aboutsummaryrefslogblamecommitdiffstats
path: root/activerecord/lib/active_record/associations/association_scope.rb
blob: 48288496c4dcbf956fc2d71d863905ab5ceaf943 (plain) (tree)
1
2
3
4
5
6
7


                                   


                                              
 




                             

                                                        















                                                             
                                        



                                              
                                                     
                                                          
 
                                                           
                                                                              

         

                              

         
                                            










                                                                

             
             

         

             
                                                      
                                 
                                                    


           



                                                                        

                                                                  


                            

                                                                      

         


                                                                 

         
                                                                                    



                                                     
                                                                                         



                                                       
                                                                                     





                                                                     
                                                                                                             







                                                              
                                                                                     





                                                                     














                                               
                          




















                                            

                                           













                                                                   
           


                                            


                                                   


                                                                 
 
                                                                          
                                                   
                                     

                                                                                             
 
                                    
                                                
                                                           
 

                                          

                                                                                                                    
             
 

                                                                     

                                                                         
 
                                             
                                                                
               
 
                     
                                                  
               
 
                                                   
                                                  
                                                   





             
                                         
                                                   
         


       
module ActiveRecord
  module Associations
    class AssociationScope #:nodoc:
      def self.scope(association, connection)
        INSTANCE.scope association, connection
      end

      class BindSubstitution
        def initialize(block)
          @block = block
        end

        def bind_value(scope, column, value, connection)
          substitute = connection.substitute_at(column)
          scope.bind_values += [[column, @block.call(value)]]
          substitute
        end
      end

      def self.create(&block)
        block = block ? block : lambda { |val| val }
        new BindSubstitution.new(block)
      end

      def initialize(bind_substitution)
        @bind_substitution = bind_substitution
      end

      INSTANCE = create

      def scope(association, connection)
        klass         = association.klass
        reflection    = association.reflection
        scope         = klass.unscoped
        owner         = association.owner
        alias_tracker = AliasTracker.empty connection
        chain         = get_chain(reflection, association)

        scope.extending! Array(reflection.options[:extend])
        add_constraints(scope, owner, klass, reflection, alias_tracker, chain)
      end

      def join_type
        Arel::Nodes::InnerJoin
      end

      def self.get_bind_values(owner, chain)
        binds = []
        last_reflection = chain.last

        binds << last_reflection.join_id_for(owner)
        if last_reflection.type
          binds << owner.class.base_class.name
        end

        chain.each_cons(2).each do |reflection, next_reflection|
          if reflection.type
            binds << next_reflection.klass.base_class.name
          end
        end
        binds
      end

      private

      def construct_tables(chain, alias_tracker, name)
        chain.map do |reflection|
          reflection.alias_name(name, alias_tracker)
        end
      end

      def join(table, constraint)
        table.create_join(table, table.create_on(constraint), join_type)
      end

      def column_for(table_name, column_name, connection)
        columns = connection.schema_cache.columns_hash(table_name)
        columns[column_name]
      end

      def bind_value(scope, column, value, connection)
        @bind_substitution.bind_value scope, column, value, connection
      end

      def bind(scope, table_name, column_name, value, connection)
        column   = column_for table_name, column_name, connection
        bind_value scope, column, value, connection
      end

      def last_chain_scope(scope, table, reflection, owner, connection, assoc_klass)
        join_keys = reflection.join_keys(assoc_klass)
        key = join_keys.key
        foreign_key = join_keys.foreign_key

        bind_val = bind scope, table.table_name, key.to_s, owner[foreign_key], connection
        scope    = scope.where(table[key].eq(bind_val))

        if reflection.type
          value    = owner.class.base_class.name
          bind_val = bind scope, table.table_name, reflection.type, value, connection
          scope    = scope.where(table[reflection.type].eq(bind_val))
        else
          scope
        end
      end

      def next_chain_scope(scope, table, reflection, connection, assoc_klass, foreign_table, next_reflection)
        join_keys = reflection.join_keys(assoc_klass)
        key = join_keys.key
        foreign_key = join_keys.foreign_key

        constraint = table[key].eq(foreign_table[foreign_key])

        if reflection.type
          value    = next_reflection.klass.base_class.name
          bind_val = bind scope, table.table_name, reflection.type, value, connection
          scope    = scope.where(table[reflection.type].eq(bind_val))
        end

        scope = scope.joins(join(foreign_table, constraint))
      end

      class RuntimeReflection
        def initialize(reflection, association)
          @reflection = reflection
          @association = association
        end

        def klass
          @association.klass
        end

        def scope
          @reflection.scope
        end

        def table_name
          klass.table_name
        end

        def plural_name
          @reflection.plural_name
        end

        def join_keys(assoc_klass)
          @reflection.join_keys(assoc_klass)
        end

        def type
          @reflection.type
        end

        def constraints
          @reflection.constraints
        end

        def source_type_info
          @reflection.source_type_info
        end

        def alias_name(name, alias_tracker)
          @alias ||= begin
            alias_name = "#{plural_name}_#{name}_join"
            table_name = klass.table_name
            alias_tracker.aliased_table_for(table_name, alias_name)
          end
        end
      end

      class ReflectionProxy < SimpleDelegator
        def alias_name(name, alias_tracker)
          @alias ||= begin
            alias_name = "#{plural_name}_#{name}"
            alias_tracker.aliased_table_for(table_name, alias_name)
          end
        end
      end

      def get_chain(reflection, association)
        chain = reflection.chain.map { |reflection|
          ReflectionProxy.new(reflection)
        }
        chain[0] = RuntimeReflection.new(reflection, association)
        chain
      end

      def add_constraints(scope, owner, assoc_klass, refl, tracker, chain)
        construct_tables(chain, tracker, refl.name)
        owner_reflection = chain.last
        table = owner_reflection.alias_name(refl.name, tracker)
        scope = last_chain_scope(scope, table, owner_reflection, owner, tracker, assoc_klass)

        # chain.first always == refl
        chain.each_with_index do |reflection, i|
          table = reflection.alias_name(refl.name, tracker)

          unless reflection == chain.last
            next_reflection = chain[i + 1]
            foreign_table = next_reflection.alias_name(refl.name, tracker)
            scope = next_chain_scope(scope, table, reflection, tracker, assoc_klass, foreign_table, next_reflection)
          end

          # Exclude the scope of the association itself, because that
          # was already merged in the #scope method.
          reflection.constraints.each do |scope_chain_item|
            item  = eval_scope(reflection.klass, scope_chain_item, owner)

            if scope_chain_item == refl.scope
              scope.merge! item.except(:where, :includes, :bind)
            end

            if i == 0
              scope.includes! item.includes_values
            end

            scope.where_values += item.where_values
            scope.bind_values  += item.bind_values
            scope.order_values |= item.order_values
          end
        end

        scope
      end

      def eval_scope(klass, scope, owner)
        klass.unscoped.instance_exec(owner, &scope)
      end
    end
  end
end