aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord
diff options
context:
space:
mode:
authorRyuta Kamizono <kamipo@gmail.com>2018-06-14 23:48:01 +0900
committerRyuta Kamizono <kamipo@gmail.com>2018-06-19 22:21:51 +0900
commit15e3e9cdccd8ab0d204ed9e74bf7916d6025d405 (patch)
tree2b5588e1f5936fabd7ee6b12a883b00799ce7f81 /activerecord
parent2be5ef6a1d5baebcb749c42a48628c741570fd73 (diff)
downloadrails-15e3e9cdccd8ab0d204ed9e74bf7916d6025d405.tar.gz
rails-15e3e9cdccd8ab0d204ed9e74bf7916d6025d405.tar.bz2
rails-15e3e9cdccd8ab0d204ed9e74bf7916d6025d405.zip
Ensure to calculate column aliases after all table aliases are constructed
Currently, column aliases which is used for eager loading are calculated before constructing all table aliases in FROM clause. `JoinDependency#join_constraints` constructs table aliases for `joins` first, and then always re-constructs table aliases for eager loading. If both `joins` and eager loading are given a same table association, the re-construction would cause the discrepancy between column aliases and table aliases. To avoid the discrepancy, the column aliases should be calculated after all table aliases are constructed. Fixes #30603.
Diffstat (limited to 'activerecord')
-rw-r--r--activerecord/lib/active_record/associations/join_dependency.rb78
-rw-r--r--activerecord/lib/active_record/associations/join_dependency/join_association.rb30
-rw-r--r--activerecord/lib/active_record/associations/join_dependency/join_part.rb7
-rw-r--r--activerecord/lib/active_record/relation.rb22
-rw-r--r--activerecord/lib/active_record/relation/finder_methods.rb4
-rw-r--r--activerecord/lib/active_record/relation/merger.rb8
-rw-r--r--activerecord/lib/active_record/relation/query_methods.rb16
-rw-r--r--activerecord/test/cases/associations/eager_test.rb46
-rw-r--r--activerecord/test/fixtures/memberships.yml7
-rw-r--r--activerecord/test/models/member.rb7
10 files changed, 135 insertions, 90 deletions
diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb
index c1faa021aa..3b0b2ea050 100644
--- a/activerecord/lib/active_record/associations/join_dependency.rb
+++ b/activerecord/lib/active_record/associations/join_dependency.rb
@@ -67,42 +67,31 @@ module ActiveRecord
end
end
- def initialize(base, table, associations, alias_tracker)
- @alias_tracker = alias_tracker
+ def initialize(base, table, associations)
tree = self.class.make_tree associations
@join_root = JoinBase.new(base, table, build(tree, base))
- @join_root.children.each { |child| construct_tables! @join_root, child }
end
def reflections
join_root.drop(1).map!(&:reflection)
end
- def join_constraints(joins_to_add, join_type)
- joins = join_root.children.flat_map { |child|
- make_join_constraints(join_root, child, join_type)
- }
+ def join_constraints(joins_to_add, join_type, alias_tracker)
+ @alias_tracker = alias_tracker
+
+ construct_tables!(join_root)
+ joins = make_join_constraints(join_root, join_type)
joins.concat joins_to_add.flat_map { |oj|
+ construct_tables!(oj.join_root)
if join_root.match? oj.join_root
walk join_root, oj.join_root
else
- oj.join_root.children.flat_map { |child|
- make_join_constraints(oj.join_root, child, join_type, true)
- }
+ make_join_constraints(oj.join_root, join_type)
end
}
end
- def aliases
- @aliases ||= Aliases.new join_root.each_with_index.map { |join_part, i|
- columns = join_part.column_names.each_with_index.map { |column_name, j|
- Aliases::Column.new column_name, "t#{i}_r#{j}"
- }
- Aliases::Table.new(join_part, columns)
- }
- end
-
def instantiate(result_set, &block)
primary_key = aliases.column_alias(join_root, join_root.primary_key)
@@ -134,28 +123,42 @@ module ActiveRecord
parents.values
end
+ def apply_column_aliases(relation)
+ relation._select!(-> { aliases.columns })
+ end
+
protected
- attr_reader :alias_tracker, :base_klass, :join_root
+ attr_reader :join_root
private
+ attr_reader :alias_tracker
- def make_constraints(parent, child, tables, join_type)
- chain = child.reflection.chain
- foreign_table = parent.table
- foreign_klass = parent.base_klass
- child.join_constraints(foreign_table, foreign_klass, join_type, tables, chain)
+ def aliases
+ @aliases ||= Aliases.new join_root.each_with_index.map { |join_part, i|
+ columns = join_part.column_names.each_with_index.map { |column_name, j|
+ Aliases::Column.new column_name, "t#{i}_r#{j}"
+ }
+ Aliases::Table.new(join_part, columns)
+ }
end
- def make_outer_joins(parent, child)
- join_type = Arel::Nodes::OuterJoin
- make_join_constraints(parent, child, join_type, true)
+ def construct_tables!(join_root)
+ join_root.each_children do |parent, child|
+ child.tables = table_aliases_for(parent, child)
+ end
end
- def make_join_constraints(parent, child, join_type, aliasing = false)
- tables = aliasing ? table_aliases_for(parent, child) : child.tables
- joins = make_constraints(parent, child, tables, join_type)
+ def make_join_constraints(join_root, join_type)
+ join_root.children.flat_map do |child|
+ make_constraints(join_root, child, join_type)
+ end
+ end
- joins.concat child.children.flat_map { |c| make_join_constraints(child, c, join_type, aliasing) }
+ def make_constraints(parent, child, join_type = Arel::Nodes::OuterJoin)
+ foreign_table = parent.table
+ foreign_klass = parent.base_klass
+ joins = child.join_constraints(foreign_table, foreign_klass, join_type, alias_tracker)
+ joins.concat child.children.flat_map { |c| make_constraints(child, c, join_type) }
end
def table_aliases_for(parent, node)
@@ -168,11 +171,6 @@ module ActiveRecord
}
end
- def construct_tables!(parent, node)
- node.tables = table_aliases_for(parent, node)
- node.children.each { |child| construct_tables! node, child }
- end
-
def table_alias_for(reflection, parent, join)
name = "#{reflection.plural_name}_#{parent.table_name}"
join ? "#{name}_join" : name
@@ -183,8 +181,8 @@ module ActiveRecord
[left.children.find { |node2| node1.match? node2 }, node1]
}.partition(&:first)
- ojs = missing.flat_map { |_, n| make_outer_joins left, n }
- intersection.flat_map { |l, r| walk l, r }.concat ojs
+ joins = intersection.flat_map { |l, r| r.table = l.table; walk(l, r) }
+ joins.concat missing.flat_map { |_, n| make_constraints(left, n) }
end
def find_reflection(klass, name)
@@ -202,7 +200,7 @@ module ActiveRecord
raise EagerLoadPolymorphicError.new(reflection)
end
- JoinAssociation.new(reflection, build(right, reflection.klass), alias_tracker)
+ JoinAssociation.new(reflection, build(right, reflection.klass))
end
end
diff --git a/activerecord/lib/active_record/associations/join_dependency/join_association.rb b/activerecord/lib/active_record/associations/join_dependency/join_association.rb
index c36386ec7e..6e5e950e90 100644
--- a/activerecord/lib/active_record/associations/join_dependency/join_association.rb
+++ b/activerecord/lib/active_record/associations/join_dependency/join_association.rb
@@ -6,17 +6,14 @@ module ActiveRecord
module Associations
class JoinDependency # :nodoc:
class JoinAssociation < JoinPart # :nodoc:
- # The reflection of the association represented
- attr_reader :reflection
+ attr_reader :reflection, :tables
+ attr_accessor :table
- attr_accessor :tables
-
- def initialize(reflection, children, alias_tracker)
+ def initialize(reflection, children)
super(reflection.klass, children)
- @alias_tracker = alias_tracker
- @reflection = reflection
- @tables = nil
+ @reflection = reflection
+ @tables = nil
end
def match?(other)
@@ -24,14 +21,13 @@ module ActiveRecord
super && reflection == other.reflection
end
- def join_constraints(foreign_table, foreign_klass, join_type, tables, chain)
- joins = []
- tables = tables.reverse
+ def join_constraints(foreign_table, foreign_klass, join_type, alias_tracker)
+ joins = []
# The chain starts with the target table, but we want to end with it here (makes
# more sense in this context), so we reverse
- chain.reverse_each do |reflection|
- table = tables.shift
+ reflection.chain.reverse_each.with_index(1) do |reflection, i|
+ table = tables[-i]
klass = reflection.klass
constraint = reflection.build_join_constraint(table, foreign_table)
@@ -54,12 +50,10 @@ module ActiveRecord
joins
end
- def table
- tables.first
+ def tables=(tables)
+ @tables = tables
+ @table = tables.first
end
-
- private
- attr_reader :alias_tracker
end
end
end
diff --git a/activerecord/lib/active_record/associations/join_dependency/join_part.rb b/activerecord/lib/active_record/associations/join_dependency/join_part.rb
index 2181f308bf..3cabb21983 100644
--- a/activerecord/lib/active_record/associations/join_dependency/join_part.rb
+++ b/activerecord/lib/active_record/associations/join_dependency/join_part.rb
@@ -33,6 +33,13 @@ module ActiveRecord
children.each { |child| child.each(&block) }
end
+ def each_children(&block)
+ children.each do |child|
+ yield self, child
+ child.each_children(&block)
+ end
+ end
+
# An Arel::Table for the active_record
def table
raise NotImplementedError
diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb
index c911cbe5ec..7ab9bb2d2d 100644
--- a/activerecord/lib/active_record/relation.rb
+++ b/activerecord/lib/active_record/relation.rb
@@ -502,17 +502,16 @@ module ActiveRecord
# # => SELECT "users".* FROM "users" WHERE "users"."name" = 'Oscar'
def to_sql
@to_sql ||= begin
- relation = self
-
- if eager_loading?
- apply_join_dependency { |rel, _| relation = rel }
- end
-
- conn = klass.connection
- conn.unprepared_statement {
- conn.to_sql(relation.arel)
- }
- end
+ if eager_loading?
+ apply_join_dependency do |relation, join_dependency|
+ relation = join_dependency.apply_column_aliases(relation)
+ relation.to_sql
+ end
+ else
+ conn = klass.connection
+ conn.unprepared_statement { conn.to_sql(arel) }
+ end
+ end
end
# Returns a hash of where conditions.
@@ -622,6 +621,7 @@ module ActiveRecord
if ActiveRecord::NullRelation === relation
[]
else
+ relation = join_dependency.apply_column_aliases(relation)
rows = connection.select_all(relation.arel, "SQL")
join_dependency.instantiate(rows, &block)
end.freeze
diff --git a/activerecord/lib/active_record/relation/finder_methods.rb b/activerecord/lib/active_record/relation/finder_methods.rb
index 31fb8ce0e5..b9e7c52e88 100644
--- a/activerecord/lib/active_record/relation/finder_methods.rb
+++ b/activerecord/lib/active_record/relation/finder_methods.rb
@@ -373,9 +373,8 @@ module ActiveRecord
def construct_join_dependency
including = eager_load_values + includes_values
- joins = joins_values.select { |join| join.is_a?(Arel::Nodes::Join) }
ActiveRecord::Associations::JoinDependency.new(
- klass, table, including, alias_tracker(joins)
+ klass, table, including
)
end
@@ -392,7 +391,6 @@ module ActiveRecord
end
if block_given?
- relation._select!(join_dependency.aliases.columns)
yield relation, join_dependency
else
relation
diff --git a/activerecord/lib/active_record/relation/merger.rb b/activerecord/lib/active_record/relation/merger.rb
index 25510d4a57..9cbcf61b3e 100644
--- a/activerecord/lib/active_record/relation/merger.rb
+++ b/activerecord/lib/active_record/relation/merger.rb
@@ -117,13 +117,11 @@ module ActiveRecord
if other.klass == relation.klass
relation.joins!(*other.joins_values)
else
- alias_tracker = nil
joins_dependency = other.joins_values.map do |join|
case join
when Hash, Symbol, Array
- alias_tracker ||= other.alias_tracker
ActiveRecord::Associations::JoinDependency.new(
- other.klass, other.table, join, alias_tracker
+ other.klass, other.table, join
)
else
join
@@ -140,13 +138,11 @@ module ActiveRecord
if other.klass == relation.klass
relation.left_outer_joins!(*other.left_outer_joins_values)
else
- alias_tracker = nil
joins_dependency = other.left_outer_joins_values.map do |join|
case join
when Hash, Symbol, Array
- alias_tracker ||= other.alias_tracker
ActiveRecord::Associations::JoinDependency.new(
- other.klass, other.table, join, alias_tracker
+ other.klass, other.table, join
)
else
join
diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb
index db9101a168..5b4ba85316 100644
--- a/activerecord/lib/active_record/relation/query_methods.rb
+++ b/activerecord/lib/active_record/relation/query_methods.rb
@@ -1018,19 +1018,19 @@ module ActiveRecord
def build_join_query(manager, buckets, join_type, aliases)
buckets.default = []
- association_joins = buckets[:association_join]
- stashed_association_joins = buckets[:stashed_join]
- join_nodes = buckets[:join_node].uniq
- string_joins = buckets[:string_join].map(&:strip).uniq
+ association_joins = buckets[:association_join]
+ stashed_joins = buckets[:stashed_join]
+ join_nodes = buckets[:join_node].uniq
+ string_joins = buckets[:string_join].map(&:strip).uniq
join_list = join_nodes + convert_join_strings_to_ast(string_joins)
alias_tracker = alias_tracker(join_list, aliases)
join_dependency = ActiveRecord::Associations::JoinDependency.new(
- klass, table, association_joins, alias_tracker
+ klass, table, association_joins
)
- joins = join_dependency.join_constraints(stashed_association_joins, join_type)
+ joins = join_dependency.join_constraints(stashed_joins, join_type, alias_tracker)
joins.each { |join| manager.from(join) }
manager.join_sources.concat(join_list)
@@ -1056,11 +1056,13 @@ module ActiveRecord
end
def arel_columns(columns)
- columns.map do |field|
+ columns.flat_map do |field|
if (Symbol === field || String === field) && (klass.has_attribute?(field) || klass.attribute_alias?(field)) && !from_clause.value
arel_attribute(field)
elsif Symbol === field
connection.quote_table_name(field.to_s)
+ elsif Proc === field
+ field.call
else
field
end
diff --git a/activerecord/test/cases/associations/eager_test.rb b/activerecord/test/cases/associations/eager_test.rb
index f46be8734b..8be663e3dc 100644
--- a/activerecord/test/cases/associations/eager_test.rb
+++ b/activerecord/test/cases/associations/eager_test.rb
@@ -77,8 +77,50 @@ class EagerAssociationTest < ActiveRecord::TestCase
end
def test_loading_with_scope_including_joins
- assert_equal clubs(:boring_club), Member.preload(:general_club).find(1).general_club
- assert_equal clubs(:boring_club), Member.eager_load(:general_club).find(1).general_club
+ member = Member.first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.general_club
+
+ member = Member.preload(:general_club).first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.general_club
+
+ member = Member.eager_load(:general_club).first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.general_club
+ end
+
+ def test_loading_association_with_same_table_joins
+ super_memberships = [memberships(:super_membership_of_boring_club)]
+
+ member = Member.joins(:favourite_memberships).first
+ assert_equal members(:groucho), member
+ assert_equal super_memberships, member.super_memberships
+
+ member = Member.joins(:favourite_memberships).preload(:super_memberships).first
+ assert_equal members(:groucho), member
+ assert_equal super_memberships, member.super_memberships
+
+ member = Member.joins(:favourite_memberships).eager_load(:super_memberships).first
+ assert_equal members(:groucho), member
+ assert_equal super_memberships, member.super_memberships
+ end
+
+ def test_loading_association_with_intersection_joins
+ member = Member.joins(:current_membership).first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.club
+ assert_equal memberships(:membership_of_boring_club), member.current_membership
+
+ member = Member.joins(:current_membership).preload(:club, :current_membership).first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.club
+ assert_equal memberships(:membership_of_boring_club), member.current_membership
+
+ member = Member.joins(:current_membership).eager_load(:club, :current_membership).first
+ assert_equal members(:groucho), member
+ assert_equal clubs(:boring_club), member.club
+ assert_equal memberships(:membership_of_boring_club), member.current_membership
end
def test_with_ordering
diff --git a/activerecord/test/fixtures/memberships.yml b/activerecord/test/fixtures/memberships.yml
index a5d52bd438..f7ca227533 100644
--- a/activerecord/test/fixtures/memberships.yml
+++ b/activerecord/test/fixtures/memberships.yml
@@ -26,6 +26,13 @@ blarpy_winkup_crazy_club:
favourite: false
type: CurrentMembership
+super_membership_of_boring_club:
+ joined_on: <%= 3.weeks.ago.to_s(:db) %>
+ club: boring_club
+ member_id: 1
+ favourite: false
+ type: SuperMembership
+
selected_membership_of_boring_club:
joined_on: <%= 3.weeks.ago.to_s(:db) %>
club: boring_club
diff --git a/activerecord/test/models/member.rb b/activerecord/test/models/member.rb
index 4315ba1941..6e33ac0a6d 100644
--- a/activerecord/test/models/member.rb
+++ b/activerecord/test/models/member.rb
@@ -26,13 +26,14 @@ class Member < ActiveRecord::Base
has_one :club_category, through: :club, source: :category
has_one :general_club, -> { general }, through: :current_membership, source: :club
- has_many :current_memberships, -> { where favourite: true }
- has_many :clubs, through: :current_memberships
+ has_many :super_memberships
+ has_many :favourite_memberships, -> { where(favourite: true) }, class_name: "Membership"
+ has_many :clubs, through: :favourite_memberships
has_many :tenant_memberships
has_many :tenant_clubs, through: :tenant_memberships, class_name: "Club", source: :club
- has_one :club_through_many, through: :current_memberships, source: :club
+ has_one :club_through_many, through: :favourite_memberships, source: :club
belongs_to :admittable, polymorphic: true
has_one :premium_club, through: :admittable