diff options
Diffstat (limited to 'activerecord/lib')
-rw-r--r-- | activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb | 112 |
1 files changed, 62 insertions, 50 deletions
diff --git a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb index 39990db8fc..66cbe3a58b 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb @@ -24,7 +24,6 @@ module ActiveRecord username = config[:username].to_s password = config[:password].to_s - schema_order = config[:schema_order] encoding = config[:encoding] min_messages = config[:min_messages] @@ -38,7 +37,7 @@ module ActiveRecord PGconn.connect(host, port, "", "", database, username, password), logger ) - pga.execute("SET search_path TO #{schema_order}") if schema_order + pga.schema_search_path = config[:schema_search_path] || config[:schema_order] pga.execute("SET client_encoding TO '#{encoding}'") if encoding pga.execute("SET client_min_messages TO '#{min_messages}'") if min_messages @@ -57,7 +56,7 @@ module ActiveRecord # * <tt>:username</tt> -- Defaults to nothing # * <tt>:password</tt> -- Defaults to nothing # * <tt>:database</tt> -- The name of the database. No default, must be provided. - # * <tt>:schema_order</tt> -- An optional schema order string that is using in a SET search_path TO <schema_order> call on connection. + # * <tt>:schema_search_path</tt> -- An optional schema search path for the connection given as a string of comma-separated schema names. This is backward-compatible with the :schema_order option. # * <tt>:encoding</tt> -- An optional client encoding that is using in a SET client_encoding TO <encoding> call on connection. # * <tt>:min_messages</tt> -- An optional client min messages that is using in a SET client_min_messages TO <min_messages> call on connection. class PostgreSQLAdapter < AbstractAdapter @@ -71,9 +70,8 @@ module ActiveRecord end def columns(table_name, name = nil) - table_structure(table_name).inject([]) do |columns, field| - columns << Column.new(field[0], field[2], field[1]) - columns + column_definitions(table_name).collect do |name, type, default| + Column.new(name, default_value(default), translate_field_type(type)) end end @@ -110,14 +108,32 @@ module ActiveRecord end def quote_column_name(name) - return "\"#{name}\"" + %("#{name}") end def adapter_name() 'PostgreSQL' end + + # Set the schema search path to a string of comma-separated schema names. + # Names beginning with $ are quoted (e.g. $user => '$user') + # See http://www.postgresql.org/docs/8.0/interactive/ddl-schemas.html + def schema_search_path=(schema_csv) + if schema_csv + execute "SET search_path TO #{schema_csv}" + @schema_search_path = nil + end + end + + def schema_search_path + @schema_search_path ||= query('SHOW search_path')[0][0] + end + + private + BYTEA_COLUMN_TYPE_OID = 17 + def last_insert_id(table, column = "id") sequence_name = "#{table}_#{column || 'id'}_seq" @connection.exec("SELECT currval('#{sequence_name}')")[0][0].to_i @@ -133,7 +149,7 @@ module ActiveRecord hashed_row = {} row.each_index do |cel_index| column = row[cel_index] - if res.type(cel_index) == 17 # type oid for bytea + if res.type(cel_index) == BYTEA_COLUMN_TYPE_OID column = unescape_bytea(column) end hashed_row[fields[cel_index]] = column @@ -156,53 +172,49 @@ module ActiveRecord s.gsub(/\\([0-9][0-9][0-9])/) { $1.oct.chr }.gsub(/\\\\/) { '\\' } unless s.nil? end - def split_table_schema(table_name) - schema_split = table_name.split('.') - schema_name = "public" - if schema_split.length > 1 - schema_name = schema_split.first.strip - table_name = schema_split.last.strip - end - return [schema_name, table_name] + # Query a table's column names, default values, and types. + # + # The underlying query is roughly: + # SELECT column.name, column.type, default.value + # FROM column LEFT JOIN default + # ON column.table_id = default.table_id + # AND column.num = default.column_num + # WHERE column.table_id = get_table_id('table_name') + # AND column.num > 0 + # AND NOT column.is_dropped + # ORDER BY column.num + # + # If the table name is not prefixed with a schema, the database will + # take the first match from the schema search path. + # + # Query implementation notes: + # - format_type includes the column size constraint, e.g. varchar(50) + # - ::regclass is a function that gives the id for a table name + def column_definitions(table_name) + query <<-end_sql + SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc + FROM pg_attribute a LEFT JOIN pg_attrdef d + ON a.attrelid = d.adrelid AND a.attnum = d.adnum + WHERE a.attrelid = '#{table_name}'::regclass + AND a.attnum > 0 AND NOT a.attisdropped + ORDER BY a.attnum + end_sql end - def table_structure(table_name) - database_name = @connection.db - schema_name, table_name = split_table_schema(table_name) - - # Grab a list of all the default values for the columns. - sql = "SELECT column_name, column_default, character_maximum_length, data_type " - sql << " FROM information_schema.columns " - sql << " WHERE table_catalog = '#{database_name}' " - sql << " AND table_schema = '#{schema_name}' " - sql << " AND table_name = '#{table_name}'" - sql << " ORDER BY ordinal_position" - - query(sql).collect do |row| - field = row[0] - type = type_as_string(row[3], row[2]) - default = default_value(row[1]) - length = row[2] - - [field, type, default, length] + # Translate PostgreSQL-specific types into simplified SQL types. + # These are special cases; standard types are handled by + # ConnectionAdapters::Column#simplified_type. + def translate_field_type(field_type) + # Match the beginning of field_type since it may have a size constraint on the end. + case field_type + when /^timestamp/i then 'datetime' + when /^real|^money/i then 'float' + when /^interval/i then 'string' + when /^bytea/i then 'binary' + else field_type # Pass through standard types. end end - def type_as_string(field_type, field_length) - type = case field_type - when 'numeric', 'real', 'money' then 'float' - when 'character varying', 'interval' then 'string' - when 'timestamp without time zone' then 'datetime' - when 'timestamp with time zone' then 'datetime' - when 'bytea' then 'binary' - else field_type - end - - size = field_length.nil? ? "" : "(#{field_length})" - - return type + size - end - def default_value(value) # Boolean types return "t" if value =~ /true/i |