aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/associations/association_scope.rb
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord/lib/active_record/associations/association_scope.rb')
-rw-r--r--activerecord/lib/active_record/associations/association_scope.rb211
1 files changed, 138 insertions, 73 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb
index 1303822868..dcbd57e61d 100644
--- a/activerecord/lib/active_record/associations/association_scope.rb
+++ b/activerecord/lib/active_record/associations/association_scope.rb
@@ -1,113 +1,182 @@
module ActiveRecord
module Associations
class AssociationScope #:nodoc:
- include JoinHelper
+ def self.scope(association, connection)
+ INSTANCE.scope association, connection
+ end
+
+ class BindSubstitution
+ def initialize(block)
+ @block = block
+ end
+
+ def bind_value(scope, column, value, alias_tracker)
+ substitute = alias_tracker.connection.substitute_at(column)
+ scope.bind_values += [[column, @block.call(value)]]
+ substitute
+ end
+ end
- attr_reader :association, :alias_tracker
+ def self.create(&block)
+ block = block ? block : lambda { |val| val }
+ new BindSubstitution.new(block)
+ end
+
+ def initialize(bind_substitution)
+ @bind_substitution = bind_substitution
+ end
- delegate :klass, :owner, :reflection, :interpolate, :to => :association
- delegate :chain, :scope_chain, :options, :source_options, :active_record, :to => :reflection
+ INSTANCE = create
+
+ def scope(association, connection)
+ klass = association.klass
+ reflection = association.reflection
+ scope = klass.unscoped
+ owner = association.owner
+ alias_tracker = AliasTracker.empty connection
+
+ scope.extending! Array(reflection.options[:extend])
+ add_constraints(scope, owner, klass, reflection, alias_tracker)
+ end
- def initialize(association)
- @association = association
- @alias_tracker = AliasTracker.new klass.connection
+ def join_type
+ Arel::Nodes::InnerJoin
end
- def scope
- scope = klass.unscoped
- scope.merge! eval_scope(klass, reflection.scope) if reflection.scope
- add_constraints(scope)
+ def self.get_bind_values(owner, chain)
+ binds = []
+ last_reflection = chain.last
+
+ binds << last_reflection.join_id_for(owner)
+ if last_reflection.type
+ binds << owner.class.base_class.name
+ end
+
+ chain.each_cons(2).each do |reflection, next_reflection|
+ if reflection.type
+ binds << next_reflection.klass.base_class.name
+ end
+ end
+ binds
end
private
- def column_for(table_name, column_name)
- columns = alias_tracker.connection.schema_cache.columns_hash[table_name]
+ def construct_tables(chain, klass, refl, alias_tracker)
+ chain.map do |reflection|
+ alias_tracker.aliased_table_for(
+ table_name_for(reflection, klass, refl),
+ table_alias_for(reflection, refl, reflection != refl)
+ )
+ end
+ end
+
+ def table_alias_for(reflection, refl, join = false)
+ name = "#{reflection.plural_name}_#{alias_suffix(refl)}"
+ name << "_join" if join
+ name
+ end
+
+ def join(table, constraint)
+ table.create_join(table, table.create_on(constraint), join_type)
+ end
+
+ def column_for(table_name, column_name, alias_tracker)
+ columns = alias_tracker.connection.schema_cache.columns_hash(table_name)
columns[column_name]
end
- def bind_value(scope, column, value)
- substitute = alias_tracker.connection.substitute_at(
- column, scope.bind_values.length)
- scope.bind_values += [[column, value]]
- substitute
+ def bind_value(scope, column, value, alias_tracker)
+ @bind_substitution.bind_value scope, column, value, alias_tracker
end
- def bind(scope, table_name, column_name, value)
- column = column_for table_name, column_name
- bind_value scope, column, value
+ def bind(scope, table_name, column_name, value, tracker)
+ column = column_for table_name, column_name, tracker
+ bind_value scope, column, value, tracker
end
- def add_constraints(scope)
- tables = construct_tables
+ def last_chain_scope(scope, table, reflection, owner, tracker, assoc_klass)
+ join_keys = reflection.join_keys(assoc_klass)
+ key = join_keys.key
+ foreign_key = join_keys.foreign_key
- chain.each_with_index do |reflection, i|
- table, foreign_table = tables.shift, tables.first
+ bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], tracker
+ scope = scope.where(table[key].eq(bind_val))
- if reflection.source_macro == :has_and_belongs_to_many
- join_table = tables.shift
+ if reflection.type
+ value = owner.class.base_class.name
+ bind_val = bind scope, table.table_name, reflection.type, value, tracker
+ scope = scope.where(table[reflection.type].eq(bind_val))
+ else
+ scope
+ end
+ end
- scope = scope.joins(join(
- join_table,
- table[reflection.association_primary_key].
- eq(join_table[reflection.association_foreign_key])
- ))
+ def next_chain_scope(scope, table, reflection, tracker, assoc_klass, foreign_table, next_reflection)
+ join_keys = reflection.join_keys(assoc_klass)
+ key = join_keys.key
+ foreign_key = join_keys.foreign_key
- table, foreign_table = join_table, tables.first
- end
+ constraint = table[key].eq(foreign_table[foreign_key])
- if reflection.source_macro == :belongs_to
- if reflection.options[:polymorphic]
- key = reflection.association_primary_key(klass)
- else
- key = reflection.association_primary_key
- end
+ if reflection.type
+ value = next_reflection.klass.base_class.name
+ bind_val = bind scope, table.table_name, reflection.type, value, tracker
+ scope = scope.where(table[reflection.type].eq(bind_val))
+ end
- foreign_key = reflection.foreign_key
- else
- key = reflection.foreign_key
- foreign_key = reflection.active_record_primary_key
- end
+ scope = scope.joins(join(foreign_table, constraint))
+ end
- if reflection == chain.last
- bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key]
- scope = scope.where(table[key].eq(bind_val))
+ def add_constraints(scope, owner, assoc_klass, refl, tracker)
+ chain = refl.chain
+ scope_chain = refl.scope_chain
- if reflection.type
- value = owner.class.base_class.name
- bind_val = bind scope, table.table_name, reflection.type.to_s, value
- scope = scope.where(table[reflection.type].eq(bind_val))
- end
- else
- constraint = table[key].eq(foreign_table[foreign_key])
+ tables = construct_tables(chain, assoc_klass, refl, tracker)
- if reflection.type
- type = chain[i + 1].klass.base_class.name
- constraint = constraint.and(table[reflection.type].eq(type))
- end
+ owner_reflection = chain.last
+ table = tables.last
+ scope = last_chain_scope(scope, table, owner_reflection, owner, tracker, assoc_klass)
+
+ chain.each_with_index do |reflection, i|
+ table, foreign_table = tables.shift, tables.first
- scope = scope.joins(join(foreign_table, constraint))
+ unless reflection == chain.last
+ next_reflection = chain[i + 1]
+ scope = next_chain_scope(scope, table, reflection, tracker, assoc_klass, foreign_table, next_reflection)
end
+ is_first_chain = i == 0
+ klass = is_first_chain ? assoc_klass : reflection.klass
+
# Exclude the scope of the association itself, because that
# was already merged in the #scope method.
- (scope_chain[i] - [self.reflection.scope]).each do |scope_chain_item|
- item = eval_scope(reflection.klass, scope_chain_item)
+ scope_chain[i].each do |scope_chain_item|
+ item = eval_scope(klass, scope_chain_item, owner)
+
+ if scope_chain_item == refl.scope
+ scope.merge! item.except(:where, :includes, :bind)
+ end
+
+ if is_first_chain
+ scope.includes! item.includes_values
+ end
- scope.includes! item.includes_values
scope.where_values += item.where_values
+ scope.bind_values += item.bind_values
+ scope.order_values |= item.order_values
end
end
scope
end
- def alias_suffix
- reflection.name
+ def alias_suffix(refl)
+ refl.name
end
- def table_name_for(reflection)
- if reflection == self.reflection
+ def table_name_for(reflection, klass, refl)
+ if reflection == refl
# If this is a polymorphic belongs_to, we want to get the klass from the
# association because it depends on the polymorphic_type attribute of
# the owner
@@ -117,12 +186,8 @@ module ActiveRecord
end
end
- def eval_scope(klass, scope)
- if scope.is_a?(Relation)
- scope
- else
- klass.unscoped.instance_exec(owner, &scope)
- end
+ def eval_scope(klass, scope, owner)
+ klass.unscoped.instance_exec(owner, &scope)
end
end
end