aboutsummaryrefslogtreecommitdiffstats
path: root/activesupport
diff options
context:
space:
mode:
Diffstat (limited to 'activesupport')
-rw-r--r--activesupport/CHANGELOG2
-rw-r--r--activesupport/lib/active_support/core_ext/enumerable.rb9
-rw-r--r--activesupport/test/core_ext/enumerable_test.rb8
3 files changed, 17 insertions, 2 deletions
diff --git a/activesupport/CHANGELOG b/activesupport/CHANGELOG
index 62450807c2..5d79c19385 100644
--- a/activesupport/CHANGELOG
+++ b/activesupport/CHANGELOG
@@ -1,5 +1,7 @@
*SVN*
+* Optional identity for Enumerable#sum defaults to zero. #5657 [gensym@mac.com]
+
* HashWithIndifferentAccess shouldn't confuse false and nil. #5601 [shugo@ruby-lang.org]
* Fixed HashWithIndifferentAccess#default #5586 [chris@seagul.co.uk]
diff --git a/activesupport/lib/active_support/core_ext/enumerable.rb b/activesupport/lib/active_support/core_ext/enumerable.rb
index 92304a23f6..59128007df 100644
--- a/activesupport/lib/active_support/core_ext/enumerable.rb
+++ b/activesupport/lib/active_support/core_ext/enumerable.rb
@@ -30,7 +30,14 @@ module Enumerable #:nodoc:
#
# Also calculates sums without the use of a block:
# [5, 15, 10].sum # => 30
- def sum(&block)
+ #
+ # The default identity (sum of an empty list) is zero.
+ # However, you can override this default:
+ #
+ # [].sum(Payment.new(0)) { |i| i.amount } # => Payment.new(0)
+ #
+ def sum(identity = 0, &block)
+ return identity unless size > 0
if block_given?
map(&block).sum
else
diff --git a/activesupport/test/core_ext/enumerable_test.rb b/activesupport/test/core_ext/enumerable_test.rb
index 3180755e5f..0590846b7b 100644
--- a/activesupport/test/core_ext/enumerable_test.rb
+++ b/activesupport/test/core_ext/enumerable_test.rb
@@ -44,7 +44,13 @@ class EnumerableTests < Test::Unit::TestCase
assert_raise(TypeError) { payments.sum(&:price) }
assert_equal 60, payments.sum { |p| p.price.to_i * 2 }
end
-
+
+ def test_empty_sums
+ assert_equal 0, [].sum
+ assert_equal 0, [].sum { |i| i }
+ assert_equal Payment.new(0), [].sum(Payment.new(0))
+ end
+
def test_index_by
payments = [ Payment.new(5), Payment.new(15), Payment.new(10) ]
assert_equal(