From 5d6c8d5e9d1c6544f8db8639e3a53a8d7682eeb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sat, 7 May 2011 17:33:40 +0200 Subject: Revert "Revert the merge because tests did not pass." This reverts commit 886818d2bab40585c0cea763002ffc16917dd0b3. --- activerecord/lib/active_record/base.rb | 4 ++++ activerecord/lib/active_record/identity_map.rb | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'activerecord/lib') diff --git a/activerecord/lib/active_record/base.rb b/activerecord/lib/active_record/base.rb index 6149865f80..1fe867495d 100644 --- a/activerecord/lib/active_record/base.rb +++ b/activerecord/lib/active_record/base.rb @@ -830,6 +830,10 @@ module ActiveRecord #:nodoc: @symbolized_base_class ||= base_class.to_s.to_sym end + def symbolized_sti_name + @symbolized_sti_name ||= sti_name ? sti_name.to_sym : symbolized_base_class + end + # Returns the base AR subclass that this class descends from. If A # extends AR::Base, A.base_class will return A. If B descends from A # through some arbitrarily deep hierarchy, B.base_class will return A. diff --git a/activerecord/lib/active_record/identity_map.rb b/activerecord/lib/active_record/identity_map.rb index 9eb47ad99f..f88ead9ca0 100644 --- a/activerecord/lib/active_record/identity_map.rb +++ b/activerecord/lib/active_record/identity_map.rb @@ -49,7 +49,7 @@ module ActiveRecord end def get(klass, primary_key) - record = repository[klass.symbolized_base_class][primary_key] + record = repository[klass.symbolized_sti_name][primary_key] if record.is_a?(klass) ActiveSupport::Notifications.instrument("identity.active_record", @@ -64,15 +64,15 @@ module ActiveRecord end def add(record) - repository[record.class.symbolized_base_class][record.id] = record + repository[record.class.symbolized_sti_name][record.id] = record end def remove(record) - repository[record.class.symbolized_base_class].delete(record.id) + repository[record.class.symbolized_sti_name].delete(record.id) end - def remove_by_id(symbolized_base_class, id) - repository[symbolized_base_class].delete(id) + def remove_by_id(symbolized_sti_name, id) + repository[symbolized_sti_name].delete(id) end def clear -- cgit v1.2.3