From caaf40d5358ae8a2b31949c2af2d94be1be73976 Mon Sep 17 00:00:00 2001 From: David Heinemeier Hansson Date: Sat, 24 Sep 2005 19:50:57 +0000 Subject: Added AbstractAdapter#select_value and AbstractAdapter#select_values as convenience methods for selecting single values, instead of hashes, of the first column in a SELECT #2283 git-svn-id: http://svn-commit.rubyonrails.org/rails/trunk@2323 5ecf4fe2-1ee6-0310-87b1-e25e094e27de --- activerecord/lib/active_record/base.rb | 8 +------- .../active_record/connection_adapters/abstract_adapter.rb | 13 +++++++++++++ activerecord/lib/active_record/transactions.rb | 4 +--- activerecord/test/finder_test.rb | 13 +++++++++++++ 4 files changed, 28 insertions(+), 10 deletions(-) (limited to 'activerecord') diff --git a/activerecord/lib/active_record/base.rb b/activerecord/lib/active_record/base.rb index 73c69a91d1..03123758ef 100755 --- a/activerecord/lib/active_record/base.rb +++ b/activerecord/lib/active_record/base.rb @@ -464,13 +464,7 @@ module ActiveRecord #:nodoc: # Product.count_by_sql "SELECT COUNT(*) FROM sales s, customers c WHERE s.customer_id = c.id" def count_by_sql(sql) sql = sanitize_conditions(sql) - rows = connection.select_one(sql, "#{name} Count") - - if !rows.nil? and count = rows.values.first - count.to_i - else - 0 - end + connection.select_value(sql, "#{name} Count").to_i end # Increments the specified counter by one. So DiscussionBoard.increment_counter("post_count", diff --git a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb index a1a71afb5b..fed51fea18 100755 --- a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb @@ -275,6 +275,19 @@ module ActiveRecord # Returns a record hash with the column names as a keys and fields as values. def select_one(sql, name = nil) end + # Returns a single value from a record + def select_value(sql, name = nil) + result = select_one(sql, name) + result.nil? ? nil : result.values.first + end + + # Returns an array of the values of the first column in a select: + # select_values("SELECT id FROM companies LIMIT 3") => [1,2,3] + def select_values(sql, name = nil) + result = select_all(sql, name) + result.map{ |v| v.values.first } + end + # Returns an array of table names for the current database. # def tables(name = nil) end diff --git a/activerecord/lib/active_record/transactions.rb b/activerecord/lib/active_record/transactions.rb index de2937a96d..e9a612495c 100644 --- a/activerecord/lib/active_record/transactions.rb +++ b/activerecord/lib/active_record/transactions.rb @@ -81,9 +81,7 @@ module ActiveRecord # Tribute: Object-level transactions are implemented by Transaction::Simple by Austin Ziegler. module ClassMethods def transaction(*objects, &block) - previous_handler = trap('TERM') do - raise TransactionError, "Transaction aborted" - end + previous_handler = trap('TERM') { raise TransactionError, "Transaction aborted" } lock_mutex begin diff --git a/activerecord/test/finder_test.rb b/activerecord/test/finder_test.rb index 2f23503ce1..8fc89f3808 100644 --- a/activerecord/test/finder_test.rb +++ b/activerecord/test/finder_test.rb @@ -311,6 +311,19 @@ class FinderTest < Test::Unit::TestCase assert developer_names.include?('Jamis') end + def test_select_value + assert_equal "37signals", Company.connection.select_value("SELECT name FROM companies WHERE id = 1") + assert_nil Company.connection.select_value("SELECT name FROM companies WHERE id = -1") + # make sure we didn't break count... + assert_equal 0, Company.count_by_sql("SELECT COUNT(*) FROM companies WHERE name = 'Halliburton'") + assert_equal 1, Company.count_by_sql("SELECT COUNT(*) FROM companies WHERE name = '37signals'") + end + + def test_select_values + assert_equal ["1","2","3"], Company.connection.select_values("SELECT id FROM companies ORDER BY id LIMIT 3") + assert_equal ["37signals","Summit","Microsoft"], Company.connection.select_values("SELECT name FROM companies ORDER BY id LIMIT 3") + end + protected def bind(statement, *vars) if vars.first.is_a?(Hash) -- cgit v1.2.3