aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/associations/join_dependency.rb
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord/lib/active_record/associations/join_dependency.rb')
-rw-r--r--activerecord/lib/active_record/associations/join_dependency.rb233
1 files changed, 135 insertions, 98 deletions
diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb
index 5aa17e5fbb..f12d5c49b5 100644
--- a/activerecord/lib/active_record/associations/join_dependency.rb
+++ b/activerecord/lib/active_record/associations/join_dependency.rb
@@ -1,11 +1,34 @@
module ActiveRecord
module Associations
class JoinDependency # :nodoc:
- autoload :JoinPart, 'active_record/associations/join_dependency/join_part'
autoload :JoinBase, 'active_record/associations/join_dependency/join_base'
autoload :JoinAssociation, 'active_record/associations/join_dependency/join_association'
- attr_reader :join_parts, :reflections, :alias_tracker, :base_klass
+ attr_reader :join_parts, :alias_tracker, :base_klass
+
+ def self.make_tree(associations)
+ hash = {}
+ walk_tree associations, hash
+ hash
+ end
+
+ def self.walk_tree(associations, hash)
+ case associations
+ when Symbol, String
+ hash[associations.to_sym] ||= {}
+ when Array
+ associations.each do |assoc|
+ walk_tree assoc, hash
+ end
+ when Hash
+ associations.each do |k,v|
+ cache = hash[k] ||= {}
+ walk_tree v, cache
+ end
+ else
+ raise ConfigurationError, associations.inspect
+ end
+ end
# base is the base class on which operation is taking place.
# associations is the list of associations which are joined using hash, symbol or array.
@@ -32,18 +55,23 @@ module ActiveRecord
@base_klass = base
@table_joins = joins
@join_parts = [JoinBase.new(base)]
- @associations = {}
- @reflections = []
@alias_tracker = AliasTracker.new(base.connection, joins)
@alias_tracker.aliased_name_for(base.table_name) # Updates the count for base.table_name to 1
- build(associations)
+ tree = self.class.make_tree associations
+ build tree, join_parts.last, Arel::InnerJoin
end
def graft(*associations)
- associations.each do |association|
- join_associations.detect {|a| association == a} ||
- build(association.reflection.name, association.find_parent_in(self) || join_base, association.join_type)
- end
+ join_assocs = join_associations
+ base = join_base
+
+ associations.reject { |association|
+ join_assocs.detect { |a| association == a }
+ }.each { |association|
+ join_part = find_parent_part(association.parent) || base
+ type = association.join_type
+ find_or_build_scalar association.reflection, join_part, type
+ }
self
end
@@ -51,8 +79,8 @@ module ActiveRecord
join_parts.drop 1
end
- def join_base
- join_parts.first
+ def reflections
+ join_associations.map(&:reflection)
end
def join_relation(relation)
@@ -70,101 +98,107 @@ module ActiveRecord
}.flatten
end
- def instantiate(rows)
+ def instantiate(result_set)
primary_key = join_base.aliased_primary_key
parents = {}
- records = rows.map { |model|
- primary_id = model[primary_key]
- parent = parents[primary_id] ||= join_base.instantiate(model)
- construct(parent, @associations, join_associations, model)
+ type_caster = result_set.column_type primary_key
+ assoc = associations
+
+ records = result_set.map { |row_hash|
+ primary_id = type_caster.type_cast row_hash[primary_key]
+ parent = parents[primary_id] ||= join_base.instantiate(row_hash)
+ construct(parent, assoc, join_associations, row_hash, result_set)
parent
}.uniq
- remove_duplicate_results!(base_klass, records, @associations)
+ remove_duplicate_results!(base_klass, records, assoc)
records
end
- protected
+ private
+
+ def associations
+ join_associations.each_with_object({}) do |assoc, tree|
+ cache_joined_association assoc, tree
+ end
+ end
+
+ def find_parent_part(parent)
+ join_parts.detect do |join_part|
+ case parent
+ when JoinBase
+ parent.base_klass == join_part.base_klass
+ else
+ parent == join_part
+ end
+ end
+ end
+
+ def join_base
+ join_parts.first
+ end
def remove_duplicate_results!(base, records, associations)
- case associations
- when Symbol, String
- reflection = base.reflections[associations]
+ associations.each_key do |name|
+ reflection = base.reflect_on_association(name)
remove_uniq_by_reflection(reflection, records)
- when Array
- associations.each do |association|
- remove_duplicate_results!(base, records, association)
- end
- when Hash
- associations.each_key do |name|
- reflection = base.reflections[name]
- remove_uniq_by_reflection(reflection, records)
-
- parent_records = []
- records.each do |record|
- if descendant = record.send(reflection.name)
- if reflection.collection?
- parent_records.concat descendant.target.uniq
- else
- parent_records << descendant
- end
+
+ parent_records = []
+ records.each do |record|
+ if descendant = record.send(reflection.name)
+ if reflection.collection?
+ parent_records.concat descendant.target.uniq
+ else
+ parent_records << descendant
end
end
+ end
- remove_duplicate_results!(reflection.klass, parent_records, associations[name]) unless parent_records.empty?
+ unless parent_records.empty?
+ remove_duplicate_results!(reflection.klass, parent_records, associations[name])
end
end
end
- def cache_joined_association(association)
+ def cache_joined_association(association, tree)
associations = []
parent = association.parent
while parent != join_base
associations.unshift(parent.reflection.name)
parent = parent.parent
end
- ref = @associations
- associations.each do |key|
- ref = ref[key]
+ ref = associations.inject(tree) do |cache,key|
+ cache[key]
end
ref[association.reflection.name] ||= {}
end
- def build(associations, parent = join_parts.last, join_type = Arel::InnerJoin)
- case associations
- when Symbol, String
- reflection = parent.reflections[associations.intern] or
- raise ConfigurationError, "Association named '#{ associations }' was not found on #{ parent.base_klass.name }; perhaps you misspelled it?"
- unless join_association = find_join_association(reflection, parent)
- @reflections << reflection
- join_association = build_join_association(reflection, parent)
- join_association.join_type = join_type
- @join_parts << join_association
- cache_joined_association(join_association)
- end
- join_association
- when Array
- associations.each do |association|
- build(association, parent, join_type)
- end
- when Hash
- associations.keys.sort_by { |a| a.to_s }.each do |name|
- join_association = build(name, parent, join_type)
- build(associations[name], join_association, join_type)
- end
- else
- raise ConfigurationError, associations.inspect
+ def find_reflection(klass, name)
+ klass.reflect_on_association(name.intern) or
+ raise ConfigurationError, "Association named '#{ name }' was not found on #{ klass.name }; perhaps you misspelled it?"
+ end
+
+ def build(associations, parent, join_type)
+ associations.each do |name, right|
+ reflection = find_reflection parent.base_klass, name
+ join_association = build_join_association reflection, parent, join_type
+ @join_parts << join_association
+ build right, join_association, join_type
end
end
- def find_join_association(name_or_reflection, parent)
- if String === name_or_reflection
- name_or_reflection = name_or_reflection.to_sym
+ def find_or_build_scalar(reflection, parent, join_type)
+ unless join_association = find_join_association(reflection, parent)
+ join_association = build_join_association(reflection, parent, join_type)
+ @join_parts << join_association
end
+ join_association
+ end
+ def find_join_association(reflection, parent)
join_associations.detect { |j|
- j.reflection == name_or_reflection && j.parent == parent
+ j.reflection == reflection && j.parent == parent
}
end
@@ -174,39 +208,42 @@ module ActiveRecord
end
end
- def build_join_association(reflection, parent)
- JoinAssociation.new(reflection, self, parent)
+ def build_join_association(reflection, parent, join_type)
+ reflection.check_validity!
+
+ if reflection.options[:polymorphic]
+ raise EagerLoadPolymorphicError.new(reflection)
+ end
+
+ JoinAssociation.new(reflection, join_parts.length, parent, join_type, alias_tracker)
end
- def construct(parent, associations, join_parts, row)
- case associations
- when Symbol, String
- name = associations.to_s
+ def construct(parent, associations, join_parts, row, rs)
+ associations.sort_by { |k,_| k.to_s }.each do |association_name, assoc|
+ association = construct_scalar(parent, association_name, join_parts, row, rs)
+ construct(association, assoc, join_parts, row, rs) if association
+ end
+ end
- join_part = join_parts.detect { |j|
- j.reflection.name.to_s == name &&
- j.parent_table_name == parent.class.table_name }
+ def construct_scalar(parent, associations, join_parts, row, rs)
+ name = associations.to_s
- raise(ConfigurationError, "No such association") unless join_part
+ join_part = join_parts.detect { |j|
+ j.reflection.name.to_s == name &&
+ j.parent_table_name == parent.class.table_name
+ }
- join_parts.delete(join_part)
- construct_association(parent, join_part, row)
- when Array
- associations.each do |association|
- construct(parent, association, join_parts, row)
- end
- when Hash
- associations.sort_by { |k,_| k.to_s }.each do |association_name, assoc|
- association = construct(parent, association_name, join_parts, row)
- construct(association, assoc, join_parts, row) if association
- end
- else
- raise ConfigurationError, associations.inspect
- end
+ raise(ConfigurationError, "No such association") unless join_part
+
+ join_parts.delete(join_part)
+ construct_association(parent, join_part, row, rs)
end
- def construct_association(record, join_part, row)
- return if record.id.to_s != join_part.parent.record_id(row).to_s
+ def construct_association(record, join_part, row, rs)
+ caster = rs.column_type(join_part.parent.aliased_primary_key)
+ row_id = caster.type_cast row[join_part.parent.aliased_primary_key]
+
+ return if record.id != row_id
macro = join_part.reflection.macro
if macro == :has_one
@@ -216,7 +253,7 @@ module ActiveRecord
else
association = join_part.instantiate(row) unless row[join_part.aliased_primary_key].nil?
case macro
- when :has_many, :has_and_belongs_to_many
+ when :has_many
other = record.association(join_part.reflection.name)
other.loaded!
other.target.push(association) if association