From 4c7e50f9328aca4e294b41fce0832bf6ac4a939a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Bu=CC=88nemann?= Date: Sat, 19 Dec 2015 19:40:46 +0100 Subject: Implement CASE Conditional Expression --- lib/arel/nodes.rb | 3 ++ lib/arel/nodes/case.rb | 57 +++++++++++++++++++++++++++ lib/arel/predications.rb | 4 ++ lib/arel/visitors/depth_first.rb | 10 ++++- lib/arel/visitors/to_sql.rb | 29 ++++++++++++++ test/nodes/test_case.rb | 82 +++++++++++++++++++++++++++++++++++++++ test/visitors/test_depth_first.rb | 12 ++++++ test/visitors/test_to_sql.rb | 60 ++++++++++++++++++++++++++++ 8 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 lib/arel/nodes/case.rb create mode 100644 test/nodes/test_case.rb diff --git a/lib/arel/nodes.rb b/lib/arel/nodes.rb index 0e66d2dd0c..89f0f563ac 100644 --- a/lib/arel/nodes.rb +++ b/lib/arel/nodes.rb @@ -47,6 +47,9 @@ require 'arel/nodes/named_function' # windows require 'arel/nodes/window' +# conditional expressions +require 'arel/nodes/case' + # joins require 'arel/nodes/full_outer_join' require 'arel/nodes/inner_join' diff --git a/lib/arel/nodes/case.rb b/lib/arel/nodes/case.rb new file mode 100644 index 0000000000..85f8851dbe --- /dev/null +++ b/lib/arel/nodes/case.rb @@ -0,0 +1,57 @@ +module Arel + module Nodes + class Case < Arel::Nodes::Node + include Arel::OrderPredications + include Arel::Predications + include Arel::AliasPredication + + attr_accessor :case, :conditions, :default + + def initialize expression = nil, default = nil + @case = expression + @conditions = [] + @default = default + end + + def when condition, expression = nil + @conditions << When.new(Nodes.build_quoted(condition), expression) + self + end + + def then expression + @conditions.last.right = Nodes.build_quoted(expression) + self + end + + def else expression + @default = Else.new Nodes.build_quoted(expression) + self + end + + def initialize_copy other + super + @case = @case.clone if @case + @conditions = @conditions.map { |x| x.clone } + @default = @default.clone if @default + end + + def hash + [@case, @conditions, @default].hash + end + + def eql? other + self.class == other.class && + self.case == other.case && + self.conditions == other.conditions && + self.default == other.default + end + alias :== :eql? + end + + class When < Binary # :nodoc: + end + + class Else < Unary # :nodoc: + end + end +end diff --git a/lib/arel/predications.rb b/lib/arel/predications.rb index 1d2b0de235..e9078e9c4b 100644 --- a/lib/arel/predications.rb +++ b/lib/arel/predications.rb @@ -198,6 +198,10 @@ Passing a range to `#not_in` is deprecated. Call `#not_between`, instead. grouping_all :lteq, others end + def when right + Nodes::Case.new(self).when quoted_node(right) + end + private def grouping_any method_id, others, *extras diff --git a/lib/arel/visitors/depth_first.rb b/lib/arel/visitors/depth_first.rb index 22704dd038..2f71455580 100644 --- a/lib/arel/visitors/depth_first.rb +++ b/lib/arel/visitors/depth_first.rb @@ -16,6 +16,7 @@ module Arel def unary o visit o.expr end + alias :visit_Arel_Nodes_Else :unary alias :visit_Arel_Nodes_Group :unary alias :visit_Arel_Nodes_Grouping :unary alias :visit_Arel_Nodes_Having :unary @@ -53,6 +54,12 @@ module Arel visit o.distinct end + def visit_Arel_Nodes_Case o + visit o.case + visit o.conditions + visit o.default + end + def nary o o.children.each { |child| visit child} end @@ -86,8 +93,9 @@ module Arel alias :visit_Arel_Nodes_Regexp :binary alias :visit_Arel_Nodes_RightOuterJoin :binary alias :visit_Arel_Nodes_TableAlias :binary - alias :visit_Arel_Nodes_Values :binary alias :visit_Arel_Nodes_Union :binary + alias :visit_Arel_Nodes_Values :binary + alias :visit_Arel_Nodes_When :binary def visit_Arel_Nodes_StringJoin o visit o.left diff --git a/lib/arel/visitors/to_sql.rb b/lib/arel/visitors/to_sql.rb index ce1fdf80ce..598bf2d984 100644 --- a/lib/arel/visitors/to_sql.rb +++ b/lib/arel/visitors/to_sql.rb @@ -708,6 +708,35 @@ module Arel visit o.right, collector end + def visit_Arel_Nodes_Case o, collector + collector << "CASE " + if o.case + visit o.case, collector + collector << " " + end + o.conditions.each do |condition| + visit condition, collector + collector << " " + end + if o.default + visit o.default, collector + collector << " " + end + collector << "END" + end + + def visit_Arel_Nodes_When o, collector + collector << "WHEN " + visit o.left, collector + collector << " THEN " + visit o.right, collector + end + + def visit_Arel_Nodes_Else o, collector + collector << "ELSE " + visit o.expr, collector + end + def visit_Arel_Nodes_UnqualifiedColumn o, collector collector << "#{quote_column_name o.name}" collector diff --git a/test/nodes/test_case.rb b/test/nodes/test_case.rb new file mode 100644 index 0000000000..a813ec7e69 --- /dev/null +++ b/test/nodes/test_case.rb @@ -0,0 +1,82 @@ +require 'helper' + +module Arel + module Nodes + describe 'Case' do + describe '#initialize' do + it 'sets case expression from first argument' do + node = Case.new 'foo' + + assert_equal 'foo', node.case + end + + it 'sets default case from second argument' do + node = Case.new nil, 'bar' + + assert_equal 'bar', node.default + end + end + + describe '#clone' do + it 'clones case, conditions and default' do + foo = Nodes.build_quoted 'foo' + + node = Case.new + node.case = foo + node.conditions = [When.new(foo, foo)] + node.default = foo + + dolly = node.clone + + assert_equal dolly.case, node.case + refute_same dolly.case, node.case + + assert_equal dolly.conditions, node.conditions + refute_same dolly.conditions, node.conditions + + assert_equal dolly.default, node.default + refute_same dolly.default, node.default + end + end + + describe 'equality' do + it 'is equal with equal ivars' do + foo = Nodes.build_quoted 'foo' + one = Nodes.build_quoted 1 + zero = Nodes.build_quoted 0 + + case1 = Case.new foo + case1.conditions = [When.new(foo, one)] + case1.default = Else.new zero + + case2 = Case.new foo + case2.conditions = [When.new(foo, one)] + case2.default = Else.new zero + + array = [case1, case2] + + assert_equal 1, array.uniq.size + end + + it 'is not equal with different ivars' do + foo = Nodes.build_quoted 'foo' + bar = Nodes.build_quoted 'bar' + one = Nodes.build_quoted 1 + zero = Nodes.build_quoted 0 + + case1 = Case.new foo + case1.conditions = [When.new(foo, one)] + case1.default = Else.new zero + + case2 = Case.new foo + case2.conditions = [When.new(bar, one)] + case2.default = Else.new zero + + array = [case1, case2] + + assert_equal 2, array.uniq.size + end + end + end + end +end diff --git a/test/visitors/test_depth_first.rb b/test/visitors/test_depth_first.rb index 3356759b7d..1a72789f83 100644 --- a/test/visitors/test_depth_first.rb +++ b/test/visitors/test_depth_first.rb @@ -34,6 +34,7 @@ module Arel Arel::Nodes::UnqualifiedColumn, Arel::Nodes::Top, Arel::Nodes::Limit, + Arel::Nodes::Else, ].each do |klass| define_method("test_#{klass.name.gsub('::', '_')}") do op = klass.new(:a) @@ -118,6 +119,7 @@ module Arel Arel::Nodes::As, Arel::Nodes::DeleteStatement, Arel::Nodes::JoinSource, + Arel::Nodes::When, ].each do |klass| define_method("test_#{klass.name.gsub('::', '_')}") do binary = klass.new(:a, :b) @@ -247,6 +249,16 @@ module Arel assert_equal [:a, :b, stmt.columns, :c, stmt], @collector.calls end + def test_case + node = Arel::Nodes::Case.new + node.case = :a + node.conditions << :b + node.default = :c + + @visitor.accept node + assert_equal [:a, :b, node.conditions, :c, node], @collector.calls + end + def test_node node = Nodes::Node.new @visitor.accept node diff --git a/test/visitors/test_to_sql.rb b/test/visitors/test_to_sql.rb index 7ae5d5b3af..ea58039529 100644 --- a/test/visitors/test_to_sql.rb +++ b/test/visitors/test_to_sql.rb @@ -607,6 +607,66 @@ module Arel end end end + + describe 'Nodes::Case' do + it 'supports simple case expressions' do + node = Arel::Nodes::Case.new(@table[:name]) + .when('foo').then(1) + .else(0) + + compile(node).must_be_like %{ + CASE "users"."name" WHEN 'foo' THEN 1 ELSE 0 END + } + end + + it 'supports extended case expressions' do + node = Arel::Nodes::Case.new + .when(@table[:name].in(%w(foo bar))).then(1) + .else(0) + + compile(node).must_be_like %{ + CASE WHEN "users"."name" IN ('foo', 'bar') THEN 1 ELSE 0 END + } + end + + it 'works without default branch' do + node = Arel::Nodes::Case.new(@table[:name]) + .when('foo').then(1) + + compile(node).must_be_like %{ + CASE "users"."name" WHEN 'foo' THEN 1 END + } + end + + it 'allows chaining multiple conditions' do + node = Arel::Nodes::Case.new(@table[:name]) + .when('foo').then(1) + .when('bar').then(2) + .else(0) + + compile(node).must_be_like %{ + CASE "users"."name" WHEN 'foo' THEN 1 WHEN 'bar' THEN 2 ELSE 0 END + } + end + + it 'supports #when with two arguments and no #then' do + node = Arel::Nodes::Case.new @table[:name] + + { foo: 1, bar: 0 }.reduce(node) { |node, pair| node.when *pair } + + compile(node).must_be_like %{ + CASE "users"."name" WHEN 'foo' THEN 1 WHEN 'bar' THEN 0 END + } + end + + it 'can be chained as a predicate' do + node = @table[:name].when('foo').then('bar').else('baz') + + compile(node).must_be_like %{ + CASE "users"."name" WHEN 'foo' THEN 'bar' ELSE 'baz' END + } + end + end end end end -- cgit v1.2.3