aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/associations/alias_tracker.rb
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord/lib/active_record/associations/alias_tracker.rb')
-rw-r--r--activerecord/lib/active_record/associations/alias_tracker.rb91
1 files changed, 53 insertions, 38 deletions
diff --git a/activerecord/lib/active_record/associations/alias_tracker.rb b/activerecord/lib/active_record/associations/alias_tracker.rb
index 0c23029981..2b7e4f28c5 100644
--- a/activerecord/lib/active_record/associations/alias_tracker.rb
+++ b/activerecord/lib/active_record/associations/alias_tracker.rb
@@ -5,71 +5,86 @@ module ActiveRecord
# Keeps track of table aliases for ActiveRecord::Associations::ClassMethods::JoinDependency and
# ActiveRecord::Associations::ThroughAssociationScope
class AliasTracker # :nodoc:
- attr_reader :aliases, :table_joins, :connection
+ attr_reader :aliases
- # table_joins is an array of arel joins which might conflict with the aliases we assign here
- def initialize(connection = Base.connection, table_joins = [])
- @aliases = Hash.new { |h,k| h[k] = initial_count_for(k) }
- @table_joins = table_joins
- @connection = connection
+ def self.create(connection, initial_table, type_caster)
+ aliases = Hash.new(0)
+ aliases[initial_table] = 1
+ new connection, aliases, type_caster
end
- def aliased_table_for(table_name, aliased_name = nil)
- table_alias = aliased_name_for(table_name, aliased_name)
-
- if table_alias == table_name
- Arel::Table.new(table_name)
+ def self.create_with_joins(connection, initial_table, joins, type_caster)
+ if joins.empty?
+ create(connection, initial_table, type_caster)
else
- Arel::Table.new(table_name).alias(table_alias)
+ aliases = Hash.new { |h, k|
+ h[k] = initial_count_for(connection, k, joins)
+ }
+ aliases[initial_table] = 1
+ new connection, aliases, type_caster
end
end
- def aliased_name_for(table_name, aliased_name = nil)
- aliased_name ||= table_name
+ def self.initial_count_for(connection, name, table_joins)
+ # quoted_name should be downcased as some database adapters (Oracle) return quoted name in uppercase
+ quoted_name = connection.quote_table_name(name).downcase
+ counts = table_joins.map do |join|
+ if join.is_a?(Arel::Nodes::StringJoin)
+ # Table names + table aliases
+ join.left.downcase.scan(
+ /join(?:\s+\w+)?\s+(\S+\s+)?#{quoted_name}\son/
+ ).size
+ elsif join.respond_to? :left
+ join.left.table_name == name ? 1 : 0
+ else
+ # this branch is reached by two tests:
+ #
+ # activerecord/test/cases/associations/cascaded_eager_loading_test.rb:37
+ # with :posts
+ #
+ # activerecord/test/cases/associations/eager_test.rb:1133
+ # with :comments
+ #
+ 0
+ end
+ end
+
+ counts.sum
+ end
+
+ # table_joins is an array of arel joins which might conflict with the aliases we assign here
+ def initialize(connection, aliases, type_caster)
+ @aliases = aliases
+ @connection = connection
+ @type_caster = type_caster
+ end
+
+ def aliased_table_for(table_name, aliased_name)
if aliases[table_name].zero?
# If it's zero, we can have our table_name
aliases[table_name] = 1
- table_name
+ Arel::Table.new(table_name, type_caster: @type_caster)
else
# Otherwise, we need to use an alias
- aliased_name = connection.table_alias_for(aliased_name)
+ aliased_name = @connection.table_alias_for(aliased_name)
# Update the count
aliases[aliased_name] += 1
- if aliases[aliased_name] > 1
+ table_alias = if aliases[aliased_name] > 1
"#{truncate(aliased_name)}_#{aliases[aliased_name]}"
else
aliased_name
end
+ Arel::Table.new(table_name, type_caster: @type_caster).alias(table_alias)
end
end
private
- def initial_count_for(name)
- return 0 if Arel::Table === table_joins
-
- # quoted_name should be downcased as some database adapters (Oracle) return quoted name in uppercase
- quoted_name = connection.quote_table_name(name).downcase
-
- counts = table_joins.map do |join|
- if join.is_a?(Arel::Nodes::StringJoin)
- # Table names + table aliases
- join.left.downcase.scan(
- /join(?:\s+\w+)?\s+(\S+\s+)?#{quoted_name}\son/
- ).size
- else
- join.left.table_name == name ? 1 : 0
- end
- end
-
- counts.sum
- end
-
def truncate(name)
- name.slice(0, connection.table_alias_length - 2)
+ name.slice(0, @connection.table_alias_length - 2)
end
end
end