aboutsummaryrefslogtreecommitdiffstats
path: root/activerecord/lib/active_record/enum.rb
diff options
context:
space:
mode:
Diffstat (limited to 'activerecord/lib/active_record/enum.rb')
-rw-r--r--activerecord/lib/active_record/enum.rb259
1 files changed, 259 insertions, 0 deletions
diff --git a/activerecord/lib/active_record/enum.rb b/activerecord/lib/active_record/enum.rb
new file mode 100644
index 0000000000..e6dba66a08
--- /dev/null
+++ b/activerecord/lib/active_record/enum.rb
@@ -0,0 +1,259 @@
+# frozen_string_literal: true
+
+require "active_support/core_ext/object/deep_dup"
+
+module ActiveRecord
+ # Declare an enum attribute where the values map to integers in the database,
+ # but can be queried by name. Example:
+ #
+ # class Conversation < ActiveRecord::Base
+ # enum status: [ :active, :archived ]
+ # end
+ #
+ # # conversation.update! status: 0
+ # conversation.active!
+ # conversation.active? # => true
+ # conversation.status # => "active"
+ #
+ # # conversation.update! status: 1
+ # conversation.archived!
+ # conversation.archived? # => true
+ # conversation.status # => "archived"
+ #
+ # # conversation.status = 1
+ # conversation.status = "archived"
+ #
+ # conversation.status = nil
+ # conversation.status.nil? # => true
+ # conversation.status # => nil
+ #
+ # Scopes based on the allowed values of the enum field will be provided
+ # as well. With the above example:
+ #
+ # Conversation.active
+ # Conversation.archived
+ #
+ # Of course, you can also query them directly if the scopes don't fit your
+ # needs:
+ #
+ # Conversation.where(status: [:active, :archived])
+ # Conversation.where.not(status: :active)
+ #
+ # You can set the default value from the database declaration, like:
+ #
+ # create_table :conversations do |t|
+ # t.column :status, :integer, default: 0
+ # end
+ #
+ # Good practice is to let the first declared status be the default.
+ #
+ # Finally, it's also possible to explicitly map the relation between attribute and
+ # database integer with a hash:
+ #
+ # class Conversation < ActiveRecord::Base
+ # enum status: { active: 0, archived: 1 }
+ # end
+ #
+ # Note that when an array is used, the implicit mapping from the values to database
+ # integers is derived from the order the values appear in the array. In the example,
+ # <tt>:active</tt> is mapped to +0+ as it's the first element, and <tt>:archived</tt>
+ # is mapped to +1+. In general, the +i+-th element is mapped to <tt>i-1</tt> in the
+ # database.
+ #
+ # Therefore, once a value is added to the enum array, its position in the array must
+ # be maintained, and new values should only be added to the end of the array. To
+ # remove unused values, the explicit hash syntax should be used.
+ #
+ # In rare circumstances you might need to access the mapping directly.
+ # The mappings are exposed through a class method with the pluralized attribute
+ # name, which return the mapping in a +HashWithIndifferentAccess+:
+ #
+ # Conversation.statuses[:active] # => 0
+ # Conversation.statuses["archived"] # => 1
+ #
+ # Use that class method when you need to know the ordinal value of an enum.
+ # For example, you can use that when manually building SQL strings:
+ #
+ # Conversation.where("status <> ?", Conversation.statuses[:archived])
+ #
+ # You can use the +:_prefix+ or +:_suffix+ options when you need to define
+ # multiple enums with same values. If the passed value is +true+, the methods
+ # are prefixed/suffixed with the name of the enum. It is also possible to
+ # supply a custom value:
+ #
+ # class Conversation < ActiveRecord::Base
+ # enum status: [:active, :archived], _suffix: true
+ # enum comments_status: [:active, :inactive], _prefix: :comments
+ # end
+ #
+ # With the above example, the bang and predicate methods along with the
+ # associated scopes are now prefixed and/or suffixed accordingly:
+ #
+ # conversation.active_status!
+ # conversation.archived_status? # => false
+ #
+ # conversation.comments_inactive!
+ # conversation.comments_active? # => false
+
+ module Enum
+ def self.extended(base) # :nodoc:
+ base.class_attribute(:defined_enums, instance_writer: false, default: {})
+ end
+
+ def inherited(base) # :nodoc:
+ base.defined_enums = defined_enums.deep_dup
+ super
+ end
+
+ class EnumType < Type::Value # :nodoc:
+ delegate :type, to: :subtype
+
+ def initialize(name, mapping, subtype)
+ @name = name
+ @mapping = mapping
+ @subtype = subtype
+ end
+
+ def cast(value)
+ return if value.blank?
+
+ if mapping.has_key?(value)
+ value.to_s
+ elsif mapping.has_value?(value)
+ mapping.key(value)
+ else
+ assert_valid_value(value)
+ end
+ end
+
+ def deserialize(value)
+ return if value.nil?
+ mapping.key(subtype.deserialize(value))
+ end
+
+ def serialize(value)
+ mapping.fetch(value, value)
+ end
+
+ def assert_valid_value(value)
+ unless value.blank? || mapping.has_key?(value) || mapping.has_value?(value)
+ raise ArgumentError, "'#{value}' is not a valid #{name}"
+ end
+ end
+
+ private
+ attr_reader :name, :mapping, :subtype
+ end
+
+ def enum(definitions)
+ klass = self
+ enum_prefix = definitions.delete(:_prefix)
+ enum_suffix = definitions.delete(:_suffix)
+ enum_scopes = definitions.delete(:_scopes)
+ definitions.each do |name, values|
+ assert_valid_enum_definition_values(values)
+ # statuses = { }
+ enum_values = ActiveSupport::HashWithIndifferentAccess.new
+ name = name.to_s
+
+ # def self.statuses() statuses end
+ detect_enum_conflict!(name, name.pluralize, true)
+ singleton_class.define_method(name.pluralize) { enum_values }
+ defined_enums[name] = enum_values
+
+ detect_enum_conflict!(name, name)
+ detect_enum_conflict!(name, "#{name}=")
+
+ attr = attribute_alias?(name) ? attribute_alias(name) : name
+ decorate_attribute_type(attr, :enum) do |subtype|
+ EnumType.new(attr, enum_values, subtype)
+ end
+
+ _enum_methods_module.module_eval do
+ pairs = values.respond_to?(:each_pair) ? values.each_pair : values.each_with_index
+ pairs.each do |label, value|
+ if enum_prefix == true
+ prefix = "#{name}_"
+ elsif enum_prefix
+ prefix = "#{enum_prefix}_"
+ end
+ if enum_suffix == true
+ suffix = "_#{name}"
+ elsif enum_suffix
+ suffix = "_#{enum_suffix}"
+ end
+
+ value_method_name = "#{prefix}#{label}#{suffix}"
+ enum_values[label] = value
+ label = label.to_s
+
+ # def active?() status == "active" end
+ klass.send(:detect_enum_conflict!, name, "#{value_method_name}?")
+ define_method("#{value_method_name}?") { self[attr] == label }
+
+ # def active!() update!(status: 0) end
+ klass.send(:detect_enum_conflict!, name, "#{value_method_name}!")
+ define_method("#{value_method_name}!") { update!(attr => value) }
+
+ # scope :active, -> { where(status: 0) }
+ if enum_scopes != false
+ klass.send(:detect_enum_conflict!, name, value_method_name, true)
+ klass.scope value_method_name, -> { where(attr => value) }
+ end
+ end
+ end
+ enum_values.freeze
+ end
+ end
+
+ private
+ def _enum_methods_module
+ @_enum_methods_module ||= begin
+ mod = Module.new
+ include mod
+ mod
+ end
+ end
+
+ def assert_valid_enum_definition_values(values)
+ unless values.is_a?(Hash) || values.all? { |v| v.is_a?(Symbol) } || values.all? { |v| v.is_a?(String) }
+ error_message = <<~MSG
+ Enum values #{values} must be either a hash, an array of symbols, or an array of strings.
+ MSG
+ raise ArgumentError, error_message
+ end
+
+ if values.is_a?(Hash) && values.keys.any?(&:blank?) || values.is_a?(Array) && values.any?(&:blank?)
+ raise ArgumentError, "Enum label name must not be blank."
+ end
+ end
+
+ ENUM_CONFLICT_MESSAGE = \
+ "You tried to define an enum named \"%{enum}\" on the model \"%{klass}\", but " \
+ "this will generate a %{type} method \"%{method}\", which is already defined " \
+ "by %{source}."
+ private_constant :ENUM_CONFLICT_MESSAGE
+
+ def detect_enum_conflict!(enum_name, method_name, klass_method = false)
+ if klass_method && dangerous_class_method?(method_name)
+ raise_conflict_error(enum_name, method_name, type: "class")
+ elsif klass_method && method_defined_within?(method_name, Relation)
+ raise_conflict_error(enum_name, method_name, type: "class", source: Relation.name)
+ elsif !klass_method && dangerous_attribute_method?(method_name)
+ raise_conflict_error(enum_name, method_name)
+ elsif !klass_method && method_defined_within?(method_name, _enum_methods_module, Module)
+ raise_conflict_error(enum_name, method_name, source: "another enum")
+ end
+ end
+
+ def raise_conflict_error(enum_name, method_name, type: "instance", source: "Active Record")
+ raise ArgumentError, ENUM_CONFLICT_MESSAGE % {
+ enum: enum_name,
+ klass: name,
+ type: type,
+ method: method_name,
+ source: source
+ }
+ end
+ end
+end