aboutsummaryrefslogblamecommitdiffstats
path: root/activerecord/lib/active_record/relation/where_clause.rb
blob: 8fae380b0aa68bd9f6fda5f92225ef8f69ec36d9 (plain) (tree)
1
2
3
4
5
6
7
8
9

                             

                   
                               
                                              
 
                                
                                



                        
                                        


         





                                        

                        
                                                               


         
                          
                                                   

         
                   



                              







                                      

         
                                
                                           





                                                  
                               
                                    

                                                


              



                                                                  

                                   
                                        

         








                                                                                
                                            

         
                    
                          
         

               
                               
 




                                                                        
           

             













                                                         
 



                                                                          
           
 


                                                             
 
                                  
                   



                                                                              



                                                                     





                                                                   

             
 
                                      
                                     
                                                                                  
             
           
 

                                                

                                                  
                                    
                     


               
 









                                              
           















                                                         


       
# frozen_string_literal: true

module ActiveRecord
  class Relation
    # Wraps the array of predicates that make up a WHERE clause. Predicates
    # are either Arel nodes or raw SQL strings. Every combinator below
    # returns a new WhereClause and leaves the receiver untouched.
    class WhereClause # :nodoc:
      delegate :any?, :empty?, to: :predicates

      # +predicates+ is the array of Arel nodes / SQL strings to wrap.
      def initialize(predicates)
        @predicates = predicates
      end

      # Returns a clause holding this clause's predicates followed by
      # +other+'s.
      def +(other)
        WhereClause.new(
          predicates + other.predicates,
        )
      end

      # Returns a clause holding the predicates of this clause that do not
      # appear in +other+ (Array#- semantics).
      def -(other)
        WhereClause.new(
          predicates - other.predicates,
        )
      end

      # Combines both clauses, letting +other+ win on conflicts: equality
      # predicates of this clause whose left-hand side is also the left-hand
      # side of an equality predicate in +other+ are dropped before
      # +other+'s predicates are appended.
      def merge(other)
        WhereClause.new(
          predicates_unreferenced_by(other) + other.predicates,
        )
      end

      # Returns a clause without the predicates that target any of the
      # given column names (see #except_predicates).
      def except(*columns)
        WhereClause.new(except_predicates(columns))
      end

      # Builds the clause for <tt>self OR other</tt>, factoring shared
      # predicates out of the OR:
      #
      #   left   - predicates only present in this clause
      #   common - predicates of this clause that are also in +other+
      #   right  - predicates of +other+ that are not in +common+
      #
      # The result is <tt>common AND (left OR right)</tt>.
      def or(other)
        left = self - other
        common = self - left
        right = other - common

        # When one side adds nothing beyond +common+, that operand *is*
        # +common+, so the whole OR collapses to +common+.
        if left.empty? || right.empty?
          common
        else
          or_clause = WhereClause.new(
            [left.ast.or(right.ast)],
          )
          common + or_clause
        end
      end

      # Returns a Hash mapping column name (String) to value for every
      # Arel::Nodes::Equality predicate, including those nested inside
      # Arel::Nodes::And nodes. When +table_name+ is given, only equalities
      # whose left-hand side belongs to a relation with that name are kept.
      # Values that #extract_node_value does not understand map to +nil+.
      def to_h(table_name = nil)
        equalities = equalities(predicates)
        if table_name
          equalities = equalities.select do |node|
            node.left.relation.name == table_name
          end
        end

        equalities.map { |node|
          name = node.left.name.to_s
          value = extract_node_value(node.right)
          [name, value]
        }.to_h
      end

      # Returns a single Arel::Nodes::And over all non-empty predicates,
      # with raw SQL predicates wrapped in a grouping node.
      def ast
        Arel::Nodes::And.new(predicates_with_wrapped_sql_literals)
      end

      # Two clauses are equal when both are WhereClauses and their
      # predicate arrays are equal.
      def ==(other)
        other.is_a?(WhereClause) &&
          predicates == other.predicates
      end
      # Keep #eql? and #hash in step with #== so that equal clauses behave
      # as the same Hash key / Set member.
      alias :eql? :==

      # Hash derived from the predicates only. The constant (rather than
      # +self.class+) is mixed in because #== accepts any WhereClause.
      def hash
        [WhereClause, predicates].hash
      end

      # Returns the negated clause.
      #
      # * A clause with exactly one predicate inverts that predicate
      #   directly, whatever +as+ is.
      # * <tt>as == :nor</tt> inverts every predicate individually
      #   (<tt>NOT a AND NOT b</tt>).
      # * Otherwise (default +:nand+) the whole clause is wrapped in a
      #   single NOT (<tt>NOT (a AND b)</tt>).
      def invert(as = :nand)
        if predicates.size == 1
          inverted_predicates = [ invert_predicate(predicates.first) ]
        elsif as == :nor
          inverted_predicates = predicates.map { |node| invert_predicate(node) }
        else
          inverted_predicates = [ Arel::Nodes::Not.new(ast) ]
        end

        WhereClause.new(inverted_predicates)
      end

      # Returns a memoized clause with no predicates. The same instance is
      # handed out on every call.
      def self.empty
        @empty ||= new([])
      end

      protected
        attr_reader :predicates

        # Set of the left-hand sides of every equality predicate in this
        # clause. Memoized on first use.
        def referenced_columns
          @referenced_columns ||= begin
            equality_nodes = predicates.select { |n| equality_node?(n) }
            Set.new(equality_nodes, &:left)
          end
        end

      private
        # Collects Arel::Nodes::Equality nodes from +predicates+, descending
        # recursively into Arel::Nodes::And children. Any other node type is
        # ignored.
        def equalities(predicates)
          equalities = []

          predicates.each do |node|
            case node
            when Arel::Nodes::Equality
              equalities << node
            when Arel::Nodes::And
              equalities.concat equalities(node.children)
            end
          end

          equalities
        end

        # Predicates of this clause minus the equality predicates whose
        # left-hand side is referenced by an equality predicate in +other+.
        def predicates_unreferenced_by(other)
          predicates.reject do |n|
            equality_node?(n) && other.referenced_columns.include?(n.left)
          end
        end

        # True when +node+ exposes an +operator+ and that operator is +:==+.
        # Raw SQL strings do not respond to +operator+ and so never match.
        def equality_node?(node)
          node.respond_to?(:operator) && node.operator == :==
        end

        # Returns the negation of a single predicate, using the dedicated
        # opposite node where one is handled here and a generic NOT wrapper
        # otherwise. Raw SQL strings are turned into SQL literals first.
        #
        # Raises ArgumentError when +node+ is +nil+.
        def invert_predicate(node)
          case node
          when NilClass
            raise ArgumentError, "Invalid argument for .where.not(), got nil."
          when Arel::Nodes::In
            Arel::Nodes::NotIn.new(node.left, node.right)
          when Arel::Nodes::IsNotDistinctFrom
            Arel::Nodes::IsDistinctFrom.new(node.left, node.right)
          when Arel::Nodes::IsDistinctFrom
            Arel::Nodes::IsNotDistinctFrom.new(node.left, node.right)
          when Arel::Nodes::Equality
            Arel::Nodes::NotEqual.new(node.left, node.right)
          when String
            Arel::Nodes::Not.new(Arel::Nodes::SqlLiteral.new(node))
          else
            Arel::Nodes::Not.new(node)
          end
        end

        # Rejects the predicates for which the block given to
        # Arel.fetch_attribute returns a truthy result, i.e. (presumably)
        # those whose attribute name is listed in +columns+.
        # NOTE(review): names are compared as Strings, so +columns+ is
        # expected to contain Strings, not Symbols - confirm against callers.
        def except_predicates(columns)
          predicates.reject do |node|
            Arel.fetch_attribute(node) { |attr| columns.include?(attr.name.to_s) }
          end
        end

        # Non-empty predicates with every raw SQL predicate (String or
        # Arel::Nodes::SqlLiteral) wrapped via #wrap_sql_literal; other
        # nodes are passed through unchanged.
        def predicates_with_wrapped_sql_literals
          non_empty_predicates.map do |node|
            case node
            when Arel::Nodes::SqlLiteral, ::String
              wrap_sql_literal(node)
            else node
            end
          end
        end

        # Frozen so the shared constant cannot be mutated by accident.
        ARRAY_WITH_EMPTY_STRING = [""].freeze

        # Predicates with empty-string entries removed.
        def non_empty_predicates
          predicates - ARRAY_WITH_EMPTY_STRING
        end

        # Wraps a raw SQL predicate in an Arel::Nodes::Grouping, converting
        # plain Strings with Arel.sql first.
        def wrap_sql_literal(node)
          if ::String === node
            node = Arel.sql(node)
          end
          Arel::Nodes::Grouping.new(node)
        end

        # Extracts the plain Ruby value carried by the right-hand side of an
        # equality node. Arrays are mapped element by element. For bind
        # params the value before type cast is preferred when available.
        # Returns +nil+ for any node type not listed here.
        def extract_node_value(node)
          case node
          when Array
            node.map { |v| extract_node_value(v) }
          when Arel::Nodes::Casted, Arel::Nodes::Quoted
            node.val
          when Arel::Nodes::BindParam
            value = node.value
            if value.respond_to?(:value_before_type_cast)
              value.value_before_type_cast
            else
              value
            end
          end
        end
    end
  end
end