diff options
Diffstat (limited to 'activerecord/lib/active_record/associations/association_scope.rb')
-rw-r--r-- | activerecord/lib/active_record/associations/association_scope.rb | 176 |
1 files changed, 73 insertions, 103 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb index b965230e60..2416167834 100644 --- a/activerecord/lib/active_record/associations/association_scope.rb +++ b/activerecord/lib/active_record/associations/association_scope.rb @@ -2,42 +2,30 @@ module ActiveRecord module Associations class AssociationScope #:nodoc: def self.scope(association, connection) - INSTANCE.scope association, connection - end - - class BindSubstitution - def initialize(block) - @block = block - end - - def bind_value(scope, column, value, alias_tracker) - substitute = alias_tracker.connection.substitute_at( - column, scope.bind_values.length) - scope.bind_values += [[column, @block.call(value)]] - substitute - end + INSTANCE.scope(association, connection) end def self.create(&block) - block = block ? block : lambda { |val| val } - new BindSubstitution.new(block) + block ||= lambda { |val| val } + new(block) end - def initialize(bind_substitution) - @bind_substitution = bind_substitution + def initialize(value_transformation) + @value_transformation = value_transformation end INSTANCE = create def scope(association, connection) - klass = association.klass - reflection = association.reflection - scope = klass.unscoped - owner = association.owner - alias_tracker = AliasTracker.empty connection + klass = association.klass + reflection = association.reflection + scope = klass.unscoped + owner = association.owner + alias_tracker = AliasTracker.create connection, association.klass.table_name, klass.type_caster + chain_head, chain_tail = get_chain(reflection, association, alias_tracker) scope.extending! Array(reflection.options[:extend]) - add_constraints(scope, owner, klass, reflection, alias_tracker) + add_constraints(scope, owner, klass, reflection, chain_head, chain_tail) end def join_type @@ -61,132 +49,114 @@ module ActiveRecord binds end - private + protected - def construct_tables(chain, klass, refl, alias_tracker) - chain.map do |reflection| - alias_tracker.aliased_table_for( - table_name_for(reflection, klass, refl), - table_alias_for(reflection, refl, reflection != refl) - ) - end - end - - def table_alias_for(reflection, refl, join = false) - name = "#{reflection.plural_name}_#{alias_suffix(refl)}" - name << "_join" if join - name - end + attr_reader :value_transformation + private def join(table, constraint) table.create_join(table, table.create_on(constraint), join_type) end - def column_for(table_name, column_name, alias_tracker) - columns = alias_tracker.connection.schema_cache.columns_hash(table_name) - columns[column_name] - end - - def bind_value(scope, column, value, alias_tracker) - @bind_substitution.bind_value scope, column, value, alias_tracker - end - - def bind(scope, table_name, column_name, value, tracker) - column = column_for table_name, column_name, tracker - bind_value scope, column, value, tracker - end - - def last_chain_scope(scope, table, reflection, owner, tracker, assoc_klass) - join_keys = reflection.join_keys(assoc_klass) + def last_chain_scope(scope, table, reflection, owner, association_klass) + join_keys = reflection.join_keys(association_klass) key = join_keys.key foreign_key = join_keys.foreign_key - bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], tracker - scope = scope.where(table[key].eq(bind_val)) + value = transform_value(owner[foreign_key]) + scope = scope.where(table.name => { key => value }) if reflection.type - value = owner.class.base_class.name - bind_val = bind scope, table.table_name, reflection.type, value, tracker - scope = scope.where(table[reflection.type].eq(bind_val)) - else - scope + polymorphic_type = transform_value(owner.class.base_class.name) + scope = scope.where(table.name => { reflection.type => polymorphic_type }) end + + scope end - def next_chain_scope(scope, table, reflection, tracker, assoc_klass, foreign_table, next_reflection) - join_keys = reflection.join_keys(assoc_klass) + def transform_value(value) + value_transformation.call(value) + end + + def next_chain_scope(scope, table, reflection, association_klass, foreign_table, next_reflection) + join_keys = reflection.join_keys(association_klass) key = join_keys.key foreign_key = join_keys.foreign_key constraint = table[key].eq(foreign_table[foreign_key]) if reflection.type - value = next_reflection.klass.base_class.name - bind_val = bind scope, table.table_name, reflection.type, value, tracker - scope = scope.where(table[reflection.type].eq(bind_val)) + value = transform_value(next_reflection.klass.base_class.name) + scope = scope.where(table.name => { reflection.type => value }) end scope = scope.joins(join(foreign_table, constraint)) end - def add_constraints(scope, owner, assoc_klass, refl, tracker) - chain = refl.chain - scope_chain = refl.scope_chain + class ReflectionProxy < SimpleDelegator # :nodoc: + attr_accessor :next + attr_reader :alias_name - tables = construct_tables(chain, assoc_klass, refl, tracker) + def initialize(reflection, alias_name) + super(reflection) + @alias_name = alias_name + end - owner_reflection = chain.last - table = tables.last - scope = last_chain_scope(scope, table, owner_reflection, owner, tracker, assoc_klass) + def all_includes; nil; end + end - chain.each_with_index do |reflection, i| - table, foreign_table = tables.shift, tables.first + def get_chain(reflection, association, tracker) + name = reflection.name + runtime_reflection = Reflection::RuntimeReflection.new(reflection, association) + previous_reflection = runtime_reflection + reflection.chain.drop(1).each do |refl| + alias_name = tracker.aliased_table_for(refl.table_name, refl.alias_candidate(name)) + proxy = ReflectionProxy.new(refl, alias_name) + previous_reflection.next = proxy + previous_reflection = proxy + end + [runtime_reflection, previous_reflection] + end - unless reflection == chain.last - next_reflection = chain[i + 1] - scope = next_chain_scope(scope, table, reflection, tracker, assoc_klass, foreign_table, next_reflection) - end + def add_constraints(scope, owner, association_klass, refl, chain_head, chain_tail) + owner_reflection = chain_tail + table = owner_reflection.alias_name + scope = last_chain_scope(scope, table, owner_reflection, owner, association_klass) - is_first_chain = i == 0 - klass = is_first_chain ? assoc_klass : reflection.klass + reflection = chain_head + loop do + break unless reflection + table = reflection.alias_name + + unless reflection == chain_tail + next_reflection = reflection.next + foreign_table = next_reflection.alias_name + scope = next_chain_scope(scope, table, reflection, association_klass, foreign_table, next_reflection) + end # Exclude the scope of the association itself, because that # was already merged in the #scope method. - scope_chain[i].each do |scope_chain_item| - item = eval_scope(klass, scope_chain_item, owner) + reflection.constraints.each do |scope_chain_item| + item = eval_scope(reflection.klass, scope_chain_item, owner) if scope_chain_item == refl.scope - scope.merge! item.except(:where, :includes, :bind) + scope.merge! item.except(:where, :includes) end - if is_first_chain + reflection.all_includes do scope.includes! item.includes_values end - scope.where_values += item.where_values - scope.bind_values += item.bind_values + scope.where_clause += item.where_clause scope.order_values |= item.order_values end + + reflection = reflection.next end scope end - def alias_suffix(refl) - refl.name - end - - def table_name_for(reflection, klass, refl) - if reflection == refl - # If this is a polymorphic belongs_to, we want to get the klass from the - # association because it depends on the polymorphic_type attribute of - # the owner - klass.table_name - else - reflection.table_name - end - end - def eval_scope(klass, scope, owner) klass.unscoped.instance_exec(owner, &scope) end |