diff options
Diffstat (limited to 'activerecord/lib/active_record')
39 files changed, 556 insertions, 424 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb index d06b7b3508..2416167834 100644 --- a/activerecord/lib/active_record/associations/association_scope.rb +++ b/activerecord/lib/active_record/associations/association_scope.rb @@ -2,42 +2,30 @@ module ActiveRecord module Associations class AssociationScope #:nodoc: def self.scope(association, connection) - INSTANCE.scope association, connection - end - - class BindSubstitution - def initialize(block) - @block = block - end - - def bind_value(scope, column, value, connection) - substitute = connection.substitute_at(column) - scope.bind_values += [[column, @block.call(value)]] - substitute - end + INSTANCE.scope(association, connection) end def self.create(&block) - block = block ? block : lambda { |val| val } - new BindSubstitution.new(block) + block ||= lambda { |val| val } + new(block) end - def initialize(bind_substitution) - @bind_substitution = bind_substitution + def initialize(value_transformation) + @value_transformation = value_transformation end INSTANCE = create def scope(association, connection) - klass = association.klass - reflection = association.reflection - scope = klass.unscoped - owner = association.owner + klass = association.klass + reflection = association.reflection + scope = klass.unscoped + owner = association.owner alias_tracker = AliasTracker.create connection, association.klass.table_name, klass.type_caster chain_head, chain_tail = get_chain(reflection, association, alias_tracker) scope.extending! Array(reflection.options[:extend]) - add_constraints(scope, owner, klass, reflection, connection, chain_head, chain_tail) + add_constraints(scope, owner, klass, reflection, chain_head, chain_tail) end def join_type @@ -61,43 +49,36 @@ module ActiveRecord binds end + protected + + attr_reader :value_transformation + private def join(table, constraint) table.create_join(table, table.create_on(constraint), join_type) end - def column_for(table_name, column_name, connection) - columns = connection.schema_cache.columns_hash(table_name) - columns[column_name] - end - - def bind_value(scope, column, value, connection) - @bind_substitution.bind_value scope, column, value, connection - end - - def bind(scope, table_name, column_name, value, connection) - column = column_for table_name, column_name, connection - bind_value scope, column, value, connection - end - - def last_chain_scope(scope, table, reflection, owner, connection, association_klass) + def last_chain_scope(scope, table, reflection, owner, association_klass) join_keys = reflection.join_keys(association_klass) key = join_keys.key foreign_key = join_keys.foreign_key - bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], connection - scope = scope.where(table[key].eq(bind_val)) + value = transform_value(owner[foreign_key]) + scope = scope.where(table.name => { key => value }) if reflection.type - value = owner.class.base_class.name - bind_val = bind scope, table.table_name, reflection.type, value, connection - scope = scope.where(table[reflection.type].eq(bind_val)) - else - scope + polymorphic_type = transform_value(owner.class.base_class.name) + scope = scope.where(table.name => { reflection.type => polymorphic_type }) end + + scope + end + + def transform_value(value) + value_transformation.call(value) end - def next_chain_scope(scope, table, reflection, connection, association_klass, foreign_table, next_reflection) + def next_chain_scope(scope, table, reflection, association_klass, foreign_table, next_reflection) join_keys = reflection.join_keys(association_klass) key = join_keys.key foreign_key = join_keys.foreign_key @@ -105,9 +86,8 @@ module ActiveRecord constraint = table[key].eq(foreign_table[foreign_key]) if reflection.type - value = next_reflection.klass.base_class.name - bind_val = bind scope, table.table_name, reflection.type, value, connection - scope = scope.where(table[reflection.type].eq(bind_val)) + value = transform_value(next_reflection.klass.base_class.name) + scope = scope.where(table.name => { reflection.type => value }) end scope = scope.joins(join(foreign_table, constraint)) @@ -138,10 +118,10 @@ module ActiveRecord [runtime_reflection, previous_reflection] end - def add_constraints(scope, owner, association_klass, refl, connection, chain_head, chain_tail) + def add_constraints(scope, owner, association_klass, refl, chain_head, chain_tail) owner_reflection = chain_tail table = owner_reflection.alias_name - scope = last_chain_scope(scope, table, owner_reflection, owner, connection, association_klass) + scope = last_chain_scope(scope, table, owner_reflection, owner, association_klass) reflection = chain_head loop do @@ -151,7 +131,7 @@ module ActiveRecord unless reflection == chain_tail next_reflection = reflection.next foreign_table = next_reflection.alias_name - scope = next_chain_scope(scope, table, reflection, connection, association_klass, foreign_table, next_reflection) + scope = next_chain_scope(scope, table, reflection, association_klass, foreign_table, next_reflection) end # Exclude the scope of the association itself, because that @@ -160,15 +140,14 @@ module ActiveRecord item = eval_scope(reflection.klass, scope_chain_item, owner) if scope_chain_item == refl.scope - scope.merge! item.except(:where, :includes, :bind) + scope.merge! item.except(:where, :includes) end reflection.all_includes do scope.includes! item.includes_values end - scope.where_values += item.where_values - scope.bind_values += item.bind_values + scope.where_clause += item.where_clause scope.order_values |= item.order_values end diff --git a/activerecord/lib/active_record/associations/belongs_to_association.rb b/activerecord/lib/active_record/associations/belongs_to_association.rb index c63b42e2a0..265a65c4c1 100644 --- a/activerecord/lib/active_record/associations/belongs_to_association.rb +++ b/activerecord/lib/active_record/associations/belongs_to_association.rb @@ -68,6 +68,9 @@ module ActiveRecord def increment_counter(counter_cache_name) if foreign_key_present? klass.increment_counter(counter_cache_name, target_id) + if target && !stale_target? + target.increment(counter_cache_name) + end end end diff --git a/activerecord/lib/active_record/associations/collection_association.rb b/activerecord/lib/active_record/associations/collection_association.rb index f2c96e9a2a..4b7591e15c 100644 --- a/activerecord/lib/active_record/associations/collection_association.rb +++ b/activerecord/lib/active_record/associations/collection_association.rb @@ -151,6 +151,7 @@ module ActiveRecord # be chained. Since << flattens its argument list and inserts each record, # +push+ and +concat+ behave identically. def concat(*records) + records = records.flatten if owner.new_record? load_target concat_records(records) @@ -549,7 +550,7 @@ module ActiveRecord def concat_records(records, should_raise = false) result = true - records.flatten.each do |record| + records.each do |record| raise_on_type_mismatch!(record) add_to_target(record) do |rec| result &&= insert_record(rec, true, should_raise) unless owner.new_record? diff --git a/activerecord/lib/active_record/associations/has_many_association.rb b/activerecord/lib/active_record/associations/has_many_association.rb index 2a782c06d0..b574185561 100644 --- a/activerecord/lib/active_record/associations/has_many_association.rb +++ b/activerecord/lib/active_record/associations/has_many_association.rb @@ -101,7 +101,7 @@ module ActiveRecord end def update_counter_in_memory(difference, reflection = reflection()) - if has_cached_counter?(reflection) + if counter_must_be_updated_by_has_many?(reflection) counter = cached_counter_attribute_name(reflection) owner[counter] += difference owner.send(:clear_attribute_changes, counter) # eww @@ -118,18 +118,28 @@ module ActiveRecord # it will be decremented twice. # # Hence this method. - def inverse_updates_counter_cache?(reflection = reflection()) + def inverse_which_updates_counter_cache(reflection = reflection()) counter_name = cached_counter_attribute_name(reflection) - inverse_updates_counter_named?(counter_name, reflection) + inverse_which_updates_counter_named(counter_name, reflection) end + alias inverse_updates_counter_cache? inverse_which_updates_counter_cache - def inverse_updates_counter_named?(counter_name, reflection = reflection()) - reflection.klass._reflections.values.any? { |inverse_reflection| + def inverse_which_updates_counter_named(counter_name, reflection) + reflection.klass._reflections.values.find { |inverse_reflection| inverse_reflection.belongs_to? && inverse_reflection.counter_cache_column == counter_name } end + def inverse_updates_counter_in_memory?(reflection) + inverse = inverse_which_updates_counter_cache(reflection) + inverse && inverse == reflection.inverse_of + end + + def counter_must_be_updated_by_has_many?(reflection) + !inverse_updates_counter_in_memory?(reflection) && has_cached_counter?(reflection) + end + def delete_count(method, scope) if method == :delete_all scope.delete_all diff --git a/activerecord/lib/active_record/associations/join_dependency/join_association.rb b/activerecord/lib/active_record/associations/join_dependency/join_association.rb index c1ef86a95b..3a184cee37 100644 --- a/activerecord/lib/active_record/associations/join_dependency/join_association.rb +++ b/activerecord/lib/active_record/associations/join_dependency/join_association.rb @@ -25,7 +25,7 @@ module ActiveRecord def join_constraints(foreign_table, foreign_klass, node, join_type, tables, scope_chain, chain) joins = [] - bind_values = [] + binds = [] tables = tables.reverse scope_chain_index = 0 @@ -66,7 +66,7 @@ module ActiveRecord end if rel && !rel.arel.constraints.empty? - bind_values.concat rel.bind_values + binds += rel.bound_attributes constraint = constraint.and rel.arel.constraints end @@ -75,7 +75,7 @@ module ActiveRecord column = klass.columns_hash[reflection.type.to_s] substitute = klass.connection.substitute_at(column) - bind_values.push [column, value] + binds << Attribute.with_cast_value(column.name, value, klass.type_for_attribute(column.name)) constraint = constraint.and table[reflection.type].eq substitute end @@ -85,7 +85,7 @@ module ActiveRecord foreign_table, foreign_klass = table, klass end - JoinInformation.new joins, bind_values + JoinInformation.new joins, binds end # Builds equality condition. diff --git a/activerecord/lib/active_record/associations/preloader.rb b/activerecord/lib/active_record/associations/preloader.rb index 4358f3b581..97f4bd3811 100644 --- a/activerecord/lib/active_record/associations/preloader.rb +++ b/activerecord/lib/active_record/associations/preloader.rb @@ -89,7 +89,7 @@ module ActiveRecord # { author: :avatar } # [ :books, { author: :avatar } ] - NULL_RELATION = Struct.new(:values, :bind_values).new({}, []) + NULL_RELATION = Struct.new(:values, :where_clause, :joins_values).new({}, Relation::WhereClause.empty, []) def preload(records, associations, preload_scope = nil) records = Array.wrap(records).compact.uniq diff --git a/activerecord/lib/active_record/associations/preloader/association.rb b/activerecord/lib/active_record/associations/preloader/association.rb index afcaa5d55a..9e4a2b925c 100644 --- a/activerecord/lib/active_record/associations/preloader/association.rb +++ b/activerecord/lib/active_record/associations/preloader/association.rb @@ -131,18 +131,19 @@ module ActiveRecord def build_scope scope = klass.unscoped - values = reflection_scope.values - reflection_binds = reflection_scope.bind_values + values = reflection_scope.values preload_values = preload_scope.values - preload_binds = preload_scope.bind_values - scope.where_values = Array(values[:where]) + Array(preload_values[:where]) + scope.where_clause = reflection_scope.where_clause + preload_scope.where_clause scope.references_values = Array(values[:references]) + Array(preload_values[:references]) - scope.bind_values = (reflection_binds + preload_binds) - scope._select! preload_values[:select] || values[:select] || table[Arel.star] + scope._select! preload_values[:select] || values[:select] || table[Arel.star] scope.includes! preload_values[:includes] || values[:includes] - scope.joins! preload_values[:joins] || values[:joins] + if preload_scope.joins_values.any? + scope.joins!(preload_scope.joins_values) + else + scope.joins!(reflection_scope.joins_values) + end scope.order! preload_values[:order] || values[:order] if preload_values[:readonly] || values[:readonly] diff --git a/activerecord/lib/active_record/associations/preloader/through_association.rb b/activerecord/lib/active_record/associations/preloader/through_association.rb index 12bf3ef138..56aa23b173 100644 --- a/activerecord/lib/active_record/associations/preloader/through_association.rb +++ b/activerecord/lib/active_record/associations/preloader/through_association.rb @@ -78,10 +78,9 @@ module ActiveRecord if options[:source_type] scope.where! reflection.foreign_type => options[:source_type] else - unless reflection_scope.where_values.empty? + unless reflection_scope.where_clause.empty? scope.includes_values = Array(reflection_scope.values[:includes] || options[:source]) - scope.where_values = reflection_scope.values[:where] - scope.bind_values = reflection_scope.bind_values + scope.where_clause = reflection_scope.where_clause end scope.references! reflection_scope.values[:references] diff --git a/activerecord/lib/active_record/associations/through_association.rb b/activerecord/lib/active_record/associations/through_association.rb index 09828dbd9b..3ce9cffdbc 100644 --- a/activerecord/lib/active_record/associations/through_association.rb +++ b/activerecord/lib/active_record/associations/through_association.rb @@ -18,7 +18,7 @@ module ActiveRecord reflection_scope = reflection.scope if reflection_scope && reflection_scope.arity.zero? - relation.merge!(reflection_scope) + relation = relation.merge(reflection_scope) end scope.merge!( diff --git a/activerecord/lib/active_record/attribute.rb b/activerecord/lib/active_record/attribute.rb index 51e4fae62e..eadc17035e 100644 --- a/activerecord/lib/active_record/attribute.rb +++ b/activerecord/lib/active_record/attribute.rb @@ -51,7 +51,7 @@ module ActiveRecord end def changed_in_place_from?(old_value) - type.changed_in_place?(old_value, value) + has_been_read? && type.changed_in_place?(old_value, value) end def with_value_from_user(value) @@ -78,6 +78,10 @@ module ActiveRecord false end + def has_been_read? + defined?(@value) + end + def ==(other) self.class == other.class && name == other.name && @@ -152,6 +156,6 @@ module ActiveRecord false end end - private_constant :FromDatabase, :FromUser, :Null, :Uninitialized + private_constant :FromDatabase, :FromUser, :Null, :Uninitialized, :WithCastValue end end diff --git a/activerecord/lib/active_record/attribute_assignment.rb b/activerecord/lib/active_record/attribute_assignment.rb index bf64830417..fdc90df5b6 100644 --- a/activerecord/lib/active_record/attribute_assignment.rb +++ b/activerecord/lib/active_record/attribute_assignment.rb @@ -3,63 +3,32 @@ require 'active_model/forbidden_attributes_protection' module ActiveRecord module AttributeAssignment extend ActiveSupport::Concern - include ActiveModel::ForbiddenAttributesProtection - - # Allows you to set all the attributes by passing in a hash of attributes with - # keys matching the attribute names (which again matches the column names). - # - # If the passed hash responds to <tt>permitted?</tt> method and the return value - # of this method is +false+ an <tt>ActiveModel::ForbiddenAttributesError</tt> - # exception is raised. - # - # cat = Cat.new(name: "Gorby", status: "yawning") - # cat.attributes # => { "name" => "Gorby", "status" => "yawning", "created_at" => nil, "updated_at" => nil} - # cat.assign_attributes(status: "sleeping") - # cat.attributes # => { "name" => "Gorby", "status" => "sleeping", "created_at" => nil, "updated_at" => nil } - # - # New attributes will be persisted in the database when the object is saved. - # - # Aliased to <tt>attributes=</tt>. - def assign_attributes(new_attributes) - if !new_attributes.respond_to?(:stringify_keys) - raise ArgumentError, "When assigning attributes, you must pass a hash as an argument." - end - return if new_attributes.blank? + include ActiveModel::AttributeAssignment + + # Alias for `assign_attributes`. See +ActiveModel::AttributeAssignment+. + def attributes=(attributes) + assign_attributes(attributes) + end - attributes = new_attributes.stringify_keys - multi_parameter_attributes = [] - nested_parameter_attributes = [] + private - attributes = sanitize_for_mass_assignment(attributes) + def _assign_attributes(attributes) # :nodoc: + multi_parameter_attributes = {} + nested_parameter_attributes = {} attributes.each do |k, v| if k.include?("(") - multi_parameter_attributes << [ k, v ] + multi_parameter_attributes[k] = attributes.delete(k) elsif v.is_a?(Hash) - nested_parameter_attributes << [ k, v ] - else - _assign_attribute(k, v) + nested_parameter_attributes[k] = attributes.delete(k) end end + super(attributes) assign_nested_parameter_attributes(nested_parameter_attributes) unless nested_parameter_attributes.empty? assign_multiparameter_attributes(multi_parameter_attributes) unless multi_parameter_attributes.empty? end - alias attributes= assign_attributes - - private - - def _assign_attribute(k, v) - public_send("#{k}=", v) - rescue NoMethodError - if respond_to?("#{k}=") - raise - else - raise UnknownAttributeError.new(self, k) - end - end - # Assign any deferred nested attributes after the base attributes have been set. def assign_nested_parameter_attributes(pairs) pairs.each { |k, v| _assign_attribute(k, v) } diff --git a/activerecord/lib/active_record/attribute_methods.rb b/activerecord/lib/active_record/attribute_methods.rb index 8f165fb1dc..6de71896ee 100644 --- a/activerecord/lib/active_record/attribute_methods.rb +++ b/activerecord/lib/active_record/attribute_methods.rb @@ -369,6 +369,39 @@ module ActiveRecord write_attribute(attr_name, value) end + # Returns the name of all database fields which have been read from this + # model. This can be useful in development mode to determine which fields + # need to be selected. For performance critical pages, selecting only the + # required fields can be an easy performance win (assuming you aren't using + # all of the fields on the model). + # + # For example: + # + # class PostsController < ActionController::Base + # after_action :print_accessed_fields, only: :index + # + # def index + # @posts = Post.all + # end + # + # private + # + # def print_accessed_fields + # p @posts.first.accessed_fields + # end + # end + # + # Which allows you to quickly change your code to: + # + # class PostsController < ActionController::Base + # def index + # @posts = Post.select(:id, :title, :author_id, :updated_at) + # end + # end + def accessed_fields + @attributes.accessed + end + protected def clone_attribute_value(reader_method, attribute_name) # :nodoc: diff --git a/activerecord/lib/active_record/attribute_methods/dirty.rb b/activerecord/lib/active_record/attribute_methods/dirty.rb index bce9c5e1e3..06d87ee01e 100644 --- a/activerecord/lib/active_record/attribute_methods/dirty.rb +++ b/activerecord/lib/active_record/attribute_methods/dirty.rb @@ -165,7 +165,7 @@ module ActiveRecord end def store_original_raw_attribute(attr_name) - original_raw_attributes[attr_name] = @attributes[attr_name].value_for_database + original_raw_attributes[attr_name] = @attributes[attr_name].value_for_database rescue nil end def store_original_raw_attributes diff --git a/activerecord/lib/active_record/attribute_set.rb b/activerecord/lib/active_record/attribute_set.rb index 66fcaf6945..fdce68ce45 100644 --- a/activerecord/lib/active_record/attribute_set.rb +++ b/activerecord/lib/active_record/attribute_set.rb @@ -64,6 +64,10 @@ module ActiveRecord end end + def accessed + attributes.select { |_, attr| attr.has_been_read? }.keys + end + protected attr_reader :attributes diff --git a/activerecord/lib/active_record/attributes.rb b/activerecord/lib/active_record/attributes.rb index b263a89d79..faf5d632ec 100644 --- a/activerecord/lib/active_record/attributes.rb +++ b/activerecord/lib/active_record/attributes.rb @@ -14,7 +14,7 @@ module ActiveRecord end module ClassMethods # :nodoc: - # Defines or overrides a attribute on this model. This allows customization of + # Defines or overrides an attribute on this model. This allows customization of # Active Record's type casting behavior, as well as adding support for user defined # types. # diff --git a/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb index 0f44c332ae..f905669a24 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb @@ -380,7 +380,7 @@ module ActiveRecord # it can be helpful to provide these in a migration's +change+ method so it can be reverted. # In that case, +options+ and the block will be used by create_table. def drop_table(table_name, options = {}) - execute "DROP TABLE #{quote_table_name(table_name)}" + execute "DROP TABLE#{' IF EXISTS' if options[:if_exists]} #{quote_table_name(table_name)}" end # Adds a new column to the named table. diff --git a/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb b/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb index 7535e9147a..7f738e89c9 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb @@ -111,7 +111,6 @@ module ActiveRecord def rollback connection.rollback_to_savepoint(savepoint_name) super - rollback_records end def commit @@ -136,13 +135,11 @@ module ActiveRecord def rollback connection.rollback_db_transaction super - rollback_records end def commit connection.commit_db_transaction super - commit_records end end @@ -159,18 +156,28 @@ module ActiveRecord else SavepointTransaction.new(@connection, "active_record_#{@stack.size}", options) end + @stack.push(transaction) transaction end def commit_transaction - transaction = @stack.pop - transaction.commit - transaction.records.each { |r| current_transaction.add_record(r) } + inner_transaction = @stack.pop + inner_transaction.commit + + if current_transaction.joinable? + inner_transaction.records.each do |r| + current_transaction.add_record(r) + end + else + inner_transaction.commit_records + end end - def rollback_transaction - @stack.pop.rollback + def rollback_transaction(transaction = nil) + transaction ||= @stack.pop + transaction.rollback + transaction.rollback_records end def within_new_transaction(options = {}) @@ -187,7 +194,7 @@ module ActiveRecord begin commit_transaction rescue Exception - transaction.rollback unless transaction.state.completed? + rollback_transaction(transaction) unless transaction.state.completed? raise end end diff --git a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb index e9a3c26c32..b61b717a61 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract_mysql_adapter.rb @@ -502,7 +502,7 @@ module ActiveRecord end def drop_table(table_name, options = {}) - execute "DROP#{' TEMPORARY' if options[:temporary]} TABLE #{quote_table_name(table_name)}#{' CASCADE' if options[:force] == :cascade}" + execute "DROP#{' TEMPORARY' if options[:temporary]} TABLE#{' IF EXISTS' if options[:if_exists]} #{quote_table_name(table_name)}#{' CASCADE' if options[:force] == :cascade}" end def rename_index(table_name, old_name, new_name) diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/oid/json.rb b/activerecord/lib/active_record/connection_adapters/postgresql/oid/json.rb index e12ddd9901..7dadc09a44 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/oid/json.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/oid/json.rb @@ -11,7 +11,7 @@ module ActiveRecord def type_cast_from_database(value) if value.is_a?(::String) - ::ActiveSupport::JSON.decode(value) + ::ActiveSupport::JSON.decode(value) rescue nil else super end diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb index a90adcf4aa..afd7a17c03 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb @@ -112,7 +112,7 @@ module ActiveRecord end def drop_table(table_name, options = {}) - execute "DROP TABLE #{quote_table_name(table_name)}#{' CASCADE' if options[:force] == :cascade}" + execute "DROP TABLE#{' IF EXISTS' if options[:if_exists]} #{quote_table_name(table_name)}#{' CASCADE' if options[:force] == :cascade}" end # Returns true if schema exists. diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb index 03dfd29a0a..52dce6291a 100644 --- a/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/sqlite3_adapter.rb @@ -53,7 +53,7 @@ module ActiveRecord class SQLite3String < Type::String # :nodoc: def type_cast_for_database(value) if value.is_a?(::String) && value.encoding == Encoding::ASCII_8BIT - value.encode(Encoding::UTF_8) + value.encode(Encoding::UTF_8, undef: :replace) else super end diff --git a/activerecord/lib/active_record/core.rb b/activerecord/lib/active_record/core.rb index a7aff9f724..4705f129f2 100644 --- a/activerecord/lib/active_record/core.rb +++ b/activerecord/lib/active_record/core.rb @@ -456,6 +456,7 @@ module ActiveRecord # Takes a PP and prettily prints this record to it, allowing you to get a nice result from `pp record` # when pp is required. def pretty_print(pp) + return super if custom_inspect_method_defined? pp.object_address_group(self) do if defined?(@attributes) && @attributes column_names = self.class.column_names.select { |name| has_attribute?(name) || new_record? } @@ -560,5 +561,9 @@ module ActiveRecord @attributes = @attributes.dup end end + + def custom_inspect_method_defined? + self.class.instance_method(:inspect).owner != ActiveRecord::Base.instance_method(:inspect).owner + end end end diff --git a/activerecord/lib/active_record/counter_cache.rb b/activerecord/lib/active_record/counter_cache.rb index 7d8e0a2063..82596b63df 100644 --- a/activerecord/lib/active_record/counter_cache.rb +++ b/activerecord/lib/active_record/counter_cache.rb @@ -156,7 +156,7 @@ module ActiveRecord def each_counter_cached_associations _reflections.each do |name, reflection| - yield association(name) if reflection.belongs_to? && reflection.counter_cache_column + yield association(name.to_sym) if reflection.belongs_to? && reflection.counter_cache_column end end diff --git a/activerecord/lib/active_record/errors.rb b/activerecord/lib/active_record/errors.rb index fc28ab585f..d710d96a9a 100644 --- a/activerecord/lib/active_record/errors.rb +++ b/activerecord/lib/active_record/errors.rb @@ -178,18 +178,10 @@ module ActiveRecord class DangerousAttributeError < ActiveRecordError end - # Raised when unknown attributes are supplied via mass assignment. - class UnknownAttributeError < NoMethodError - - attr_reader :record, :attribute - - def initialize(record, attribute) - @record = record - @attribute = attribute.to_s - super("unknown attribute '#{attribute}' for #{@record.class}.") - end - - end + UnknownAttributeError = ActiveSupport::Deprecation::DeprecatedConstantProxy.new( # :nodoc: + 'ActiveRecord::UnknownAttributeError', + 'ActiveModel::AttributeAssignment::UnknownAttributeError' + ) # Raised when an error occurred while doing a mass assignment to an attribute through the # +attributes=+ method. The exception has an +attribute+ property that is the name of the diff --git a/activerecord/lib/active_record/model_schema.rb b/activerecord/lib/active_record/model_schema.rb index 641512d323..af0b667262 100644 --- a/activerecord/lib/active_record/model_schema.rb +++ b/activerecord/lib/active_record/model_schema.rb @@ -111,17 +111,6 @@ module ActiveRecord # class Mouse < ActiveRecord::Base # self.table_name = "mice" # end - # - # Alternatively, you can override the table_name method to define your - # own computation. (Possibly using <tt>super</tt> to manipulate the default - # table name.) Example: - # - # class Post < ActiveRecord::Base - # def self.table_name - # "special_" + super - # end - # end - # Post.table_name # => "special_posts" def table_name reset_table_name unless defined?(@table_name) @table_name @@ -132,9 +121,6 @@ module ActiveRecord # class Project < ActiveRecord::Base # self.table_name = "project" # end - # - # You can also just define your own <tt>self.table_name</tt> method; see - # the documentation for ActiveRecord::Base#table_name. def table_name=(value) value = value && value.to_s diff --git a/activerecord/lib/active_record/persistence.rb b/activerecord/lib/active_record/persistence.rb index 6306a25745..22112fe8ff 100644 --- a/activerecord/lib/active_record/persistence.rb +++ b/activerecord/lib/active_record/persistence.rb @@ -245,8 +245,8 @@ module ActiveRecord def update_attribute(name, value) name = name.to_s verify_readonly_attribute(name) - send("#{name}=", value) - save(validate: false) + public_send("#{name}=", value) + save(validate: false) if changed? end # Updates the attributes of the model from the passed-in hash and saves the @@ -352,7 +352,7 @@ module ActiveRecord # method toggles directly the underlying value without calling any setter. # Returns +self+. def toggle(attribute) - self[attribute] = !send("#{attribute}?") + self[attribute] = !public_send("#{attribute}?") self end diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index dd78814c6a..07e27fe59e 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -1,18 +1,19 @@ # -*- coding: utf-8 -*- -require 'arel/collectors/bind' +require "arel/collectors/bind" module ActiveRecord # = Active Record Relation class Relation MULTI_VALUE_METHODS = [:includes, :eager_load, :preload, :select, :group, - :order, :joins, :where, :having, :bind, :references, + :order, :joins, :references, :extending, :unscope] - SINGLE_VALUE_METHODS = [:limit, :offset, :lock, :readonly, :from, :reordering, + SINGLE_VALUE_METHODS = [:limit, :offset, :lock, :readonly, :reordering, :reverse_order, :distinct, :create_with, :uniq] + CLAUSE_METHODS = [:where, :having, :from] INVALID_METHODS_FOR_DELETE_ALL = [:limit, :distinct, :offset, :group, :having] - VALUE_METHODS = MULTI_VALUE_METHODS + SINGLE_VALUE_METHODS + VALUE_METHODS = MULTI_VALUE_METHODS + SINGLE_VALUE_METHODS + CLAUSE_METHODS include FinderMethods, Calculations, SpawnMethods, QueryMethods, Batches, Explain, Delegation @@ -33,7 +34,6 @@ module ActiveRecord # This method is a hot spot, so for now, use Hash[] to dup the hash. # https://bugs.ruby-lang.org/issues/7166 @values = Hash[@values] - @values[:bind] = @values[:bind].dup if @values.key? :bind reset end @@ -342,8 +342,7 @@ module ActiveRecord stmt.wheres = arel.constraints end - bvs = arel.bind_values + bind_values - @klass.connection.update stmt, 'SQL', bvs + @klass.connection.update stmt, 'SQL', bind_values end # Updates an object (or multiple objects) and saves it to the database, if validations pass. @@ -467,8 +466,10 @@ module ActiveRecord invalid_methods = INVALID_METHODS_FOR_DELETE_ALL.select { |method| if MULTI_VALUE_METHODS.include?(method) send("#{method}_values").any? - else + elsif SINGLE_VALUE_METHODS.include?(method) send("#{method}_value") + elsif CLAUSE_METHODS.include?(method) + send("#{method}_clause").any? end } if invalid_methods.any? @@ -557,11 +558,10 @@ module ActiveRecord find_with_associations { |rel| relation = rel } end - arel = relation.arel - binds = arel.bind_values + relation.bind_values + binds = relation.bind_values binds = connection.prepare_binds_for_database(binds) binds.map! { |_, value| connection.quote(value) } - collect = visitor.accept(arel.ast, Arel::Collectors::Bind.new) + collect = visitor.accept(relation.arel.ast, Arel::Collectors::Bind.new) collect.substitute_binds(binds).join end end @@ -571,22 +571,7 @@ module ActiveRecord # User.where(name: 'Oscar').where_values_hash # # => {name: "Oscar"} def where_values_hash(relation_table_name = table_name) - equalities = where_values.grep(Arel::Nodes::Equality).find_all { |node| - node.left.relation.name == relation_table_name - } - - binds = Hash[bind_values.find_all(&:first).map { |column, v| [column.name, v] }] - - Hash[equalities.map { |where| - name = where.left.name - [name, binds.fetch(name.to_s) { - case where.right - when Array then where.right.map(&:val) - when Arel::Nodes::Casted, Arel::Nodes::Quoted - where.right.val - end - }] - }] + where_clause.to_h(relation_table_name) end def scope_for_create @@ -649,7 +634,7 @@ module ActiveRecord private def exec_queries - @records = eager_loading? ? find_with_associations : @klass.find_by_sql(arel, arel.bind_values + bind_values) + @records = eager_loading? ? find_with_associations : @klass.find_by_sql(arel, bind_values) preload = preload_values preload += includes_values unless eager_loading? diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index 1d4cb1a83b..724bba5f87 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -166,7 +166,7 @@ module ActiveRecord relation.select_values = column_names.map { |cn| columns_hash.key?(cn) ? arel_table[cn] : cn } - result = klass.connection.select_all(relation.arel, nil, relation.arel.bind_values + bind_values) + result = klass.connection.select_all(relation.arel, nil, bind_values) result.cast_values(klass.column_types) end end @@ -227,14 +227,11 @@ module ActiveRecord column_alias = column_name - bind_values = nil - if operation == "count" && (relation.limit_value || relation.offset_value) # Shortcut when limit is zero. return 0 if relation.limit_value == 0 query_builder = build_count_subquery(relation, column_name, distinct) - bind_values = query_builder.bind_values + relation.bind_values else column = aggregate_column(column_name) @@ -245,7 +242,6 @@ module ActiveRecord relation.select_values = [select_value] query_builder = relation.arel - bind_values = query_builder.bind_values + relation.bind_values end result = @klass.connection.select_all(query_builder, nil, bind_values) @@ -290,7 +286,7 @@ module ActiveRecord operation, distinct).as(aggregate_alias) ] - select_values += select_values unless having_values.empty? + select_values += select_values unless having_clause.empty? select_values.concat group_fields.zip(group_aliases).map { |field,aliaz| if field.respond_to?(:as) @@ -304,7 +300,7 @@ module ActiveRecord relation.group_values = group relation.select_values = select_values - calculated_data = @klass.connection.select_all(relation, nil, relation.arel.bind_values + bind_values) + calculated_data = @klass.connection.select_all(relation, nil, relation.bind_values) if association key_ids = calculated_data.collect { |row| row[group_aliases.first] } @@ -378,11 +374,9 @@ module ActiveRecord aliased_column = aggregate_column(column_name == :all ? 1 : column_name).as(column_alias) relation.select_values = [aliased_column] - arel = relation.arel - subquery = arel.as(subquery_alias) + subquery = relation.arel.as(subquery_alias) sm = Arel::SelectManager.new relation.engine - sm.bind_values = arel.bind_values select_value = operation_over_aggregate_column(column_alias, 'count', distinct) sm.project(select_value).from(subquery) end diff --git a/activerecord/lib/active_record/relation/finder_methods.rb b/activerecord/lib/active_record/relation/finder_methods.rb index 088bc203b7..77bf1217e6 100644 --- a/activerecord/lib/active_record/relation/finder_methods.rb +++ b/activerecord/lib/active_record/relation/finder_methods.rb @@ -307,11 +307,11 @@ module ActiveRecord relation = relation.where(conditions) else unless conditions == :none - relation = where(primary_key => conditions) + relation = relation.where(primary_key => conditions) end end - connection.select_value(relation, "#{name} Exists", relation.arel.bind_values + relation.bind_values) ? true : false + connection.select_value(relation, "#{name} Exists", relation.bind_values) ? true : false end # This method is called whenever no records are found with either a single @@ -365,7 +365,7 @@ module ActiveRecord [] else arel = relation.arel - rows = connection.select_all(arel, 'SQL', arel.bind_values + relation.bind_values) + rows = connection.select_all(arel, 'SQL', relation.bind_values) join_dependency.instantiate(rows, aliases) end end @@ -410,7 +410,7 @@ module ActiveRecord relation = relation.except(:select).select(values).distinct! arel = relation.arel - id_rows = @klass.connection.select_all(arel, 'SQL', arel.bind_values + relation.bind_values) + id_rows = @klass.connection.select_all(arel, 'SQL', relation.bind_values) id_rows.map {|row| row[primary_key]} end diff --git a/activerecord/lib/active_record/relation/from_clause.rb b/activerecord/lib/active_record/relation/from_clause.rb new file mode 100644 index 0000000000..a93952fa30 --- /dev/null +++ b/activerecord/lib/active_record/relation/from_clause.rb @@ -0,0 +1,32 @@ +module ActiveRecord + class Relation + class FromClause + attr_reader :value, :name + + def initialize(value, name) + @value = value + @name = name + end + + def binds + if value.is_a?(Relation) + value.bound_attributes + else + [] + end + end + + def merge(other) + self + end + + def empty? + value.nil? + end + + def self.empty + new(nil, nil) + end + end + end +end diff --git a/activerecord/lib/active_record/relation/merger.rb b/activerecord/lib/active_record/relation/merger.rb index afb0b208c3..65b607ff1c 100644 --- a/activerecord/lib/active_record/relation/merger.rb +++ b/activerecord/lib/active_record/relation/merger.rb @@ -49,9 +49,9 @@ module ActiveRecord @other = other end - NORMAL_VALUES = Relation::SINGLE_VALUE_METHODS + - Relation::MULTI_VALUE_METHODS - - [:joins, :where, :order, :bind, :reverse_order, :lock, :create_with, :reordering, :from] # :nodoc: + NORMAL_VALUES = Relation::VALUE_METHODS - + Relation::CLAUSE_METHODS - + [:joins, :order, :reverse_order, :lock, :create_with, :reordering] # :nodoc: def normal_values NORMAL_VALUES @@ -75,6 +75,7 @@ module ActiveRecord merge_multi_values merge_single_values + merge_clauses merge_joins relation @@ -107,20 +108,6 @@ module ActiveRecord end def merge_multi_values - lhs_wheres = relation.where_values - rhs_wheres = other.where_values - - lhs_binds = relation.bind_values - rhs_binds = other.bind_values - - removed, kept = partition_overwrites(lhs_wheres, rhs_wheres) - - where_values = kept + rhs_wheres - bind_values = filter_binds(lhs_binds, removed) + rhs_binds - - relation.where_values = where_values - relation.bind_values = bind_values - if other.reordering_value # override any order specified in the original relation relation.reorder! other.order_values @@ -133,36 +120,18 @@ module ActiveRecord end def merge_single_values - relation.from_value = other.from_value unless relation.from_value - relation.lock_value = other.lock_value unless relation.lock_value + relation.lock_value ||= other.lock_value unless other.create_with_value.blank? relation.create_with_value = (relation.create_with_value || {}).merge(other.create_with_value) end end - def filter_binds(lhs_binds, removed_wheres) - return lhs_binds if removed_wheres.empty? - - set = Set.new removed_wheres.map { |x| x.left.name.to_s } - lhs_binds.dup.delete_if { |col,_| set.include? col.name } - end - - # Remove equalities from the existing relation with a LHS which is - # present in the relation being merged in. - # returns [things_to_remove, things_to_keep] - def partition_overwrites(lhs_wheres, rhs_wheres) - if lhs_wheres.empty? || rhs_wheres.empty? - return [[], lhs_wheres] - end - - nodes = rhs_wheres.find_all do |w| - w.respond_to?(:operator) && w.operator == :== - end - seen = Set.new(nodes) { |node| node.left } - - lhs_wheres.partition do |w| - w.respond_to?(:operator) && w.operator == :== && seen.include?(w.left) + def merge_clauses + CLAUSE_METHODS.each do |name| + clause = relation.send("#{name}_clause") + other_clause = other.send("#{name}_clause") + relation.send("#{name}_clause=", clause.merge(other_clause)) end end end diff --git a/activerecord/lib/active_record/relation/predicate_builder.rb b/activerecord/lib/active_record/relation/predicate_builder.rb index 567efce8ae..7b605db88c 100644 --- a/activerecord/lib/active_record/relation/predicate_builder.rb +++ b/activerecord/lib/active_record/relation/predicate_builder.rb @@ -28,6 +28,11 @@ module ActiveRecord expand_from_hash(attributes) end + def create_binds(attributes) + attributes = convert_dot_notation_to_hash(attributes.stringify_keys) + create_binds_for_hash(attributes) + end + def expand(column, value) # Find the foreign key when using queries such as: # Post.where(author: author) @@ -80,16 +85,43 @@ module ActiveRecord attributes.flat_map do |key, value| if value.is_a?(Hash) - builder = self.class.new(table.associated_table(key)) - builder.expand_from_hash(value) + associated_predicate_builder(key).expand_from_hash(value) else expand(key, value) end end end + + def create_binds_for_hash(attributes) + result = attributes.dup + binds = [] + + attributes.each do |column_name, value| + case value + when Hash + attrs, bvs = associated_predicate_builder(column_name).create_binds_for_hash(value) + result[column_name] = attrs + binds += bvs + when Relation + binds += value.bound_attributes + else + if can_be_bound?(column_name, value) + result[column_name] = Arel::Nodes::BindParam.new + binds << Attribute.with_cast_value(column_name.to_s, value, table.type(column_name)) + end + end + end + + [result, binds] + end + private + def associated_predicate_builder(association_name) + self.class.new(table.associated_table(association_name)) + end + def convert_dot_notation_to_hash(attributes) dot_notation = attributes.keys.select { |s| s.include?(".") } @@ -107,5 +139,11 @@ module ActiveRecord def handler_for(object) @handlers.detect { |klass, _| klass === object }.last end + + def can_be_bound?(column_name, value) + !value.nil? && + handler_for(value).is_a?(BasicObjectHandler) && + !table.associated_with?(column_name) + end end end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index d6e6cb4d05..bc5f126a23 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -1,6 +1,8 @@ -require 'active_support/core_ext/array/wrap' -require 'active_support/core_ext/string/filters' +require "active_record/relation/from_clause" +require "active_record/relation/where_clause" +require "active_record/relation/where_clause_factory" require 'active_model/forbidden_attributes_protection' +require 'active_support/core_ext/string/filters' module ActiveRecord module QueryMethods @@ -39,23 +41,10 @@ module ActiveRecord # User.where.not(name: "Jon", role: "admin") # # SELECT * FROM users WHERE name != 'Jon' AND role != 'admin' def not(opts, *rest) - where_value = @scope.send(:build_where, opts, rest).map do |rel| - case rel - when NilClass - raise ArgumentError, 'Invalid argument for .where.not(), got nil.' - when Arel::Nodes::In - Arel::Nodes::NotIn.new(rel.left, rel.right) - when Arel::Nodes::Equality - Arel::Nodes::NotEqual.new(rel.left, rel.right) - when String - Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(rel)) - else - Arel::Nodes::Not.new(rel) - end - end + where_clause = @scope.send(:where_clause_factory).build(opts, rest) @scope.references!(PredicateBuilder.references(opts)) if Hash === opts - @scope.where_values += where_value + @scope.where_clause += where_clause.invert @scope end end @@ -90,6 +79,33 @@ module ActiveRecord CODE end + Relation::CLAUSE_METHODS.each do |name| + class_eval <<-CODE, __FILE__, __LINE__ + 1 + def #{name}_clause # def where_clause + @values[:#{name}] || new_#{name}_clause # @values[:where] || new_where_clause + end # end + # + def #{name}_clause=(value) # def where_clause=(value) + assert_mutability! # assert_mutability! + @values[:#{name}] = value # @values[:where] = value + end # end + CODE + end + + def bound_attributes + from_clause.binds + arel.bind_values + where_clause.binds + having_clause.binds + end + + def bind_values + # convert to old style + bound_attributes.map do |attribute| + if attribute.name + column = ConnectionAdapters::Column.new(attribute.name, nil, attribute.type) + end + [column, attribute.value_before_type_cast] + end + end + def create_with_value # :nodoc: @values[:create_with] || {} end @@ -392,9 +408,8 @@ module ActiveRecord raise ArgumentError, "Hash arguments in .unscope(*args) must have :where as the key." end - Array(target_value).each do |val| - where_unscoping(val) - end + target_values = Array(target_value).map(&:to_s) + self.where_clause = where_clause.except(*target_values) end else raise ArgumentError, "Unrecognized scoping: #{args.inspect}. Use .unscope(where: :attribute_name) or .unscope(:order), for example." @@ -425,15 +440,6 @@ module ActiveRecord self end - def bind(value) # :nodoc: - spawn.bind!(value) - end - - def bind!(value) # :nodoc: - self.bind_values += [value] - self - end - # Returns a new relation, which is the result of filtering the current relation # according to the conditions in the arguments. # @@ -569,7 +575,7 @@ module ActiveRecord references!(PredicateBuilder.references(opts)) end - self.where_values += build_where(opts, rest) + self.where_clause += where_clause_factory.build(opts, rest) self end @@ -596,7 +602,7 @@ module ActiveRecord def having!(opts, *rest) # :nodoc: references!(PredicateBuilder.references(opts)) if Hash === opts - self.having_values += build_where(opts, rest) + self.having_clause += having_clause_factory.build(opts, rest) self end @@ -744,10 +750,7 @@ module ActiveRecord end def from!(value, subquery_name = nil) # :nodoc: - self.from_value = [value, subquery_name] - if value.is_a? Relation - self.bind_values = value.arel.bind_values + value.bind_values + bind_values - end + self.from_clause = Relation::FromClause.new(value, subquery_name) self end @@ -858,13 +861,10 @@ module ActiveRecord build_joins(arel, joins_values.flatten) unless joins_values.empty? - collapse_wheres(arel, (where_values - [''])) #TODO: Add uniq with real value comparison / ignore uniqs that have binds - - arel.having(*having_values.uniq.reject(&:blank?)) unless having_values.empty? - + arel.where(where_clause.ast) unless where_clause.empty? + arel.having(having_clause.ast) unless having_clause.empty? arel.take(connection.sanitize_limit(limit_value)) if limit_value arel.skip(offset_value.to_i) if offset_value - arel.group(*group_values.uniq.reject(&:blank?)) unless group_values.empty? build_order(arel) @@ -872,7 +872,7 @@ module ActiveRecord build_select(arel, select_values.uniq) arel.distinct(distinct_value) - arel.from(build_from) if from_value + arel.from(build_from) unless from_clause.empty? arel.lock(lock_value) if lock_value arel @@ -883,114 +883,24 @@ module ActiveRecord raise ArgumentError, "Called unscope() with invalid unscoping argument ':#{scope}'. Valid arguments are :#{VALID_UNSCOPING_VALUES.to_a.join(", :")}." end - single_val_method = Relation::SINGLE_VALUE_METHODS.include?(scope) - unscope_code = "#{scope}_value#{'s' unless single_val_method}=" + clause_method = Relation::CLAUSE_METHODS.include?(scope) + multi_val_method = Relation::MULTI_VALUE_METHODS.include?(scope) + if clause_method + unscope_code = "#{scope}_clause=" + else + unscope_code = "#{scope}_value#{'s' if multi_val_method}=" + end case scope when :order result = [] - when :where - self.bind_values = [] else - result = [] unless single_val_method + result = [] if multi_val_method end self.send(unscope_code, result) end - def where_unscoping(target_value) - target_value = target_value.to_s - - where_values.reject! do |rel| - case rel - when Arel::Nodes::Between, Arel::Nodes::In, Arel::Nodes::NotIn, Arel::Nodes::Equality, Arel::Nodes::NotEqual, Arel::Nodes::LessThanOrEqual, Arel::Nodes::GreaterThanOrEqual - subrelation = (rel.left.kind_of?(Arel::Attributes::Attribute) ? rel.left : rel.right) - subrelation.name.to_s == target_value - end - end - - bind_values.reject! { |col,_| col.name == target_value } - end - - def custom_join_ast(table, joins) - joins = joins.reject(&:blank?) - - return [] if joins.empty? - - joins.map! do |join| - case join - when Array - join = Arel.sql(join.join(' ')) if array_of_strings?(join) - when String - join = Arel.sql(join) - end - table.create_string_join(join) - end - end - - def collapse_wheres(arel, wheres) - predicates = wheres.map do |where| - next where if ::Arel::Nodes::Equality === where - where = Arel.sql(where) if String === where - Arel::Nodes::Grouping.new(where) - end - - arel.where(Arel::Nodes::And.new(predicates)) if predicates.present? - end - - def build_where(opts, other = []) - case opts - when String, Array - [@klass.send(:sanitize_sql, other.empty? ? opts : ([opts] + other))] - when Hash - opts = predicate_builder.resolve_column_aliases(opts) - - tmp_opts, bind_values = create_binds(opts) - self.bind_values += bind_values - - attributes = @klass.send(:expand_hash_conditions_for_aggregates, tmp_opts) - add_relations_to_bind_values(attributes) - - predicate_builder.build_from_hash(attributes) - else - [opts] - end - end - - def create_binds(opts) - bindable, non_binds = opts.partition do |column, value| - case value - when String, Integer, ActiveRecord::StatementCache::Substitute - @klass.columns_hash.include? column.to_s - else - false - end - end - - association_binds, non_binds = non_binds.partition do |column, value| - value.is_a?(Hash) && association_for_table(column) - end - - new_opts = {} - binds = [] - - bindable.each do |(column,value)| - binds.push [@klass.columns_hash[column.to_s], value] - new_opts[column] = connection.substitute_at(column) - end - - association_binds.each do |(column, value)| - association_relation = association_for_table(column).klass.send(:relation) - association_new_opts, association_bind = association_relation.send(:create_binds, value) - new_opts[column] = association_new_opts - binds += association_bind - end - - non_binds.each { |column,value| new_opts[column] = value } - - [new_opts, binds] - end - def association_for_table(table_name) table_name = table_name.to_s @klass._reflect_on_association(table_name) || @@ -998,7 +908,8 @@ module ActiveRecord end def build_from - opts, name = from_value + opts = from_clause.value + name = from_clause.name case opts when Relation name ||= 'subquery' @@ -1023,13 +934,14 @@ module ActiveRecord raise 'unknown class: %s' % join.class.name end end + buckets.default = [] - association_joins = buckets[:association_join] || [] - stashed_association_joins = buckets[:stashed_join] || [] - join_nodes = (buckets[:join_node] || []).uniq - string_joins = (buckets[:string_join] || []).map(&:strip).uniq + association_joins = buckets[:association_join] + stashed_association_joins = buckets[:stashed_join] + join_nodes = buckets[:join_node].uniq + string_joins = buckets[:string_join].map(&:strip).uniq - join_list = join_nodes + custom_join_ast(manager, string_joins) + join_list = join_nodes + convert_join_strings_to_ast(manager, string_joins) join_dependency = ActiveRecord::Associations::JoinDependency.new( @klass, @@ -1049,6 +961,13 @@ module ActiveRecord manager end + def convert_join_strings_to_ast(table, joins) + joins + .flatten + .reject(&:blank?) + .map { |join| table.create_string_join(Arel.sql(join)) } + end + def build_select(arel, selects) if !selects.empty? expanded_select = selects.map do |field| @@ -1083,10 +1002,6 @@ module ActiveRecord end end - def array_of_strings?(o) - o.is_a?(Array) && o.all? { |obj| obj.is_a?(String) } - end - def build_order(arel) orders = order_values.uniq orders.reject!(&:blank?) @@ -1154,16 +1069,18 @@ module ActiveRecord end end - def add_relations_to_bind_values(attributes) - if attributes.is_a?(Hash) - attributes.each_value do |value| - if value.is_a?(ActiveRecord::Relation) - self.bind_values += value.arel.bind_values + value.bind_values - else - add_relations_to_bind_values(value) - end - end - end + def new_where_clause + Relation::WhereClause.empty + end + alias new_having_clause new_where_clause + + def where_clause_factory + @where_clause_factory ||= Relation::WhereClauseFactory.new(klass, predicate_builder) + end + alias having_clause_factory where_clause_factory + + def new_from_clause + Relation::FromClause.empty end end end diff --git a/activerecord/lib/active_record/relation/spawn_methods.rb b/activerecord/lib/active_record/relation/spawn_methods.rb index 01bddea6c9..dd3610d7aa 100644 --- a/activerecord/lib/active_record/relation/spawn_methods.rb +++ b/activerecord/lib/active_record/relation/spawn_methods.rb @@ -58,9 +58,6 @@ module ActiveRecord # Post.order('id asc').only(:where) # discards the order condition # Post.order('id asc').only(:where, :order) # uses the specified order def only(*onlies) - if onlies.any? { |o| o == :where } - onlies << :bind - end relation_with values.slice(*onlies) end diff --git a/activerecord/lib/active_record/relation/where_clause.rb b/activerecord/lib/active_record/relation/where_clause.rb new file mode 100644 index 0000000000..ae5667dfd6 --- /dev/null +++ b/activerecord/lib/active_record/relation/where_clause.rb @@ -0,0 +1,160 @@ +module ActiveRecord + class Relation + class WhereClause # :nodoc: + attr_reader :binds + + delegate :any?, :empty?, to: :predicates + + def initialize(predicates, binds) + @predicates = predicates + @binds = binds + end + + def +(other) + WhereClause.new( + predicates + other.predicates, + binds + other.binds, + ) + end + + def merge(other) + WhereClause.new( + predicates_unreferenced_by(other) + other.predicates, + non_conflicting_binds(other) + other.binds, + ) + end + + def except(*columns) + WhereClause.new( + predicates_except(columns), + binds_except(columns), + ) + end + + def to_h(table_name = nil) + equalities = predicates.grep(Arel::Nodes::Equality) + if table_name + equalities = equalities.select do |node| + node.left.relation.name == table_name + end + end + + binds = self.binds.map { |attr| [attr.name, attr.value] }.to_h + + equalities.map { |node| + name = node.left.name + [name, binds.fetch(name.to_s) { + case node.right + when Array then node.right.map(&:val) + when Arel::Nodes::Casted, Arel::Nodes::Quoted + node.right.val + end + }] + }.to_h + end + + def ast + Arel::Nodes::And.new(predicates_with_wrapped_sql_literals) + end + + def ==(other) + other.is_a?(WhereClause) && + predicates == other.predicates && + binds == other.binds + end + + def invert + WhereClause.new(inverted_predicates, binds) + end + + def self.empty + new([], []) + end + + protected + + attr_reader :predicates + + def referenced_columns + @referenced_columns ||= begin + equality_nodes = predicates.select { |n| equality_node?(n) } + Set.new(equality_nodes, &:left) + end + end + + private + + def predicates_unreferenced_by(other) + predicates.reject do |n| + equality_node?(n) && other.referenced_columns.include?(n.left) + end + end + + def equality_node?(node) + node.respond_to?(:operator) && node.operator == :== + end + + def non_conflicting_binds(other) + conflicts = referenced_columns & other.referenced_columns + conflicts.map! { |node| node.name.to_s } + binds.reject { |attr| conflicts.include?(attr.name) } + end + + def inverted_predicates + predicates.map { |node| invert_predicate(node) } + end + + def invert_predicate(node) + case node + when NilClass + raise ArgumentError, 'Invalid argument for .where.not(), got nil.' + when Arel::Nodes::In + Arel::Nodes::NotIn.new(node.left, node.right) + when Arel::Nodes::Equality + Arel::Nodes::NotEqual.new(node.left, node.right) + when String + Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(node)) + else + Arel::Nodes::Not.new(node) + end + end + + def predicates_except(columns) + predicates.reject do |node| + case node + when Arel::Nodes::Between, Arel::Nodes::In, Arel::Nodes::NotIn, Arel::Nodes::Equality, Arel::Nodes::NotEqual, Arel::Nodes::LessThanOrEqual, Arel::Nodes::GreaterThanOrEqual + subrelation = (node.left.kind_of?(Arel::Attributes::Attribute) ? node.left : node.right) + columns.include?(subrelation.name.to_s) + end + end + end + + def binds_except(columns) + binds.reject do |attr| + columns.include?(attr.name) + end + end + + def predicates_with_wrapped_sql_literals + non_empty_predicates.map do |node| + if Arel::Nodes::Equality === node + node + else + wrap_sql_literal(node) + end + end + end + + def non_empty_predicates + predicates - [''] + end + + def wrap_sql_literal(node) + if ::String === node + node = Arel.sql(node) + end + Arel::Nodes::Grouping.new(node) + end + end + end +end diff --git a/activerecord/lib/active_record/relation/where_clause_factory.rb b/activerecord/lib/active_record/relation/where_clause_factory.rb new file mode 100644 index 0000000000..0430922be3 --- /dev/null +++ b/activerecord/lib/active_record/relation/where_clause_factory.rb @@ -0,0 +1,34 @@ +module ActiveRecord + class Relation + class WhereClauseFactory + def initialize(klass, predicate_builder) + @klass = klass + @predicate_builder = predicate_builder + end + + def build(opts, other) + binds = [] + + case opts + when String, Array + parts = [klass.send(:sanitize_sql, other.empty? ? opts : ([opts] + other))] + when Hash + attributes = predicate_builder.resolve_column_aliases(opts) + attributes = klass.send(:expand_hash_conditions_for_aggregates, attributes) + + attributes, binds = predicate_builder.create_binds(attributes) + + parts = predicate_builder.build_from_hash(attributes) + else + parts = [opts] + end + + WhereClause.new(parts, binds) + end + + protected + + attr_reader :klass, :predicate_builder + end + end +end diff --git a/activerecord/lib/active_record/table_metadata.rb b/activerecord/lib/active_record/table_metadata.rb index 11e33e8dfe..6c8792ee80 100644 --- a/activerecord/lib/active_record/table_metadata.rb +++ b/activerecord/lib/active_record/table_metadata.rb @@ -22,6 +22,14 @@ module ActiveRecord arel_table[column_name] end + def type(column_name) + if klass + klass.type_for_attribute(column_name.to_s) + else + Type::Value.new + end + end + def associated_with?(association_name) klass && klass._reflect_on_association(association_name) end @@ -34,7 +42,7 @@ module ActiveRecord association_klass = association.klass arel_table = association_klass.arel_table else - type_caster = TypeCaster::Connection.new(klass.connection, table_name) + type_caster = TypeCaster::Connection.new(klass, table_name) association_klass = nil arel_table = Arel::Table.new(table_name, type_caster: type_caster) end diff --git a/activerecord/lib/active_record/type/integer.rb b/activerecord/lib/active_record/type/integer.rb index 394fe008f9..90ca9f88da 100644 --- a/activerecord/lib/active_record/type/integer.rb +++ b/activerecord/lib/active_record/type/integer.rb @@ -16,13 +16,19 @@ module ActiveRecord :integer end - alias type_cast_for_database type_cast - def type_cast_from_database(value) return if value.nil? value.to_i end + def type_cast_for_database(value) + result = type_cast(value) + if result + ensure_in_range(result) + end + result + end + protected attr_reader :range @@ -34,9 +40,7 @@ module ActiveRecord when true then 1 when false then 0 else - result = value.to_i rescue nil - ensure_in_range(result) if result - result + value.to_i rescue nil end end diff --git a/activerecord/lib/active_record/type_caster/connection.rb b/activerecord/lib/active_record/type_caster/connection.rb index 9e4a130b40..1d204edb76 100644 --- a/activerecord/lib/active_record/type_caster/connection.rb +++ b/activerecord/lib/active_record/type_caster/connection.rb @@ -1,8 +1,8 @@ module ActiveRecord module TypeCaster class Connection - def initialize(connection, table_name) - @connection = connection + def initialize(klass, table_name) + @klass = klass @table_name = table_name end @@ -14,7 +14,8 @@ module ActiveRecord protected - attr_reader :connection, :table_name + attr_reader :table_name + delegate :connection, to: :@klass private |