# frozen_string_literal: true

module ActiveRecord
  class Relation
    # Value object wrapping the list of predicates that make up a WHERE
    # clause. Every combining operation (+, -, merge, except, or, invert)
    # returns a new WhereClause; none of the methods here mutates the
    # receiver's predicate list.
    #
    # Predicates are either Arel nodes or raw SQL strings. They are ANDed
    # together when rendered through #ast.
    class WhereClause # :nodoc:
      delegate :any?, :empty?, to: :predicates

      # predicates - Array of Arel nodes and/or SQL strings.
      def initialize(predicates)
        @predicates = predicates
      end

      # Returns a clause containing the predicates of both operands.
      def +(other)
        WhereClause.new(
          predicates + other.predicates,
        )
      end

      # Returns a clause containing the receiver's predicates that do not
      # appear in +other+ (Array#- semantics).
      def -(other)
        WhereClause.new(
          predicates - other.predicates,
        )
      end

      # Like #+, but first drops any of the receiver's equality predicates
      # whose left-hand side is also the left-hand side of an equality
      # predicate in +other+, so +other+ wins on conflict.
      def merge(other)
        WhereClause.new(
          predicates_unreferenced_by(other) + other.predicates,
        )
      end

      # Returns a clause without the predicates that target any of the given
      # column names. Names are compared against +attr.name.to_s+, so they
      # are expected to be strings.
      def except(*columns)
        WhereClause.new(except_predicates(columns))
      end

      # Combines the two clauses with OR while keeping the predicates they
      # share outside of the OR node.
      def or(other)
        left = self - other     # predicates only the receiver has
        common = self - left    # predicates both clauses share
        right = other - common  # predicates only +other+ has

        # When one side contributes nothing beyond the shared predicates,
        # "common OR (common AND extra)" reduces to just "common".
        if left.empty? || right.empty?
          common
        else
          or_clause = WhereClause.new(
            [left.ast.or(right.ast)],
          )
          common + or_clause
        end
      end

      # Returns a Hash of column name (String) => value built from the
      # equality predicates, including those nested inside And nodes.
      #
      # table_name - when given, only predicates whose left-hand side belongs
      #              to a relation with that name are included.
      def to_h(table_name = nil)
        nodes = equalities(predicates)
        if table_name
          nodes = nodes.select do |node|
            node.left.relation.name == table_name
          end
        end

        nodes.map { |node|
          name = node.left.name.to_s
          value = extract_node_value(node.right)
          [name, value]
        }.to_h
      end

      # Returns a single Arel And node over all non-empty predicates, with
      # raw SQL fragments wrapped in a Grouping node.
      def ast
        Arel::Nodes::And.new(predicates_with_wrapped_sql_literals)
      end

      # Two clauses are equal when they hold equal predicate lists.
      def ==(other)
        other.is_a?(WhereClause) &&
          predicates == other.predicates
      end
      # Keep eql?/hash in step with == so that equal clauses are also
      # treated as the same key by Hash, Set and Array#uniq.
      alias :eql? :==

      def hash
        [self.class, predicates].hash
      end

      # Returns the negation of this clause.
      #
      # as - :nand (default) wraps the whole ANDed clause in a single NOT;
      #      :nor inverts every predicate individually, and since the
      #      resulting predicates are ANDed again this yields
      #      NOT a AND NOT b.
      #
      # A clause with exactly one predicate is always inverted directly,
      # whatever +as+ is.
      def invert(as = :nand)
        if predicates.size == 1
          inverted_predicates = [ invert_predicate(predicates.first) ]
        elsif as == :nor
          inverted_predicates = predicates.map { |node| invert_predicate(node) }
        else
          inverted_predicates = [ Arel::Nodes::Not.new(ast) ]
        end

        WhereClause.new(inverted_predicates)
      end

      # Memoized clause with no predicates. The same instance is returned on
      # every call.
      def self.empty
        @empty ||= new([])
      end

      protected
        attr_reader :predicates

        # Set of the left-hand sides of all equality predicates. Memoized on
        # first use.
        def referenced_columns
          @referenced_columns ||= begin
            equality_nodes = predicates.select { |n| equality_node?(n) }
            Set.new(equality_nodes, &:left)
          end
        end

      private
        # Collects Equality nodes from +predicates+, descending recursively
        # into And nodes. Any other kind of node is ignored.
        def equalities(predicates)
          nodes = []

          predicates.each do |node|
            case node
            when Arel::Nodes::Equality
              nodes << node
            when Arel::Nodes::And
              nodes.concat equalities(node.children)
            end
          end

          nodes
        end

        # The receiver's predicates minus the equality predicates whose
        # left-hand side is also referenced by +other+.
        def predicates_unreferenced_by(other)
          predicates.reject do |n|
            equality_node?(n) && other.referenced_columns.include?(n.left)
          end
        end

        # Duck-typed check: anything that exposes an +operator+ equal to :==
        # counts as an equality node.
        def equality_node?(node)
          node.respond_to?(:operator) && node.operator == :==
        end

        # Returns the logical negation of a single predicate, using a
        # dedicated negated node type where one is handled below and falling
        # back to a generic NOT otherwise.
        #
        # Raises ArgumentError when the predicate is nil.
        def invert_predicate(node)
          case node
          when NilClass
            raise ArgumentError, "Invalid argument for .where.not(), got nil."
          when Arel::Nodes::In
            Arel::Nodes::NotIn.new(node.left, node.right)
          when Arel::Nodes::IsNotDistinctFrom
            Arel::Nodes::IsDistinctFrom.new(node.left, node.right)
          when Arel::Nodes::IsDistinctFrom
            Arel::Nodes::IsNotDistinctFrom.new(node.left, node.right)
          when Arel::Nodes::Equality
            Arel::Nodes::NotEqual.new(node.left, node.right)
          when String
            Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(node))
          else
            Arel::Nodes::Not.new(node)
          end
        end

        # Predicates for which the Arel.fetch_attribute block does not report
        # a match against +columns+.
        # NOTE(review): presumably fetch_attribute yields the attribute the
        # predicate is about -- confirm against Arel.
        def except_predicates(columns)
          predicates.reject do |node|
            Arel.fetch_attribute(node) { |attr| columns.include?(attr.name.to_s) }
          end
        end

        # Non-empty predicates with raw SQL (String or SqlLiteral) wrapped in
        # a Grouping node; all other nodes are passed through untouched.
        def predicates_with_wrapped_sql_literals
          non_empty_predicates.map do |node|
            case node
            when Arel::Nodes::SqlLiteral, ::String
              wrap_sql_literal(node)
            else node
            end
          end
        end

        # Frozen so the shared constant cannot be mutated by accident.
        ARRAY_WITH_EMPTY_STRING = [""].freeze

        # Predicates with empty-string entries removed.
        def non_empty_predicates
          predicates - ARRAY_WITH_EMPTY_STRING
        end

        # Wraps a SQL fragment in a Grouping node, converting a plain String
        # through Arel.sql first.
        def wrap_sql_literal(node)
          node = Arel.sql(node) if ::String === node
          Arel::Nodes::Grouping.new(node)
        end

        # Unwraps the Ruby value carried by the right-hand side of an
        # equality. Arrays are unwrapped element by element. Returns nil for
        # node types not handled below.
        def extract_node_value(node)
          case node
          when Array
            node.map { |v| extract_node_value(v) }
          when Arel::Nodes::Casted, Arel::Nodes::Quoted
            node.val
          when Arel::Nodes::BindParam
            value = node.value
            # Prefer the raw value when the bound object exposes one.
            if value.respond_to?(:value_before_type_cast)
              value.value_before_type_cast
            else
              value
            end
          end
        end
    end
  end
end