diff options
author | Pratik Naik <pratiknaik@gmail.com> | 2009-12-30 13:28:26 +0530 |
---|---|---|
committer | Pratik Naik <pratiknaik@gmail.com> | 2009-12-30 19:29:26 +0530 |
commit | 7aabaac0f5ba108f917af2c65a79511694393b85 (patch) | |
tree | 5a9a1ed9620140af878844157c9b6dc2a4359eae /activerecord/lib/active_record/relation | |
parent | 97204fc0bc52af8fb6714e6f9fcd414567e0fc1a (diff) | |
download | rails-7aabaac0f5ba108f917af2c65a79511694393b85.tar.gz rails-7aabaac0f5ba108f917af2c65a79511694393b85.tar.bz2 rails-7aabaac0f5ba108f917af2c65a79511694393b85.zip |
Organize Relation methods into separate modules
Diffstat (limited to 'activerecord/lib/active_record/relation')
3 files changed, 429 insertions, 0 deletions
diff --git a/activerecord/lib/active_record/relation/calculation_methods.rb b/activerecord/lib/active_record/relation/calculation_methods.rb new file mode 100644 index 0000000000..a925a99464 --- /dev/null +++ b/activerecord/lib/active_record/relation/calculation_methods.rb @@ -0,0 +1,177 @@ +module ActiveRecord + module CalculationMethods + + def count(*args) + calculate(:count, *construct_count_options_from_args(*args)) + end + + def average(column_name) + calculate(:average, column_name) + end + + def minimum(column_name) + calculate(:minimum, column_name) + end + + def maximum(column_name) + calculate(:maximum, column_name) + end + + def sum(column_name) + calculate(:sum, column_name) + end + + def calculate(operation, column_name, options = {}) + operation = operation.to_s.downcase + + if operation == "count" + joins = @relation.joins(relation) + if joins.present? && joins =~ /LEFT OUTER/i + distinct = true + column_name = @klass.primary_key if column_name == :all + end + + distinct = nil if column_name.to_s =~ /\s*DISTINCT\s+/i + distinct ||= options[:distinct] + else + distinct = nil + end + + distinct = options[:distinct] || distinct + column_name = :all if column_name.blank? && operation == "count" + + if @relation.send(:groupings).any? + return execute_grouped_calculation(operation, column_name) + else + return execute_simple_calculation(operation, column_name, distinct) + end + rescue ThrowResult + 0 + end + + private + + def execute_simple_calculation(operation, column_name, distinct) #:nodoc: + column = if @klass.column_names.include?(column_name.to_s) + Arel::Attribute.new(@klass.arel_table, column_name) + else + Arel::SqlLiteral.new(column_name == :all ? "*" : column_name.to_s) + end + + relation = select(operation == 'count' ? column.count(distinct) : column.send(operation)) + type_cast_calculated_value(@klass.connection.select_value(relation.to_sql), column_for(column_name), operation) + end + + def execute_grouped_calculation(operation, column_name) #:nodoc: + group_attr = @relation.send(:groupings).first.value + association = @klass.reflect_on_association(group_attr.to_sym) + associated = association && association.macro == :belongs_to # only count belongs_to associations + group_field = associated ? association.primary_key_name : group_attr + group_alias = column_alias_for(group_field) + group_column = column_for(group_field) + + group = @klass.connection.adapter_name == 'FrontBase' ? group_alias : group_field + + aggregate_alias = column_alias_for(operation, column_name) + + select_statement = if operation == 'count' && column_name == :all + "COUNT(*) AS count_all" + else + Arel::Attribute.new(@klass.arel_table, column_name).send(operation).as(aggregate_alias).to_sql + end + + select_statement << ", #{group_field} AS #{group_alias}" + + relation = select(select_statement).group(group) + + calculated_data = @klass.connection.select_all(relation.to_sql) + + if association + key_ids = calculated_data.collect { |row| row[group_alias] } + key_records = association.klass.base_class.find(key_ids) + key_records = key_records.inject({}) { |hsh, r| hsh.merge(r.id => r) } + end + + calculated_data.inject(ActiveSupport::OrderedHash.new) do |all, row| + key = type_cast_calculated_value(row[group_alias], group_column) + key = key_records[key] if associated + value = row[aggregate_alias] + all[key] = type_cast_calculated_value(value, column_for(column_name), operation) + all + end + end + + def construct_count_options_from_args(*args) + options = {} + column_name = :all + + # Handles count(), count(:column), count(:distinct => true), count(:column, :distinct => true) + # TODO : relation.projections only works when .select() was last in the chain. Fix it! + case args.size + when 0 + select = get_projection_name_from_chained_relations + column_name = select if select !~ /(,|\*)/ + when 1 + if args[0].is_a?(Hash) + select = get_projection_name_from_chained_relations + column_name = select if select !~ /(,|\*)/ + options = args[0] + else + column_name = args[0] + end + when 2 + column_name, options = args + else + raise ArgumentError, "Unexpected parameters passed to count(): #{args.inspect}" + end + + [column_name || :all, options] + end + + # Converts the given keys to the value that the database adapter returns as + # a usable column name: + # + # column_alias_for("users.id") # => "users_id" + # column_alias_for("sum(id)") # => "sum_id" + # column_alias_for("count(distinct users.id)") # => "count_distinct_users_id" + # column_alias_for("count(*)") # => "count_all" + # column_alias_for("count", "id") # => "count_id" + def column_alias_for(*keys) + table_name = keys.join(' ') + table_name.downcase! + table_name.gsub!(/\*/, 'all') + table_name.gsub!(/\W+/, ' ') + table_name.strip! + table_name.gsub!(/ +/, '_') + + @klass.connection.table_alias_for(table_name) + end + + def column_for(field) + field_name = field.to_s.split('.').last + @klass.columns.detect { |c| c.name.to_s == field_name } + end + + def type_cast_calculated_value(value, column, operation = nil) + case operation + when 'count' then value.to_i + when 'sum' then type_cast_using_column(value || '0', column) + when 'average' then value && (value.is_a?(Fixnum) ? value.to_f : value).to_d + else type_cast_using_column(value, column) + end + end + + def type_cast_using_column(value, column) + column ? column.type_cast(value) : value + end + + def get_projection_name_from_chained_relations(relation = @relation) + if relation.respond_to?(:projections) && relation.projections.present? + relation.send(:select_clauses).join(', ') + elsif relation.respond_to?(:relation) + get_projection_name_from_chained_relations(relation.relation) + end + end + + end +end diff --git a/activerecord/lib/active_record/relation/finder_methods.rb b/activerecord/lib/active_record/relation/finder_methods.rb new file mode 100644 index 0000000000..7a1d6fc538 --- /dev/null +++ b/activerecord/lib/active_record/relation/finder_methods.rb @@ -0,0 +1,120 @@ +module ActiveRecord + module FinderMethods + + def find(*ids, &block) + return to_a.find(&block) if block_given? + + expects_array = ids.first.kind_of?(Array) + return ids.first if expects_array && ids.first.empty? + + ids = ids.flatten.compact.uniq + + case ids.size + when 0 + raise RecordNotFound, "Couldn't find #{@klass.name} without an ID" + when 1 + result = find_one(ids.first) + expects_array ? [ result ] : result + else + find_some(ids) + end + end + + def exists?(id = nil) + relation = select("#{@klass.quoted_table_name}.#{@klass.primary_key}").limit(1) + relation = relation.where(@klass.primary_key => id) if id + relation.first ? true : false + end + + def first + if loaded? + @records.first + else + @first ||= limit(1).to_a[0] + end + end + + def last + if loaded? + @records.last + else + @last ||= reverse_order.limit(1).to_a[0] + end + end + + protected + + def find_by_attributes(match, attributes, *args) + conditions = attributes.inject({}) {|h, a| h[a] = args[attributes.index(a)]; h} + result = where(conditions).send(match.finder) + + if match.bang? && result.blank? + raise RecordNotFound, "Couldn't find #{@klass.name} with #{conditions.to_a.collect {|p| p.join(' = ')}.join(', ')}" + else + result + end + end + + def find_or_instantiator_by_attributes(match, attributes, *args) + guard_protected_attributes = false + + if args[0].is_a?(Hash) + guard_protected_attributes = true + attributes_for_create = args[0].with_indifferent_access + conditions = attributes_for_create.slice(*attributes).symbolize_keys + else + attributes_for_create = conditions = attributes.inject({}) {|h, a| h[a] = args[attributes.index(a)]; h} + end + + record = where(conditions).first + + unless record + record = @klass.new { |r| r.send(:attributes=, attributes_for_create, guard_protected_attributes) } + yield(record) if block_given? + record.save if match.instantiator == :create + end + + record + end + + def find_one(id) + record = where(@klass.primary_key => id).first + + unless record + conditions = where_clause(', ') + conditions = " [WHERE #{conditions}]" if conditions.present? + raise RecordNotFound, "Couldn't find #{@klass.name} with ID=#{id}#{conditions}" + end + + record + end + + def find_some(ids) + result = where(@klass.primary_key => ids).all + + expected_size = + if @relation.taken && ids.size > @relation.taken + @relation.taken + else + ids.size + end + + # 11 ids with limit 3, offset 9 should give 2 results. + if @relation.skipped && (ids.size - @relation.skipped < expected_size) + expected_size = ids.size - @relation.skipped + end + + if result.size == expected_size + result + else + conditions = where_clause(', ') + conditions = " [WHERE #{conditions}]" if conditions.present? + + error = "Couldn't find all #{@klass.name.pluralize} with IDs " + error << "(#{ids.join(", ")})#{conditions} (found #{result.size} results, but was looking for #{expected_size})" + raise RecordNotFound, error + end + end + + end +end diff --git a/activerecord/lib/active_record/relation/query_methods.rb b/activerecord/lib/active_record/relation/query_methods.rb new file mode 100644 index 0000000000..631c80da25 --- /dev/null +++ b/activerecord/lib/active_record/relation/query_methods.rb @@ -0,0 +1,132 @@ +module ActiveRecord + module QueryMethods + + def preload(*associations) + spawn.tap {|r| r.preload_associations += Array.wrap(associations) } + end + + def eager_load(*associations) + spawn.tap {|r| r.eager_load_associations += Array.wrap(associations) } + end + + def readonly(status = true) + spawn.tap {|r| r.readonly = status } + end + + def select(selects) + if selects.present? + relation = spawn(@relation.project(selects)) + relation.readonly = @relation.joins(relation).present? ? false : @readonly + relation + else + spawn + end + end + + def from(from) + from.present? ? spawn(@relation.from(from)) : spawn + end + + def having(*args) + return spawn if args.blank? + + if [String, Hash, Array].include?(args.first.class) + havings = @klass.send(:merge_conditions, args.size > 1 ? Array.wrap(args) : args.first) + else + havings = args.first + end + + spawn(@relation.having(havings)) + end + + def group(groups) + groups.present? ? spawn(@relation.group(groups)) : spawn + end + + def order(orders) + orders.present? ? spawn(@relation.order(orders)) : spawn + end + + def lock(locks = true) + case locks + when String + spawn(@relation.lock(locks)) + when TrueClass, NilClass + spawn(@relation.lock) + else + spawn + end + end + + def reverse_order + relation = spawn + relation.instance_variable_set(:@orders, nil) + + order_clause = @relation.send(:order_clauses).join(', ') + if order_clause.present? + relation.order(reverse_sql_order(order_clause)) + else + relation.order("#{@klass.table_name}.#{@klass.primary_key} DESC") + end + end + + def limit(limits) + limits.present? ? spawn(@relation.take(limits)) : spawn + end + + def offset(offsets) + offsets.present? ? spawn(@relation.skip(offsets)) : spawn + end + + def on(join) + spawn(@relation.on(join)) + end + + def joins(join, join_type = nil) + return spawn if join.blank? + + join_relation = case join + when String + @relation.join(join) + when Hash, Array, Symbol + if @klass.send(:array_of_strings?, join) + @relation.join(join.join(' ')) + else + @relation.join(@klass.send(:build_association_joins, join)) + end + else + @relation.join(join, join_type) + end + + spawn(join_relation).tap { |r| r.readonly = true } + end + + def where(*args) + return spawn if args.blank? + + if [String, Hash, Array].include?(args.first.class) + conditions = @klass.send(:merge_conditions, args.size > 1 ? Array.wrap(args) : args.first) + conditions = Arel::SqlLiteral.new(conditions) if conditions + else + conditions = args.first + end + + spawn(@relation.where(conditions)) + end + + private + + def reverse_sql_order(order_query) + order_query.to_s.split(/,/).each { |s| + if s.match(/\s(asc|ASC)$/) + s.gsub!(/\s(asc|ASC)$/, ' DESC') + elsif s.match(/\s(desc|DESC)$/) + s.gsub!(/\s(desc|DESC)$/, ' ASC') + else + s.concat(' DESC') + end + }.join(',') + end + + end +end |