diff options
Diffstat (limited to 'activerecord/lib/active_record/relation/calculations.rb')
-rw-r--r-- | activerecord/lib/active_record/relation/calculations.rb | 320 |
1 files changed, 159 insertions, 161 deletions
diff --git a/activerecord/lib/active_record/relation/calculations.rb b/activerecord/lib/active_record/relation/calculations.rb index 54c9af4898..e4676f79a5 100644 --- a/activerecord/lib/active_record/relation/calculations.rb +++ b/activerecord/lib/active_record/relation/calculations.rb @@ -37,7 +37,11 @@ module ActiveRecord # Note: not all valid {Relation#select}[rdoc-ref:QueryMethods#select] expressions are valid #count expressions. The specifics differ # between databases. In invalid cases, an error from the database is thrown. def count(column_name = nil) - calculate(:count, column_name) + if block_given? + to_a.count { |*block_args| yield(*block_args) } + else + calculate(:count, column_name) + end end # Calculates the average value on a given column. Returns +nil+ if there's @@ -89,7 +93,7 @@ module ActiveRecord # # There are two basic forms of output: # - # * Single aggregate value: The single value is type cast to Fixnum for COUNT, Float + # * Single aggregate value: The single value is type cast to Integer for COUNT, Float # for AVG, and the given column's type for everything else. # # * Grouped values: This returns an ordered hash of the values and groups them. It @@ -108,12 +112,11 @@ module ActiveRecord # ... # end def calculate(operation, column_name) - if column_name.is_a?(Symbol) && attribute_alias?(column_name) - column_name = attribute_alias(column_name) - end - if has_include?(column_name) - construct_relation_for_association_calculations.calculate(operation, column_name) + relation = construct_relation_for_association_calculations + relation = relation.distinct if operation.to_s.downcase == "count" + + relation.calculate(operation, column_name) else perform_calculation(operation, column_name) end @@ -156,7 +159,7 @@ module ActiveRecord # def pluck(*column_names) if loaded? && (column_names.map(&:to_s) - @klass.attribute_names - @klass.attribute_aliases.keys).empty? - return @records.pluck(*column_names) + return records.pluck(*column_names) end if has_include?(column_names.first) @@ -181,201 +184,196 @@ module ActiveRecord private - def has_include?(column_name) - eager_loading? || (includes_values.present? && column_name && column_name != :all) - end - - def perform_calculation(operation, column_name) - operation = operation.to_s.downcase + def has_include?(column_name) + eager_loading? || (includes_values.present? && column_name && column_name != :all) + end - # If #count is used with #distinct (i.e. `relation.distinct.count`) it is - # considered distinct. - distinct = self.distinct_value + def perform_calculation(operation, column_name) + operation = operation.to_s.downcase - if operation == "count" - column_name ||= select_for_count + # If #count is used with #distinct (i.e. `relation.distinct.count`) it is + # considered distinct. + distinct = self.distinct_value - unless arel.ast.grep(Arel::Nodes::OuterJoin).empty? - distinct = true + if operation == "count" + column_name ||= select_for_count + column_name = primary_key if column_name == :all && distinct + distinct = nil if column_name =~ /\s*DISTINCT[\s(]+/i end - column_name = primary_key if column_name == :all && distinct - distinct = nil if column_name =~ /\s*DISTINCT[\s(]+/i - end - - if group_values.any? - execute_grouped_calculation(operation, column_name, distinct) - else - execute_simple_calculation(operation, column_name, distinct) + if group_values.any? + execute_grouped_calculation(operation, column_name, distinct) + else + execute_simple_calculation(operation, column_name, distinct) + end end - end - def aggregate_column(column_name) - return column_name if Arel::Expressions === column_name + def aggregate_column(column_name) + return column_name if Arel::Expressions === column_name - if @klass.column_names.include?(column_name.to_s) - Arel::Attribute.new(@klass.unscoped.table, column_name) - else - Arel.sql(column_name == :all ? "*" : column_name.to_s) + if @klass.has_attribute?(column_name.to_s) || @klass.attribute_alias?(column_name.to_s) + @klass.arel_attribute(column_name) + else + Arel.sql(column_name == :all ? "*" : column_name.to_s) + end end - end - - def operation_over_aggregate_column(column, operation, distinct) - operation == 'count' ? column.count(distinct) : column.send(operation) - end - - def execute_simple_calculation(operation, column_name, distinct) #:nodoc: - # PostgreSQL doesn't like ORDER BY when there are no GROUP BY - relation = unscope(:order) - column_alias = column_name + def operation_over_aggregate_column(column, operation, distinct) + operation == "count" ? column.count(distinct) : column.send(operation) + end - if operation == "count" && (relation.limit_value || relation.offset_value) - # Shortcut when limit is zero. - return 0 if relation.limit_value == 0 + def execute_simple_calculation(operation, column_name, distinct) #:nodoc: + # PostgreSQL doesn't like ORDER BY when there are no GROUP BY + relation = unscope(:order) - query_builder = build_count_subquery(relation, column_name, distinct) - else - column = aggregate_column(column_name) + column_alias = column_name - select_value = operation_over_aggregate_column(column, operation, distinct) + if operation == "count" && (relation.limit_value || relation.offset_value) + # Shortcut when limit is zero. + return 0 if relation.limit_value == 0 - column_alias = select_value.alias - column_alias ||= @klass.connection.column_name_for_operation(operation, select_value) - relation.select_values = [select_value] + query_builder = build_count_subquery(relation, column_name, distinct) + else + column = aggregate_column(column_name) - query_builder = relation.arel - end + select_value = operation_over_aggregate_column(column, operation, distinct) - result = @klass.connection.select_all(query_builder, nil, bound_attributes) - row = result.first - value = row && row.values.first - column = result.column_types.fetch(column_alias) do - type_for(column_name) - end + column_alias = select_value.alias + column_alias ||= @klass.connection.column_name_for_operation(operation, select_value) + relation.select_values = [select_value] - type_cast_calculated_value(value, column, operation) - end + query_builder = relation.arel + end - def execute_grouped_calculation(operation, column_name, distinct) #:nodoc: - group_attrs = group_values + result = @klass.connection.select_all(query_builder, nil, bound_attributes) + row = result.first + value = row && row.values.first + type = result.column_types.fetch(column_alias) do + type_for(column_name) + end - if group_attrs.first.respond_to?(:to_sym) - association = @klass._reflect_on_association(group_attrs.first) - associated = group_attrs.size == 1 && association && association.belongs_to? # only count belongs_to associations - group_fields = Array(associated ? association.foreign_key : group_attrs) - else - group_fields = group_attrs + type_cast_calculated_value(value, type, operation) end - group_fields = arel_columns(group_fields) - group_aliases = group_fields.map { |field| column_alias_for(field) } - group_columns = group_aliases.zip(group_fields) + def execute_grouped_calculation(operation, column_name, distinct) #:nodoc: + group_attrs = group_values - if operation == 'count' && column_name == :all - aggregate_alias = 'count_all' - else - aggregate_alias = column_alias_for([operation, column_name].join(' ')) - end - - select_values = [ - operation_over_aggregate_column( - aggregate_column(column_name), - operation, - distinct).as(aggregate_alias) - ] - select_values += select_values unless having_clause.empty? - - select_values.concat group_columns.map { |aliaz, field| - if field.respond_to?(:as) - field.as(aliaz) + if group_attrs.first.respond_to?(:to_sym) + association = @klass._reflect_on_association(group_attrs.first) + associated = group_attrs.size == 1 && association && association.belongs_to? # only count belongs_to associations + group_fields = Array(associated ? association.foreign_key : group_attrs) else - "#{field} AS #{aliaz}" + group_fields = group_attrs end - } - - relation = except(:group) - relation.group_values = group_fields - relation.select_values = select_values + group_fields = arel_columns(group_fields) - calculated_data = @klass.connection.select_all(relation, nil, relation.bound_attributes) + group_aliases = group_fields.map { |field| column_alias_for(field) } + group_columns = group_aliases.zip(group_fields) - if association - key_ids = calculated_data.collect { |row| row[group_aliases.first] } - key_records = association.klass.base_class.where(association.klass.base_class.primary_key => key_ids) - key_records = Hash[key_records.map { |r| [r.id, r] }] - end + if operation == "count" && column_name == :all + aggregate_alias = "count_all" + else + aggregate_alias = column_alias_for([operation, column_name].join(" ")) + end - Hash[calculated_data.map do |row| - key = group_columns.map { |aliaz, col_name| - column = calculated_data.column_types.fetch(aliaz) do - type_for(col_name) + select_values = [ + operation_over_aggregate_column( + aggregate_column(column_name), + operation, + distinct).as(aggregate_alias) + ] + select_values += select_values unless having_clause.empty? + + select_values.concat group_columns.map { |aliaz, field| + if field.respond_to?(:as) + field.as(aliaz) + else + "#{field} AS #{aliaz}" end - type_cast_calculated_value(row[aliaz], column) } - key = key.first if key.size == 1 - key = key_records[key] if associated - column_type = calculated_data.column_types.fetch(aggregate_alias) { type_for(column_name) } - [key, type_cast_calculated_value(row[aggregate_alias], column_type, operation)] - end] - end + relation = except(:group) + relation.group_values = group_fields + relation.select_values = select_values - # Converts the given keys to the value that the database adapter returns as - # a usable column name: - # - # column_alias_for("users.id") # => "users_id" - # column_alias_for("sum(id)") # => "sum_id" - # column_alias_for("count(distinct users.id)") # => "count_distinct_users_id" - # column_alias_for("count(*)") # => "count_all" - def column_alias_for(keys) - if keys.respond_to? :name - keys = "#{keys.relation.name}.#{keys.name}" + calculated_data = @klass.connection.select_all(relation, nil, relation.bound_attributes) + + if association + key_ids = calculated_data.collect { |row| row[group_aliases.first] } + key_records = association.klass.base_class.where(association.klass.base_class.primary_key => key_ids) + key_records = Hash[key_records.map { |r| [r.id, r] }] + end + + Hash[calculated_data.map do |row| + key = group_columns.map { |aliaz, col_name| + type = type_for(col_name) do + calculated_data.column_types.fetch(aliaz, Type.default_value) + end + type_cast_calculated_value(row[aliaz], type) + } + key = key.first if key.size == 1 + key = key_records[key] if associated + + type = calculated_data.column_types.fetch(aggregate_alias) { type_for(column_name) } + [key, type_cast_calculated_value(row[aggregate_alias], type, operation)] + end] end - table_name = keys.to_s.downcase - table_name.gsub!(/\*/, 'all') - table_name.gsub!(/\W+/, ' ') - table_name.strip! - table_name.gsub!(/ +/, '_') + # Converts the given keys to the value that the database adapter returns as + # a usable column name: + # + # column_alias_for("users.id") # => "users_id" + # column_alias_for("sum(id)") # => "sum_id" + # column_alias_for("count(distinct users.id)") # => "count_distinct_users_id" + # column_alias_for("count(*)") # => "count_all" + def column_alias_for(keys) + if keys.respond_to? :name + keys = "#{keys.relation.name}.#{keys.name}" + end - @klass.connection.table_alias_for(table_name) - end + table_name = keys.to_s.downcase + table_name.gsub!(/\*/, "all") + table_name.gsub!(/\W+/, " ") + table_name.strip! + table_name.gsub!(/ +/, "_") - def type_for(field) - field_name = field.respond_to?(:name) ? field.name.to_s : field.to_s.split('.').last - @klass.type_for_attribute(field_name) - end + @klass.connection.table_alias_for(table_name) + end - def type_cast_calculated_value(value, type, operation = nil) - case operation - when 'count' then value.to_i - when 'sum' then type.deserialize(value || 0) - when 'average' then value.respond_to?(:to_d) ? value.to_d : value + def type_for(field, &block) + field_name = field.respond_to?(:name) ? field.name.to_s : field.to_s.split(".").last + @klass.type_for_attribute(field_name, &block) + end + + def type_cast_calculated_value(value, type, operation = nil) + case operation + when "count" then value.to_i + when "sum" then type.deserialize(value || 0) + when "average" then value.respond_to?(:to_d) ? value.to_d : value else type.deserialize(value) + end end - end - def select_for_count - if select_values.present? - return select_values.first if select_values.one? - select_values.join(", ") - else - :all + def select_for_count + if select_values.present? + return select_values.first if select_values.one? + select_values.join(", ") + else + :all + end end - end - def build_count_subquery(relation, column_name, distinct) - column_alias = Arel.sql('count_column') - subquery_alias = Arel.sql('subquery_for_count') + def build_count_subquery(relation, column_name, distinct) + column_alias = Arel.sql("count_column") + subquery_alias = Arel.sql("subquery_for_count") - aliased_column = aggregate_column(column_name == :all ? 1 : column_name).as(column_alias) - relation.select_values = [aliased_column] - subquery = relation.arel.as(subquery_alias) + aliased_column = aggregate_column(column_name == :all ? 1 : column_name).as(column_alias) + relation.select_values = [aliased_column] + subquery = relation.arel.as(subquery_alias) - sm = Arel::SelectManager.new relation.engine - select_value = operation_over_aggregate_column(column_alias, 'count', distinct) - sm.project(select_value).from(subquery) - end + sm = Arel::SelectManager.new relation.engine + select_value = operation_over_aggregate_column(column_alias, "count", distinct) + sm.project(select_value).from(subquery) + end end end |