From b5302d5a820b078b6488104dd695a679e5a49623 Mon Sep 17 00:00:00 2001
From: Dmytro Shteflyuk <kpumuk@kpumuk.info>
Date: Thu, 15 Nov 2018 14:49:55 -0500
Subject: Arel: Implemented DB-aware NULL-safe comparison (#34451)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Arel: Implemented DB-aware NULL-safe comparison

* Fixed where clause inversion for NULL-safe comparison

* Renaming "null_safe_eq" to "is_not_distinct_from", "null_safe_not_eq" to "is_distinct_from"

[Dmytro Shteflyuk + Rafael Mendonça França]
---
 .../lib/active_record/relation/where_clause.rb     |  4 +++
 activerecord/lib/arel/nodes/equality.rb            |  7 +++++
 activerecord/lib/arel/predications.rb              |  8 ++++++
 activerecord/lib/arel/visitors/depth_first.rb      |  2 ++
 activerecord/lib/arel/visitors/dot.rb              |  2 ++
 activerecord/lib/arel/visitors/ibm_db.rb           |  6 ++++
 activerecord/lib/arel/visitors/mssql.rb            | 25 ++++++++++++++++
 activerecord/lib/arel/visitors/mysql.rb            | 11 ++++++++
 activerecord/lib/arel/visitors/oracle.rb           |  6 ++++
 activerecord/lib/arel/visitors/oracle12.rb         |  6 ++++
 activerecord/lib/arel/visitors/postgresql.rb       | 12 ++++++++
 activerecord/lib/arel/visitors/sqlite.rb           | 12 ++++++++
 activerecord/lib/arel/visitors/to_sql.rb           | 33 ++++++++++++++++++++++
 13 files changed, 134 insertions(+)

(limited to 'activerecord/lib')

diff --git a/activerecord/lib/active_record/relation/where_clause.rb b/activerecord/lib/active_record/relation/where_clause.rb
index a502713e56..e225628bae 100644
--- a/activerecord/lib/active_record/relation/where_clause.rb
+++ b/activerecord/lib/active_record/relation/where_clause.rb
@@ -125,6 +125,10 @@ module ActiveRecord
             raise ArgumentError, "Invalid argument for .where.not(), got nil."
           when Arel::Nodes::In
             Arel::Nodes::NotIn.new(node.left, node.right)
+          when Arel::Nodes::IsNotDistinctFrom
+            Arel::Nodes::IsDistinctFrom.new(node.left, node.right)
+          when Arel::Nodes::IsDistinctFrom
+            Arel::Nodes::IsNotDistinctFrom.new(node.left, node.right)
           when Arel::Nodes::Equality
             Arel::Nodes::NotEqual.new(node.left, node.right)
           when String
diff --git a/activerecord/lib/arel/nodes/equality.rb b/activerecord/lib/arel/nodes/equality.rb
index 2aa85a977e..551d56c2ff 100644
--- a/activerecord/lib/arel/nodes/equality.rb
+++ b/activerecord/lib/arel/nodes/equality.rb
@@ -7,5 +7,12 @@ module Arel # :nodoc: all
       alias :operand1 :left
       alias :operand2 :right
     end
+
+    %w{
+      IsDistinctFrom
+      IsNotDistinctFrom
+    }.each do |name|
+      const_set name, Class.new(Equality)
+    end
   end
 end
diff --git a/activerecord/lib/arel/predications.rb b/activerecord/lib/arel/predications.rb
index e83a6f162f..77502dd199 100644
--- a/activerecord/lib/arel/predications.rb
+++ b/activerecord/lib/arel/predications.rb
@@ -18,6 +18,14 @@ module Arel # :nodoc: all
       Nodes::Equality.new self, quoted_node(other)
     end
 
+    def is_not_distinct_from(other)
+      Nodes::IsNotDistinctFrom.new self, quoted_node(other)
+    end
+
+    def is_distinct_from(other)
+      Nodes::IsDistinctFrom.new self, quoted_node(other)
+    end
+
     def eq_any(others)
       grouping_any :eq, others
     end
diff --git a/activerecord/lib/arel/visitors/depth_first.rb b/activerecord/lib/arel/visitors/depth_first.rb
index 8f65d303ac..92d309453c 100644
--- a/activerecord/lib/arel/visitors/depth_first.rb
+++ b/activerecord/lib/arel/visitors/depth_first.rb
@@ -95,6 +95,8 @@ module Arel # :nodoc: all
         alias :visit_Arel_Nodes_NotEqual           :binary
         alias :visit_Arel_Nodes_NotIn              :binary
         alias :visit_Arel_Nodes_NotRegexp          :binary
+        alias :visit_Arel_Nodes_IsNotDistinctFrom  :binary
+        alias :visit_Arel_Nodes_IsDistinctFrom     :binary
         alias :visit_Arel_Nodes_Or                 :binary
         alias :visit_Arel_Nodes_OuterJoin          :binary
         alias :visit_Arel_Nodes_Regexp             :binary
diff --git a/activerecord/lib/arel/visitors/dot.rb b/activerecord/lib/arel/visitors/dot.rb
index 9054f0159b..6389c875cb 100644
--- a/activerecord/lib/arel/visitors/dot.rb
+++ b/activerecord/lib/arel/visitors/dot.rb
@@ -195,6 +195,8 @@ module Arel # :nodoc: all
         alias :visit_Arel_Nodes_JoinSource         :binary
         alias :visit_Arel_Nodes_LessThan           :binary
         alias :visit_Arel_Nodes_LessThanOrEqual    :binary
+        alias :visit_Arel_Nodes_IsNotDistinctFrom  :binary
+        alias :visit_Arel_Nodes_IsDistinctFrom     :binary
         alias :visit_Arel_Nodes_Matches            :binary
         alias :visit_Arel_Nodes_NotEqual           :binary
         alias :visit_Arel_Nodes_NotIn              :binary
diff --git a/activerecord/lib/arel/visitors/ibm_db.rb b/activerecord/lib/arel/visitors/ibm_db.rb
index 0a06aef60b..73166054da 100644
--- a/activerecord/lib/arel/visitors/ibm_db.rb
+++ b/activerecord/lib/arel/visitors/ibm_db.rb
@@ -10,6 +10,12 @@ module Arel # :nodoc: all
           collector = visit o.expr, collector
           collector << " ROWS ONLY"
         end
+
+        def is_distinct_from(o, collector)
+          collector << "DECODE("
+          collector = visit [o.left, o.right, 0, 1], collector
+          collector << ")"
+        end
     end
   end
 end
diff --git a/activerecord/lib/arel/visitors/mssql.rb b/activerecord/lib/arel/visitors/mssql.rb
index d564e19089..fdd864b40d 100644
--- a/activerecord/lib/arel/visitors/mssql.rb
+++ b/activerecord/lib/arel/visitors/mssql.rb
@@ -12,6 +12,31 @@ module Arel # :nodoc: all
 
       private
 
+        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
+          right = o.right
+
+          if right.nil?
+            collector = visit o.left, collector
+            collector << " IS NULL"
+          else
+            collector << "EXISTS (VALUES ("
+            collector = visit o.left, collector
+            collector << ") INTERSECT VALUES ("
+            collector = visit right, collector
+            collector << "))"
+          end
+        end
+
+        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
+          if o.right.nil?
+            collector = visit o.left, collector
+            collector << " IS NOT NULL"
+          else
+            collector << "NOT "
+            visit_Arel_Nodes_IsNotDistinctFrom o, collector
+          end
+        end
+
         def visit_Arel_Visitors_MSSQL_RowNumber(o, collector)
           collector << "ROW_NUMBER() OVER (ORDER BY "
           inject_join(o.children, collector, ", ") << ") as _row_num"
diff --git a/activerecord/lib/arel/visitors/mysql.rb b/activerecord/lib/arel/visitors/mysql.rb
index 4e7b2456aa..dd77cfdf66 100644
--- a/activerecord/lib/arel/visitors/mysql.rb
+++ b/activerecord/lib/arel/visitors/mysql.rb
@@ -37,6 +37,17 @@ module Arel # :nodoc: all
           collector
         end
 
+        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
+          collector = visit o.left, collector
+          collector << " <=> "
+          visit o.right, collector
+        end
+
+        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
+          collector << "NOT "
+          visit_Arel_Nodes_IsNotDistinctFrom o, collector
+        end
+
         # In the simple case, MySQL allows us to place JOINs directly into the UPDATE
         # query. However, this does not allow for LIMIT, OFFSET and ORDER. To support
         # these, we must use a subquery.
diff --git a/activerecord/lib/arel/visitors/oracle.rb b/activerecord/lib/arel/visitors/oracle.rb
index 30a1529d46..f96bf65ee5 100644
--- a/activerecord/lib/arel/visitors/oracle.rb
+++ b/activerecord/lib/arel/visitors/oracle.rb
@@ -148,6 +148,12 @@ module Arel # :nodoc: all
         def visit_Arel_Nodes_BindParam(o, collector)
           collector.add_bind(o.value) { |i| ":a#{i}" }
         end
+
+        def is_distinct_from(o, collector)
+          collector << "DECODE("
+          collector = visit [o.left, o.right, 0, 1], collector
+          collector << ")"
+        end
     end
   end
 end
diff --git a/activerecord/lib/arel/visitors/oracle12.rb b/activerecord/lib/arel/visitors/oracle12.rb
index 7061f06087..b092aa95e0 100644
--- a/activerecord/lib/arel/visitors/oracle12.rb
+++ b/activerecord/lib/arel/visitors/oracle12.rb
@@ -56,6 +56,12 @@ module Arel # :nodoc: all
         def visit_Arel_Nodes_BindParam(o, collector)
           collector.add_bind(o.value) { |i| ":a#{i}" }
         end
+
+        def is_distinct_from(o, collector)
+          collector << "DECODE("
+          collector = visit [o.left, o.right, 0, 1], collector
+          collector << ")"
+        end
     end
   end
 end
diff --git a/activerecord/lib/arel/visitors/postgresql.rb b/activerecord/lib/arel/visitors/postgresql.rb
index c5110fa89c..920776b4dc 100644
--- a/activerecord/lib/arel/visitors/postgresql.rb
+++ b/activerecord/lib/arel/visitors/postgresql.rb
@@ -77,6 +77,18 @@ module Arel # :nodoc: all
           grouping_parentheses o, collector
         end
 
+        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
+          collector = visit o.left, collector
+          collector << " IS NOT DISTINCT FROM "
+          visit o.right, collector
+        end
+
+        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
+          collector = visit o.left, collector
+          collector << " IS DISTINCT FROM "
+          visit o.right, collector
+        end
+
         # Used by Lateral visitor to enclose select queries in parentheses
         def grouping_parentheses(o, collector)
           if o.expr.is_a? Nodes::SelectStatement
diff --git a/activerecord/lib/arel/visitors/sqlite.rb b/activerecord/lib/arel/visitors/sqlite.rb
index cb1d2424ad..af6f7e856a 100644
--- a/activerecord/lib/arel/visitors/sqlite.rb
+++ b/activerecord/lib/arel/visitors/sqlite.rb
@@ -22,6 +22,18 @@ module Arel # :nodoc: all
         def visit_Arel_Nodes_False(o, collector)
           collector << "0"
         end
+
+        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
+          collector = visit o.left, collector
+          collector << " IS "
+          visit o.right, collector
+        end
+
+        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
+          collector = visit o.left, collector
+          collector << " IS NOT "
+          visit o.right, collector
+        end
     end
   end
 end
diff --git a/activerecord/lib/arel/visitors/to_sql.rb b/activerecord/lib/arel/visitors/to_sql.rb
index 7efd74dbc9..f9fe4404eb 100644
--- a/activerecord/lib/arel/visitors/to_sql.rb
+++ b/activerecord/lib/arel/visitors/to_sql.rb
@@ -641,6 +641,26 @@ module Arel # :nodoc: all
           end
         end
 
+        def visit_Arel_Nodes_IsNotDistinctFrom(o, collector)
+          if o.right.nil?
+            collector = visit o.left, collector
+            collector << " IS NULL"
+          else
+            collector = is_distinct_from(o, collector)
+            collector << " = 0"
+          end
+        end
+
+        def visit_Arel_Nodes_IsDistinctFrom(o, collector)
+          if o.right.nil?
+            collector = visit o.left, collector
+            collector << " IS NOT NULL"
+          else
+            collector = is_distinct_from(o, collector)
+            collector << " = 1"
+          end
+        end
+
         def visit_Arel_Nodes_NotEqual(o, collector)
           right = o.right
 
@@ -873,6 +893,19 @@ module Arel # :nodoc: all
             collector
           end
         end
+
+        def is_distinct_from(o, collector)
+          collector << "CASE WHEN "
+          collector = visit o.left, collector
+          collector << " = "
+          collector = visit o.right, collector
+          collector << " OR ("
+          collector = visit o.left, collector
+          collector << " IS NULL AND "
+          collector = visit o.right, collector
+          collector << " IS NULL)"
+          collector << " THEN 0 ELSE 1 END"
+        end
     end
   end
 end
-- 
cgit v1.2.3