From e98a877d6c0a5798ac9011e9d1b5331ae66257bb Mon Sep 17 00:00:00 2001 From: Aaron Patterson Date: Tue, 8 Oct 2013 15:13:12 -0700 Subject: transform the association input so we can avoid type checks later. We should consider moving the input munging outside the class instantiation --- .../active_record/associations/join_dependency.rb | 48 ++++++++++++++-------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb index 74cbb95137..1d9db51c34 100644 --- a/activerecord/lib/active_record/associations/join_dependency.rb +++ b/activerecord/lib/active_record/associations/join_dependency.rb @@ -6,6 +6,30 @@ module ActiveRecord attr_reader :join_parts, :reflections, :alias_tracker, :base_klass + def self.make_tree(associations) + hash = {} + walk_tree associations, hash + hash + end + + def self.walk_tree(associations, hash) + case associations + when Symbol, String + hash[associations.to_sym] ||= {} + when Array + associations.each do |assoc| + walk_tree assoc, hash + end + when Hash + associations.each do |k,v| + cache = hash[k] ||= {} + walk_tree v, cache + end + else + raise ConfigurationError, associations.inspect + end + end + # base is the base class on which operation is taking place. # associations is the list of associations which are joined using hash, symbol or array. # joins is the list of all string join commnads and arel nodes. @@ -34,13 +58,14 @@ module ActiveRecord @reflections = [] @alias_tracker = AliasTracker.new(base.connection, joins) @alias_tracker.aliased_name_for(base.table_name) # Updates the count for base.table_name to 1 - build(associations, join_parts.last, Arel::InnerJoin) + tree = self.class.make_tree associations + build tree, join_parts.last, Arel::InnerJoin end def graft(*associations) associations.each do |association| - join_associations.detect {|a| association == a} || - build(association.reflection.name, find_parent_part(association.parent) || join_base, association.join_type) + join_associations.detect { |a| association == a } || + find_or_build_scalar(association.reflection.name, find_parent_part(association.parent) || join_base, association.join_type) end self end @@ -150,20 +175,9 @@ module ActiveRecord end def build(associations, parent, join_type) - case associations - when Symbol, String - find_or_build_scalar associations, parent, join_type - when Array - associations.each do |association| - build(association, parent, join_type) - end - when Hash - associations.each do |left, right| - join_association = find_or_build_scalar left, parent, join_type - build(right, join_association, join_type) - end - else - raise ConfigurationError, associations.inspect + associations.each do |left, right| + join_association = find_or_build_scalar left, parent, join_type + build right, join_association, join_type end end -- cgit v1.2.3