diff options
Diffstat (limited to 'activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb')
-rw-r--r-- | activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb | 46 |
1 files changed, 36 insertions, 10 deletions
diff --git a/activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb b/activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb index 3f90d61e2d..378fc79949 100644 --- a/activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb +++ b/activerecord/lib/active_record/associations/has_and_belongs_to_many_association.rb @@ -38,15 +38,42 @@ module ActiveRecord self end - def find(association_id = nil, &block) - if block_given? || @options[:finder_sql] - load_collection - @collection.find(&block) + def find_first + load_collection.first + end + + def find(*args) + # Return an Array if multiple ids are given. + expects_array = args.first.kind_of?(Array) + + ids = args.flatten.compact.uniq + + # If no block is given, raise RecordNotFound. + if ids.empty? + raise RecordNotFound, "Couldn't find #{@association_class.name} without an ID#{conditions}" + + # If using a custom finder_sql, scan the entire collection. + elsif @options[:finder_sql] + if ids.size == 1 + id = ids.first + record = load_collection.detect { |record| id == record.id } + expects_array? ? [record] : record + else + load_collection.select { |record| ids.include?(record.id) } + end + + # Otherwise, construct a query. else - if loaded? - find_all { |record| record.id == association_id.to_i }.first + ids_list = ids.map { |id| @owner.send(:quote, id) }.join(',') + records = find_all_records(@finder_sql.sub(/ORDER BY/, "AND j.#{@association_foreign_key} IN (#{ids_list}) ORDER BY")) + if records.size == ids.size + if ids.size == 1 and !expects_array + records.first + else + records + end else - find_all_records(@finder_sql.sub(/ORDER BY/, "AND j.#{@association_foreign_key} = #{@owner.send(:quote, association_id)} ORDER BY")).first + raise RecordNotFound, "Couldn't find #{@association_class.name} with ID in (#{ids_list})" end end end @@ -70,10 +97,9 @@ module ActiveRecord records = @association_class.find_by_sql(sql) @options[:uniq] ? uniq(records) : records end - + def count_records - load_collection - @collection.size + load_collection.size end def insert_record(record) |