diff options
4 files changed, 84 insertions, 8 deletions
diff --git a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb index 174450eb00..f08040a1a7 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb @@ -59,7 +59,7 @@ module ActiveRecord # Executes insert +sql+ statement in the context of this connection using # +binds+ as the bind substitutes. +name+ is the logged along with # the executed +sql+ statement. - def exec_insert(sql, name, binds) + def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) exec_query(sql, name, binds) end @@ -87,7 +87,7 @@ module ActiveRecord # passed in as +id_value+. def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = []) sql, binds = sql_for_insert(to_sql(arel, binds), pk, id_value, sequence_name, binds) - value = exec_insert(sql, name, binds) + value = exec_insert(sql, name, binds, pk, sequence_name) id_value || last_inserted_id(value) end diff --git a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb index 10a178e369..2509fe0490 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb @@ -406,6 +406,7 @@ module ActiveRecord initialize_type_map @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"] + self.use_returning = true end # Clears the prepared statements cache. @@ -667,8 +668,11 @@ module ActiveRecord pk = primary_key(table_ref) if table_ref end - if pk + if pk && use_returning? select_value("#{sql} RETURNING #{quote_column_name(pk)}") + elsif pk + super + last_insert_id_value(sequence_name || default_sequence_name(table_ref, pk)) else super end @@ -783,11 +787,35 @@ module ActiveRecord pk = primary_key(table_ref) if table_ref end - sql = "#{sql} RETURNING #{quote_column_name(pk)}" if pk + if pk && use_returning? + sql = "#{sql} RETURNING #{quote_column_name(pk)}" + end [sql, binds] end + def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) + val = exec_query(sql, name, binds) + if !use_returning? && pk + if sequence_name + last_insert_id_value(sequence_name) + else + table_ref = extract_table_ref_from_insert_sql(sql) + sequence_name = default_sequence_name(table_ref, pk) + return val unless sequence_name + last_insert_id(sequence_name) + end + else + val + end + end + + def last_inserted_id(result) + return result if result.kind_of?(Integer) + row = result.rows.first + row && row.first + end + # Executes an UPDATE query and returns the number of affected tuples. def update_sql(sql, name = nil) super.cmd_tuples @@ -1028,7 +1056,9 @@ module ActiveRecord # Returns the sequence name for a table's primary key or some other specified key. def default_sequence_name(table_name, pk = nil) #:nodoc: - serial_sequence(table_name, pk || 'id').split('.').last + result = serial_sequence(table_name, pk || 'id') + return nil unless result + result.split('.').last rescue ActiveRecord::StatementInvalid "#{table_name}_#{pk || 'id'}_seq" end @@ -1236,6 +1266,14 @@ module ActiveRecord end end + def use_returning=(val) + @use_returning = val + end + + def use_returning? + @use_returning + end + protected # Returns the version of the connected PostgreSQL server. def postgresql_version @@ -1365,8 +1403,12 @@ module ActiveRecord # Returns the current ID of a table's sequence. def last_insert_id(sequence_name) #:nodoc: + Integer(last_insert_id_value(sequence_name)) + end + + def last_insert_id_value(sequence_name) #:nodoc: r = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]]) - Integer(r.rows.first.first) + r.rows.first.first end # Executes a SELECT query and returns the results, performing any data type diff --git a/activerecord/test/cases/adapters/postgresql/postgresql_adapter_test.rb b/activerecord/test/cases/adapters/postgresql/postgresql_adapter_test.rb index a71d0bb848..19da6163b5 100644 --- a/activerecord/test/cases/adapters/postgresql/postgresql_adapter_test.rb +++ b/activerecord/test/cases/adapters/postgresql/postgresql_adapter_test.rb @@ -49,6 +49,15 @@ module ActiveRecord assert_equal expect, id end + def test_insert_sql_with_returning_disabled + @connection.use_returning = false + id = @connection.insert_sql("insert into postgresql_partitioned_table_parent (number) VALUES (1)") + expect = @connection.query('select max(id) from postgresql_partitioned_table_parent').first.first + assert_equal expect, id + ensure + @connection.use_returning = true + end + def test_serial_sequence assert_equal 'public.accounts_id_seq', @connection.serial_sequence('accounts', 'id') diff --git a/activerecord/test/schema/postgresql_specific_schema.rb b/activerecord/test/schema/postgresql_specific_schema.rb index 25b416a906..84228cdd0a 100644 --- a/activerecord/test/schema/postgresql_specific_schema.rb +++ b/activerecord/test/schema/postgresql_specific_schema.rb @@ -1,8 +1,8 @@ ActiveRecord::Schema.define do %w(postgresql_tsvectors postgresql_hstores postgresql_arrays postgresql_moneys postgresql_numbers postgresql_times postgresql_network_addresses postgresql_bit_strings - postgresql_oids postgresql_xml_data_type defaults geometrics postgresql_timestamp_with_zones).each do |table_name| - execute "DROP TABLE IF EXISTS #{quote_table_name table_name}" + postgresql_oids postgresql_xml_data_type defaults geometrics postgresql_timestamp_with_zones postgresql_partitioned_table postgresql_partitioned_table_parent).each do |table_name| + execute "DROP TABLE IF EXISTS #{quote_table_name table_name}" end execute 'DROP SEQUENCE IF EXISTS companies_nonstd_seq CASCADE' @@ -10,6 +10,8 @@ ActiveRecord::Schema.define do execute "ALTER TABLE companies ALTER COLUMN id SET DEFAULT nextval('companies_nonstd_seq')" execute 'DROP SEQUENCE IF EXISTS companies_id_seq' + execute 'DROP FUNCTION IF EXISTS partitioned_insert_trigger()' + %w(accounts_id_seq developers_id_seq projects_id_seq topics_id_seq customers_id_seq orders_id_seq).each do |seq_name| execute "SELECT setval('#{seq_name}', 100)" end @@ -125,6 +127,29 @@ _SQL ); _SQL + execute <<_SQL + CREATE TABLE postgresql_partitioned_table_parent ( + id SERIAL PRIMARY KEY, + number integer + ); + CREATE TABLE postgresql_partitioned_table ( ) + INHERITS (postgresql_partitioned_table_parent); + + CREATE OR REPLACE FUNCTION partitioned_insert_trigger() + RETURNS TRIGGER AS $$ + BEGIN + INSERT INTO postgresql_partitioned_table VALUES (NEW.*); + RETURN NULL; + END; + $$ + LANGUAGE plpgsql; + + CREATE TRIGGER insert_partitioning_trigger + BEFORE INSERT ON postgresql_partitioned_table_parent + FOR EACH ROW EXECUTE PROCEDURE partitioned_insert_trigger(); +_SQL + + begin execute <<_SQL CREATE TABLE postgresql_xml_data_type ( |