From caa178c178468f7adcc5f4c597280e6362d9e175 Mon Sep 17 00:00:00 2001
From: Sean Griffin <sean@seantheprogrammer.com>
Date: Wed, 31 Aug 2016 08:54:38 -0400
Subject: Ensure that inverse associations are set before running callbacks

If a parent association was accessed in an `after_find` or
`after_initialize` callback, it would always end up loading the
association, and then immediately overwriting the association we just
loaded. If this occurred in a way that the parent's `current_scope` was
set to eager load the child, this would result in an infinite loop and
eventually overflow the stack.

For records that are created with `.new`, we have a mechanism to
perform an action before the callbacks are run. I've introduced the same
code path for records created with `instantiate`, and updated all code
which sets inverse instances on newly loaded associations to use this
block instead.

Fixes #26320.
---
 activerecord/CHANGELOG.md                          |  7 ++++++
 .../lib/active_record/association_relation.rb      |  5 +++-
 .../associations/collection_association.rb         |  6 ++---
 .../active_record/associations/join_dependency.rb  |  8 ++++---
 .../associations/join_dependency/join_part.rb      |  4 ++--
 .../associations/preloader/association.rb          | 20 +++++++++++++---
 .../preloader/collection_association.rb            |  1 -
 .../associations/preloader/singular_association.rb |  1 -
 activerecord/lib/active_record/core.rb             |  2 ++
 activerecord/lib/active_record/persistence.rb      |  4 ++--
 activerecord/lib/active_record/querying.rb         |  4 ++--
 activerecord/lib/active_record/relation.rb         |  8 +++----
 activerecord/lib/active_record/statement_cache.rb  |  4 ++--
 .../associations/inverse_associations_test.rb      | 27 ++++++++++++++++++++++
 14 files changed, 77 insertions(+), 24 deletions(-)

diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md
index 7838fe7167..f62d29b5bf 100644
--- a/activerecord/CHANGELOG.md
+++ b/activerecord/CHANGELOG.md
@@ -1,3 +1,10 @@
+*   Inverse association instances will now be set before `after_find` or
+    `after_initialize` callbacks are run.
+
+    Fixes #26320.
+
+    *Sean Griffin*
+
 *   Remove unnecessarily association load when a `belongs_to` association has already been
     loaded then the foreign key is changed directly and the record saved.
 
diff --git a/activerecord/lib/active_record/association_relation.rb b/activerecord/lib/active_record/association_relation.rb
index 2da2d968b9..de2d03cd0b 100644
--- a/activerecord/lib/active_record/association_relation.rb
+++ b/activerecord/lib/active_record/association_relation.rb
@@ -29,7 +29,10 @@ module ActiveRecord
     private
 
       def exec_queries
-        super.each { |r| @association.set_inverse_instance r }
+        super do |r|
+          @association.set_inverse_instance r
+          yield r if block_given?
+        end
       end
   end
 end
diff --git a/activerecord/lib/active_record/associations/collection_association.rb b/activerecord/lib/active_record/associations/collection_association.rb
index 0f51b35164..0c911a5396 100644
--- a/activerecord/lib/active_record/associations/collection_association.rb
+++ b/activerecord/lib/active_record/associations/collection_association.rb
@@ -390,9 +390,9 @@ module ActiveRecord
           end
 
           binds = AssociationScope.get_bind_values(owner, reflection.chain)
-          records = sc.execute(binds, klass, conn)
-          records.each { |record| set_inverse_instance(record) }
-          records
+          sc.execute(binds, klass, conn) do |record|
+            set_inverse_instance(record)
+          end
         end
 
         # We have some records loaded from the database (persisted) and some that are
diff --git a/activerecord/lib/active_record/associations/join_dependency.rb b/activerecord/lib/active_record/associations/join_dependency.rb
index 62acad0eda..02f0721bed 100644
--- a/activerecord/lib/active_record/associations/join_dependency.rb
+++ b/activerecord/lib/active_record/associations/join_dependency.rb
@@ -286,17 +286,19 @@ module ActiveRecord
         end
 
         def construct_model(record, node, row, model_cache, id, aliases)
-          model = model_cache[node][id] ||= node.instantiate(row,
-                                                             aliases.column_aliases(node))
           other = record.association(node.reflection.name)
 
+          model = model_cache[node][id] ||=
+            node.instantiate(row, aliases.column_aliases(node)) do |m|
+              other.set_inverse_instance(m)
+            end
+
           if node.reflection.collection?
             other.target.push(model)
           else
             other.target = model
           end
 
-          other.set_inverse_instance(model)
           model
         end
     end
diff --git a/activerecord/lib/active_record/associations/join_dependency/join_part.rb b/activerecord/lib/active_record/associations/join_dependency/join_part.rb
index 551087f822..61cec5403a 100644
--- a/activerecord/lib/active_record/associations/join_dependency/join_part.rb
+++ b/activerecord/lib/active_record/associations/join_dependency/join_part.rb
@@ -62,8 +62,8 @@ module ActiveRecord
           hash
         end
 
-        def instantiate(row, aliases)
-          base_klass.instantiate(extract_record(row, aliases))
+        def instantiate(row, aliases, &block)
+          base_klass.instantiate(extract_record(row, aliases), &block)
         end
       end
     end
diff --git a/activerecord/lib/active_record/associations/preloader/association.rb b/activerecord/lib/active_record/associations/preloader/association.rb
index 4bb627f399..07407700cd 100644
--- a/activerecord/lib/active_record/associations/preloader/association.rb
+++ b/activerecord/lib/active_record/associations/preloader/association.rb
@@ -62,7 +62,12 @@ module ActiveRecord
         private
 
           def associated_records_by_owner(preloader)
-            records = load_records
+            records = load_records do |record|
+              owner = owners_by_key[convert_key(record[association_key_name])]
+              association = owner.association(reflection.name)
+              association.set_inverse_instance(record)
+            end
+
             owners.each_with_object({}) do |owner, result|
               result[owner] = records[convert_key(owner[owner_key_name])] || []
             end
@@ -79,6 +84,15 @@ module ActiveRecord
             @owner_keys
           end
 
+          def owners_by_key
+            unless defined?(@owners_by_key)
+              @owners_by_key = owners.each_with_object({}) do |owner, h|
+                h[convert_key(owner[owner_key_name])] = owner
+              end
+            end
+            @owners_by_key
+          end
+
           def key_conversion_required?
             @key_conversion_required ||= association_key_type != owner_key_type
           end
@@ -99,13 +113,13 @@ module ActiveRecord
             @model.type_for_attribute(owner_key_name.to_s).type
           end
 
-          def load_records
+          def load_records(&block)
             return {} if owner_keys.empty?
             # Some databases impose a limit on the number of ids in a list (in Oracle it's 1000)
             # Make several smaller queries if necessary or make one query if the adapter supports it
             slices  = owner_keys.each_slice(klass.connection.in_clause_length || owner_keys.size)
             @preloaded_records = slices.flat_map do |slice|
-              records_for(slice)
+              records_for(slice).load(&block)
             end
             @preloaded_records.group_by do |record|
               convert_key(record[association_key_name])
diff --git a/activerecord/lib/active_record/associations/preloader/collection_association.rb b/activerecord/lib/active_record/associations/preloader/collection_association.rb
index 24b8e01029..26690bf16d 100644
--- a/activerecord/lib/active_record/associations/preloader/collection_association.rb
+++ b/activerecord/lib/active_record/associations/preloader/collection_association.rb
@@ -9,7 +9,6 @@ module ActiveRecord
               association = owner.association(reflection.name)
               association.loaded!
               association.target.concat(records)
-              records.each { |record| association.set_inverse_instance(record) }
             end
           end
       end
diff --git a/activerecord/lib/active_record/associations/preloader/singular_association.rb b/activerecord/lib/active_record/associations/preloader/singular_association.rb
index 0888d383a6..5c5828262e 100644
--- a/activerecord/lib/active_record/associations/preloader/singular_association.rb
+++ b/activerecord/lib/active_record/associations/preloader/singular_association.rb
@@ -10,7 +10,6 @@ module ActiveRecord
 
               association = owner.association(reflection.name)
               association.target = record
-              association.set_inverse_instance(record) if record
             end
           end
       end
diff --git a/activerecord/lib/active_record/core.rb b/activerecord/lib/active_record/core.rb
index aef4761be4..2725c85446 100644
--- a/activerecord/lib/active_record/core.rb
+++ b/activerecord/lib/active_record/core.rb
@@ -366,6 +366,8 @@ module ActiveRecord
 
       self.class.define_attribute_methods
 
+      yield self if block_given?
+
       _run_find_callbacks
       _run_initialize_callbacks
 
diff --git a/activerecord/lib/active_record/persistence.rb b/activerecord/lib/active_record/persistence.rb
index a6615f3774..a04ef2e263 100644
--- a/activerecord/lib/active_record/persistence.rb
+++ b/activerecord/lib/active_record/persistence.rb
@@ -63,10 +63,10 @@ module ActiveRecord
       #
       # See <tt>ActiveRecord::Inheritance#discriminate_class_for_record</tt> to see
       # how this "single-table" inheritance mapping is implemented.
-      def instantiate(attributes, column_types = {})
+      def instantiate(attributes, column_types = {}, &block)
         klass = discriminate_class_for_record(attributes)
         attributes = klass.attributes_builder.build_from_database(attributes, column_types)
-        klass.allocate.init_with("attributes" => attributes, "new_record" => false)
+        klass.allocate.init_with("attributes" => attributes, "new_record" => false, &block)
       end
 
       private
diff --git a/activerecord/lib/active_record/querying.rb b/activerecord/lib/active_record/querying.rb
index dd7d650207..36689f6559 100644
--- a/activerecord/lib/active_record/querying.rb
+++ b/activerecord/lib/active_record/querying.rb
@@ -35,7 +35,7 @@ module ActiveRecord
     #
     #   Post.find_by_sql ["SELECT title FROM posts WHERE author = ? AND created > ?", author_id, start_date]
     #   Post.find_by_sql ["SELECT body FROM comments WHERE author = :user_id OR approved_by = :user_id", { :user_id => user_id }]
-    def find_by_sql(sql, binds = [], preparable: nil)
+    def find_by_sql(sql, binds = [], preparable: nil, &block)
       result_set = connection.select_all(sanitize_sql(sql), "#{name} Load", binds, preparable: preparable)
       column_types = result_set.column_types.dup
       columns_hash.each_key { |k| column_types.delete k }
@@ -47,7 +47,7 @@ module ActiveRecord
       }
 
       message_bus.instrument("instantiation.active_record", payload) do
-        result_set.map { |record| instantiate(record, column_types) }
+        result_set.map { |record| instantiate(record, column_types, &block) }
       end
     end
 
diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb
index d7de1032b6..6d571cf026 100644
--- a/activerecord/lib/active_record/relation.rb
+++ b/activerecord/lib/active_record/relation.rb
@@ -562,8 +562,8 @@ module ActiveRecord
     # return value is the relation itself, not the records.
     #
     #   Post.where(published: true).load # => #<ActiveRecord::Relation>
-    def load
-      exec_queries unless loaded?
+    def load(&block)
+      exec_queries(&block) unless loaded?
 
       self
     end
@@ -678,8 +678,8 @@ module ActiveRecord
 
     private
 
-      def exec_queries
-        @records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, bound_attributes).freeze
+      def exec_queries(&block)
+        @records = eager_loading? ? find_with_associations.freeze : @klass.find_by_sql(arel, bound_attributes, &block).freeze
 
         preload = preload_values
         preload +=  includes_values unless eager_loading?
diff --git a/activerecord/lib/active_record/statement_cache.rb b/activerecord/lib/active_record/statement_cache.rb
index fd67032235..d19bb96ede 100644
--- a/activerecord/lib/active_record/statement_cache.rb
+++ b/activerecord/lib/active_record/statement_cache.rb
@@ -99,12 +99,12 @@ module ActiveRecord
       @bind_map      = bind_map
     end
 
-    def execute(params, klass, connection)
+    def execute(params, klass, connection, &block)
       bind_values = bind_map.bind params
 
       sql = query_builder.sql_for bind_values, connection
 
-      klass.find_by_sql(sql, bind_values, preparable: true)
+      klass.find_by_sql(sql, bind_values, preparable: true, &block)
     end
     alias :call :execute
   end
diff --git a/activerecord/test/cases/associations/inverse_associations_test.rb b/activerecord/test/cases/associations/inverse_associations_test.rb
index 0b23cea420..6fe6ee6783 100644
--- a/activerecord/test/cases/associations/inverse_associations_test.rb
+++ b/activerecord/test/cases/associations/inverse_associations_test.rb
@@ -494,6 +494,33 @@ class InverseHasManyTests < ActiveRecord::TestCase
 
     assert !man.persisted?
   end
+
+  def test_inverse_instance_should_be_set_before_find_callbacks_are_run
+    reset_callbacks(Interest, :find) do
+      Interest.after_find { raise unless association(:man).loaded? && man.present? }
+
+      assert Man.first.interests.reload.any?
+      assert Man.includes(:interests).first.interests.any?
+      assert Man.joins(:interests).includes(:interests).first.interests.any?
+    end
+  end
+
+  def test_inverse_instance_should_be_set_before_initialize_callbacks_are_run
+    reset_callbacks(Interest, :initialize) do
+      Interest.after_initialize { raise unless association(:man).loaded? && man.present? }
+
+      assert Man.first.interests.reload.any?
+      assert Man.includes(:interests).first.interests.any?
+      assert Man.joins(:interests).includes(:interests).first.interests.any?
+    end
+  end
+
+  def reset_callbacks(target, type)
+    old_callbacks = target.send(:get_callbacks, type).deep_dup
+    yield
+  ensure
+    target.send(:set_callbacks, type, old_callbacks) if old_callbacks
+  end
 end
 
 class InverseBelongsToTests < ActiveRecord::TestCase
-- 
cgit v1.2.3