diff options
Diffstat (limited to 'activerecord/lib/active_record/associations/association_scope.rb')
-rw-r--r-- | activerecord/lib/active_record/associations/association_scope.rb | 87 |
1 files changed, 56 insertions, 31 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb index 17f056e764..27fd9e35db 100644 --- a/activerecord/lib/active_record/associations/association_scope.rb +++ b/activerecord/lib/active_record/associations/association_scope.rb @@ -1,52 +1,77 @@ module ActiveRecord module Associations class AssociationScope #:nodoc: - include JoinHelper + INSTANCE = new - attr_reader :association, :alias_tracker + def self.scope(association, connection) + INSTANCE.scope association, connection + end - delegate :klass, :owner, :reflection, :interpolate, :to => :association - delegate :chain, :scope_chain, :options, :source_options, :active_record, :to => :reflection + def scope(association, connection) + klass = association.klass + reflection = association.reflection + scope = klass.unscoped + owner = association.owner + alias_tracker = AliasTracker.empty connection - def initialize(association) - @association = association - @alias_tracker = AliasTracker.new klass.connection + scope.extending! Array(reflection.options[:extend]) + add_constraints(scope, owner, klass, reflection, alias_tracker) end - def scope - scope = klass.unscoped - scope.extending! Array(options[:extend]) - add_constraints(scope) + def join_type + Arel::Nodes::InnerJoin end private - def column_for(table_name, column_name) + def construct_tables(chain, klass, refl, alias_tracker) + chain.map do |reflection| + alias_tracker.aliased_table_for( + table_name_for(reflection, klass, refl), + table_alias_for(reflection, refl, reflection != refl) + ) + end + end + + def table_alias_for(reflection, refl, join = false) + name = "#{reflection.plural_name}_#{alias_suffix(refl)}" + name << "_join" if join + name + end + + def join(table, constraint) + table.create_join(table, table.create_on(constraint), join_type) + end + + def column_for(table_name, column_name, alias_tracker) columns = alias_tracker.connection.schema_cache.columns_hash(table_name) columns[column_name] end - def bind_value(scope, column, value) + def bind_value(scope, column, value, alias_tracker) substitute = alias_tracker.connection.substitute_at( column, scope.bind_values.length) scope.bind_values += [[column, value]] substitute end - def bind(scope, table_name, column_name, value) - column = column_for table_name, column_name - bind_value scope, column, value + def bind(scope, table_name, column_name, value, tracker) + column = column_for table_name, column_name, tracker + bind_value scope, column, value, tracker end - def add_constraints(scope) - tables = construct_tables + def add_constraints(scope, owner, assoc_klass, refl, tracker) + chain = refl.chain + scope_chain = refl.scope_chain + + tables = construct_tables(chain, assoc_klass, refl, tracker) chain.each_with_index do |reflection, i| table, foreign_table = tables.shift, tables.first if reflection.source_macro == :belongs_to if reflection.options[:polymorphic] - key = reflection.association_primary_key(self.klass) + key = reflection.association_primary_key(assoc_klass) else key = reflection.association_primary_key end @@ -58,12 +83,12 @@ module ActiveRecord end if reflection == chain.last - bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key] + bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], tracker scope = scope.where(table[key].eq(bind_val)) if reflection.type value = owner.class.base_class.name - bind_val = bind scope, table.table_name, reflection.type.to_s, value + bind_val = bind scope, table.table_name, reflection.type.to_s, value, tracker scope = scope.where(table[reflection.type].eq(bind_val)) end else @@ -71,7 +96,7 @@ module ActiveRecord if reflection.type value = chain[i + 1].klass.base_class.name - bind_val = bind scope, table.table_name, reflection.type.to_s, value + bind_val = bind scope, table.table_name, reflection.type.to_s, value, tracker scope = scope.where(table[reflection.type].eq(bind_val)) end @@ -79,14 +104,14 @@ module ActiveRecord end is_first_chain = i == 0 - klass = is_first_chain ? self.klass : reflection.klass + klass = is_first_chain ? assoc_klass : reflection.klass # Exclude the scope of the association itself, because that # was already merged in the #scope method. scope_chain[i].each do |scope_chain_item| - item = eval_scope(klass, scope_chain_item) + item = eval_scope(klass, scope_chain_item, owner) - if scope_chain_item == self.reflection.scope + if scope_chain_item == refl.scope scope.merge! item.except(:where, :includes, :bind) end @@ -102,22 +127,22 @@ module ActiveRecord scope end - def alias_suffix - reflection.name + def alias_suffix(refl) + refl.name end - def table_name_for(reflection) - if reflection == self.reflection + def table_name_for(reflection, klass, refl) + if reflection == refl # If this is a polymorphic belongs_to, we want to get the klass from the # association because it depends on the polymorphic_type attribute of # the owner klass.table_name else - super + reflection.table_name end end - def eval_scope(klass, scope) + def eval_scope(klass, scope, owner) if scope.is_a?(Relation) scope else |