From 08633bae5e4f05e913ec5d5d2483bfd6c07c7375 Mon Sep 17 00:00:00 2001 From: Pratik Naik Date: Tue, 29 Dec 2009 04:28:43 +0530 Subject: Migrate all the calculation methods to Relation --- activerecord/lib/active_record/associations.rb | 11 +- .../associations/association_collection.rb | 2 +- activerecord/lib/active_record/base.rb | 4 + activerecord/lib/active_record/calculations.rb | 200 ++++++--------------- activerecord/lib/active_record/relation.rb | 7 +- .../lib/active_record/relational_calculations.rb | 147 +++++++++++++-- .../associations/inner_join_association_test.rb | 2 + activerecord/test/cases/calculations_test.rb | 18 +- 8 files changed, 213 insertions(+), 178 deletions(-) (limited to 'activerecord') diff --git a/activerecord/lib/active_record/associations.rb b/activerecord/lib/active_record/associations.rb index 0b248a6c55..f0bad6c3ba 100755 --- a/activerecord/lib/active_record/associations.rb +++ b/activerecord/lib/active_record/associations.rb @@ -1466,11 +1466,10 @@ module ActiveRecord end def find_with_associations(options = {}, join_dependency = nil) - catch :invalid_query do - join_dependency ||= JoinDependency.new(self, merge_includes(scope(:find, :include), options[:include]), options[:joins]) - rows = select_all_rows(options, join_dependency) - return join_dependency.instantiate(rows) - end + join_dependency ||= JoinDependency.new(self, merge_includes(scope(:find, :include), options[:include]), options[:joins]) + rows = select_all_rows(options, join_dependency) + join_dependency.instantiate(rows) + rescue ThrowResult [] end @@ -1733,7 +1732,7 @@ module ActiveRecord def construct_arel_limited_ids_condition(options, join_dependency) if (ids_array = select_limited_ids_array(options, join_dependency)).empty? - throw :invalid_query + raise ThrowResult else Arel::Predicates::In.new( Arel::SqlLiteral.new("#{connection.quote_table_name table_name}.#{primary_key}"), diff --git a/activerecord/lib/active_record/associations/association_collection.rb b/activerecord/lib/active_record/associations/association_collection.rb index 56b2a90138..b2b3a9789c 100644 --- a/activerecord/lib/active_record/associations/association_collection.rb +++ b/activerecord/lib/active_record/associations/association_collection.rb @@ -177,7 +177,7 @@ module ActiveRecord if @reflection.options[:counter_sql] @reflection.klass.count_by_sql(@counter_sql) else - column_name, options = @reflection.klass.send(:construct_count_options_from_args, *args) + column_name, options = @reflection.klass.scoped.send(:construct_count_options_from_args, *args) if @reflection.options[:uniq] # This is needed because 'SELECT count(DISTINCT *)..' is not valid SQL. column_name = "#{@reflection.quoted_table_name}.#{@reflection.klass.primary_key}" if column_name == :all diff --git a/activerecord/lib/active_record/base.rb b/activerecord/lib/active_record/base.rb index 767109474d..53f0a920a3 100755 --- a/activerecord/lib/active_record/base.rb +++ b/activerecord/lib/active_record/base.rb @@ -69,6 +69,10 @@ module ActiveRecord #:nodoc: class StatementInvalid < ActiveRecordError end + # Raised when SQL statement is invalid and the application gets a blank result. + class ThrowResult < ActiveRecordError + end + # Parent class for all specific exceptions which wrap database driver exceptions # provides access to the original exception also. class WrappedDatabaseException < StatementInvalid diff --git a/activerecord/lib/active_record/calculations.rb b/activerecord/lib/active_record/calculations.rb index 59811c40a8..d51d9f2159 100644 --- a/activerecord/lib/active_record/calculations.rb +++ b/activerecord/lib/active_record/calculations.rb @@ -44,7 +44,26 @@ module ActiveRecord # # Note: Person.count(:all) will not work because it will use :all as the condition. Use Person.count instead. def count(*args) - calculate(:count, *construct_count_options_from_args(*args)) + case args.size + when 0 + construct_calculation_arel.count + when 1 + if args[0].is_a?(Hash) + options = args[0] + distinct = options.has_key?(:distinct) ? options.delete(:distinct) : false + construct_calculation_arel(options).count(options[:select], :distinct => distinct) + else + construct_calculation_arel.count(args[0]) + end + when 2 + column_name, options = args + distinct = options.has_key?(:distinct) ? options.delete(:distinct) : false + construct_calculation_arel(options).count(column_name, :distinct => distinct) + else + raise ArgumentError, "Unexpected parameters passed to count(): #{args.inspect}" + end + rescue ThrowResult + 0 end # Calculates the average value on a given column. The value is returned as @@ -122,168 +141,63 @@ module ActiveRecord # Person.minimum(:age, :having => 'min(age) > 17', :group => :last_name) # Selects the minimum age for any family without any minors # Person.sum("2 * age") def calculate(operation, column_name, options = {}) - validate_calculation_options(operation, options) - operation = operation.to_s.downcase - - scope = scope(:find) + construct_calculation_arel(options).calculate(operation, column_name, options.slice(:distinct)) + rescue ThrowResult + 0 + end - merged_includes = merge_includes(scope ? scope[:include] : [], options[:include]) + private + def validate_calculation_options(options = {}) + options.assert_valid_keys(CALCULATIONS_OPTIONS) + end - if operation == "count" - if merged_includes.any? - distinct = true - column_name = options[:select] || primary_key - end + def construct_calculation_arel(options = {}) + validate_calculation_options(options) + options = options.except(:distinct) - distinct = nil if column_name.to_s =~ /\s*DISTINCT\s+/i - distinct ||= options[:distinct] - else - distinct = nil - end + scope = scope(:find) + includes = merge_includes(scope ? scope[:include] : [], options[:include]) - catch :invalid_query do - relation = if merged_includes.any? - join_dependency = ActiveRecord::Associations::ClassMethods::JoinDependency.new(self, merged_includes, construct_join(options[:joins], scope)) - construct_finder_arel_with_included_associations(options, join_dependency) + if includes.any? + join_dependency = ActiveRecord::Associations::ClassMethods::JoinDependency.new(self, includes, construct_join(options[:joins], scope)) + construct_calculation_arel_with_included_associations(options, join_dependency) else - relation = arel_table(options[:from]). + arel_table. joins(construct_join(options[:joins], scope)). + from((scope && scope[:from]) || options[:from]). where(construct_conditions(options[:conditions], scope)). order(options[:order]). limit(options[:limit]). - offset(options[:offset]) - end - if options[:group] - return execute_grouped_calculation(operation, column_name, options, relation) - else - return execute_simple_calculation(operation, column_name, options.merge(:distinct => distinct), relation) + offset(options[:offset]). + group(options[:group]). + having(options[:having]). + select(options[:select] || (scope && scope[:select]) || default_select(options[:joins] || (scope && scope[:joins]))) end end - 0 - end - - def execute_simple_calculation(operation, column_name, options, relation) #:nodoc: - column = if column_names.include?(column_name.to_s) - Arel::Attribute.new(arel_table(options[:from] || table_name), - options[:select] || column_name) - else - Arel::SqlLiteral.new(options[:select] || - (column_name == :all ? "*" : column_name.to_s)) - end - - relation = relation.select(operation == 'count' ? column.count(options[:distinct]) : column.send(operation)) - - type_cast_calculated_value(connection.select_value(relation.to_sql), column_for(column_name), operation) - end - def execute_grouped_calculation(operation, column_name, options, relation) #:nodoc: - group_attr = options[:group].to_s - association = reflect_on_association(group_attr.to_sym) - associated = association && association.macro == :belongs_to # only count belongs_to associations - group_field = associated ? association.primary_key_name : group_attr - group_alias = column_alias_for(group_field) - group_column = column_for group_field + def construct_calculation_arel_with_included_associations(options, join_dependency) + scope = scope(:find) - options[:group] = connection.adapter_name == 'FrontBase' ? group_alias : group_field + relation = arel_table - aggregate_alias = column_alias_for(operation, column_name) - - options[:select] = (operation == 'count' && column_name == :all) ? - "COUNT(*) AS count_all" : - Arel::Attribute.new(arel_table, column_name).send(operation).as(aggregate_alias).to_sql - - options[:select] << ", #{group_field} AS #{group_alias}" - - relation = relation.select(options[:select]).group(options[:group]).having(options[:having]) - - calculated_data = connection.select_all(relation.to_sql) - - if association - key_ids = calculated_data.collect { |row| row[group_alias] } - key_records = association.klass.base_class.find(key_ids) - key_records = key_records.inject({}) { |hsh, r| hsh.merge(r.id => r) } - end - - calculated_data.inject(ActiveSupport::OrderedHash.new) do |all, row| - key = type_cast_calculated_value(row[group_alias], group_column) - key = key_records[key] if associated - value = row[aggregate_alias] - all[key] = type_cast_calculated_value(value, column_for(column_name), operation) - all - end - end - - protected - def construct_count_options_from_args(*args) - options = {} - column_name = :all - - # We need to handle - # count() - # count(:column_name=:all) - # count(options={}) - # count(column_name=:all, options={}) - # selects specified by scopes - case args.size - when 0 - column_name = scope(:find)[:select] if scope(:find) - when 1 - if args[0].is_a?(Hash) - column_name = scope(:find)[:select] if scope(:find) - options = args[0] - else - column_name = args[0] - end - when 2 - column_name, options = args - else - raise ArgumentError, "Unexpected parameters passed to count(): #{args.inspect}" + for association in join_dependency.join_associations + relation = association.join_relation(relation) end - [column_name || :all, options] - end - - private - def validate_calculation_options(operation, options = {}) - options.assert_valid_keys(CALCULATIONS_OPTIONS) - end + relation = relation.joins(construct_join(options[:joins], scope)). + select(column_aliases(join_dependency)). + group(options[:group]). + having(options[:having]). + order(options[:order]). + where(construct_conditions(options[:conditions], scope)). + from((scope && scope[:from]) || options[:from]) - # Converts the given keys to the value that the database adapter returns as - # a usable column name: - # - # column_alias_for("users.id") # => "users_id" - # column_alias_for("sum(id)") # => "sum_id" - # column_alias_for("count(distinct users.id)") # => "count_distinct_users_id" - # column_alias_for("count(*)") # => "count_all" - # column_alias_for("count", "id") # => "count_id" - def column_alias_for(*keys) - table_name = keys.join(' ') - table_name.downcase! - table_name.gsub!(/\*/, 'all') - table_name.gsub!(/\W+/, ' ') - table_name.strip! - table_name.gsub!(/ +/, '_') + relation = relation.where(construct_arel_limited_ids_condition(options, join_dependency)) if !using_limitable_reflections?(join_dependency.reflections) && ((scope && scope[:limit]) || options[:limit]) + relation = relation.limit(construct_limit(options[:limit], scope)) if using_limitable_reflections?(join_dependency.reflections) - connection.table_alias_for(table_name) + relation end - def column_for(field) - field_name = field.to_s.split('.').last - columns.detect { |c| c.name.to_s == field_name } - end - - def type_cast_calculated_value(value, column, operation = nil) - case operation - when 'count' then value.to_i - when 'sum' then type_cast_using_column(value || '0', column) - when 'average' then value && (value.is_a?(Fixnum) ? value.to_f : value).to_d - else type_cast_using_column(value, column) - end - end - - def type_cast_using_column(value, column) - column ? column.type_cast(value) : value - end end end end diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index cb743358b2..e495aa80db 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -149,8 +149,8 @@ module ActiveRecord return @records if loaded? @records = if @eager_load_associations.any? - catch :invalid_query do - return @klass.send(:find_with_associations, { + begin + @klass.send(:find_with_associations, { :select => @relation.send(:select_clauses).join(', '), :joins => @relation.joins(relation), :group => @relation.send(:group_clauses).join(', '), @@ -161,8 +161,9 @@ module ActiveRecord :from => (@relation.send(:from_clauses) if @relation.send(:sources).present?) }, ActiveRecord::Associations::ClassMethods::JoinDependency.new(@klass, @eager_load_associations, nil)) + rescue ThrowResult + [] end - [] else @klass.find_by_sql(@relation.to_sql) end diff --git a/activerecord/lib/active_record/relational_calculations.rb b/activerecord/lib/active_record/relational_calculations.rb index 10eb992167..d77624c7bf 100644 --- a/activerecord/lib/active_record/relational_calculations.rb +++ b/activerecord/lib/active_record/relational_calculations.rb @@ -2,39 +2,119 @@ module ActiveRecord module RelationalCalculations def count(*args) - column_name, options = construct_count_options_from_args(*args) - distinct = options[:distinct] ? true : false + calculate(:count, *construct_count_options_from_args(*args)) + end + + def average(column_name) + calculate(:average, column_name) + end + + def minimum(column_name) + calculate(:minimum, column_name) + end + + def maximum(column_name) + calculate(:maximum, column_name) + end + + def sum(column_name) + calculate(:sum, column_name) + end + + def calculate(operation, column_name, options = {}) + operation = operation.to_s.downcase + + if operation == "count" + joins = @relation.joins(relation) + if joins.present? && joins =~ /LEFT OUTER/i + distinct = true + column_name = @klass.primary_key if column_name == :all + end + distinct = nil if column_name.to_s =~ /\s*DISTINCT\s+/i + distinct ||= options[:distinct] + else + distinct = nil + end + + distinct = options[:distinct] || distinct + column_name = :all if column_name.blank? && operation == "count" + + if @relation.send(:groupings).any? + return execute_grouped_calculation(operation, column_name) + else + return execute_simple_calculation(operation, column_name, distinct) + end + rescue ThrowResult + 0 + end + + private + + def execute_simple_calculation(operation, column_name, distinct) #:nodoc: column = if @klass.column_names.include?(column_name.to_s) - Arel::Attribute.new(@relation.table, column_name) + Arel::Attribute.new(@klass.arel_table, column_name) else Arel::SqlLiteral.new(column_name == :all ? "*" : column_name.to_s) end - relation = select(column.count(distinct)) - @klass.connection.select_value(relation.to_sql).to_i + relation = select(operation == 'count' ? column.count(distinct) : column.send(operation)) + type_cast_calculated_value(@klass.connection.select_value(relation.to_sql), column_for(column_name), operation) end - private + def execute_grouped_calculation(operation, column_name) #:nodoc: + group_attr = @relation.send(:groupings).first.value + association = @klass.reflect_on_association(group_attr.to_sym) + associated = association && association.macro == :belongs_to # only count belongs_to associations + group_field = associated ? association.primary_key_name : group_attr + group_alias = column_alias_for(group_field) + group_column = column_for(group_field) + + group = @klass.connection.adapter_name == 'FrontBase' ? group_alias : group_field + + aggregate_alias = column_alias_for(operation, column_name) + + select_statement = if operation == 'count' && column_name == :all + "COUNT(*) AS count_all" + else + Arel::Attribute.new(@klass.arel_table, column_name).send(operation).as(aggregate_alias).to_sql + end + + select_statement << ", #{group_field} AS #{group_alias}" + + relation = select(select_statement).group(group) + + calculated_data = @klass.connection.select_all(relation.to_sql) + + if association + key_ids = calculated_data.collect { |row| row[group_alias] } + key_records = association.klass.base_class.find(key_ids) + key_records = key_records.inject({}) { |hsh, r| hsh.merge(r.id => r) } + end + + calculated_data.inject(ActiveSupport::OrderedHash.new) do |all, row| + key = type_cast_calculated_value(row[group_alias], group_column) + key = key_records[key] if associated + value = row[aggregate_alias] + all[key] = type_cast_calculated_value(value, column_for(column_name), operation) + all + end + end def construct_count_options_from_args(*args) options = {} column_name = :all - # We need to handle - # count() - # count(:column_name=:all) - # count(options={}) - # count(column_name=:all, options={}) - # selects specified by scopes - + # Handles count(), count(:column), count(:distinct => true), count(:column, :distinct => true) # TODO : relation.projections only works when .select() was last in the chain. Fix it! case args.size when 0 - column_name = @relation.send(:select_clauses).join(', ') if @relation.respond_to?(:projections) && @relation.projections.present? + select = @relation.send(:select_clauses).join(', ') if @relation.respond_to?(:projections) && @relation.projections.present? + column_name = select if select !~ /(,|\*)/ when 1 if args[0].is_a?(Hash) - column_name = @relation.send(:select_clauses).join(', ') if @relation.respond_to?(:projections) && @relation.projections.present? + select = @relation.send(:select_clauses).join(', ') if @relation.respond_to?(:projections) && @relation.projections.present? + column_name = select if select !~ /(,|\*)/ options = args[0] else column_name = args[0] @@ -48,5 +128,42 @@ module ActiveRecord [column_name || :all, options] end + # Converts the given keys to the value that the database adapter returns as + # a usable column name: + # + # column_alias_for("users.id") # => "users_id" + # column_alias_for("sum(id)") # => "sum_id" + # column_alias_for("count(distinct users.id)") # => "count_distinct_users_id" + # column_alias_for("count(*)") # => "count_all" + # column_alias_for("count", "id") # => "count_id" + def column_alias_for(*keys) + table_name = keys.join(' ') + table_name.downcase! + table_name.gsub!(/\*/, 'all') + table_name.gsub!(/\W+/, ' ') + table_name.strip! + table_name.gsub!(/ +/, '_') + + @klass.connection.table_alias_for(table_name) + end + + def column_for(field) + field_name = field.to_s.split('.').last + @klass.columns.detect { |c| c.name.to_s == field_name } + end + + def type_cast_calculated_value(value, column, operation = nil) + case operation + when 'count' then value.to_i + when 'sum' then type_cast_using_column(value || '0', column) + when 'average' then value && (value.is_a?(Fixnum) ? value.to_f : value).to_d + else type_cast_using_column(value, column) + end + end + + def type_cast_using_column(value, column) + column ? column.type_cast(value) : value + end + end end diff --git a/activerecord/test/cases/associations/inner_join_association_test.rb b/activerecord/test/cases/associations/inner_join_association_test.rb index 18a1cd3cd0..0315604106 100644 --- a/activerecord/test/cases/associations/inner_join_association_test.rb +++ b/activerecord/test/cases/associations/inner_join_association_test.rb @@ -81,6 +81,8 @@ class InnerJoinAssociationTest < ActiveRecord::TestCase end def test_calculate_honors_implicit_inner_joins + Author.calculate(:count, 'authors.id', :joins => :posts) + return real_count = Author.scoped.to_a.sum{|a| a.posts.count } assert_equal real_count, Author.calculate(:count, 'authors.id', :joins => :posts), "plain inner join count should match the number of referenced posts records" end diff --git a/activerecord/test/cases/calculations_test.rb b/activerecord/test/cases/calculations_test.rb index 7e4c4a0913..bd2d471fc7 100644 --- a/activerecord/test/cases/calculations_test.rb +++ b/activerecord/test/cases/calculations_test.rb @@ -29,8 +29,8 @@ class CalculationsTest < ActiveRecord::TestCase end def test_type_cast_calculated_value_should_convert_db_averages_of_fixnum_class_to_decimal - assert_equal 0, NumericData.send(:type_cast_calculated_value, 0, nil, 'avg') - assert_equal 53.0, NumericData.send(:type_cast_calculated_value, 53, nil, 'avg') + assert_equal 0, NumericData.scoped.send(:type_cast_calculated_value, 0, nil, 'avg') + assert_equal 53.0, NumericData.scoped.send(:type_cast_calculated_value, 53, nil, 'avg') end def test_should_get_maximum_of_field @@ -248,17 +248,15 @@ class CalculationsTest < ActiveRecord::TestCase def test_should_reject_invalid_options assert_nothing_raised do - [:count, :sum].each do |func| - # empty options are valid - Company.send(:validate_calculation_options, func) - # these options are valid for all calculations - [:select, :conditions, :joins, :order, :group, :having, :distinct].each do |opt| - Company.send(:validate_calculation_options, func, opt => true) - end + # empty options are valid + Company.send(:validate_calculation_options) + # these options are valid for all calculations + [:select, :conditions, :joins, :order, :group, :having, :distinct].each do |opt| + Company.send(:validate_calculation_options, opt => true) end # :include is only valid on :count - Company.send(:validate_calculation_options, :count, :include => true) + Company.send(:validate_calculation_options, :include => true) end assert_raise(ArgumentError) { Company.send(:validate_calculation_options, :sum, :foo => :bar) } -- cgit v1.2.3