aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/connection_adapters/postgresql/quoting.rb
blob: 07b66de36651a142ef5705823522abc30b1c06a8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# frozen_string_literal: true

module ActiveRecord
  module ConnectionAdapters
    module PostgreSQL
      module Quoting
        # Escapes binary strings for bytea input to the database.
        def escape_bytea(value)
          # A nil/false value has nothing to escape; the result is nil in that case.
          return unless value

          @connection.escape_bytea(value)
        end

        # Unescapes bytea output from a database to the binary string it represents.
        # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
        # on escaped binary output from the database driver.
        def unescape_bytea(value)
          # A nil/false value is passed through as nil without touching the connection.
          return unless value

          @connection.unescape_bytea(value)
        end

        # Quotes strings for use in SQL input.
        def quote_string(s) #:nodoc:
          # Escaping is delegated to the raw connection object's #escape. Callers
          # in this module (see _quote) add the surrounding single quotes
          # themselves, so the result is presumably the escaped body only.
          @connection.escape(s)
        end

        # Checks the following cases:
        #
        # - table_name
        # - "table.name"
        # - schema_name.table_name
        # - schema_name."table.name"
        # - "schema.name".table_name
        # - "schema.name"."table.name"
        def quote_table_name(name) # :nodoc:
          # Results are memoized in a class-level cache keyed by the +name+
          # object as given; the stored string is frozen.
          cache = self.class.quoted_table_names
          cache[name] ||= begin
            qualified_name = Utils.extract_schema_qualified_name(name.to_s)
            qualified_name.quoted.freeze
          end
        end

        # Quotes schema names for use in SQL queries.
        def quote_schema_name(name)
          # Unlike quote_column_name, the result is not memoized and +name+ is
          # handed to PG::Connection.quote_ident as-is.
          PG::Connection.quote_ident(name)
        end

        # Returns the quoted column name for +attr+ only.
        def quote_table_name_for_assignment(table, attr)
          # +table+ is not used here: the result carries no table qualifier.
          quote_column_name(attr)
        end

        # Quotes column names for use in SQL queries.
        def quote_column_name(name) # :nodoc:
          # Memoized in a class-level cache keyed by +name+ as given. The value
          # produced by +super+ (defined outside this module) is what gets
          # quoted as an identifier, and the stored string is frozen.
          cache = self.class.quoted_column_names
          cache[name] ||= begin
            unquoted = super
            PG::Connection.quote_ident(unquoted).freeze
          end
        end

        # Quote date/time values for use in SQL input.
        def quoted_date(value) #:nodoc:
          # Years from 1 upwards need no special handling.
          return super if value.year > 0

          # Year 0 and negative years are rewritten as a positive year with a
          # " BC" suffix: year 0 becomes 0001 BC, year -1 becomes 0002 BC, etc.
          bce_year = format("%04d", 1 - value.year)
          super.sub(/^-?\d+/, bce_year) + " BC"
        end

        def quoted_binary(value) # :nodoc:
          # The value is stringified, bytea-escaped, then wrapped in single quotes.
          escaped = escape_bytea(value.to_s)
          "'#{escaped}'"
        end

        def quote_default_expression(value, column) # :nodoc:
          # A Proc default supplies its own SQL expression.
          return value.call if value.is_a?(Proc)

          # Does not quote function default values for UUID columns
          return value if column.type == :uuid && value.is_a?(String) && /\(\)/.match?(value)

          # Columns that do not expose #array? fall back to the generic handling.
          return super unless column.respond_to?(:array?)

          # Otherwise cast through the column's type first, then quote.
          quote(type_cast_from_column(column, value))
        end

        # Looks up the cast type for +column+ in the adapter's type_map.
        def lookup_cast_type_from_column(column) # :nodoc:
          # The lookup key is the column's OID together with its fmod
          # (presumably the PostgreSQL type modifier) and SQL type string.
          type_map.lookup(column.oid, column.fmod, column.sql_type)
        end

        # Returns the regexp (the private COLUMN_NAME constant of this module)
        # that recognizes a comma-separated list of plain column references.
        def column_name_matcher
          COLUMN_NAME
        end

        # Returns the regexp (the private COLUMN_NAME_WITH_ORDER constant of this
        # module) that recognizes a comma-separated list of column references
        # with optional ASC/DESC and NULLS FIRST/LAST modifiers.
        def column_name_with_order_matcher
          COLUMN_NAME_WITH_ORDER
        end

        # Matches an entire string (\A ... \z) made of one or more
        # comma-separated column references. Each reference is either
        #
        # - an optionally table-qualified column name, where each part is a bare
        #   \w+ word or a double-quoted \w+ word, with an optional ::type cast, or
        # - a function call taking no argument or one such column as its
        #   argument, again with an optional ::type cast,
        #
        # optionally followed by an alias written as "AS name" or just "name".
        #
        # \g<2> re-uses the column sub-pattern (capture group 2) as the function
        # argument, and \g<1> re-uses the whole item (capture group 1) for every
        # further comma-separated entry. The i flag makes the match
        # case-insensitive; the x flag permits the layout whitespace and the
        # inline comment inside the literal.
        COLUMN_NAME = /
          \A
          (
            (?:
              # "table_name"."column_name"::type_name | function(one or no argument)::type_name
              ((?:\w+\.|"\w+"\.)?(?:\w+|"\w+")(?:::\w+)?) | \w+\((?:|\g<2>)\)(?:::\w+)?
            )
            (?:(?:\s+AS)?\s+(?:\w+|"\w+"))?
          )
          (?:\s*,\s*\g<1>)*
          \z
        /ix

        # Matches an entire string (\A ... \z) made of one or more
        # comma-separated ordering terms. Each term is a column reference (an
        # optionally table-qualified, optionally double-quoted column name, or a
        # function call with no argument or one such column, each with an
        # optional ::type cast) followed by an optional ASC or DESC and an
        # optional NULLS FIRST / NULLS LAST.
        #
        # \g<2> re-uses the column sub-pattern (capture group 2) as the function
        # argument, and \g<1> re-uses the whole term (capture group 1) for every
        # further comma-separated entry. The i flag makes the keywords
        # case-insensitive; the x flag permits the layout whitespace and the
        # inline comment inside the literal.
        COLUMN_NAME_WITH_ORDER = /
          \A
          (
            (?:
              # "table_name"."column_name"::type_name | function(one or no argument)::type_name
              ((?:\w+\.|"\w+"\.)?(?:\w+|"\w+")(?:::\w+)?) | \w+\((?:|\g<2>)\)(?:::\w+)?
            )
            (?:\s+ASC|\s+DESC)?
            (?:\s+NULLS\s+(?:FIRST|LAST))?
          )
          (?:\s*,\s*\g<1>)*
          \z
        /ix

        # Both regexps are private to this module; they are reachable from
        # outside only through column_name_matcher and
        # column_name_with_order_matcher.
        private_constant :COLUMN_NAME, :COLUMN_NAME_WITH_ORDER

        private
          # Resolves +sql_type+ to its numeric OID by asking the database to cast
          # the quoted type name to regtype, then defers to +super+ with that OID.
          def lookup_cast_type(sql_type)
            oid_query = "SELECT #{quote(sql_type)}::regtype::oid"
            oid = query_value(oid_query, "SCHEMA").to_i
            super(oid)
          end

          # Produces the SQL literal for the value kinds handled specially by
          # this module and defers to +super+ for everything else.
          #
          # - XML data becomes an "xml '...'" literal with the text escaped.
          # - Bit data becomes B'...' or X'...'; when it reports neither binary?
          #   nor hex?, nil is returned.
          # - Non-finite numerics (Infinity, NaN) are wrapped in single quotes.
          # - Array data and Ranges are first encoded to their text form, which
          #   is then quoted by a recursive call.
          def _quote(value)
            case value
            when OID::Xml::Data
              xml_body = quote_string(value.to_s)
              "xml '#{xml_body}'"
            when OID::Bit::Data
              prefix = if value.binary? then "B" elsif value.hex? then "X" end
              "#{prefix}'#{value}'" if prefix
            when Numeric
              value.finite? ? super : "'#{value}'"
            when OID::Array::Data
              array_literal = encode_array(value)
              _quote(array_literal)
            when Range
              range_literal = encode_range(value)
              _quote(range_literal)
            else
              super
            end
          end

          # Converts +value+ into the form handed to the driver as a bind
          # parameter, deferring to +super+ for kinds not listed here.
          def _type_cast(value)
            case value
            when Type::Binary::Data
              # Return a bind param hash with format as binary.
              # See https://deveiate.org/code/pg/PG/Connection.html#method-i-exec_prepared-doc
              # for more information
              { value: value.to_s, format: 1 }
            when OID::Xml::Data, OID::Bit::Data then value.to_s
            when OID::Array::Data then encode_array(value)
            when Range then encode_range(value)
            else super
            end
          end

          # Encodes +array_data+ with its own encoder after type-casting every
          # element. When the cast values contain a String whose encoding can be
          # determined, the encoded result is re-labelled with that encoding.
          def encode_array(array_data)
            encoder = array_data.encoder
            values = type_cast_array(array_data.values)

            encoder.encode(values).tap do |encoded|
              encoding = determine_encoding_of_strings_in_array(values)
              encoded.force_encoding(encoding) if encoding
            end
          end

          # Renders +range+ as "[lower,upper]" or, when the end is excluded,
          # "[lower,upper)". The lower bound is always written as inclusive.
          def encode_range(range)
            lower = type_cast_range_value(range.begin)
            upper = type_cast_range_value(range.end)
            closing = range.exclude_end? ? ")" : "]"
            "[#{lower},#{upper}#{closing}"
          end

          # Returns the Encoding of the first String found in +value+, searching
          # nested arrays depth-first, or nil when no String is present.
          #
          # Leading non-string entries (for example nil, or an empty nested
          # array) are skipped rather than ending the search, so an array such as
          # [nil, "text"] still reports the encoding of "text". Arrays whose first
          # element is a String give the same answer as before.
          def determine_encoding_of_strings_in_array(value)
            case value
            when ::String then value.encoding
            when ::Array
              value.each do |item|
                encoding = determine_encoding_of_strings_in_array(item)
                return encoding if encoding
              end
              nil
            end
          end

          # Type-casts every leaf of a (possibly nested) array, keeping the
          # nesting structure. A non-array argument is cast directly.
          def type_cast_array(values)
            # `===` is used deliberately: it is the same test `case/when` applies.
            return _type_cast(values) unless ::Array === values

            values.map { |element| type_cast_array(element) }
          end

          # Casts one bound of a range. An infinite bound is written as the
          # empty string; any other bound goes through type_cast.
          def type_cast_range_value(value)
            return "" if infinity?(value)

            type_cast(value)
          end

          # Truthy when +value+ responds to #infinite? and reports itself as
          # infinite. Objects without #infinite? yield false; otherwise the raw
          # result of #infinite? is returned.
          def infinity?(value)
            return false unless value.respond_to?(:infinite?)

            value.infinite?
          end
      end
    end
  end
end