diff options
Diffstat (limited to 'activerecord/lib/active_record/associations/has_many_through_association.rb')
-rw-r--r-- | activerecord/lib/active_record/associations/has_many_through_association.rb | 69 |
1 files changed, 43 insertions, 26 deletions
diff --git a/activerecord/lib/active_record/associations/has_many_through_association.rb b/activerecord/lib/active_record/associations/has_many_through_association.rb index 59929b8c4e..0d384950fe 100644 --- a/activerecord/lib/active_record/associations/has_many_through_association.rb +++ b/activerecord/lib/active_record/associations/has_many_through_association.rb @@ -21,20 +21,6 @@ module ActiveRecord super end - def concat_records(records) - ensure_not_nested - - records = super(records, true) - - if owner.new_record? && records - records.flatten.each do |record| - build_through_record(record) - end - end - - records - end - def insert_record(record, validate = true, raise = false) ensure_not_nested @@ -48,6 +34,20 @@ module ActiveRecord end private + def concat_records(records) + ensure_not_nested + + records = super(records, true) + + if owner.new_record? && records + records.flatten.each do |record| + build_through_record(record) + end + end + + records + end + # The through record (built with build_record) is temporarily cached # so that it may be reused if insert_record is subsequently called. # @@ -57,21 +57,14 @@ module ActiveRecord @through_records[record.object_id] ||= begin ensure_mutable - through_record = through_association.build(*options_for_through_record) - through_record.send("#{source_reflection.name}=", record) + attributes = through_scope_attributes + attributes[source_reflection.name] = record + attributes[source_reflection.foreign_type] = options[:source_type] if options[:source_type] - if options[:source_type] - through_record.send("#{source_reflection.foreign_type}=", options[:source_type]) - end - - through_record + through_association.build(attributes) end end - def options_for_through_record - [through_scope_attributes] - end - def through_scope_attributes scope.where_values_hash(through_association.reflection.name.to_s). except!(through_association.reflection.foreign_key, @@ -90,7 +83,7 @@ module ActiveRecord def build_record(attributes) ensure_not_nested - record = super(attributes) + record = super inverse = source_reflection.inverse_of if inverse @@ -161,6 +154,30 @@ module ActiveRecord else update_counter(-count) end + + count + end + + def difference(a, b) + distribution = distribution(b) + + a.reject { |record| mark_occurrence(distribution, record) } + end + + def intersection(a, b) + distribution = distribution(b) + + a.select { |record| mark_occurrence(distribution, record) } + end + + def mark_occurrence(distribution, record) + distribution[record] > 0 && distribution[record] -= 1 + end + + def distribution(array) + array.each_with_object(Hash.new(0)) do |record, distribution| + distribution[record] += 1 + end end def through_records_for(record) |