diff options
Diffstat (limited to 'activerecord')
31 files changed, 452 insertions, 174 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb index 27fd9e35db..bb889a8f3b 100644 --- a/activerecord/lib/active_record/associations/association_scope.rb +++ b/activerecord/lib/active_record/associations/association_scope.rb @@ -120,6 +120,7 @@ module ActiveRecord end scope.where_values += item.where_values + scope.bind_values += item.bind_values scope.order_values |= item.order_values end end diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb index 94f69d4c2d..b7dc037a65 100644 --- a/activerecord/lib/active_record/associations/join_dependency.rb +++ b/activerecord/lib/active_record/associations/join_dependency.rb @@ -164,17 +164,17 @@ module ActiveRecord def make_outer_joins(parent, child) tables = table_aliases_for(parent, child) join_type = Arel::Nodes::OuterJoin - joins = make_constraints parent, child, tables, join_type + info = make_constraints parent, child, tables, join_type - joins.concat child.children.flat_map { |c| make_outer_joins(child, c) } + [info] + child.children.flat_map { |c| make_outer_joins(child, c) } end def make_inner_joins(parent, child) tables = child.tables join_type = Arel::Nodes::InnerJoin - joins = make_constraints parent, child, tables, join_type + info = make_constraints parent, child, tables, join_type - joins.concat child.children.flat_map { |c| make_inner_joins(child, c) } + [info] + child.children.flat_map { |c| make_inner_joins(child, c) } end def table_aliases_for(parent, node) diff --git a/activerecord/lib/active_record/associations/join_dependency/join_association.rb b/activerecord/lib/active_record/associations/join_dependency/join_association.rb index 1d923ecc09..a0e83c0a02 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_association.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_association.rb @@ -21,8 +21,11 @@ module ActiveRecord super && reflection == other.reflection end + JoinInformation = Struct.new :joins, :binds + def join_constraints(foreign_table, foreign_klass, node, join_type, tables, scope_chain, chain) joins = [] + bind_values = [] tables = tables.reverse scope_chain_index = 0 @@ -60,21 +63,27 @@ module ActiveRecord left.merge right end - if reflection.type - constraint = constraint.and table[reflection.type].eq foreign_klass.base_class.name - end - if rel && !rel.arel.constraints.empty? + bind_values.concat rel.bind_values constraint = constraint.and rel.arel.constraints end + if reflection.type + value = foreign_klass.base_class.name + column = klass.columns_hash[column.to_s] + + substitute = klass.connection.substitute_at(column, bind_values.length) + bind_values.push [column, value] + constraint = constraint.and table[reflection.type].eq substitute + end + joins << table.create_join(table, table.create_on(constraint), join_type) # The current table in this iteration becomes the foreign table in the next foreign_table, foreign_klass = table, klass end - joins + JoinInformation.new joins, bind_values end # Builds equality condition. diff --git a/activerecord/lib/active_record/associations/preloader.rb b/activerecord/lib/active_record/associations/preloader.rb index 31ddf4e0fc..311684d886 100644 --- a/activerecord/lib/active_record/associations/preloader.rb +++ b/activerecord/lib/active_record/associations/preloader.rb @@ -80,7 +80,7 @@ module ActiveRecord # { author: :avatar } # [ :books, { author: :avatar } ] - NULL_RELATION = Struct.new(:values).new({}) + NULL_RELATION = Struct.new(:values, :bind_values).new({}, []) def preload(records, associations, preload_scope = nil) records = Array.wrap(records).compact.uniq diff --git a/activerecord/lib/active_record/associations/preloader/association.rb b/activerecord/lib/active_record/associations/preloader/association.rb index 69b65982b3..83c69586e6 100644 --- a/activerecord/lib/active_record/associations/preloader/association.rb +++ b/activerecord/lib/active_record/associations/preloader/association.rb @@ -111,10 +111,13 @@ module ActiveRecord scope = klass.unscoped values = reflection_scope.values + reflection_binds = reflection_scope.bind_values preload_values = preload_scope.values + preload_binds = preload_scope.bind_values scope.where_values = Array(values[:where]) + Array(preload_values[:where]) scope.references_values = Array(values[:references]) + Array(preload_values[:references]) + scope.bind_values = (reflection_binds + preload_binds) scope.select! preload_values[:select] || values[:select] || table[Arel.star] scope.includes! preload_values[:includes] || values[:includes] diff --git a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb index 47f2ad9b10..b7b9a4363e 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb @@ -9,15 +9,19 @@ module ActiveRecord # Converts an arel AST to SQL def to_sql(arel, binds = []) if arel.respond_to?(:ast) - binds = binds.dup - visitor.accept(arel.ast) do - quote(*binds.shift.reverse) - end + collected = visitor.accept(arel.ast, collector) + collected.compile(binds.dup, self) else arel end end + # This is used in the StatementCache object. It returns an object that + # can be used to query the database repeatedly. + def cacheable_query(arel) # :nodoc: + ActiveRecord::StatementCache.query visitor, arel.ast + end + # Returns an ActiveRecord::Result instance. def select_all(arel, name = nil, binds = []) arel, binds = binds_from_relation arel, binds diff --git a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb index ffd5055dec..78343cf4f5 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb @@ -6,6 +6,8 @@ require 'active_record/connection_adapters/schema_cache' require 'active_record/connection_adapters/abstract/schema_dumper' require 'active_record/connection_adapters/abstract/schema_creation' require 'monitor' +require 'arel/collectors/bind' +require 'arel/collectors/sql_string' module ActiveRecord module ConnectionAdapters # :nodoc: @@ -90,6 +92,8 @@ module ActiveRecord end end + attr_reader :prepared_statements + def initialize(connection, logger = nil, pool = nil) #:nodoc: super() @@ -103,6 +107,26 @@ module ActiveRecord @prepared_statements = false end + class BindCollector < Arel::Collectors::Bind + def compile(bvs, conn) + super(bvs.map { |bv| conn.quote(*bv.reverse) }) + end + end + + class SQLString < Arel::Collectors::SQLString + def compile(bvs, conn) + super(bvs) + end + end + + def collector + if @prepared_statements + SQLString.new + else + BindCollector.new + end + end + def valid_type?(type) true end @@ -128,16 +152,11 @@ module ActiveRecord @owner = nil end - def unprepared_visitor - self.class::BindSubstitution.new self - end - def unprepared_statement old_prepared_statements, @prepared_statements = @prepared_statements, false - old_visitor, @visitor = @visitor, unprepared_visitor yield ensure - @visitor, @prepared_statements = old_visitor, old_prepared_statements + @prepared_statements = old_prepared_statements end # Returns the human-readable name of the adapter. Use mixed case - one @@ -318,13 +337,14 @@ module ActiveRecord def release_savepoint(name = nil) end - def case_sensitive_modifier(node) + def case_sensitive_modifier(node, table_attribute) node end def case_sensitive_comparison(table, attribute, column, value) - value = case_sensitive_modifier(value) unless value.nil? - table[attribute].eq(value) + table_attr = table[attribute] + value = case_sensitive_modifier(value, table_attr) unless value.nil? + table_attr.eq(value) end def case_insensitive_comparison(table, attribute, column, value) diff --git a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb index 20eea208ec..75c58ac7d9 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb @@ -183,21 +183,18 @@ module ActiveRecord INDEX_TYPES = [:fulltext, :spatial] INDEX_USINGS = [:btree, :hash] - class BindSubstitution < Arel::Visitors::MySQL # :nodoc: - include Arel::Visitors::BindVisitor - end - # FIXME: Make the first parameter more similar for the two adapters def initialize(connection, logger, connection_options, config) super(connection, logger) @connection_options, @config = connection_options, config @quoted_column_names, @quoted_table_names = {}, {} + @visitor = Arel::Visitors::MySQL.new self + if self.class.type_cast_config_to_boolean(config.fetch(:prepared_statements) { true }) @prepared_statements = true - @visitor = Arel::Visitors::MySQL.new self else - @visitor = unprepared_visitor + @prepared_statements = false end end @@ -610,7 +607,8 @@ module ActiveRecord pk_and_sequence && pk_and_sequence.first end - def case_sensitive_modifier(node) + def case_sensitive_modifier(node, table_attribute) + node = Arel::Nodes.build_quoted node, table_attribute Arel::Nodes::Bin.new(node) end diff --git a/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb b/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb index 5e82fdcbe0..a9d260b98c 100644 --- a/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb @@ -40,10 +40,14 @@ module ActiveRecord def initialize(connection, logger, connection_options, config) super - @visitor = BindSubstitution.new self + @prepared_statements = false configure_connection end + def cacheable_query(arel) + ActiveRecord::StatementCache.partial_query visitor, arel.ast, collector + end + MAX_INDEX_LENGTH_FOR_UTF8MB4 = 191 def initialize_schema_migrations_table if @config[:encoding] == 'utf8mb4' diff --git a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb index 0485093123..56dd2da249 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb @@ -338,19 +338,15 @@ module ActiveRecord end end - class BindSubstitution < Arel::Visitors::PostgreSQL # :nodoc: - include Arel::Visitors::BindVisitor - end - # Initializes and connects a PostgreSQL adapter. def initialize(connection, logger, connection_parameters, config) super(connection, logger) + @visitor = Arel::Visitors::PostgreSQL.new self if self.class.type_cast_config_to_boolean(config.fetch(:prepared_statements) { true }) @prepared_statements = true - @visitor = Arel::Visitors::PostgreSQL.new self else - @visitor = unprepared_visitor + @prepared_statements = false end @connection_parameters, @config = connection_parameters, config diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb index cd1f7a16c6..f59c2432dd 100644 --- a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb @@ -123,10 +123,6 @@ module ActiveRecord end end - class BindSubstitution < Arel::Visitors::SQLite # :nodoc: - include Arel::Visitors::BindVisitor - end - def initialize(connection, logger, config) super(connection, logger) @@ -135,11 +131,12 @@ module ActiveRecord self.class.type_cast_config_to_integer(config.fetch(:statement_limit) { 1000 })) @config = config + @visitor = Arel::Visitors::SQLite.new self + if self.class.type_cast_config_to_boolean(config.fetch(:prepared_statements) { true }) @prepared_statements = true - @visitor = Arel::Visitors::SQLite.new self else - @visitor = unprepared_visitor + @prepared_statements = false end end @@ -273,7 +270,7 @@ module ActiveRecord def explain(arel, binds = []) sql = "EXPLAIN QUERY PLAN #{to_sql(arel, binds)}" - ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds)) + ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', [])) end class ExplainPrettyPrinter diff --git a/activerecord/lib/active_record/core.rb b/activerecord/lib/active_record/core.rb index 4e53f66005..3be9c7695f 100644 --- a/activerecord/lib/active_record/core.rb +++ b/activerecord/lib/active_record/core.rb @@ -94,6 +94,7 @@ module ActiveRecord end class_attribute :default_connection_handler, instance_writer: false + class_attribute :find_by_statement_cache def self.connection_handler ActiveRecord::RuntimeRegistry.connection_handler || default_connection_handler @@ -107,6 +108,71 @@ module ActiveRecord end module ClassMethods + def initialize_find_by_cache + self.find_by_statement_cache = {}.extend(Mutex_m) + end + + def inherited(child_class) + child_class.initialize_find_by_cache + super + end + + def find(*ids) + # We don't have cache keys for this stuff yet + return super unless ids.length == 1 + return super if block_given? || + primary_key.nil? || + default_scopes.any? || + columns_hash.include?(inheritance_column) || + ids.first.kind_of?(Array) + + id = ids.first + if ActiveRecord::Base === id + id = id.id + ActiveSupport::Deprecation.warn "You are passing an instance of ActiveRecord::Base to `find`." \ + "Please pass the id of the object by calling `.id`" + end + key = primary_key + + s = find_by_statement_cache[key] || find_by_statement_cache.synchronize { + find_by_statement_cache[key] ||= StatementCache.create(connection) { |params| + where(key => params.bind).limit(1) + } + } + record = s.execute([id], self, connection).first + unless record + raise RecordNotFound, "Couldn't find #{name} with '#{primary_key}'=#{id}" + end + record + end + + def find_by(*args) + return super if current_scope || args.length > 1 || reflect_on_all_aggregations.any? + + hash = args.first + + return super if hash.values.any? { |v| + v.nil? || Array === v || Hash === v + } + + key = hash.keys + + klass = self + s = find_by_statement_cache[key] || find_by_statement_cache.synchronize { + find_by_statement_cache[key] ||= StatementCache.create(connection) { |params| + wheres = key.each_with_object({}) { |param,o| + o[param] = params.bind + } + klass.where(wheres).limit(1) + } + } + begin + s.execute(hash.values, self, connection).first + rescue TypeError => e + raise ActiveRecord::StatementInvalid.new(e.message, e) + end + end + def initialize_generated_modules super diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index 709edbee88..24b33ab0a8 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +require 'arel/collectors/bind' module ActiveRecord # = Active Record Relation @@ -223,6 +224,7 @@ module ActiveRecord # Please see further details in the # {Active Record Query Interface guide}[http://guides.rubyonrails.org/active_record_querying.html#running-explain]. def explain + #TODO: Fix for binds. exec_explain(collecting_queries_for_explain { exec_queries }) end @@ -323,7 +325,8 @@ module ActiveRecord stmt.wheres = arel.constraints end - @klass.connection.update stmt, 'SQL', bind_values + bvs = bind_values + arel.bind_values + @klass.connection.update stmt, 'SQL', bvs end # Updates an object (or multiple objects) and saves it to the database, if validations pass. @@ -516,11 +519,11 @@ module ActiveRecord find_with_associations { |rel| relation = rel } end - ast = relation.arel.ast - binds = relation.bind_values.dup - visitor.accept(ast) do - connection.quote(*binds.shift.reverse) - end + arel = relation.arel + binds = (arel.bind_values + relation.bind_values).dup + binds.map! { |bv| connection.quote(*bv.reverse) } + collect = visitor.accept(arel.ast, Arel::Collectors::Bind.new) + collect.substitute_binds(binds).join end end @@ -537,7 +540,13 @@ module ActiveRecord Hash[equalities.map { |where| name = where.left.name - [name, binds.fetch(name.to_s) { where.right }] + [name, binds.fetch(name.to_s) { + case where.right + when Array then where.right.map(&:val) + else + where.right.val + end + }] }] end @@ -601,7 +610,7 @@ module ActiveRecord private def exec_queries - @records = eager_loading? ? find_with_associations : @klass.find_by_sql(arel, bind_values) + @records = eager_loading? ? find_with_associations : @klass.find_by_sql(arel, arel.bind_values + bind_values) preload = preload_values preload += includes_values unless eager_loading? diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index 812e3e800a..5e525340e0 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -235,11 +235,14 @@ module ActiveRecord column_alias = column_name + bind_values = nil + if operation == "count" && (relation.limit_value || relation.offset_value) # Shortcut when limit is zero. return 0 if relation.limit_value == 0 query_builder = build_count_subquery(relation, column_name, distinct) + bind_values = relation.bind_values else column = aggregate_column(column_name) @@ -249,9 +252,10 @@ module ActiveRecord relation.select_values = [select_value] query_builder = relation.arel + bind_values = query_builder.bind_values + relation.bind_values end - result = @klass.connection.select_all(query_builder, nil, relation.bind_values) + result = @klass.connection.select_all(query_builder, nil, bind_values) row = result.first value = row && row.values.first column = result.column_types.fetch(column_alias) do diff --git a/activerecord/lib/active_record/relation/finder_methods.rb b/activerecord/lib/active_record/relation/finder_methods.rb index 7af4b29ebc..6f57441a66 100644 --- a/activerecord/lib/active_record/relation/finder_methods.rb +++ b/activerecord/lib/active_record/relation/finder_methods.rb @@ -299,11 +299,8 @@ module ActiveRecord when Array, Hash relation = relation.where(conditions) else - if conditions != :none - column = columns_hash[primary_key] - substitute = connection.substitute_at(column, bind_values.length) - relation = where(table[primary_key].eq(substitute)) - relation.bind_values += [[column, conditions]] + unless conditions == :none + relation = where(primary_key => conditions) end end @@ -351,7 +348,8 @@ module ActiveRecord if ActiveRecord::NullRelation === relation [] else - rows = connection.select_all(relation.arel, 'SQL', relation.bind_values.dup) + arel = relation.arel + rows = connection.select_all(arel, 'SQL', arel.bind_values + relation.bind_values) join_dependency.instantiate(rows, aliases) end end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index 4287304945..8cd97a3715 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -843,7 +843,7 @@ module ActiveRecord build_joins(arel, joins_values.flatten) unless joins_values.empty? - collapse_wheres(arel, (where_values - ['']).uniq) + collapse_wheres(arel, (where_values - [''])) #TODO: Add uniq with real value comparison / ignore uniqs that have binds arel.having(*having_values.uniq.reject(&:blank?)) unless having_values.empty? @@ -860,6 +860,15 @@ module ActiveRecord arel.from(build_from) if from_value arel.lock(lock_value) if lock_value + # Reorder bind indexes if joins produced bind values + if arel.bind_values.any? + bvs = arel.bind_values + bind_values + arel.ast.grep(Arel::Nodes::BindParam).each_with_index do |bp, i| + column = bvs[i].first + bp.replace connection.substitute_at(column, i) + end + end + arel end @@ -874,6 +883,8 @@ module ActiveRecord case scope when :order result = [] + when :where + self.bind_values = [] else result = [] unless single_val_method end @@ -924,18 +935,16 @@ module ActiveRecord def build_where(opts, other = []) case opts when String, Array - #TODO: Remove duplication with: /activerecord/lib/active_record/sanitization.rb:113 - values = Hash === other.first ? other.first.values : other - - values.grep(ActiveRecord::Relation) do |rel| - self.bind_values += rel.bind_values - end - [@klass.send(:sanitize_sql, other.empty? ? opts : ([opts] + other))] when Hash opts = PredicateBuilder.resolve_column_aliases(klass, opts) attributes = @klass.send(:expand_hash_conditions_for_aggregates, opts) + bv_len = bind_values.length + tmp_opts, bind_values = create_binds(opts, bv_len) + self.bind_values += bind_values + + attributes = @klass.send(:expand_hash_conditions_for_aggregates, tmp_opts) attributes.values.grep(ActiveRecord::Relation) do |rel| self.bind_values += rel.bind_values end @@ -946,6 +955,29 @@ module ActiveRecord end end + def create_binds(opts, idx) + bindable, non_binds = opts.partition do |column, value| + case value + when String, Integer, ActiveRecord::StatementCache::Substitute + @klass.columns_hash.include? column.to_s + else + false + end + end + + new_opts = {} + binds = [] + + bindable.each_with_index do |(column,value), index| + binds.push [@klass.columns_hash[column.to_s], value] + new_opts[column] = connection.substitute_at(column, index + idx) + end + + non_binds.each { |column,value| new_opts[column] = value } + + [new_opts, binds] + end + def build_from opts, name = from_value case opts @@ -987,9 +1019,12 @@ module ActiveRecord join_list ) - joins = join_dependency.join_constraints stashed_association_joins + join_infos = join_dependency.join_constraints stashed_association_joins - joins.each { |join| manager.from(join) } + join_infos.each do |info| + info.joins.each { |join| manager.from(join) } + manager.bind_values.concat info.binds + end manager.join_sources.concat(join_list) diff --git a/activerecord/lib/active_record/relation/spawn_methods.rb b/activerecord/lib/active_record/relation/spawn_methods.rb index 2552cbd234..57d66bce4b 100644 --- a/activerecord/lib/active_record/relation/spawn_methods.rb +++ b/activerecord/lib/active_record/relation/spawn_methods.rb @@ -58,6 +58,9 @@ module ActiveRecord # Post.order('id asc').only(:where) # discards the order condition # Post.order('id asc').only(:where, :order) # uses the specified order def only(*onlies) + if onlies.any? { |o| o == :where } + onlies << :bind + end relation_with values.slice(*onlies) end diff --git a/activerecord/lib/active_record/sanitization.rb b/activerecord/lib/active_record/sanitization.rb index 5a71c13d91..936f8dba02 100644 --- a/activerecord/lib/active_record/sanitization.rb +++ b/activerecord/lib/active_record/sanitization.rb @@ -92,7 +92,7 @@ module ActiveRecord table = Arel::Table.new(table_name, arel_engine).alias(default_table_name) PredicateBuilder.build_from_hash(self, attrs, table).map { |b| - connection.visitor.accept b + connection.visitor.compile b }.join(' AND ') end alias_method :sanitize_sql_hash, :sanitize_sql_hash_for_conditions diff --git a/activerecord/lib/active_record/statement_cache.rb b/activerecord/lib/active_record/statement_cache.rb index dd4ee0c4a0..aece446384 100644 --- a/activerecord/lib/active_record/statement_cache.rb +++ b/activerecord/lib/active_record/statement_cache.rb @@ -14,13 +14,87 @@ module ActiveRecord # The relation returned by the block is cached, and for each +execute+ call the cached relation gets duped. # Database is queried when +to_a+ is called on the relation. class StatementCache - def initialize - @relation = yield - raise ArgumentError.new("Statement cannot be nil") if @relation.nil? + class Substitute; end + + class Query + def initialize(sql) + @sql = sql + end + + def sql_for(binds, connection) + @sql + end + end + + class PartialQuery < Query + def initialize values + @values = values + @indexes = values.each_with_index.find_all { |thing,i| + Arel::Nodes::BindParam === thing + }.map(&:last) + end + + def sql_for(binds, connection) + val = @values.dup + binds = binds.dup + @indexes.each { |i| val[i] = connection.quote(*binds.shift.reverse) } + val.join + end + end + + def self.query(visitor, ast) + Query.new visitor.accept(ast, Arel::Collectors::SQLString.new).value + end + + def self.partial_query(visitor, ast, collector) + collected = visitor.accept(ast, collector).value + PartialQuery.new collected + end + + class Params + def bind; Substitute.new; end end - def execute - @relation.dup.to_a + class BindMap + def initialize(bind_values) + @indexes = [] + @bind_values = bind_values + + bind_values.each_with_index do |(_, value), i| + if Substitute === value + @indexes << i + end + end + end + + def bind(values) + bvs = @bind_values.map { |pair| pair.dup } + @indexes.each_with_index { |offset,i| bvs[offset][1] = values[i] } + bvs + end + end + + attr_reader :bind_map, :query_builder + + def self.create(connection, block = Proc.new) + relation = block.call Params.new + bind_map = BindMap.new relation.bind_values + query_builder = connection.cacheable_query relation.arel + new query_builder, bind_map + end + + def initialize(query_builder, bind_map) + @query_builder = query_builder + @bind_map = bind_map + end + + def execute(params, klass, connection) + bind_values = bind_map.bind params + + sql = query_builder.sql_for bind_values, connection + + klass.find_by_sql sql, bind_values end + alias :call :execute end end diff --git a/activerecord/test/cases/adapters/mysql2/explain_test.rb b/activerecord/test/cases/adapters/mysql2/explain_test.rb index 1cd356e868..675703caa1 100644 --- a/activerecord/test/cases/adapters/mysql2/explain_test.rb +++ b/activerecord/test/cases/adapters/mysql2/explain_test.rb @@ -9,15 +9,15 @@ module ActiveRecord def test_explain_for_one_query explain = Developer.where(:id => 1).explain - assert_match %(EXPLAIN for: SELECT `developers`.* FROM `developers` WHERE `developers`.`id` = 1), explain + assert_match %(EXPLAIN for: SELECT `developers`.* FROM `developers` WHERE `developers`.`id` = 1), explain assert_match %r(developers |.* const), explain end def test_explain_with_eager_loading explain = Developer.where(:id => 1).includes(:audit_logs).explain - assert_match %(EXPLAIN for: SELECT `developers`.* FROM `developers` WHERE `developers`.`id` = 1), explain + assert_match %(EXPLAIN for: SELECT `developers`.* FROM `developers` WHERE `developers`.`id` = 1), explain assert_match %r(developers |.* const), explain - assert_match %(EXPLAIN for: SELECT `audit_logs`.* FROM `audit_logs` WHERE `audit_logs`.`developer_id` IN (1)), explain + assert_match %(EXPLAIN for: SELECT `audit_logs`.* FROM `audit_logs` WHERE `audit_logs`.`developer_id` IN (1)), explain assert_match %r(audit_logs |.* ALL), explain end end diff --git a/activerecord/test/cases/adapters/postgresql/explain_test.rb b/activerecord/test/cases/adapters/postgresql/explain_test.rb index 0b61f61572..416f84cb38 100644 --- a/activerecord/test/cases/adapters/postgresql/explain_test.rb +++ b/activerecord/test/cases/adapters/postgresql/explain_test.rb @@ -9,7 +9,7 @@ module ActiveRecord def test_explain_for_one_query explain = Developer.where(:id => 1).explain - assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = 1), explain + assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = $1), explain assert_match %(QUERY PLAN), explain assert_match %(Index Scan using developers_pkey on developers), explain end @@ -17,9 +17,9 @@ module ActiveRecord def test_explain_with_eager_loading explain = Developer.where(:id => 1).includes(:audit_logs).explain assert_match %(QUERY PLAN), explain - assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = 1), explain + assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = $1), explain assert_match %(Index Scan using developers_pkey on developers), explain - assert_match %(EXPLAIN for: SELECT "audit_logs".* FROM "audit_logs" WHERE "audit_logs"."developer_id" IN (1)), explain + assert_match %(EXPLAIN for: SELECT "audit_logs".* FROM "audit_logs" WHERE "audit_logs"."developer_id" IN (1)), explain assert_match %(Seq Scan on audit_logs), explain end end diff --git a/activerecord/test/cases/adapters/sqlite3/explain_test.rb b/activerecord/test/cases/adapters/sqlite3/explain_test.rb index b227bce680..f1d6119d2e 100644 --- a/activerecord/test/cases/adapters/sqlite3/explain_test.rb +++ b/activerecord/test/cases/adapters/sqlite3/explain_test.rb @@ -9,15 +9,15 @@ module ActiveRecord def test_explain_for_one_query explain = Developer.where(:id => 1).explain - assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = 1), explain + assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = ?), explain assert_match(/(SEARCH )?TABLE developers USING (INTEGER )?PRIMARY KEY/, explain) end def test_explain_with_eager_loading explain = Developer.where(:id => 1).includes(:audit_logs).explain - assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = 1), explain + assert_match %(EXPLAIN for: SELECT "developers".* FROM "developers" WHERE "developers"."id" = ?), explain assert_match(/(SEARCH )?TABLE developers USING (INTEGER )?PRIMARY KEY/, explain) - assert_match %(EXPLAIN for: SELECT "audit_logs".* FROM "audit_logs" WHERE "audit_logs"."developer_id" IN (1)), explain + assert_match %(EXPLAIN for: SELECT "audit_logs".* FROM "audit_logs" WHERE "audit_logs"."developer_id" IN (1)), explain assert_match(/(SCAN )?TABLE audit_logs/, explain) end end diff --git a/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb b/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb index 14aad61ce2..2630a0f3a4 100644 --- a/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb +++ b/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb @@ -182,7 +182,7 @@ module ActiveRecord def test_quote_binary_column_escapes_it DualEncoding.connection.execute(<<-eosql) - CREATE TABLE dual_encodings ( + CREATE TABLE IF NOT EXISTS dual_encodings ( id integer PRIMARY KEY AUTOINCREMENT, name varchar(255), data binary @@ -192,9 +192,8 @@ module ActiveRecord binary = DualEncoding.new name: 'いただきます!', data: str binary.save! assert_equal str, binary.data - ensure - DualEncoding.connection.drop_table('dual_encodings') + DualEncoding.connection.execute('DROP TABLE IF EXISTS dual_encodings') end def test_type_cast_should_not_mutate_encoding diff --git a/activerecord/test/cases/associations/association_scope_test.rb b/activerecord/test/cases/associations/association_scope_test.rb index c78b036f53..3e0032ec73 100644 --- a/activerecord/test/cases/associations/association_scope_test.rb +++ b/activerecord/test/cases/associations/association_scope_test.rb @@ -9,7 +9,12 @@ module ActiveRecord scope = AssociationScope.scope(Author.new.association(:welcome_posts), Author.connection) wheres = scope.where_values.map(&:right) + binds = scope.bind_values.map(&:last) + wheres = scope.where_values.map(&:right).reject { |node| + Arel::Nodes::BindParam === node + } assert_equal wheres.uniq, wheres + assert_equal binds.uniq, binds end end end diff --git a/activerecord/test/cases/explain_test.rb b/activerecord/test/cases/explain_test.rb index 6dac5db111..9d25bdd82a 100644 --- a/activerecord/test/cases/explain_test.rb +++ b/activerecord/test/cases/explain_test.rb @@ -26,8 +26,12 @@ if ActiveRecord::Base.connection.supports_explain? sql, binds = queries[0] assert_match "SELECT", sql - assert_match "honda", sql - assert_equal [], binds + if binds.any? + assert_equal 1, binds.length + assert_equal "honda", binds.flatten.last + else + assert_match 'honda', sql + end end def test_exec_explain_with_no_binds diff --git a/activerecord/test/cases/hot_compatibility_test.rb b/activerecord/test/cases/hot_compatibility_test.rb index 367d04a154..b4617cf6f9 100644 --- a/activerecord/test/cases/hot_compatibility_test.rb +++ b/activerecord/test/cases/hot_compatibility_test.rb @@ -15,7 +15,7 @@ class HotCompatibilityTest < ActiveRecord::TestCase end teardown do - @klass.connection.drop_table :hot_compatibilities + ActiveRecord::Base.connection.drop_table :hot_compatibilities end test "insert after remove_column" do diff --git a/activerecord/test/cases/relation/merging_test.rb b/activerecord/test/cases/relation/merging_test.rb index 23500bf5d8..ff1c2a0d82 100644 --- a/activerecord/test/cases/relation/merging_test.rb +++ b/activerecord/test/cases/relation/merging_test.rb @@ -17,9 +17,7 @@ class RelationMergingTest < ActiveRecord::TestCase end def test_relation_to_sql - sql = Post.connection.unprepared_statement do - Post.first.comments.to_sql - end + sql = Post.first.comments.to_sql assert_no_match(/\?/, sql) end @@ -81,31 +79,20 @@ class RelationMergingTest < ActiveRecord::TestCase left = Post.where(title: "omg").where(comments_count: 1) right = Post.where(title: "wtf").where(title: "bbq") - expected = [left.where_values[1]] + right.where_values + expected = [left.bind_values[1]] + right.bind_values merged = left.merge(right) - assert_equal expected, merged.where_values + assert_equal expected, merged.bind_values assert !merged.to_sql.include?("omg") assert merged.to_sql.include?("wtf") assert merged.to_sql.include?("bbq") end - def test_merging_removes_rhs_bind_parameters - left = Post.where(id: Arel::Nodes::BindParam.new('?')) - column = Post.columns_hash['id'] - left.bind_values += [[column, 20]] - right = Post.where(id: 10) - - merged = left.merge(right) - assert_equal [], merged.bind_values - end - def test_merging_keeps_lhs_bind_parameters column = Post.columns_hash['id'] binds = [[column, 20]] - right = Post.where(id: Arel::Nodes::BindParam.new('?')) - right.bind_values += binds + right = Post.where(id: 20) left = Post.where(id: 10) merged = left.merge(right) @@ -113,17 +100,9 @@ class RelationMergingTest < ActiveRecord::TestCase end def test_merging_reorders_bind_params - post = Post.first - id_column = Post.columns_hash['id'] - title_column = Post.columns_hash['title'] - - bv = Post.connection.substitute_at id_column, 0 - - right = Post.where(id: bv) - right.bind_values += [[id_column, post.id]] - - left = Post.where(title: bv) - left.bind_values += [[title_column, post.title]] + post = Post.first + right = Post.where(id: 1) + left = Post.where(title: post.title) merged = left.merge(right) assert_equal post, merged.first diff --git a/activerecord/test/cases/relation/where_chain_test.rb b/activerecord/test/cases/relation/where_chain_test.rb index c628ca44ff..c6decaad89 100644 --- a/activerecord/test/cases/relation/where_chain_test.rb +++ b/activerecord/test/cases/relation/where_chain_test.rb @@ -12,9 +12,15 @@ module ActiveRecord end def test_not_eq - expected = Post.arel_table[@name].not_eq('hello') relation = Post.where.not(title: 'hello') - assert_equal([expected], relation.where_values) + + assert_equal 1, relation.where_values.length + + value = relation.where_values.first + bind = relation.bind_values.first + + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::NotEqual + assert_equal 'hello', bind.last end def test_not_null @@ -44,21 +50,29 @@ module ActiveRecord def test_not_eq_with_preceding_where relation = Post.where(title: 'hello').where.not(title: 'world') - expected = Post.arel_table[@name].eq('hello') - assert_equal(expected, relation.where_values.first) + value = relation.where_values.first + bind = relation.bind_values.first + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::Equality + assert_equal 'hello', bind.last - expected = Post.arel_table[@name].not_eq('world') - assert_equal(expected, relation.where_values.last) + value = relation.where_values.last + bind = relation.bind_values.last + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::NotEqual + assert_equal 'world', bind.last end def test_not_eq_with_succeeding_where relation = Post.where.not(title: 'hello').where(title: 'world') - expected = Post.arel_table[@name].not_eq('hello') - assert_equal(expected, relation.where_values.first) + value = relation.where_values.first + bind = relation.bind_values.first + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::NotEqual + assert_equal 'hello', bind.last - expected = Post.arel_table[@name].eq('world') - assert_equal(expected, relation.where_values.last) + value = relation.where_values.last + bind = relation.bind_values.last + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::Equality + assert_equal 'world', bind.last end def test_not_eq_with_string_parameter @@ -79,38 +93,61 @@ module ActiveRecord expected = Post.arel_table['author_id'].not_in([1, 2]) assert_equal(expected, relation.where_values[0]) - expected = Post.arel_table[@name].not_eq('ruby on rails') - assert_equal(expected, relation.where_values[1]) + value = relation.where_values[1] + bind = relation.bind_values.first + + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::NotEqual + assert_equal 'ruby on rails', bind.last end def test_rewhere_with_one_condition relation = Post.where(title: 'hello').where(title: 'world').rewhere(title: 'alone') - expected = Post.arel_table[@name].eq('alone') assert_equal 1, relation.where_values.size - assert_equal expected, relation.where_values.first + value = relation.where_values.first + bind = relation.bind_values.first + assert_bound_ast value, Post.arel_table[@name], Arel::Nodes::Equality + assert_equal 'alone', bind.last end def test_rewhere_with_multiple_overwriting_conditions relation = Post.where(title: 'hello').where(body: 'world').rewhere(title: 'alone', body: 'again') - title_expected = Post.arel_table['title'].eq('alone') - body_expected = Post.arel_table['body'].eq('again') - assert_equal 2, relation.where_values.size - assert_equal title_expected, relation.where_values.first - assert_equal body_expected, relation.where_values.second + + value = relation.where_values.first + bind = relation.bind_values.first + assert_bound_ast value, Post.arel_table['title'], Arel::Nodes::Equality + assert_equal 'alone', bind.last + + value = relation.where_values[1] + bind = relation.bind_values[1] + assert_bound_ast value, Post.arel_table['body'], Arel::Nodes::Equality + assert_equal 'again', bind.last + end + + def assert_bound_ast value, table, type + assert_equal table, value.left + assert_kind_of type, value + assert_kind_of Arel::Nodes::BindParam, value.right end def test_rewhere_with_one_overwriting_condition_and_one_unrelated relation = Post.where(title: 'hello').where(body: 'world').rewhere(title: 'alone') - title_expected = Post.arel_table['title'].eq('alone') - body_expected = Post.arel_table['body'].eq('world') - assert_equal 2, relation.where_values.size - assert_equal body_expected, relation.where_values.first - assert_equal title_expected, relation.where_values.second + + value = relation.where_values.first + bind = relation.bind_values.first + + assert_bound_ast value, Post.arel_table['body'], Arel::Nodes::Equality + assert_equal 'world', bind.last + + value = relation.where_values.second + bind = relation.bind_values.second + + assert_bound_ast value, Post.arel_table['title'], Arel::Nodes::Equality + assert_equal 'alone', bind.last end end end diff --git a/activerecord/test/cases/relations_test.rb b/activerecord/test/cases/relations_test.rb index 562cfe6796..a2a2a79180 100644 --- a/activerecord/test/cases/relations_test.rb +++ b/activerecord/test/cases/relations_test.rb @@ -874,6 +874,14 @@ class RelationTest < ActiveRecord::TestCase assert_equal 11, posts.distinct(false).select(:comments_count).count end + def test_update_all_with_scope + tag = Tag.first + Post.tagged_with(tag.id).update_all title: "rofl" + list = Post.tagged_with(tag.id).all.to_a + assert_operator list.length, :>, 0 + list.each { |post| assert_equal 'rofl', post.title } + end + def test_count_explicit_columns Post.update_all(:comments_count => nil) posts = Post.all @@ -1621,10 +1629,8 @@ class RelationTest < ActiveRecord::TestCase end def test_merging_removes_rhs_bind_parameters - left = Post.where(id: Arel::Nodes::BindParam.new('?')) - column = Post.columns_hash['id'] - left.bind_values += [[column, 20]] - right = Post.where(id: 10) + left = Post.where(id: 20) + right = Post.where(id: [1,2,3,4]) merged = left.merge(right) assert_equal [], merged.bind_values @@ -1634,8 +1640,7 @@ class RelationTest < ActiveRecord::TestCase column = Post.columns_hash['id'] binds = [[column, 20]] - right = Post.where(id: Arel::Nodes::BindParam.new('?')) - right.bind_values += binds + right = Post.where(id: 20) left = Post.where(id: 10) merged = left.merge(right) @@ -1643,17 +1648,9 @@ class RelationTest < ActiveRecord::TestCase end def test_merging_reorders_bind_params - post = Post.first - id_column = Post.columns_hash['id'] - title_column = Post.columns_hash['title'] - - bv = Post.connection.substitute_at id_column, 0 - - right = Post.where(id: bv) - right.bind_values += [[id_column, post.id]] - - left = Post.where(title: bv) - left.bind_values += [[title_column, post.title]] + post = Post.first + right = Post.where(id: post.id) + left = Post.where(title: post.title) merged = left.merge(right) assert_equal post, merged.first diff --git a/activerecord/test/cases/statement_cache_test.rb b/activerecord/test/cases/statement_cache_test.rb index 76da49707f..a704b861cb 100644 --- a/activerecord/test/cases/statement_cache_test.rb +++ b/activerecord/test/cases/statement_cache_test.rb @@ -10,27 +10,61 @@ module ActiveRecord @connection = ActiveRecord::Base.connection end + #Cache v 1.1 tests + def test_statement_cache + Book.create(name: "my book") + Book.create(name: "my other book") + + cache = StatementCache.create(Book.connection) do |params| + Book.where(:name => params.bind) + end + + b = cache.execute([ "my book" ], Book, Book.connection) + assert_equal "my book", b[0].name + b = cache.execute([ "my other book" ], Book, Book.connection) + assert_equal "my other book", b[0].name + end + + + def test_statement_cache_id + b1 = Book.create(name: "my book") + b2 = Book.create(name: "my other book") + + cache = StatementCache.create(Book.connection) do |params| + Book.where(id: params.bind) + end + + b = cache.execute([ b1.id ], Book, Book.connection) + assert_equal b1.name, b[0].name + b = cache.execute([ b2.id ], Book, Book.connection) + assert_equal b2.name, b[0].name + end + + def test_find_or_create_by + Book.create(name: "my book") + + a = Book.find_or_create_by(name: "my book") + b = Book.find_or_create_by(name: "my other book") + + assert_equal("my book", a.name) + assert_equal("my other book", b.name) + end + + #End + def test_statement_cache_with_simple_statement - cache = ActiveRecord::StatementCache.new do + cache = ActiveRecord::StatementCache.create(Book.connection) do |params| Book.where(name: "my book").where("author_id > 3") end Book.create(name: "my book", author_id: 4) - books = cache.execute + books = cache.execute([], Book, Book.connection) assert_equal "my book", books[0].name end - def test_statement_cache_with_nil_statement_raises_error - assert_raise(ArgumentError) do - ActiveRecord::StatementCache.new do - nil - end - end - end - def test_statement_cache_with_complex_statement - cache = ActiveRecord::StatementCache.new do + cache = ActiveRecord::StatementCache.create(Book.connection) do |params| Liquid.joins(:molecules => :electrons).where('molecules.name' => 'dioxane', 'electrons.name' => 'lepton') end @@ -38,12 +72,12 @@ module ActiveRecord molecule = salty.molecules.create(name: 'dioxane') molecule.electrons.create(name: 'lepton') - liquids = cache.execute + liquids = cache.execute([], Book, Book.connection) assert_equal "salty", liquids[0].name end def test_statement_cache_values_differ - cache = ActiveRecord::StatementCache.new do + cache = ActiveRecord::StatementCache.create(Book.connection) do |params| Book.where(name: "my book") end @@ -51,13 +85,13 @@ module ActiveRecord Book.create(name: "my book") end - first_books = cache.execute + first_books = cache.execute([], Book, Book.connection) 3.times do Book.create(name: "my book") end - additional_books = cache.execute + additional_books = cache.execute([], Book, Book.connection) assert first_books != additional_books end end diff --git a/activerecord/test/models/post.rb b/activerecord/test/models/post.rb index 099e039255..d9ecaee1d9 100644 --- a/activerecord/test/models/post.rb +++ b/activerecord/test/models/post.rb @@ -40,6 +40,8 @@ class Post < ActiveRecord::Base scope :with_comments, -> { preload(:comments) } scope :with_tags, -> { preload(:taggings) } + scope :tagged_with, ->(id) { joins(:taggings).where(taggings: { tag_id: id }) } + has_many :comments do def find_most_recent order("id DESC").first |