aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/statement_cache.rb
blob: 1b1736dcabee99247fd42411fb5043ccc21f7cbd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# frozen_string_literal: true

module ActiveRecord
  # Statement cache is used to cache a single statement in order to avoid creating the AST again.
  # Initializing the cache is done by passing the statement in the create block:
  #
  #   cache = StatementCache.create(Book.connection) do |params|
  #     Book.where(name: "my book").where("author_id > 3")
  #   end
  #
  # The cached statement is executed by using the
  # {connection.execute}[rdoc-ref:ConnectionAdapters::DatabaseStatements#execute] method:
  #
  #   cache.execute([], Book.connection)
  #
  # The relation returned by the block is cached, and for each
  # {execute}[rdoc-ref:ConnectionAdapters::DatabaseStatements#execute]
  # call the cached relation gets duped. The database is queried when +to_a+ is called on the relation.
  #
  # If you want to cache the statement without the values you can use the +bind+ method of the
  # block parameter.
  #
  #   cache = StatementCache.create(Book.connection) do |params|
  #     Book.where(name: params.bind)
  #   end
  #
  # Then pass the bind values as the first argument of the +execute+ call:
  #
  #   cache.execute(["my book"], Book.connection)
  class StatementCache # :nodoc:
    # Marker object that stands in for a bind value to be supplied later.
    # Params#bind and PartialQueryCollector#add_bind create instances of it, while
    # PartialQuery and BindMap look for it with <tt>Substitute ===</tt> to find the
    # positions that still have to be filled in.
    class Substitute # :nodoc:
    end

    # Query builder for a statement whose SQL text is already complete: it keeps one
    # SQL object and hands the very same object back on every +sql_for+ call.
    class Query # :nodoc:
      # Stores +sql+ as given; no copy is made.
      def initialize(sql); @sql = sql; end

      # Returns the stored SQL. +binds+ and +connection+ are not used by this class;
      # they are accepted so that it answers the same call as PartialQuery#sql_for.
      def sql_for(binds, connection); @sql; end
    end

    # Query builder for a statement stored as a list of pieces, where some pieces are
    # Substitute markers that get replaced by quoted bind values on every +sql_for+ call.
    class PartialQuery < Query # :nodoc:
      # +values+ is the ordered list of pieces. The positions holding a Substitute are
      # computed once here so +sql_for+ does not have to scan the list again.
      def initialize(values)
        @values = values
        @indexes = values.each_with_index.each_with_object([]) do |(piece, position), slots|
          slots << position if Substitute === piece
        end
      end

      # Builds the SQL string for one execution.
      #
      # Each element of +binds+ is converted with +value_for_database+; the converted
      # values are then consumed front to back, one per Substitute position, and written
      # into a copy of the pieces via <tt>connection.quote</tt>. The stored pieces are
      # never modified. Returns the pieces joined into a single string.
      def sql_for(binds, connection)
        pieces = @values.dup
        pending = binds.map(&:value_for_database)
        @indexes.each do |position|
          pieces[position] = connection.quote(pending.shift)
        end
        pieces.join
      end
    end

    # Collector that records pieces of a statement and its bind objects in the order
    # they are appended.
    class PartialQueryCollector
      def initialize
        @parts, @binds = [], []
      end

      # Appends +str+ to the list of pieces. Returns the collector itself so calls
      # can be chained.
      def <<(str)
        tap { @parts << str }
      end

      # Records +obj+ as a bind and appends a new Substitute to the pieces at the spot
      # where that bind belongs. Returns the collector itself.
      def add_bind(obj)
        tap do
          @binds << obj
          @parts << Substitute.new
        end
      end

      # Returns the two-element array <tt>[parts, binds]</tt>. Both elements are the
      # collector's own arrays, not copies.
      def value
        [@parts, @binds]
      end
    end

    class << self
      # Returns a Query wrapping +sql+.
      def query(sql)
        Query.new(sql)
      end

      # Returns a PartialQuery built from +values+.
      def partial_query(values)
        PartialQuery.new(values)
      end

      # Returns a new, empty PartialQueryCollector.
      def partial_query_collector
        PartialQueryCollector.new
      end
    end

    # Object handed to the callable given to StatementCache.create.
    class Params # :nodoc:
      # Returns a new Substitute on every call, marking a value to be supplied later.
      def bind
        Substitute.new
      end
    end

    # Remembers which of a statement's bound attributes are placeholders and fills
    # those positions with concrete values on demand.
    class BindMap # :nodoc:
      # +bound_attributes+ is the ordered list of attributes of the statement. The
      # positions whose +value+ is a Substitute are recorded once, up front.
      def initialize(bound_attributes)
        @bound_attributes = bound_attributes
        @indexes = bound_attributes.each_with_index.select { |attribute, _position|
          Substitute === attribute.value
        }.map(&:last)
      end

      # Returns a copy of the bound attributes in which the n-th placeholder position
      # holds <tt>with_cast_value(values[n])</tt> of the original attribute.
      # Non-placeholder positions and the stored list itself are left untouched.
      def bind(values)
        attributes = @bound_attributes.dup
        @indexes.each_with_index do |position, nth|
          attributes[position] = attributes[position].with_cast_value(values[nth])
        end
        attributes
      end
    end

    # Builds a statement cache from the relation produced by a callable.
    #
    # The callable is invoked once with a new Params instance and must return an object
    # responding to +arel+ and +klass+. Its +arel+ is handed, together with this class,
    # to <tt>connection.cacheable_query</tt>, which returns the query builder and the
    # binds; the binds are wrapped in a BindMap.
    #
    # The callable can be given either as a literal block or as the second positional
    # argument (anything responding to +call+). When both are present the positional
    # argument wins.
    #
    # Raises ArgumentError when neither a callable nor a block is given.
    def self.create(connection, block = nil, &implicit_block)
      # The block is captured explicitly with & because a bare <tt>Proc.new</tt> used as a
      # default value no longer picks up the caller's block: that form is deprecated
      # in Ruby 2.7 and raises ArgumentError from Ruby 3.0 on.
      callable = block || implicit_block
      raise ArgumentError, "tried to create a StatementCache without a block" unless callable

      relation = callable.call(Params.new)
      query_builder, binds = connection.cacheable_query(self, relation.arel)
      new(query_builder, BindMap.new(binds), relation.klass)
    end

    # Stores the three collaborators used by +execute+: the object that produces SQL
    # (+query_builder+), the object that turns caller-supplied values into binds
    # (+bind_map+) and the receiver of +find_by_sql+ (+klass+).
    def initialize(query_builder, bind_map, klass)
      @query_builder, @bind_map, @klass = query_builder, bind_map, klass
    end

    # Runs the cached statement.
    #
    # +params+ are turned into bind values by the bind map, the query builder produces
    # the SQL for those binds and +connection+, and the result of
    # <tt>klass.find_by_sql</tt> (called with <tt>preparable: true</tt>) is returned.
    # A block, if given, is forwarded to +find_by_sql+.
    def execute(params, connection, &block)
      bound = bind_map.bind(params)
      statement = query_builder.sql_for(bound, connection)
      klass.find_by_sql(statement, bound, preparable: true, &block)
    end

    # Returns +true+ when +value+ is +nil+, an Array, a Range, a Hash, a Relation or a
    # Base instance, and +nil+ for anything else.
    def self.unsupported_value?(value)
      core_type = NilClass === value || Array === value || Range === value || Hash === value
      true if core_type || Relation === value || Base === value
    end

    private
      # Internal readers for the three objects stored by +initialize+; +execute+ reaches
      # the query builder, the bind map and the class through these.
      attr_reader :query_builder, :bind_map, :klass
  end
end