aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--activerecord/CHANGELOG.md9
-rw-r--r--activerecord/lib/active_record/base.rb1
-rw-r--r--activerecord/lib/active_record/inheritance.rb26
-rw-r--r--activerecord/lib/active_record/reflection.rb2
-rw-r--r--activerecord/test/cases/associations/belongs_to_associations_test.rb28
-rw-r--r--activerecord/test/cases/associations/has_many_associations_test.rb28
-rw-r--r--activerecord/test/cases/associations/has_one_associations_test.rb30
-rw-r--r--activerecord/test/cases/forbidden_attributes_protection_test.rb15
-rw-r--r--activerecord/test/cases/inheritance_test.rb23
-rw-r--r--activerecord/test/models/author.rb1
-rw-r--r--activerecord/test/models/company.rb1
11 files changed, 163 insertions, 1 deletions
diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md
index c90ce121a7..130d0f05d2 100644
--- a/activerecord/CHANGELOG.md
+++ b/activerecord/CHANGELOG.md
@@ -1,5 +1,14 @@
## Rails 4.0.0 (unreleased) ##
+* Add STI support to init and building associations.
+ Allows you to do BaseClass.new(:type => "SubClass") as well as
+ parent.children.build(:type => "SubClass") or parent.build_child
+ to initialize an STI subclass. Ensures that the class name is a
+ valid class and that it is in the ancestors of the super class
+ that the association is expecting.
+
+ *Jason Rush*
+
* Observers was extracted from Active Record as `rails-observers` gem.
*Rafael Mendonça França*
diff --git a/activerecord/lib/active_record/base.rb b/activerecord/lib/active_record/base.rb
index 965fe3f33a..aab832c2f7 100644
--- a/activerecord/lib/active_record/base.rb
+++ b/activerecord/lib/active_record/base.rb
@@ -13,6 +13,7 @@ require 'active_support/core_ext/string/behavior'
require 'active_support/core_ext/kernel/singleton_class'
require 'active_support/core_ext/module/introspection'
require 'active_support/core_ext/object/duplicable'
+require 'active_support/core_ext/class/subclasses'
require 'arel'
require 'active_record/errors'
require 'active_record/log_subscriber'
diff --git a/activerecord/lib/active_record/inheritance.rb b/activerecord/lib/active_record/inheritance.rb
index a448fa1f5c..850911ebe7 100644
--- a/activerecord/lib/active_record/inheritance.rb
+++ b/activerecord/lib/active_record/inheritance.rb
@@ -9,6 +9,19 @@ module ActiveRecord
end
module ClassMethods
+ # Determines if one of the attributes passed in is the inheritance column,
+ # and if the inheritance column is attr accessible, it initializes an
+ # instance of the given subclass instead of the base class
+ def new(*args, &block)
+ if (attrs = args.first).is_a?(Hash)
+ if subclass = subclass_from_attrs(attrs)
+ return subclass.new(*args, &block)
+ end
+ end
+ # Delegate to the original .new
+ super
+ end
+
# True if this isn't a concrete subclass needing a STI type condition.
def descends_from_active_record?
if self == Base
@@ -145,6 +158,19 @@ module ActiveRecord
sti_column.in(sti_names)
end
+
+ # Detect the subclass from the inheritance column of attrs. If the inheritance column value
+ # is not self or a valid subclass, raises ActiveRecord::SubclassNotFound
+ # If this is a StrongParameters hash, and access to inheritance_column is not permitted,
+ # this will ignore the inheritance column and return nil
+ def subclass_from_attrs(attrs)
+ subclass_name = attrs.with_indifferent_access[inheritance_column]
+ return nil if subclass_name.blank? || subclass_name == self.name
+ unless subclass = subclasses.detect { |sub| sub.name == subclass_name }
+ raise ActiveRecord::SubclassNotFound.new("Invalid single-table inheritance type: #{subclass_name} is not a subclass of #{name}")
+ end
+ subclass
+ end
end
private
diff --git a/activerecord/lib/active_record/reflection.rb b/activerecord/lib/active_record/reflection.rb
index 0103de4cbd..bcfcb061f2 100644
--- a/activerecord/lib/active_record/reflection.rb
+++ b/activerecord/lib/active_record/reflection.rb
@@ -179,7 +179,7 @@ module ActiveRecord
@collection = [:has_many, :has_and_belongs_to_many].include?(macro)
end
- # Returns a new, unsaved instance of the associated class. +options+ will
+ # Returns a new, unsaved instance of the associated class. +attributes+ will
# be passed to the class's constructor.
def build_association(attributes, &block)
klass.new(attributes, &block)
diff --git a/activerecord/test/cases/associations/belongs_to_associations_test.rb b/activerecord/test/cases/associations/belongs_to_associations_test.rb
index 5f7825783b..49d6c31c9a 100644
--- a/activerecord/test/cases/associations/belongs_to_associations_test.rb
+++ b/activerecord/test/cases/associations/belongs_to_associations_test.rb
@@ -109,6 +109,34 @@ class BelongsToAssociationsTest < ActiveRecord::TestCase
assert_equal apple.id, citibank.firm_id
end
+ def test_building_the_belonging_object_with_implicit_sti_base_class
+ account = Account.new
+ company = account.build_firm
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_belonging_object_with_explicit_sti_base_class
+ account = Account.new
+ company = account.build_firm(:type => "Company")
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_belonging_object_with_sti_subclass
+ account = Account.new
+ company = account.build_firm(:type => "Firm")
+ assert(company.kind_of?(Firm), "Expected #{company.class} to be a Firm")
+ end
+
+ def test_building_the_belonging_object_with_an_invalid_type
+ account = Account.new
+ assert_raise(ActiveRecord::SubclassNotFound) { account.build_firm(:type => "InvalidType") }
+ end
+
+ def test_building_the_belonging_object_with_an_unrelated_type
+ account = Account.new
+ assert_raise(ActiveRecord::SubclassNotFound) { account.build_firm(:type => "Account") }
+ end
+
def test_building_the_belonging_object_with_primary_key
client = Client.create(:name => "Primary key client")
apple = client.build_firm_with_primary_key("name" => "Apple")
diff --git a/activerecord/test/cases/associations/has_many_associations_test.rb b/activerecord/test/cases/associations/has_many_associations_test.rb
index e5022d49f1..eee0cf03aa 100644
--- a/activerecord/test/cases/associations/has_many_associations_test.rb
+++ b/activerecord/test/cases/associations/has_many_associations_test.rb
@@ -144,6 +144,34 @@ class HasManyAssociationsTest < ActiveRecord::TestCase
assert_equal 'defaulty', bulb.name
end
+ def test_building_the_associated_object_with_implicit_sti_base_class
+ firm = DependentFirm.new
+ company = firm.companies.build
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_associated_object_with_explicit_sti_base_class
+ firm = DependentFirm.new
+ company = firm.companies.build(:type => "Company")
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_associated_object_with_sti_subclass
+ firm = DependentFirm.new
+ company = firm.companies.build(:type => "Client")
+ assert(company.kind_of?(Client), "Expected #{company.class} to be a Client")
+ end
+
+ def test_building_the_associated_object_with_an_invalid_type
+ firm = DependentFirm.new
+ assert_raise(ActiveRecord::SubclassNotFound) { firm.companies.build(:type => "Invalid") }
+ end
+
+ def test_building_the_associated_object_with_an_unrelated_type
+ firm = DependentFirm.new
+ assert_raise(ActiveRecord::SubclassNotFound) { firm.companies.build(:type => "Account") }
+ end
+
def test_association_keys_bypass_attribute_protection
car = Car.create(:name => 'honda')
diff --git a/activerecord/test/cases/associations/has_one_associations_test.rb b/activerecord/test/cases/associations/has_one_associations_test.rb
index ea1cfa0805..5c4f6bbb32 100644
--- a/activerecord/test/cases/associations/has_one_associations_test.rb
+++ b/activerecord/test/cases/associations/has_one_associations_test.rb
@@ -6,6 +6,8 @@ require 'models/ship'
require 'models/pirate'
require 'models/car'
require 'models/bulb'
+require 'models/author'
+require 'models/post'
class HasOneAssociationsTest < ActiveRecord::TestCase
self.use_transactional_fixtures = false unless supports_savepoints?
@@ -212,6 +214,34 @@ class HasOneAssociationsTest < ActiveRecord::TestCase
}
end
+ def test_building_the_associated_object_with_implicit_sti_base_class
+ firm = DependentFirm.new
+ company = firm.build_company
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_associated_object_with_explicit_sti_base_class
+ firm = DependentFirm.new
+ company = firm.build_company(:type => "Company")
+ assert(company.kind_of?(Company), "Expected #{company.class} to be a Company")
+ end
+
+ def test_building_the_associated_object_with_sti_subclass
+ firm = DependentFirm.new
+ company = firm.build_company(:type => "Client")
+ assert(company.kind_of?(Client), "Expected #{company.class} to be a Client")
+ end
+
+ def test_building_the_associated_object_with_an_invalid_type
+ firm = DependentFirm.new
+ assert_raise(ActiveRecord::SubclassNotFound) { firm.build_company(:type => "Invalid") }
+ end
+
+ def test_building_the_associated_object_with_an_unrelated_type
+ firm = DependentFirm.new
+ assert_raise(ActiveRecord::SubclassNotFound) { firm.build_company(:type => "Account") }
+ end
+
def test_build_and_create_should_not_happen_within_scope
pirate = pirates(:blackbeard)
scoped_count = pirate.association(:foo_bulb).scope.where_values.count
diff --git a/activerecord/test/cases/forbidden_attributes_protection_test.rb b/activerecord/test/cases/forbidden_attributes_protection_test.rb
index 9a2172f41e..490b599fb6 100644
--- a/activerecord/test/cases/forbidden_attributes_protection_test.rb
+++ b/activerecord/test/cases/forbidden_attributes_protection_test.rb
@@ -1,6 +1,7 @@
require 'cases/helper'
require 'active_support/core_ext/hash/indifferent_access'
require 'models/person'
+require 'models/company'
class ProtectedParams < ActiveSupport::HashWithIndifferentAccess
attr_accessor :permitted
@@ -40,6 +41,20 @@ class ForbiddenAttributesProtectionTest < ActiveRecord::TestCase
assert_equal 'm', person.gender
end
+ def test_forbidden_attributes_cannot_be_used_for_sti_inheritance_column
+ params = ProtectedParams.new(type: 'Client')
+ assert_raises(ActiveModel::ForbiddenAttributesError) do
+ Company.new(params)
+ end
+ end
+
+ def test_permitted_attributes_can_be_used_for_sti_inheritance_column
+ params = ProtectedParams.new(type: 'Client')
+ params.permit!
+ person = Company.new(params)
+ assert_equal person.class, Client
+ end
+
def test_regular_hash_should_still_be_used_for_mass_assignment
person = Person.new(first_name: 'Guille', gender: 'm')
diff --git a/activerecord/test/cases/inheritance_test.rb b/activerecord/test/cases/inheritance_test.rb
index aab7aa51dd..2466a764f6 100644
--- a/activerecord/test/cases/inheritance_test.rb
+++ b/activerecord/test/cases/inheritance_test.rb
@@ -156,6 +156,29 @@ class InheritanceTest < ActiveRecord::TestCase
assert_kind_of Cabbage, savoy
end
+ def test_inheritance_new_with_default_class
+ company = Company.new
+ assert_equal company.class, Company
+ end
+
+ def test_inheritance_new_with_base_class
+ company = Company.new(:type => 'Company')
+ assert_equal company.class, Company
+ end
+
+ def test_inheritance_new_with_subclass
+ firm = Company.new(:type => 'Firm')
+ assert_equal firm.class, Firm
+ end
+
+ def test_new_with_invalid_type
+ assert_raise(ActiveRecord::SubclassNotFound) { Company.new(:type => 'InvalidType') }
+ end
+
+ def test_new_with_unrelated_type
+ assert_raise(ActiveRecord::SubclassNotFound) { Company.new(:type => 'Account') }
+ end
+
def test_inheritance_condition
assert_equal 10, Company.count
assert_equal 2, Firm.count
diff --git a/activerecord/test/models/author.rb b/activerecord/test/models/author.rb
index 77f4a2ec87..6935cfb0ea 100644
--- a/activerecord/test/models/author.rb
+++ b/activerecord/test/models/author.rb
@@ -1,5 +1,6 @@
class Author < ActiveRecord::Base
has_many :posts
+ has_one :post
has_many :very_special_comments, :through => :posts
has_many :posts_with_comments, -> { includes(:comments) }, :class_name => "Post"
has_many :popular_grouped_posts, -> { includes(:comments).group("type").having("SUM(comments_count) > 1").select("type") }, :class_name => "Post"
diff --git a/activerecord/test/models/company.rb b/activerecord/test/models/company.rb
index 17b17724e8..3ca8f69646 100644
--- a/activerecord/test/models/company.rb
+++ b/activerecord/test/models/company.rb
@@ -111,6 +111,7 @@ end
class DependentFirm < Company
has_one :account, :foreign_key => "firm_id", :dependent => :nullify
has_many :companies, :foreign_key => 'client_of', :dependent => :nullify
+ has_one :company, :foreign_key => 'client_of', :dependent => :nullify
end
class RestrictedFirm < Company