aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord')
-rw-r--r--activerecord/lib/active_record/associations.rb37
-rw-r--r--activerecord/test/cases/associations/belongs_to_associations_test.rb37
2 files changed, 74 insertions, 0 deletions
diff --git a/activerecord/lib/active_record/associations.rb b/activerecord/lib/active_record/associations.rb
index 0d9171d876..4711522b00 100644
--- a/activerecord/lib/active_record/associations.rb
+++ b/activerecord/lib/active_record/associations.rb
@@ -1223,12 +1223,14 @@ module ActiveRecord
if reflection.options[:polymorphic]
association_accessor_methods(reflection, BelongsToPolymorphicAssociation)
+ association_foreign_type_setter_method(reflection)
else
association_accessor_methods(reflection, BelongsToAssociation)
association_constructor_method(:build, reflection, BelongsToAssociation)
association_constructor_method(:create, reflection, BelongsToAssociation)
end
+ association_foreign_key_setter_method(reflection)
add_counter_cache_callbacks(reflection) if options[:counter_cache]
add_touch_callbacks(reflection, options[:touch]) if options[:touch]
@@ -1555,6 +1557,41 @@ module ActiveRecord
end
end
+ def association_foreign_key_setter_method(reflection)
+ setters = reflect_on_all_associations(:belongs_to).map do |belongs_to_reflection|
+ if belongs_to_reflection.primary_key_name == reflection.primary_key_name
+ "association_instance_set(:#{belongs_to_reflection.name}, nil);"
+ end
+ end.compact.join
+
+ class_eval <<-FILE, __FILE__, __LINE__ + 1
+ def #{reflection.primary_key_name}=(new_id)
+ write_attribute :#{reflection.primary_key_name}, new_id
+ if #{reflection.primary_key_name}_changed?
+ #{ setters }
+ end
+ end
+ FILE
+ end
+
+ def association_foreign_type_setter_method(reflection)
+ setters = reflect_on_all_associations(:belongs_to).map do |belongs_to_reflection|
+ if belongs_to_reflection.options[:foreign_type] == reflection.options[:foreign_type]
+ "association_instance_set(:#{belongs_to_reflection.name}, nil);"
+ end
+ end.compact.join
+
+ field = reflection.options[:foreign_type]
+ class_eval <<-FILE, __FILE__, __LINE__ + 1
+ def #{field}=(new_id)
+ write_attribute :#{field}, new_id
+ if #{field}_changed?
+ #{ setters }
+ end
+ end
+ FILE
+ end
+
def add_counter_cache_callbacks(reflection)
cache_column = reflection.counter_cache_column
diff --git a/activerecord/test/cases/associations/belongs_to_associations_test.rb b/activerecord/test/cases/associations/belongs_to_associations_test.rb
index 1b0c00bd5a..1820f95261 100644
--- a/activerecord/test/cases/associations/belongs_to_associations_test.rb
+++ b/activerecord/test/cases/associations/belongs_to_associations_test.rb
@@ -486,4 +486,41 @@ class BelongsToAssociationsTest < ActiveRecord::TestCase
new_firm = accounts(:signals37).build_firm(:name => 'Apple')
assert_equal new_firm.name, "Apple"
end
+
+ def test_reassigning_the_parent_id_updates_the_object
+ original_parent = Firm.create! :name => "original"
+ updated_parent = Firm.create! :name => "updated"
+
+ client = Client.new("client_of" => original_parent.id)
+ assert_equal original_parent, client.firm
+ assert_equal original_parent, client.firm_with_condition
+ assert_equal original_parent, client.firm_with_other_name
+
+ client.client_of = updated_parent.id
+ assert_equal updated_parent, client.firm
+ assert_equal updated_parent, client.firm_with_condition
+ assert_equal updated_parent, client.firm_with_other_name
+ end
+
+ def test_polymorphic_reassignment_of_associated_id_updates_the_object
+ member1 = Member.create!
+ member2 = Member.create!
+
+ sponsor = Sponsor.new("sponsorable_type" => "Member", "sponsorable_id" => member1.id)
+ assert_equal member1, sponsor.sponsorable
+
+ sponsor.sponsorable_id = member2.id
+ assert_equal member2, sponsor.sponsorable
+ end
+
+ def test_polymorphic_reassignment_of_associated_type_updates_the_object
+ member1 = Member.create!
+
+ sponsor = Sponsor.new("sponsorable_type" => "Member", "sponsorable_id" => member1.id)
+ assert_equal member1, sponsor.sponsorable
+
+ sponsor.sponsorable_type = "Firm"
+ assert_not_equal member1, sponsor.sponsorable
+ end
+
end