From 4c7e50f9328aca4e294b41fce0832bf6ac4a939a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Bu=CC=88nemann?= Date: Sat, 19 Dec 2015 19:40:46 +0100 Subject: Implement CASE Conditional Expression --- lib/arel/nodes.rb | 3 +++ lib/arel/nodes/case.rb | 57 ++++++++++++++++++++++++++++++++++++++++ lib/arel/predications.rb | 4 +++ lib/arel/visitors/depth_first.rb | 10 ++++++- lib/arel/visitors/to_sql.rb | 29 ++++++++++++++++++++ 5 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 lib/arel/nodes/case.rb (limited to 'lib') diff --git a/lib/arel/nodes.rb b/lib/arel/nodes.rb index 0e66d2dd0c..89f0f563ac 100644 --- a/lib/arel/nodes.rb +++ b/lib/arel/nodes.rb @@ -47,6 +47,9 @@ require 'arel/nodes/named_function' # windows require 'arel/nodes/window' +# conditional expressions +require 'arel/nodes/case' + # joins require 'arel/nodes/full_outer_join' require 'arel/nodes/inner_join' diff --git a/lib/arel/nodes/case.rb b/lib/arel/nodes/case.rb new file mode 100644 index 0000000000..85f8851dbe --- /dev/null +++ b/lib/arel/nodes/case.rb @@ -0,0 +1,57 @@ +module Arel + module Nodes + class Case < Arel::Nodes::Node + include Arel::OrderPredications + include Arel::Predications + include Arel::AliasPredication + + attr_accessor :case, :conditions, :default + + def initialize expression = nil, default = nil + @case = expression + @conditions = [] + @default = default + end + + def when condition, expression = nil + @conditions << When.new(Nodes.build_quoted(condition), expression) + self + end + + def then expression + @conditions.last.right = Nodes.build_quoted(expression) + self + end + + def else expression + @default = Else.new Nodes.build_quoted(expression) + self + end + + def initialize_copy other + super + @case = @case.clone if @case + @conditions = @conditions.map { |x| x.clone } + @default = @default.clone if @default + end + + def hash + [@case, @conditions, @default].hash + end + + def eql? other + self.class == other.class && + self.case == other.case && + self.conditions == other.conditions && + self.default == other.default + end + alias :== :eql? + end + + class When < Binary # :nodoc: + end + + class Else < Unary # :nodoc: + end + end +end diff --git a/lib/arel/predications.rb b/lib/arel/predications.rb index 1d2b0de235..e9078e9c4b 100644 --- a/lib/arel/predications.rb +++ b/lib/arel/predications.rb @@ -198,6 +198,10 @@ Passing a range to `#not_in` is deprecated. Call `#not_between`, instead. grouping_all :lteq, others end + def when right + Nodes::Case.new(self).when quoted_node(right) + end + private def grouping_any method_id, others, *extras diff --git a/lib/arel/visitors/depth_first.rb b/lib/arel/visitors/depth_first.rb index 22704dd038..2f71455580 100644 --- a/lib/arel/visitors/depth_first.rb +++ b/lib/arel/visitors/depth_first.rb @@ -16,6 +16,7 @@ module Arel def unary o visit o.expr end + alias :visit_Arel_Nodes_Else :unary alias :visit_Arel_Nodes_Group :unary alias :visit_Arel_Nodes_Grouping :unary alias :visit_Arel_Nodes_Having :unary @@ -53,6 +54,12 @@ module Arel visit o.distinct end + def visit_Arel_Nodes_Case o + visit o.case + visit o.conditions + visit o.default + end + def nary o o.children.each { |child| visit child} end @@ -86,8 +93,9 @@ module Arel alias :visit_Arel_Nodes_Regexp :binary alias :visit_Arel_Nodes_RightOuterJoin :binary alias :visit_Arel_Nodes_TableAlias :binary - alias :visit_Arel_Nodes_Values :binary alias :visit_Arel_Nodes_Union :binary + alias :visit_Arel_Nodes_Values :binary + alias :visit_Arel_Nodes_When :binary def visit_Arel_Nodes_StringJoin o visit o.left diff --git a/lib/arel/visitors/to_sql.rb b/lib/arel/visitors/to_sql.rb index ce1fdf80ce..598bf2d984 100644 --- a/lib/arel/visitors/to_sql.rb +++ b/lib/arel/visitors/to_sql.rb @@ -708,6 +708,35 @@ module Arel visit o.right, collector end + def visit_Arel_Nodes_Case o, collector + collector << "CASE " + if o.case + visit o.case, collector + collector << " " + end + o.conditions.each do |condition| + visit condition, collector + collector << " " + end + if o.default + visit o.default, collector + collector << " " + end + collector << "END" + end + + def visit_Arel_Nodes_When o, collector + collector << "WHEN " + visit o.left, collector + collector << " THEN " + visit o.right, collector + end + + def visit_Arel_Nodes_Else o, collector + collector << "ELSE " + visit o.expr, collector + end + def visit_Arel_Nodes_UnqualifiedColumn o, collector collector << "#{quote_column_name o.name}" collector -- cgit v1.2.3