aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/associations/association_scope.rb
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord/lib/active_record/associations/association_scope.rb')
-rw-r--r--activerecord/lib/active_record/associations/association_scope.rb116
1 files changed, 65 insertions, 51 deletions
diff --git a/activerecord/lib/active_record/associations/association_scope.rb b/activerecord/lib/active_record/associations/association_scope.rb
index 519d4d8651..53f65920e1 100644
--- a/activerecord/lib/active_record/associations/association_scope.rb
+++ b/activerecord/lib/active_record/associations/association_scope.rb
@@ -10,9 +10,8 @@ module ActiveRecord
@block = block
end
- def bind_value(scope, column, value, alias_tracker)
- substitute = alias_tracker.connection.substitute_at(
- column, scope.bind_values.length)
+ def bind_value(scope, column, value, connection)
+ substitute = connection.substitute_at(column)
scope.bind_values += [[column, @block.call(value)]]
substitute
end
@@ -45,20 +44,20 @@ module ActiveRecord
end
def self.get_bind_values(owner, chain)
- bvs = []
- chain.each_with_index do |reflection, i|
- if reflection == chain.last
- bvs << reflection.join_id_for(owner)
- if reflection.type
- bvs << owner.class.base_class.name
- end
- else
- if reflection.type
- bvs << chain[i + 1].klass.base_class.name
- end
+ binds = []
+ last_reflection = chain.last
+
+ binds << last_reflection.join_id_for(owner)
+ if last_reflection.type
+ binds << owner.class.base_class.name
+ end
+
+ chain.each_cons(2).each do |reflection, next_reflection|
+ if reflection.type
+ binds << next_reflection.klass.base_class.name
end
end
- bvs
+ binds
end
private
@@ -67,7 +66,8 @@ module ActiveRecord
chain.map do |reflection|
alias_tracker.aliased_table_for(
table_name_for(reflection, klass, refl),
- table_alias_for(reflection, refl, reflection != refl)
+ table_alias_for(reflection, refl, reflection != refl),
+ type_caster: klass.type_caster,
)
end
end
@@ -82,52 +82,70 @@ module ActiveRecord
table.create_join(table, table.create_on(constraint), join_type)
end
- def column_for(table_name, column_name, alias_tracker)
- columns = alias_tracker.connection.schema_cache.columns_hash(table_name)
+ def column_for(table_name, column_name, connection)
+ columns = connection.schema_cache.columns_hash(table_name)
columns[column_name]
end
- def bind_value(scope, column, value, alias_tracker)
- @bind_substitution.bind_value scope, column, value, alias_tracker
+ def bind_value(scope, column, value, connection)
+ @bind_substitution.bind_value scope, column, value, connection
end
- def bind(scope, table_name, column_name, value, tracker)
- column = column_for table_name, column_name, tracker
- bind_value scope, column, value, tracker
+ def bind(scope, table_name, column_name, value, connection)
+ column = column_for table_name, column_name, connection
+ bind_value scope, column, value, connection
+ end
+
+ def last_chain_scope(scope, table, reflection, owner, connection, assoc_klass)
+ join_keys = reflection.join_keys(assoc_klass)
+ key = join_keys.key
+ foreign_key = join_keys.foreign_key
+
+ bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], connection
+ scope = scope.where(table[key].eq(bind_val))
+
+ if reflection.type
+ value = owner.class.base_class.name
+ bind_val = bind scope, table.table_name, reflection.type, value, connection
+ scope = scope.where(table[reflection.type].eq(bind_val))
+ else
+ scope
+ end
+ end
+
+ def next_chain_scope(scope, table, reflection, connection, assoc_klass, foreign_table, next_reflection)
+ join_keys = reflection.join_keys(assoc_klass)
+ key = join_keys.key
+ foreign_key = join_keys.foreign_key
+
+ constraint = table[key].eq(foreign_table[foreign_key])
+
+ if reflection.type
+ value = next_reflection.klass.base_class.name
+ bind_val = bind scope, table.table_name, reflection.type, value, connection
+ scope = scope.where(table[reflection.type].eq(bind_val))
+ end
+
+ scope = scope.joins(join(foreign_table, constraint))
end
def add_constraints(scope, owner, assoc_klass, refl, tracker)
chain = refl.chain
scope_chain = refl.scope_chain
+ connection = tracker.connection
tables = construct_tables(chain, assoc_klass, refl, tracker)
+ owner_reflection = chain.last
+ table = tables.last
+ scope = last_chain_scope(scope, table, owner_reflection, owner, connection, assoc_klass)
+
chain.each_with_index do |reflection, i|
table, foreign_table = tables.shift, tables.first
- join_keys = reflection.join_keys(assoc_klass)
- key = join_keys.key
- foreign_key = join_keys.foreign_key
-
- if reflection == chain.last
- bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], tracker
- scope = scope.where(table[key].eq(bind_val))
-
- if reflection.type
- value = owner.class.base_class.name
- bind_val = bind scope, table.table_name, reflection.type.to_s, value, tracker
- scope = scope.where(table[reflection.type].eq(bind_val))
- end
- else
- constraint = table[key].eq(foreign_table[foreign_key])
-
- if reflection.type
- value = chain[i + 1].klass.base_class.name
- bind_val = bind scope, table.table_name, reflection.type.to_s, value, tracker
- scope = scope.where(table[reflection.type].eq(bind_val))
- end
-
- scope = scope.joins(join(foreign_table, constraint))
+ unless reflection == chain.last
+ next_reflection = chain[i + 1]
+ scope = next_chain_scope(scope, table, reflection, connection, assoc_klass, foreign_table, next_reflection)
end
is_first_chain = i == 0
@@ -171,11 +189,7 @@ module ActiveRecord
end
def eval_scope(klass, scope, owner)
- if scope.is_a?(Relation)
- scope
- else
- klass.unscoped.instance_exec(owner, &scope)
- end
+ klass.unscoped.instance_exec(owner, &scope)
end
end
end