diff options
-rw-r--r-- | activerecord/lib/active_record/connection_adapters/abstract/transaction.rb | 29 | ||||
-rw-r--r-- | activerecord/test/cases/transactions_test.rb | 12 |
2 files changed, 34 insertions, 7 deletions
diff --git a/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb b/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb index 2b8026dbf9..3ecef96b10 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/transaction.rb @@ -5,17 +5,36 @@ module ActiveRecord def initialize(connection) @connection = connection - @state = nil + @state = TransactionState.new + end + + def state + @state + end + end + + class TransactionState + + VALID_STATES = Set.new([:committed, :rolledback, nil]) + + def initialize(state = nil) + @state = state end def committed? - @state == :commit + @state == :committed end def rolledback? - @state == :rollback + @state == :rolledback end + def set_state(state) + if !VALID_STATES.include?(state) + raise ArgumentError, "Invalid transaction state: #{state}" + end + @state = state + end end class ClosedTransaction < Transaction #:nodoc: @@ -101,7 +120,7 @@ module ActiveRecord end def rollback_records - @state = :rollback + @state.set_state(:rolledback) records.uniq.each do |record| begin record.rolledback!(parent.closed?) @@ -112,7 +131,7 @@ module ActiveRecord end def commit_records - @state = :commit + @state.set_state(:committed) records.uniq.each do |record| begin record.committed! diff --git a/activerecord/test/cases/transactions_test.rb b/activerecord/test/cases/transactions_test.rb index 9d278480ef..546737b398 100644 --- a/activerecord/test/cases/transactions_test.rb +++ b/activerecord/test/cases/transactions_test.rb @@ -456,9 +456,13 @@ class TransactionTest < ActiveRecord::TestCase transaction = ActiveRecord::ConnectionAdapters::ClosedTransaction.new(connection).begin assert transaction.open? + assert !transaction.state.rolledback? + assert !transaction.state.committed? + transaction.perform_rollback - assert transaction.rolledback? + assert transaction.state.rolledback? + assert !transaction.state.committed? end def test_transactions_state_from_commit @@ -466,9 +470,13 @@ class TransactionTest < ActiveRecord::TestCase transaction = ActiveRecord::ConnectionAdapters::ClosedTransaction.new(connection).begin assert transaction.open? + assert !transaction.state.rolledback? + assert !transaction.state.committed? + transaction.perform_commit - assert transaction.committed? + assert !transaction.state.rolledback? + assert transaction.state.committed? end private |