diff options
author | Jon Leighton <j@jonathanleighton.com> | 2012-04-12 17:55:39 +0100 |
---|---|---|
committer | Jon Leighton <j@jonathanleighton.com> | 2012-04-12 17:57:54 +0100 |
commit | 8c2c60511beaad05a218e73c4918ab89fb1804f0 (patch) | |
tree | 761212dd7e1a8661f4ab30b814f2d3c4b2e86078 | |
parent | 5c51cd0b2f7ac28825bdeb1f2f49f4647be12e52 (diff) | |
download | rails-8c2c60511beaad05a218e73c4918ab89fb1804f0.tar.gz rails-8c2c60511beaad05a218e73c4918ab89fb1804f0.tar.bz2 rails-8c2c60511beaad05a218e73c4918ab89fb1804f0.zip |
Add bang versions of relation query methods.
The main reason for this is that I want to separate the code that does
the mutating from the code that does the cloning.
-rw-r--r-- | activerecord/CHANGELOG.md | 7 | ||||
-rw-r--r-- | activerecord/lib/active_record/relation/query_methods.rb | 219 | ||||
-rw-r--r-- | activerecord/test/cases/relation_test.rb | 52 |
3 files changed, 194 insertions, 84 deletions
diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md index 26f6093bc2..85cb3e0e20 100644 --- a/activerecord/CHANGELOG.md +++ b/activerecord/CHANGELOG.md @@ -1,5 +1,12 @@ ## Rails 4.0.0 (unreleased) ## +* Added bang methods for mutating `ActiveRecord::Relation` objects. + For example, while `foo.where(:bar)` will return a new object + leaving `foo` unchanged, `foo.where!(:bar)` will mutate the foo + object + + *Jon Leighton* + * Added `#find_by` and `#find_by!` to mirror the functionality provided by dynamic finders in a way that allows dynamic input more easily: diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb index d737b34115..10492165fb 100644 --- a/activerecord/lib/active_record/relation/query_methods.rb +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -13,29 +13,32 @@ module ActiveRecord :uniq_value, :references_values def includes(*args) - args.reject! {|a| a.blank? } + args.empty? ? self : clone.includes!(*args) + end - return self if args.empty? + def includes!(*args) + args.reject! {|a| a.blank? } - relation = clone - relation.includes_values = (relation.includes_values + args).flatten.uniq - relation + self.includes_values = (includes_values + args).flatten.uniq + self end def eager_load(*args) - return self if args.blank? + args.blank? ? self : clone.eager_load!(*args) + end - relation = clone - relation.eager_load_values += args - relation + def eager_load!(*args) + self.eager_load_values += args + self end def preload(*args) - return self if args.blank? + args.blank? ? self : clone.preload!(*args) + end - relation = clone - relation.preload_values += args - relation + def preload!(*args) + self.preload_values += args + self end # Used to indicate that an association is referenced by an SQL string, and should @@ -49,11 +52,12 @@ module ActiveRecord # User.includes(:posts).where("posts.name = 'foo'").references(:posts) # # => Query now knows the string references posts, so adds a JOIN def references(*args) - return self if args.blank? + args.blank? ? self : clone.references!(*args) + end - relation = clone - relation.references_values = (references_values + args.flatten.map(&:to_s)).uniq - relation + def references!(*args) + self.references_values = (references_values + args.flatten.map(&:to_s)).uniq + self end # Works in two unique ways. @@ -87,34 +91,45 @@ module ActiveRecord # => ActiveModel::MissingAttributeError: missing attribute: other_field def select(value = Proc.new) if block_given? - to_a.select {|*block_args| value.call(*block_args) } + to_a.select { |*block_args| value.call(*block_args) } + else + clone.select!(value) + end + end + + def select!(value = Proc.new) + if block_given? + # TODO: test + to_a.select! { |*block_args| value.call(*block_args) } else - relation = clone - relation.select_values += Array.wrap(value) - relation + self.select_values += Array.wrap(value) + self end end def group(*args) - return self if args.blank? + args.blank? ? self : clone.group!(*args) + end - relation = clone - relation.group_values += args.flatten - relation + def group!(*args) + self.group_values += args.flatten + self end def order(*args) - return self if args.blank? + args.blank? ? self : clone.order!(*args) + end + def order!(*args) args = args.flatten + references = args.reject { |arg| Arel::Node === arg } .map { |arg| arg =~ /^([a-zA-Z]\w*)\.(\w+)/ && $1 } .compact + references!(references) if references.any? - relation = clone - relation = relation.references(references) if references.any? - relation.order_values += args - relation + self.order_values += args + self end # Replaces any existing order defined on the relation with the specified order. @@ -128,72 +143,88 @@ module ActiveRecord # generates a query with 'ORDER BY id ASC, name ASC'. # def reorder(*args) - return self if args.blank? + args.blank? ? self : clone.reorder!(*args) + end - relation = clone - relation.reordering_value = true - relation.order_values = args.flatten - relation + def reorder!(*args) + self.reordering_value = true + self.order_values = args.flatten + self end def joins(*args) - return self if args.compact.blank? - - relation = clone + args.compact.blank? ? self : clone.joins!(*args) + end + def joins!(*args) args.flatten! - relation.joins_values += args - relation + self.joins_values += args + self end def bind(value) - relation = clone - relation.bind_values += [value] - relation + clone.bind!(value) + end + + def bind!(value) + self.bind_values += [value] + self end def where(opts, *rest) - return self if opts.blank? + opts.blank? ? self : clone.where!(opts, *rest) + end + + def where!(opts, *rest) + references!(PredicateBuilder.references(opts)) if Hash === opts - relation = clone - relation = relation.references(PredicateBuilder.references(opts)) if Hash === opts - relation.where_values += build_where(opts, rest) - relation + self.where_values += build_where(opts, rest) + self end def having(opts, *rest) - return self if opts.blank? + opts.blank? ? self : clone.having!(opts, *rest) + end + + def having!(opts, *rest) + references!(PredicateBuilder.references(opts)) if Hash === opts - relation = clone - relation = relation.references(PredicateBuilder.references(opts)) if Hash === opts - relation.having_values += build_where(opts, rest) - relation + self.having_values += build_where(opts, rest) + self end def limit(value) - relation = clone - relation.limit_value = value - relation + clone.limit!(value) + end + + def limit!(value) + self.limit_value = value + self end def offset(value) - relation = clone - relation.offset_value = value - relation + clone.offset!(value) + end + + def offset!(value) + self.offset_value = value + self end def lock(locks = true) - relation = clone + clone.lock!(locks) + end + def lock!(locks = true) case locks when String, TrueClass, NilClass - relation.lock_value = locks || true + self.lock_value = locks || true else - relation.lock_value = false + self.lock_value = false end - relation + self end # Returns a chainable relation with zero records, specifically an @@ -230,21 +261,30 @@ module ActiveRecord end def readonly(value = true) - relation = clone - relation.readonly_value = value - relation + clone.readonly!(value) + end + + def readonly!(value = true) + self.readonly_value = value + self end def create_with(value) - relation = clone - relation.create_with_value = value ? create_with_value.merge(value) : {} - relation + clone.create_with!(value) + end + + def create_with!(value) + self.create_with_value = value ? create_with_value.merge(value) : {} + self end def from(value) - relation = clone - relation.from_value = value - relation + clone.from!(value) + end + + def from!(value) + self.from_value = value + self end # Specifies whether the records should be unique or not. For example: @@ -258,9 +298,12 @@ module ActiveRecord # User.select(:name).uniq.uniq(false) # # => You can also remove the uniqueness def uniq(value = true) - relation = clone - relation.uniq_value = value - relation + clone.uniq!(value) + end + + def uniq!(value = true) + self.uniq_value = value + self end # Used to extend a scope with additional methods, either through @@ -299,20 +342,28 @@ module ActiveRecord # # pagination code goes here # end # end - def extending(*modules) - modules << Module.new(&Proc.new) if block_given? + def extending(*modules, &block) + if modules.any? || block + clone.extending!(*modules, &block) + else + self + end + end - return self if modules.empty? + def extending!(*modules, &block) + modules << Module.new(&block) if block_given? - relation = clone - relation.send(:apply_modules, modules.flatten) - relation + self.send(:apply_modules, modules.flatten) + self end def reverse_order - relation = clone - relation.reverse_order_value = !relation.reverse_order_value - relation + clone.reverse_order! + end + + def reverse_order! + self.reverse_order_value = !reverse_order_value + self end def arel diff --git a/activerecord/test/cases/relation_test.rb b/activerecord/test/cases/relation_test.rb index ac6dee3c6a..4b18c37d27 100644 --- a/activerecord/test/cases/relation_test.rb +++ b/activerecord/test/cases/relation_test.rb @@ -155,4 +155,56 @@ module ActiveRecord assert_equal ['foo'], relation.references_values end end + + class RelationMutationTest < ActiveSupport::TestCase + def relation + @relation ||= Relation.new :a, :b + end + + (Relation::ASSOCIATION_METHODS + Relation::MULTI_VALUE_METHODS - [:references]).each do |method| + test "##{method}!" do + assert relation.public_send("#{method}!", :foo).equal?(relation) + assert_equal [:foo], relation.public_send("#{method}_values") + end + end + + test '#references!' do + assert relation.references!(:foo).equal?(relation) + assert relation.references_values.include?('foo') + end + + (Relation::SINGLE_VALUE_METHODS - [:lock, :reordering, :reverse_order]).each do |method| + test "##{method}!" do + assert relation.public_send("#{method}!", :foo).equal?(relation) + assert_equal :foo, relation.public_send("#{method}_value") + end + end + + test '#lock!' do + assert relation.lock!('foo').equal?(relation) + assert_equal 'foo', relation.lock_value + end + + test '#reorder!' do + relation = self.relation.order('foo') + + assert relation.reorder!('bar').equal?(relation) + assert_equal ['bar'], relation.order_values + assert relation.reordering_value + end + + test 'reverse_order!' do + assert relation.reverse_order!.equal?(relation) + assert relation.reverse_order_value + relation.reverse_order! + assert !relation.reverse_order_value + end + + test 'extending!' do + mod = Module.new + + assert relation.extending!(mod).equal?(relation) + assert relation.is_a?(mod) + end + end end |