aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb9
-rw-r--r--activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb4
-rw-r--r--activerecord/lib/active_record/connection_adapters/mysql_adapter.rb4
-rw-r--r--activerecord/lib/active_record/relation.rb13
-rw-r--r--activerecord/test/cases/relations_test.rb8
5 files changed, 34 insertions, 4 deletions
diff --git a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb
index 2ae655e68d..7543d35d3b 100644
--- a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb
+++ b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb
@@ -306,6 +306,15 @@ module ActiveRecord
end
end
+ # The default strategy for an UPDATE with joins is to use a subquery. This doesn't work
+ # on mysql (even when aliasing the tables), but mysql allows using JOIN directly in
+ # an UPDATE statement, so in the mysql adapters we redefine this to do that.
+ def join_to_update(update, select) #:nodoc:
+ subselect = select.clone
+ subselect.ast.cores.last.projections = [update.ast.key]
+ update.wheres = [update.ast.key.in(subselect)]
+ end
+
protected
# Returns an array of record hashes with the column names as keys and
# column values as values.
diff --git a/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb b/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb
index 18fdfa29ec..c01a64e354 100644
--- a/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb
+++ b/activerecord/lib/active_record/connection_adapters/mysql2_adapter.rb
@@ -577,6 +577,10 @@ module ActiveRecord
where_sql
end
+ def join_to_update(update, select) #:nodoc:
+ update.table select.ast.cores.last.source
+ end
+
protected
def quoted_columns_for_index(column_names, options = {})
length = options[:length] if options.is_a?(Hash)
diff --git a/activerecord/lib/active_record/connection_adapters/mysql_adapter.rb b/activerecord/lib/active_record/connection_adapters/mysql_adapter.rb
index 14b950dbb0..ea0970028c 100644
--- a/activerecord/lib/active_record/connection_adapters/mysql_adapter.rb
+++ b/activerecord/lib/active_record/connection_adapters/mysql_adapter.rb
@@ -491,6 +491,10 @@ module ActiveRecord
execute("RELEASE SAVEPOINT #{current_savepoint_name}")
end
+ def join_to_update(update, select) #:nodoc:
+ update.table select.ast.cores.last.source
+ end
+
# SCHEMA STATEMENTS ========================================
def structure_dump #:nodoc:
diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb
index 7e59eb4584..565ece1930 100644
--- a/activerecord/lib/active_record/relation.rb
+++ b/activerecord/lib/active_record/relation.rb
@@ -217,13 +217,18 @@ module ActiveRecord
where(conditions).apply_finder_options(options.slice(:limit, :order)).update_all(updates)
else
stmt = arel.compile_update(Arel.sql(@klass.send(:sanitize_sql_for_assignment, updates)))
+ stmt.key = table[primary_key]
- if limit = arel.limit
- stmt.take limit
+ if joins_values.any?
+ @klass.connection.join_to_update(stmt, arel)
+ else
+ if limit = arel.limit
+ stmt.take limit
+ end
+
+ stmt.order(*arel.orders)
end
- stmt.order(*arel.orders)
- stmt.key = table[primary_key]
@klass.connection.update stmt, 'SQL', bind_values
end
end
diff --git a/activerecord/test/cases/relations_test.rb b/activerecord/test/cases/relations_test.rb
index 821da91f0a..7bd9c44651 100644
--- a/activerecord/test/cases/relations_test.rb
+++ b/activerecord/test/cases/relations_test.rb
@@ -965,4 +965,12 @@ class RelationTest < ActiveRecord::TestCase
def test_ordering_with_extra_spaces
assert_equal authors(:david), Author.order('id DESC , name DESC').last
end
+
+ def test_update_all_with_joins
+ comments = Comment.joins(:post).where('posts.id' => posts(:welcome).id)
+ count = comments.count
+
+ assert_equal count, comments.update_all(:post_id => posts(:thinking).id)
+ assert_equal posts(:thinking), comments(:greetings).post
+ end
end