# frozen_string_literal: true

module Arel # :nodoc: all
  module Visitors
    # Raised when the visitor is handed a plain Ruby object (String, Hash,
    # Time, ...) instead of an Arel node.
    class UnsupportedVisitError < StandardError
      def initialize(object)
        super "Unsupported argument type: #{object.class.name}. Construct an Arel node instead."
      end
    end

    # Visitor that walks an Arel AST and appends the corresponding SQL text to
    # a collector. Every +visit_*+ method receives the node (+o+) and the
    # collector, and returns the collector so that calls can be chained.
    #
    # Quoting and the IN-list length limit are delegated to the connection
    # object given to the constructor.
    class ToSql < Arel::Visitors::Visitor
      # connection:: object used for quote, quote_table_name,
      #              quote_column_name, sanitize_as_sql_comment and
      #              in_clause_length.
      def initialize(connection)
        super()
        @connection = connection
      end

      # Visits +node+ and returns the +value+ of the collector that was
      # accumulated while doing so.
      def compile(node, collector = Arel::Collectors::SQLString.new)
        accept(node, collector).value
      end

      private
        # DELETE. When the (prepared) statement still has join sources, the
        # target table is emitted between DELETE and FROM.
        def visit_Arel_Nodes_DeleteStatement(o, collector)
          o = prepare_delete_statement(o)

          if has_join_sources?(o)
            collector << "DELETE "
            visit o.relation.left, collector
            collector << " FROM "
          else
            collector << "DELETE FROM "
          end
          collector = visit o.relation, collector

          collect_nodes_for o.wheres, collector, " WHERE ", " AND "
          collect_nodes_for o.orders, collector, " ORDER BY "
          maybe_visit o.limit, collector
        end

        # UPDATE <relation> SET ... WHERE ... ORDER BY ... <limit>
        def visit_Arel_Nodes_UpdateStatement(o, collector)
          o = prepare_update_statement(o)

          collector << "UPDATE "
          collector = visit o.relation, collector

          collect_nodes_for o.values, collector, " SET "
          collect_nodes_for o.wheres, collector, " WHERE ", " AND "
          collect_nodes_for o.orders, collector, " ORDER BY "
          maybe_visit o.limit, collector
        end

        # INSERT INTO <relation> (<columns>) followed by the values node if
        # there is one, otherwise by the select node if there is one.
        def visit_Arel_Nodes_InsertStatement(o, collector)
          collector << "INSERT INTO "
          collector = visit o.relation, collector

          if o.columns.any?
            collector << " (#{o.columns.map { |x| quote_column_name x.name }.join ', '})"
          end

          if o.values
            maybe_visit o.values, collector
          elsif o.select
            maybe_visit o.select, collector
          else
            collector
          end
        end

        # EXISTS (<expressions>) with an optional " AS <alias>".
        def visit_Arel_Nodes_Exists(o, collector)
          collector << "EXISTS ("
          collector = visit(o.expressions, collector) << ")"

          if o.alias
            collector << " AS "
            visit o.alias, collector
          else
            collector
          end
        end

        # Quotes the value, type-casting it through its attribute when the
        # attribute supports that (see #quoted).
        def visit_Arel_Nodes_Casted(o, collector)
          collector << quoted(o.val, o.attribute).to_s
        end

        def visit_Arel_Nodes_Quoted(o, collector)
          collector << quoted(o.expr, nil).to_s
        end

        def visit_Arel_Nodes_True(o, collector)
          collector << "TRUE"
        end

        def visit_Arel_Nodes_False(o, collector)
          collector << "FALSE"
        end

        # VALUES (a, b), (c, d). SqlLiteral and BindParam cells are visited;
        # every other cell is quoted through the connection.
        def visit_Arel_Nodes_ValuesList(o, collector)
          collector << "VALUES "

          len = o.rows.length - 1
          o.rows.each_with_index { |row, i|
            collector << "("
            row_len = row.length - 1
            row.each_with_index do |value, k|
              case value
              when Nodes::SqlLiteral, Nodes::BindParam
                collector = visit(value, collector)
              else
                collector << quote(value).to_s
              end
              collector << ", " unless k == row_len
            end
            collector << ")"
            collector << ", " unless i == len
          }
          collector
        end

        # Optional WITH clause, then every select core, then ORDER BY and the
        # select options (limit / offset / lock).
        def visit_Arel_Nodes_SelectStatement(o, collector)
          if o.with
            collector = visit o.with, collector
            collector << " "
          end

          collector = o.cores.inject(collector) { |c, x|
            visit_Arel_Nodes_SelectCore(x, c)
          }

          unless o.orders.empty?
            collector << " ORDER BY "
            len = o.orders.length - 1
            o.orders.each_with_index { |x, i|
              collector = visit(x, collector)
              collector << ", " unless len == i
            }
          end

          visit_Arel_Nodes_SelectOptions(o, collector)
        end

        # Emits limit, offset and lock, in that order, each preceded by a
        # space when present.
        def visit_Arel_Nodes_SelectOptions(o, collector)
          collector = maybe_visit o.limit, collector
          collector = maybe_visit o.offset, collector
          maybe_visit o.lock, collector
        end

        # SELECT <hints> <quantifier> <projections> FROM ... WHERE ...
        # GROUP BY ... HAVING ... WINDOW ... <comment>
        def visit_Arel_Nodes_SelectCore(o, collector)
          collector << "SELECT"

          collector = collect_optimizer_hints(o, collector)
          collector = maybe_visit o.set_quantifier, collector

          collect_nodes_for o.projections, collector, " "

          if o.source && !o.source.empty?
            collector << " FROM "
            collector = visit o.source, collector
          end

          collect_nodes_for o.wheres, collector, " WHERE ", " AND "
          collect_nodes_for o.groups, collector, " GROUP BY "
          collect_nodes_for o.havings, collector, " HAVING ", " AND "
          collect_nodes_for o.windows, collector, " WINDOW "

          maybe_visit o.comment, collector
        end

        # All hints are sanitized and emitted inside a single /*+ ... */.
        def visit_Arel_Nodes_OptimizerHints(o, collector)
          hints = o.expr.map { |v| sanitize_as_sql_comment(v) }.join(" ")
          collector << "/*+ #{hints} */"
        end

        # Each value becomes its own sanitized /* ... */ comment.
        def visit_Arel_Nodes_Comment(o, collector)
          collector << o.values.map { |v| "/* #{sanitize_as_sql_comment(v)} */" }.join(" ")
        end

        # Emits +spacer+ followed by +nodes+ joined with +connector+. Nothing
        # is emitted for an empty list. The collector is returned in both
        # cases so the result can always be chained.
        def collect_nodes_for(nodes, collector, spacer, connector = ", ")
          return collector if nodes.empty?

          collector << spacer
          inject_join nodes, collector, connector
        end

        def visit_Arel_Nodes_Bin(o, collector)
          visit o.expr, collector
        end

        def visit_Arel_Nodes_Distinct(o, collector)
          collector << "DISTINCT"
        end

        def visit_Arel_Nodes_DistinctOn(o, collector)
          raise NotImplementedError, "DISTINCT ON not implemented for this db"
        end

        def visit_Arel_Nodes_With(o, collector)
          collector << "WITH "
          inject_join o.children, collector, ", "
        end

        def visit_Arel_Nodes_WithRecursive(o, collector)
          collector << "WITH RECURSIVE "
          inject_join o.children, collector, ", "
        end

        def visit_Arel_Nodes_Union(o, collector)
          infix_value_with_paren(o, collector, " UNION ")
        end

        def visit_Arel_Nodes_UnionAll(o, collector)
          infix_value_with_paren(o, collector, " UNION ALL ")
        end

        def visit_Arel_Nodes_Intersect(o, collector)
          collector << "( "
          infix_value(o, collector, " INTERSECT ") << " )"
        end

        def visit_Arel_Nodes_Except(o, collector)
          collector << "( "
          infix_value(o, collector, " EXCEPT ") << " )"
        end

        # <quoted name> AS (<window definition>)
        def visit_Arel_Nodes_NamedWindow(o, collector)
          collector << quote_column_name(o.name)
          collector << " AS "
          visit_Arel_Nodes_Window o, collector
        end

        # (PARTITION BY ... ORDER BY ... <framing>); a single space separates
        # whichever of the three parts are present.
        def visit_Arel_Nodes_Window(o, collector)
          collector << "("

          collect_nodes_for o.partitions, collector, "PARTITION BY "

          if o.orders.any?
            collector << " " if o.partitions.any?
            collector << "ORDER BY "
            collector = inject_join o.orders, collector, ", "
          end

          if o.framing
            collector << " " if o.partitions.any? || o.orders.any?
            collector = visit o.framing, collector
          end

          collector << ")"
        end

        def visit_Arel_Nodes_Rows(o, collector)
          if o.expr
            collector << "ROWS "
            visit o.expr, collector
          else
            collector << "ROWS"
          end
        end

        def visit_Arel_Nodes_Range(o, collector)
          if o.expr
            collector << "RANGE "
            visit o.expr, collector
          else
            collector << "RANGE"
          end
        end

        # "<expr> PRECEDING", or "UNBOUNDED PRECEDING" without an expression.
        def visit_Arel_Nodes_Preceding(o, collector)
          collector = if o.expr
            visit o.expr, collector
          else
            collector << "UNBOUNDED"
          end

          collector << " PRECEDING"
        end

        # "<expr> FOLLOWING", or "UNBOUNDED FOLLOWING" without an expression.
        def visit_Arel_Nodes_Following(o, collector)
          collector = if o.expr
            visit o.expr, collector
          else
            collector << "UNBOUNDED"
          end

          collector << " FOLLOWING"
        end

        def visit_Arel_Nodes_CurrentRow(o, collector)
          collector << "CURRENT ROW"
        end

        # <left> OVER <window>. A nil window renders as "()", a String or
        # Symbol is quoted as a name, anything else is visited.
        def visit_Arel_Nodes_Over(o, collector)
          case o.right
          when nil
            visit(o.left, collector) << " OVER ()"
          when Arel::Nodes::SqlLiteral
            # NOTE(review): tested before String, presumably because
            # SqlLiteral is itself a String and must not be quoted - confirm.
            infix_value o, collector, " OVER "
          when String, Symbol
            visit(o.left, collector) << " OVER #{quote_column_name o.right.to_s}"
          else
            infix_value o, collector, " OVER "
          end
        end

        def visit_Arel_Nodes_Offset(o, collector)
          collector << "OFFSET "
          visit o.expr, collector
        end

        def visit_Arel_Nodes_Limit(o, collector)
          collector << "LIMIT "
          visit o.expr, collector
        end

        def visit_Arel_Nodes_Lock(o, collector)
          visit o.expr, collector
        end

        # Wraps the expression in parentheses. A Grouping whose expression is
        # itself a Grouping delegates to the inner one, so nested groupings
        # produce a single pair of parentheses.
        def visit_Arel_Nodes_Grouping(o, collector)
          if o.expr.is_a? Nodes::Grouping
            visit(o.expr, collector)
          else
            collector << "("
            visit(o.expr, collector) << ")"
          end
        end

        def visit_Arel_SelectManager(o, collector)
          collector << "("
          visit(o.ast, collector) << ")"
        end

        def visit_Arel_Nodes_Ascending(o, collector)
          visit(o.expr, collector) << " ASC"
        end

        def visit_Arel_Nodes_Descending(o, collector)
          visit(o.expr, collector) << " DESC"
        end

        def visit_Arel_Nodes_Group(o, collector)
          visit o.expr, collector
        end

        # name([DISTINCT ]<expressions>) with an optional " AS <alias>". The
        # function name is emitted as given, without quoting.
        def visit_Arel_Nodes_NamedFunction(o, collector)
          collector << o.name
          collector << "("
          collector << "DISTINCT " if o.distinct
          collector = inject_join(o.expressions, collector, ", ") << ")"

          if o.alias
            collector << " AS "
            visit o.alias, collector
          else
            collector
          end
        end

        def visit_Arel_Nodes_Extract(o, collector)
          collector << "EXTRACT(#{o.field.to_s.upcase} FROM "
          visit(o.expr, collector) << ")"
        end

        def visit_Arel_Nodes_Count(o, collector)
          aggregate "COUNT", o, collector
        end

        def visit_Arel_Nodes_Sum(o, collector)
          aggregate "SUM", o, collector
        end

        def visit_Arel_Nodes_Max(o, collector)
          aggregate "MAX", o, collector
        end

        def visit_Arel_Nodes_Min(o, collector)
          aggregate "MIN", o, collector
        end

        def visit_Arel_Nodes_Avg(o, collector)
          aggregate "AVG", o, collector
        end

        # <relation> <quoted alias name>
        def visit_Arel_Nodes_TableAlias(o, collector)
          collector = visit o.relation, collector
          collector << " "
          collector << quote_table_name(o.name)
        end

        def visit_Arel_Nodes_Between(o, collector)
          collector = visit o.left, collector
          collector << " BETWEEN "
          visit o.right, collector
        end

        def visit_Arel_Nodes_GreaterThanOrEqual(o, collector)
          collector = visit o.left, collector
          collector << " >= "
          visit o.right, collector
        end

        def visit_Arel_Nodes_GreaterThan(o, collector)
          collector = visit o.left, collector
          collector << " > "
          visit o.right, collector
        end

        def visit_Arel_Nodes_LessThanOrEqual(o, collector)
          collector = visit o.left, collector
          collector << " <= "
          visit o.right, collector
        end

        def visit_Arel_Nodes_LessThan(o, collector)
          collector = visit o.left, collector
          collector << " < "
          visit o.right, collector
        end

        # <left> LIKE <right> with an optional " ESCAPE <escape>".
        def visit_Arel_Nodes_Matches(o, collector)
          collector = visit o.left, collector
          collector << " LIKE "
          collector = visit o.right, collector
          if o.escape
            collector << " ESCAPE "
            visit o.escape, collector
          else
            collector
          end
        end

        # <left> NOT LIKE <right> with an optional " ESCAPE <escape>".
        def visit_Arel_Nodes_DoesNotMatch(o, collector)
          collector = visit o.left, collector
          collector << " NOT LIKE "
          collector = visit o.right, collector
          if o.escape
            collector << " ESCAPE "
            visit o.escape, collector
          else
            collector
          end
        end

        # The left source (if any) followed by the space-separated joins.
        def visit_Arel_Nodes_JoinSource(o, collector)
          if o.left
            collector = visit o.left, collector
          end
          if o.right.any?
            collector << " " if o.left
            collector = inject_join o.right, collector, " "
          end
          collector
        end

        def visit_Arel_Nodes_Regexp(o, collector)
          raise NotImplementedError, "~ not implemented for this db"
        end

        def visit_Arel_Nodes_NotRegexp(o, collector)
          raise NotImplementedError, "!~ not implemented for this db"
        end

        def visit_Arel_Nodes_StringJoin(o, collector)
          visit o.left, collector
        end

        def visit_Arel_Nodes_FullOuterJoin(o, collector)
          collector << "FULL OUTER JOIN "
          collector = visit o.left, collector
          collector << " "
          visit o.right, collector
        end

        def visit_Arel_Nodes_OuterJoin(o, collector)
          collector << "LEFT OUTER JOIN "
          collector = visit o.left, collector
          collector << " "
          visit o.right, collector
        end

        def visit_Arel_Nodes_RightOuterJoin(o, collector)
          collector << "RIGHT OUTER JOIN "
          collector = visit o.left, collector
          collector << " "
          visit o.right, collector
        end

        # Unlike the outer joins, the right-hand side is optional here.
        def visit_Arel_Nodes_InnerJoin(o, collector)
          collector << "INNER JOIN "
          collector = visit o.left, collector
          if o.right
            collector << " "
            visit(o.right, collector)
          else
            collector
          end
        end

        def visit_Arel_Nodes_On(o, collector)
          collector << "ON "
          visit o.expr, collector
        end

        def visit_Arel_Nodes_Not(o, collector)
          collector << "NOT ("
          visit(o.expr, collector) << ")"
        end

        # Quoted table name, followed by the quoted alias when one is set.
        def visit_Arel_Table(o, collector)
          if o.table_alias
            collector << "#{quote_table_name o.name} #{quote_table_name o.table_alias}"
          else
            collector << quote_table_name(o.name)
          end
        end

        # <left> IN (<right>).
        #
        # When the right-hand side is an Array, values that report themselves
        # as unboundable are dropped; if nothing remains the predicate
        # collapses to the always-false "1=0". When the connection reports an
        # in_clause_length and the list is longer than that, the list is split
        # into OR-ed IN clauses wrapped in one pair of parentheses.
        #
        # The filtering is done on a copy: the node's own array is neither
        # modified nor required to be unfrozen.
        def visit_Arel_Nodes_In(o, collector)
          return collect_in_clause(o.left, o.right, collector) unless Array === o.right

          values = o.right.reject { |value| unboundable?(value) }
          return collector << "1=0" if values.empty?

          in_clause_length = @connection.in_clause_length

          if !in_clause_length || values.length <= in_clause_length
            collect_in_clause(o.left, values, collector)
          else
            collector << "("
            values.each_slice(in_clause_length).each_with_index do |slice, i|
              collector << " OR " unless i == 0
              collect_in_clause(o.left, slice, collector)
            end
            collector << ")"
          end
        end

        def collect_in_clause(left, right, collector)
          collector = visit left, collector
          collector << " IN ("
          visit(right, collector) << ")"
        end

        # <left> NOT IN (<right>).
        #
        # Mirrors visit_Arel_Nodes_In: unboundable values are dropped from an
        # Array right-hand side, an empty remainder collapses to the
        # always-true "1=1", and over-long lists are split into AND-ed NOT IN
        # clauses (emitted without surrounding parentheses).
        #
        # The filtering is done on a copy: the node's own array is neither
        # modified nor required to be unfrozen.
        def visit_Arel_Nodes_NotIn(o, collector)
          return collect_not_in_clause(o.left, o.right, collector) unless Array === o.right

          values = o.right.reject { |value| unboundable?(value) }
          return collector << "1=1" if values.empty?

          in_clause_length = @connection.in_clause_length

          if !in_clause_length || values.length <= in_clause_length
            collect_not_in_clause(o.left, values, collector)
          else
            values.each_slice(in_clause_length).each_with_index do |slice, i|
              collector << " AND " unless i == 0
              collect_not_in_clause(o.left, slice, collector)
            end
            collector
          end
        end

        def collect_not_in_clause(left, right, collector)
          collector = visit left, collector
          collector << " NOT IN ("
          visit(right, collector) << ")"
        end

        def visit_Arel_Nodes_And(o, collector)
          inject_join o.children, collector, " AND "
        end

        def visit_Arel_Nodes_Or(o, collector)
          collector = visit o.left, collector
          collector << " OR "
          visit o.right, collector
        end

        # <left> = <right>. Nodes and attributes on the right are visited;
        # any other value is quoted through the connection.
        def visit_Arel_Nodes_Assignment(o, collector)
          case o.right
          when Arel::Nodes::Node, Arel::Attributes::Attribute
            collector = visit o.left, collector
            collector << " = "
            visit o.right, collector
          else
            collector = visit o.left, collector
            collector << " = "
            collector << quote(o.right).to_s
          end
        end

        # <left> = <right>, "<left> IS NULL" for a nil right-hand side, and
        # the always-false "1=0" for an unboundable one.
        def visit_Arel_Nodes_Equality(o, collector)
          right = o.right

          return collector << "1=0" if unboundable?(right)

          collector = visit o.left, collector

          if right.nil?
            collector << " IS NULL"
          else
            collector << " = "
            visit right, collector
          end
        end

        # Built on the CASE expression from #is_distinct_from, which yields 0
        # when both sides are equal or both NULL.
        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
          if o.right.nil?
            collector = visit o.left, collector
            collector << " IS NULL"
          else
            collector = is_distinct_from(o, collector)
            collector << " = 0"
          end
        end

        # Built on the CASE expression from #is_distinct_from, which yields 1
        # when the two sides differ.
        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
          if o.right.nil?
            collector = visit o.left, collector
            collector << " IS NOT NULL"
          else
            collector = is_distinct_from(o, collector)
            collector << " = 1"
          end
        end

        # <left> != <right>, "<left> IS NOT NULL" for a nil right-hand side,
        # and the always-true "1=1" for an unboundable one.
        def visit_Arel_Nodes_NotEqual(o, collector)
          right = o.right

          return collector << "1=1" if unboundable?(right)

          collector = visit o.left, collector

          if right.nil?
            collector << " IS NOT NULL"
          else
            collector << " != "
            visit right, collector
          end
        end

        def visit_Arel_Nodes_As(o, collector)
          collector = visit o.left, collector
          collector << " AS "
          visit o.right, collector
        end

        # CASE [<case>] <conditions...> [<default>] END. The collector
        # returned by each nested visit is carried forward, as everywhere
        # else in this visitor.
        def visit_Arel_Nodes_Case(o, collector)
          collector << "CASE "
          if o.case
            collector = visit o.case, collector
            collector << " "
          end
          o.conditions.each do |condition|
            collector = visit condition, collector
            collector << " "
          end
          if o.default
            collector = visit o.default, collector
            collector << " "
          end
          collector << "END"
        end

        def visit_Arel_Nodes_When(o, collector)
          collector << "WHEN "
          collector = visit o.left, collector
          collector << " THEN "
          visit o.right, collector
        end

        def visit_Arel_Nodes_Else(o, collector)
          collector << "ELSE "
          visit o.expr, collector
        end

        def visit_Arel_Nodes_UnqualifiedColumn(o, collector)
          collector << "#{quote_column_name o.name}"
          collector
        end

        # <quoted table or alias>.<quoted column>; the table alias wins over
        # the table name when both are available.
        def visit_Arel_Attributes_Attribute(o, collector)
          join_name = o.relation.table_alias || o.relation.name
          collector << "#{quote_table_name join_name}.#{quote_column_name o.name}"
        end
        alias :visit_Arel_Attributes_Integer :visit_Arel_Attributes_Attribute
        alias :visit_Arel_Attributes_Float :visit_Arel_Attributes_Attribute
        alias :visit_Arel_Attributes_Decimal :visit_Arel_Attributes_Attribute
        alias :visit_Arel_Attributes_String :visit_Arel_Attributes_Attribute
        alias :visit_Arel_Attributes_Time :visit_Arel_Attributes_Attribute
        alias :visit_Arel_Attributes_Boolean :visit_Arel_Attributes_Attribute

        # Emits the object's to_s verbatim, with no quoting.
        def literal(o, collector); collector << o.to_s; end

        # Registers the bound value with the collector; the block supplies
        # the "?" placeholder text.
        def visit_Arel_Nodes_BindParam(o, collector)
          collector.add_bind(o.value) { "?" }
        end

        alias :visit_Arel_Nodes_SqlLiteral :literal
        alias :visit_Integer :literal

        # Quotes +o+, first type-casting it for the database through
        # attribute +a+ when +a+ is given and able to type cast.
        def quoted(o, a)
          if a && a.able_to_type_cast?
            quote(a.type_cast_for_database(o))
          else
            quote(o)
          end
        end

        # Plain Ruby values are rejected: they have to be wrapped in an Arel
        # node before they reach the visitor.
        def unsupported(o, collector)
          raise UnsupportedVisitError.new(o)
        end

        alias :visit_ActiveSupport_Multibyte_Chars :unsupported
        alias :visit_ActiveSupport_StringInquirer :unsupported
        alias :visit_BigDecimal :unsupported
        alias :visit_Class :unsupported
        alias :visit_Date :unsupported
        alias :visit_DateTime :unsupported
        alias :visit_FalseClass :unsupported
        alias :visit_Float :unsupported
        alias :visit_Hash :unsupported
        alias :visit_NilClass :unsupported
        alias :visit_String :unsupported
        alias :visit_Symbol :unsupported
        alias :visit_Time :unsupported
        alias :visit_TrueClass :unsupported

        # <left> <operator> <right>, with the operator taken from the node.
        def visit_Arel_Nodes_InfixOperation(o, collector)
          collector = visit o.left, collector
          collector << " #{o.operator} "
          visit o.right, collector
        end

        alias :visit_Arel_Nodes_Addition :visit_Arel_Nodes_InfixOperation
        alias :visit_Arel_Nodes_Subtraction :visit_Arel_Nodes_InfixOperation
        alias :visit_Arel_Nodes_Multiplication :visit_Arel_Nodes_InfixOperation
        alias :visit_Arel_Nodes_Division :visit_Arel_Nodes_InfixOperation

        # The operator is emitted with a space on both sides, including the
        # leading one.
        def visit_Arel_Nodes_UnaryOperation(o, collector)
          collector << " #{o.operator} "
          visit o.expr, collector
        end

        def visit_Array(o, collector)
          inject_join o, collector, ", "
        end
        alias :visit_Set :visit_Array

        # The quoting helpers below pass SqlLiteral values through untouched
        # and delegate everything else to the connection.
        def quote(value)
          return value if Arel::Nodes::SqlLiteral === value
          @connection.quote value
        end

        def quote_table_name(name)
          return name if Arel::Nodes::SqlLiteral === name
          @connection.quote_table_name(name)
        end

        def quote_column_name(name)
          return name if Arel::Nodes::SqlLiteral === name
          @connection.quote_column_name(name)
        end

        def sanitize_as_sql_comment(value)
          return value if Arel::Nodes::SqlLiteral === value
          @connection.sanitize_as_sql_comment(value)
        end

        def collect_optimizer_hints(o, collector)
          maybe_visit o.optimizer_hints, collector
        end

        # Visits +thing+ preceded by a single space; a nil/false +thing+
        # emits nothing.
        def maybe_visit(thing, collector)
          return collector unless thing
          collector << " "
          visit thing, collector
        end

        # Visits every element of +list+, emitting +join_str+ between
        # consecutive elements (not after the last one).
        def inject_join(list, collector, join_str)
          len = list.length - 1
          list.each_with_index.inject(collector) { |c, (x, i)|
            if i == len
              visit x, c
            else
              visit(x, c) << join_str
            end
          }
        end

        # True only for values that respond to unboundable? and answer
        # truthily.
        def unboundable?(value)
          value.respond_to?(:unboundable?) && value.unboundable?
        end

        def has_join_sources?(o)
          o.relation.is_a?(Nodes::JoinSource) && !o.relation.right.empty?
        end

        def has_limit_or_offset_or_orders?(o)
          o.limit || o.offset || !o.orders.empty?
        end

        # The default strategy for an UPDATE with joins is to use a subquery. This doesn't work
        # on MySQL (even when aliasing the tables), but MySQL allows using JOIN directly in
        # an UPDATE statement, so in the MySQL visitor we redefine this to do that.
        #
        # When the statement has a key and uses limit/offset/orders or joins,
        # a clone is returned whose limit, offset and orders are cleared and
        # whose only condition is "key IN (<subselect>)"; the subselect
        # carries the original relation, conditions, limit, offset and
        # orders. Any other statement is returned unchanged.
        def prepare_update_statement(o)
          if o.key && (has_limit_or_offset_or_orders?(o) || has_join_sources?(o))
            stmt = o.clone
            stmt.limit = nil
            stmt.offset = nil
            stmt.orders = []
            stmt.wheres = [Nodes::In.new(o.key, [build_subselect(o.key, o)])]
            stmt.relation = o.relation.left if has_join_sources?(o)
            stmt
          else
            o
          end
        end
        alias :prepare_delete_statement :prepare_update_statement

        # FIXME: we should probably have a 2-pass visitor for this
        #
        # Builds a SELECT of +key+ over the statement's relation that reuses
        # the statement's wheres, limit, offset and orders.
        def build_subselect(key, o)
          stmt             = Nodes::SelectStatement.new
          core             = stmt.cores.first
          core.froms       = o.relation
          core.wheres      = o.wheres
          core.projections = [key]
          stmt.limit       = o.limit
          stmt.offset      = o.offset
          stmt.orders      = o.orders
          stmt
        end

        # <left><value><right>; +value+ is emitted verbatim, so it has to
        # carry its own surrounding spaces.
        def infix_value(o, collector, value)
          collector = visit o.left, collector
          collector << value
          visit o.right, collector
        end

        # Like infix_value, wrapped in "( ... )". An operand of exactly the
        # same class as +o+ is rendered inline (suppress_parens = true), so a
        # chain of the same operator gets one pair of parentheses overall.
        def infix_value_with_paren(o, collector, value, suppress_parens = false)
          collector << "( " unless suppress_parens
          collector = if o.left.class == o.class
            infix_value_with_paren(o.left, collector, value, true)
          else
            visit o.left, collector
          end
          collector << value
          collector = if o.right.class == o.class
            infix_value_with_paren(o.right, collector, value, true)
          else
            visit o.right, collector
          end
          collector << " )" unless suppress_parens
          collector
        end

        # NAME([DISTINCT ]<expressions>) with an optional " AS <alias>".
        def aggregate(name, o, collector)
          collector << "#{name}("
          if o.distinct
            collector << "DISTINCT "
          end
          collector = inject_join(o.expressions, collector, ", ") << ")"
          if o.alias
            collector << " AS "
            visit o.alias, collector
          else
            collector
          end
        end

        # Emits a CASE expression that evaluates to 0 when left and right are
        # equal or both NULL, and to 1 otherwise. Both operands are visited
        # twice.
        def is_distinct_from(o, collector)
          collector << "CASE WHEN "
          collector = visit o.left, collector
          collector << " = "
          collector = visit o.right, collector
          collector << " OR ("
          collector = visit o.left, collector
          collector << " IS NULL AND "
          collector = visit o.right, collector
          collector << " IS NULL)"
          collector << " THEN 0 ELSE 1 END"
        end
    end
  end
end