# path: root/activerecord/test/cases/adapters/postgresql/transaction_test.rb
# blob: 8e952a94083d96e7b78e62246aa6c7d550e3b29c
# frozen_string_literal: true

require "cases/helper"
require "support/connection_helper"
require "concurrent/atomic/cyclic_barrier"

module ActiveRecord
  class PostgresqlTransactionTest < ActiveRecord::PostgreSQLTestCase
    # NOTE(review): presumably switched off because the tests below run
    # transactions from more than one thread and need rows written by one
    # thread to be visible to the other — confirm.
    self.use_transactional_tests = false

    # Minimal model mapped onto the "samples" table that the setup block creates.
    class Sample < ActiveRecord::Base
      self.table_name = "samples"
    end

    # Quietens thread exception handling for the duration of a test and
    # rebuilds an empty "samples" table with a single integer "value" column.
    setup do
      # Remember the process-wide Thread flags so teardown can put them back.
      @abort = Thread.abort_on_exception
      Thread.abort_on_exception = false

      if Thread.respond_to?(:report_on_exception)
        @original_report_on_exception = Thread.report_on_exception
        Thread.report_on_exception = false
      end

      @connection = ActiveRecord::Base.connection

      # Drop and recreate inside one transaction.
      @connection.transaction do
        @connection.drop_table("samples", if_exists: true)
        @connection.create_table("samples") { |table| table.integer "value" }
      end

      Sample.reset_column_information
    end

    # Drops the "samples" table and restores the Thread flags saved in setup.
    teardown do
      begin
        @connection.drop_table "samples", if_exists: true
      ensure
        # These flags are process-wide: restore them even when dropping the
        # table fails, so the silenced settings do not leak into later tests.
        Thread.abort_on_exception = @abort
        Thread.report_on_exception = @original_report_on_exception if Thread.respond_to?(:report_on_exception)
      end
    end

    # Runs the same serializable transaction on the main thread and on a worker
    # thread at once: each inserts a row whose value is the current sum of all
    # values. The test expects SerializationFailure to surface from one of them.
    test "raises SerializationFailure when a serialization failure occurs" do
      assert_raises(ActiveRecord::SerializationFailure) do
        before = Concurrent::CyclicBarrier.new(2)
        after = Concurrent::CyclicBarrier.new(2)

        # Work shared by both participants. `before` holds each one inside its
        # transaction until the other has opened its own; `after` keeps either
        # from leaving the transaction block until both have inserted.
        insert_sum = lambda do
          with_warning_suppression do
            Sample.transaction(isolation: :serializable) do
              before.wait
              Sample.create(value: Sample.sum(:value))
              after.wait
            end
          end
        end

        worker = Thread.new { insert_sum.call }

        begin
          insert_sum.call
        ensure
          # join re-raises an exception that terminated the worker, so a
          # failure on either side reaches assert_raises.
          worker.join
        end
      end
    end

    # Two transactions lock one row each and then try to update the row the
    # other one holds; the test expects Deadlocked to be raised.
    test "raises Deadlocked when a deadlock is encountered" do
      with_warning_suppression do
        assert_raises(ActiveRecord::Deadlocked) do
          barrier = Concurrent::CyclicBarrier.new(2)

          first = Sample.create(value: 1)
          second = Sample.create(value: 2)

          # Lock `held`, wait until the other participant holds its own lock,
          # then write to `wanted`, the row the other participant has locked.
          lock_then_update = lambda do |held, wanted, new_value|
            Sample.transaction do
              held.lock!
              barrier.wait
              wanted.update(value: new_value)
            end
          end

          worker = Thread.new { lock_then_update.call(first, second, 1) }

          begin
            lock_then_update.call(second, first, 2)
          ensure
            # join re-raises an exception that terminated the worker.
            worker.join
          end
        end
      end
    end

    # A helper thread takes a row lock and keeps it until released; the main
    # thread then requests the same lock with lock_timeout set to 1 (PostgreSQL
    # reads a unit-less value as milliseconds) and is expected to fail with
    # LockWaitTimeout. Skipped on servers older than 9.3.
    test "raises LockWaitTimeout when lock wait timeout exceeded" do
      skip unless ActiveRecord::Base.connection.postgresql_version >= 90300
      assert_raises(ActiveRecord::LockWaitTimeout) do
        s = Sample.create!(value: 1)
        latch1 = Concurrent::CountDownLatch.new # opened once the helper holds the row lock
        latch2 = Concurrent::CountDownLatch.new # opened to let the helper finish

        thread = Thread.new do
          Sample.transaction do
            Sample.lock.find(s.id)
            latch1.count_down
            latch2.wait
          end
        end

        begin
          Sample.transaction do
            latch1.wait
            Sample.connection.execute("SET lock_timeout = 1")
            Sample.lock.find(s.id)
          end
        ensure
          begin
            Sample.connection.execute("SET lock_timeout = DEFAULT")
          ensure
            # Release and reap the helper even if the reset above raises;
            # otherwise it would keep waiting on latch2 with its transaction
            # and row lock still open.
            latch2.count_down
            thread.join
          end
        end
      end
    end

    # A helper thread takes a row lock and keeps it until released; the main
    # thread then requests the same lock with statement_timeout set to 1
    # (PostgreSQL reads a unit-less value as milliseconds) and is expected to
    # fail with QueryCanceled.
    test "raises QueryCanceled when statement timeout exceeded" do
      assert_raises(ActiveRecord::QueryCanceled) do
        s = Sample.create!(value: 1)
        latch1 = Concurrent::CountDownLatch.new # opened once the helper holds the row lock
        latch2 = Concurrent::CountDownLatch.new # opened to let the helper finish

        thread = Thread.new do
          Sample.transaction do
            Sample.lock.find(s.id)
            latch1.count_down
            latch2.wait
          end
        end

        begin
          Sample.transaction do
            latch1.wait
            Sample.connection.execute("SET statement_timeout = 1")
            Sample.lock.find(s.id)
          end
        ensure
          begin
            Sample.connection.execute("SET statement_timeout = DEFAULT")
          ensure
            # Release and reap the helper even if the reset above raises;
            # otherwise it would keep waiting on latch2 with its transaction
            # and row lock still open.
            latch2.count_down
            thread.join
          end
        end
      end
    end

    # A helper thread locks a row, waits half a second, then looks up a backend
    # whose query text ends in " FOR UPDATE" and cancels it with
    # pg_cancel_backend. The main thread's locking read of the same row is
    # expected to raise QueryCanceled.
    test "raises QueryCanceled when canceling statement due to user request" do
      assert_raises(ActiveRecord::QueryCanceled) do
        s = Sample.create!(value: 1)
        latch = Concurrent::CountDownLatch.new

        thread = Thread.new do
          Sample.transaction do
            Sample.lock.find(s.id)
            # Tell the main thread the row is now locked.
            latch.count_down
            # NOTE(review): presumably gives the main thread time to issue its
            # own locking read and start waiting — timing-based, confirm.
            sleep(0.5)
            conn = Sample.connection
            # Takes the pid of the first pg_stat_activity row whose query ends
            # in " FOR UPDATE".
            pid = conn.query_value("SELECT pid FROM pg_stat_activity WHERE query LIKE '% FOR UPDATE'")
            conn.execute("SELECT pg_cancel_backend(#{pid})")
          end
        end

        begin
          Sample.transaction do
            latch.wait
            Sample.lock.find(s.id)
          end
        ensure
          thread.join
        end
      end
    end

    private

      # Runs the given block with the current connection's client_min_messages
      # set to "error", then puts the previous level back.
      #
      # Returns whatever the block returns. The previous level is restored even
      # when the block raises.
      def with_warning_suppression
        # Capture the connection and its level before entering the protected
        # region: if reading the level fails there is nothing to restore, and
        # the restore must go to the same connection that was changed.
        connection = ActiveRecord::Base.connection
        log_level = connection.client_min_messages

        begin
          connection.client_min_messages = "error"
          yield
        ensure
          connection.client_min_messages = log_level
        end
      end
  end
end