Logo Search packages:      
Sourcecode: sqlalchemy version File versions  Download package

oracle.py

# oracle.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php


import sys, StringIO, string, re

import sqlalchemy.util as util
import sqlalchemy.sql as sql
import sqlalchemy.engine as engine
import sqlalchemy.engine.default as default
import sqlalchemy.schema as schema
import sqlalchemy.ansisql as ansisql
import sqlalchemy.types as sqltypes
import sqlalchemy.exceptions as exceptions

try:
    import cx_Oracle
except:
    cx_Oracle = None
        
class OracleNumeric(sqltypes.Numeric):
    """Numeric type rendered as an Oracle NUMERIC column."""
    def get_col_spec(self):
        """Return the column type text built from this type's precision and length."""
        spec_args = dict(precision=self.precision, length=self.length)
        return "NUMERIC(%(precision)s, %(length)s)" % spec_args
class OracleInteger(sqltypes.Integer):
    """Integer type rendered as an Oracle INTEGER column."""
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "INTEGER"
class OracleSmallInteger(sqltypes.Smallinteger):
    """Small integer type rendered as an Oracle SMALLINT column."""
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "SMALLINT"
class OracleDateTime(sqltypes.DateTime):
    """DateTime type rendered as an Oracle DATE column.

    The module's ``colspecs`` table maps both the generic DateTime and the
    generic Date type to this class.
    """
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "DATE"
# Note:
# Oracle DATE == DATETIME
# Oracle does not allow milliseconds in DATE
# Oracle does not support TIME columns

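# OracleTimestamp is only usable when the cx_Oracle module exposes a TIMESTAMP
# attribute (OracleDialect records this in its supports_timestamp flag).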
class OracleTimestamp(sqltypes.DateTime):
    """DateTime type rendered as an Oracle TIMESTAMP column."""
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "TIMESTAMP"
    def get_dbapi_type(self, dialect):
        # Returns the TIMESTAMP attribute of whatever object is passed in.
        # NOTE(review): the OracleDialect in this module does not set a TIMESTAMP
        # attribute itself (it only checks the DBAPI module for one); presumably
        # the caller passes the DBAPI module here -- confirm against the caller.
        return dialect.TIMESTAMP
        
class OracleText(sqltypes.TEXT):
    """TEXT type rendered as an Oracle CLOB column."""
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "CLOB"
class OracleString(sqltypes.String):
    """String type rendered as an Oracle VARCHAR column of this type's length."""
    def get_col_spec(self):
        """Return the column type text, e.g. ``VARCHAR(30)`` for length 30."""
        params = {'length': self.length}
        return "VARCHAR(%(length)s)" % params
class OracleChar(sqltypes.CHAR):
    """CHAR type rendered as an Oracle CHAR column of this type's length."""
    def get_col_spec(self):
        """Return the column type text, e.g. ``CHAR(2)`` for length 2."""
        params = {'length': self.length}
        return "CHAR(%(length)s)" % params
class OracleBinary(sqltypes.Binary):
    """Binary type rendered as an Oracle BLOB column."""
    def get_col_spec(self):
        # Column type text; the schema generator appends it after the column name.
        return "BLOB"
class OracleBoolean(sqltypes.Boolean):
    """Boolean type stored in a SMALLINT column."""

    def get_col_spec(self):
        """Return the column type text."""
        return "SMALLINT"

    def convert_result_value(self, value, dialect):
        """Turn a fetched value into a Python bool; None is passed through."""
        if value is None:
            return None
        return bool(value)

    def convert_bind_param(self, value, dialect):
        """Turn a Python value into a bind parameter.

        The singletons True and False become 1 and 0, None is passed through,
        and any other value is reduced to its truth value.
        """
        if value is None:
            return None
        if value is True:
            return 1
        if value is False:
            return 0
        return bool(value)

        
# Maps generic sqltypes classes to the Oracle-specific classes defined above.
# OracleDialect.type_descriptor() hands this table to sqltypes.adapt_type().
# Float shares the Numeric implementation and Date shares the DateTime one.
colspecs = {
    sqltypes.Integer : OracleInteger,
    sqltypes.Smallinteger : OracleSmallInteger,
    sqltypes.Numeric : OracleNumeric,
    sqltypes.Float : OracleNumeric,
    sqltypes.DateTime : OracleDateTime,
    sqltypes.Date : OracleDateTime,
    sqltypes.String : OracleString,
    sqltypes.Binary : OracleBinary,
    sqltypes.Boolean : OracleBoolean,
    sqltypes.TEXT : OracleText,
    sqltypes.TIMESTAMP : OracleTimestamp,
    sqltypes.CHAR: OracleChar,
}

# Maps DATA_TYPE names read from ALL_TAB_COLUMNS to type classes; used by
# OracleDialect.reflecttable().  There is no 'CHAR' entry: reflecttable()
# falls back to OracleString for CHAR columns.
ischema_names = {
    'VARCHAR2' : OracleString,
    'DATE' : OracleDateTime,
    'DATETIME' : OracleDateTime,
    'NUMBER' : OracleNumeric,
    'BLOB' : OracleBinary,
    'CLOB' : OracleText,
    'TIMESTAMP' : OracleTimestamp
}

# Query used by OracleDialect.reflecttable() to read primary key ('P') and
# foreign key ('R') constraints for one table.  Bind parameters: :table_name
# and :owner.  Each row is (constraint_name, constraint_type, local_column,
# remote_table, remote_column); the last three are lower-cased by the query.
# The "(+)" markers outer-join the referenced constraint's columns.
constraintSQL = """SELECT
  ac.constraint_name,
  ac.constraint_type,
  LOWER(loc.column_name) AS local_column,
  LOWER(rem.table_name) AS remote_table,
  LOWER(rem.column_name) AS remote_column
FROM all_constraints ac,
  all_cons_columns loc,
  all_cons_columns rem
WHERE ac.table_name = :table_name
AND ac.constraint_type IN ('R','P')
AND ac.owner = :owner
AND ac.owner = loc.owner
AND ac.constraint_name = loc.constraint_name
AND ac.r_owner = rem.owner(+)
AND ac.r_constraint_name = rem.constraint_name(+)
-- order multiple primary keys correctly
ORDER BY ac.constraint_name, loc.position, rem.position"""


def descriptor():
    """Return a dict describing this database module.

    Keys are 'name', 'description' and 'arguments'; 'arguments' is a list of
    (argument name, label, default) tuples.
    """
    arguments = [
        ('dsn', 'Data Source Name', None),
        ('user', 'Username', None),
        ('password', 'Password', None),
    ]
    return dict(name='oracle', description='Oracle', arguments=arguments)

class OracleExecutionContext(default.DefaultExecutionContext):
    """Execution context that can set input sizes before a statement runs."""

    def pre_exec(self, engine, proxy, compiled, parameters):
        """Run the base pre-execution step, then optionally set input sizes.

        Input sizes are set only when the dialect was created with
        auto_setinputsizes enabled.
        """
        super(OracleExecutionContext, self).pre_exec(engine, proxy, compiled, parameters)
        if not self.dialect.auto_setinputsizes:
            return
        # proxy is called with no arguments; its result is what receives the input sizes.
        target = proxy()
        self.set_input_sizes(target, parameters)
        
class OracleDialect(ansisql.ANSIDialect):
    """Dialect for Oracle, using cx_Oracle (or a caller-supplied module) as the DBAPI."""

    def __init__(self, use_ansi=True, auto_setinputsizes=False, module=None, threaded=True, **kwargs):
        """Configure the dialect.

        use_ansi -- when False, OracleCompiler renders joins in the WHERE clause
            using the "(+)" outer join marker instead of ANSI JOIN syntax.
        auto_setinputsizes -- when True, OracleExecutionContext sets input sizes
            before each execution.
        module -- DBAPI module to use; defaults to the imported cx_Oracle (which
            is None when that import failed).
        threaded -- passed as the "threaded" connect argument.
        Remaining keyword arguments go to ANSIDialect.__init__.
        """
        self.use_ansi = use_ansi
        self.threaded = threaded
        if module is None:
            self.module = cx_Oracle
        else:
            self.module = module
        self.supports_timestamp = hasattr(self.module, 'TIMESTAMP')
        self.auto_setinputsizes = auto_setinputsizes
        ansisql.ANSIDialect.__init__(self, **kwargs)

    def dbapi(self):
        """Return the DBAPI module in use."""
        return self.module

    def create_connect_args(self, url):
        """Translate a URL object into ([], keyword-args) for the DBAPI connect call.

        When the URL names a database, a DSN is built from host, port (1521 when
        the URL gives none) and database via the module's makedsn(); otherwise the
        host part is used directly as the DSN.  Entries from url.query override
        the computed options.
        """
        if url.database:
            # if we have a database, then we have a remote host
            port = url.port
            if port:
                port = int(port)
            else:
                port = 1521
            dsn = self.module.makedsn(url.host, port, url.database)
        else:
            # we have a local tnsname
            dsn = url.host
        opts = dict(
            user=url.username,
            password=url.password,
            dsn=dsn,
            threaded=self.threaded
            )
        opts.update(url.query)
        return ([], opts)

    def type_descriptor(self, typeobj):
        """Adapt a generic type object to its Oracle implementation via colspecs."""
        return sqltypes.adapt_type(typeobj, colspecs)

    def oid_column_name(self):
        """Return the name of the row identifier column."""
        return "rowid"

    def create_execution_context(self):
        """Return a new OracleExecutionContext bound to this dialect."""
        return OracleExecutionContext(self)

    def compiler(self, statement, bindparams, **kwargs):
        """Return an OracleCompiler for the given statement and bind parameters."""
        return OracleCompiler(self, statement, bindparams, **kwargs)

    def schemagenerator(self, *args, **kwargs):
        """Return an OracleSchemaGenerator built from the given arguments."""
        return OracleSchemaGenerator(*args, **kwargs)

    def schemadropper(self, *args, **kwargs):
        """Return an OracleSchemaDropper built from the given arguments."""
        return OracleSchemaDropper(*args, **kwargs)

    def defaultrunner(self, engine, proxy):
        """Return an OracleDefaultRunner for the given engine and proxy."""
        return OracleDefaultRunner(engine, proxy)

    def has_table(self, connection, table_name):
        """Return True if all_tables has a row for the upper-cased table name."""
        cursor = connection.execute("""select table_name from all_tables where table_name=:name""", {'name':table_name.upper()})
        return bool( cursor.fetchone() is not None )

    def has_sequence(self, connection, sequence_name):
        """Return True if all_sequences has a row for the upper-cased sequence name."""
        cursor = connection.execute("""select sequence_name from all_sequences where sequence_name=:name""", {'name':sequence_name.upper()})
        return bool( cursor.fetchone() is not None )

    def reflecttable(self, connection, table):
        """Load columns, primary key and foreign keys of ``table`` from the data dictionary.

        Raises exceptions.NoSuchTableError when no owner has a table of that
        name, and exceptions.AssertionError when the requested owner does not
        own it, when several owners do and none was specified, or when a
        column's data type has no entry in ischema_names.
        """
        preparer = self.identifier_preparer
        # Names that do not need quoting are looked up upper-cased; names that
        # do are looked up exactly as given.
        if not preparer.should_quote(table):
            table_name = table.name.upper()
        else:
            table_name = table.name

        c = connection.execute ("select distinct OWNER from ALL_TAB_COLUMNS where TABLE_NAME = :table_name", {'table_name':table_name})
        owners = [r[0] for r in c.fetchall()]
        if not owners:
            raise exceptions.NoSuchTableError(table.name)

        if table.owner is not None:
            if table.owner.upper() in owners:
                owner = table.owner.upper()
            else:
                raise exceptions.AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
        elif len(owners) == 1:
            owner = owners[0]
        else:
            raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)

        c = connection.execute ("select COLUMN_NAME, DATA_TYPE, DATA_LENGTH, DATA_PRECISION, DATA_SCALE, NULLABLE, DATA_DEFAULT from ALL_TAB_COLUMNS where TABLE_NAME = :table_name and OWNER = :owner", {'table_name':table_name, 'owner':owner})

        while True:
            row = c.fetchone()
            if row is None:
                break

            # Column values get their own local names so that table_name, which
            # the constraint query below still needs, is never overwritten.
            (colname, coltype, length, precision, scale, nullable, default_text) = (row[0], row[1], row[2], row[3], row[4], row[5]=='Y', row[6])

            # INTEGER if the scale is 0 and precision is null
            # NUMBER if the scale and precision are both null
            # NUMBER(9,2) if the precision is 9 and the scale is 2
            # NUMBER(3) if the precision is 3 and scale is 0
            # length is ignored except for CHAR and VARCHAR2
            if coltype == 'NUMBER':
                if precision is None and scale is None:
                    coltype = OracleNumeric
                elif precision is None and scale == 0:
                    coltype = OracleInteger
                else:
                    coltype = OracleNumeric(precision, scale)
            elif coltype == 'CHAR' or coltype == 'VARCHAR2':
                # CHAR has no ischema_names entry and therefore becomes OracleString.
                coltype = ischema_names.get(coltype, OracleString)(length)
            else:
                # strip a parenthesised size such as the "(6)" in "TIMESTAMP(6)"
                coltype = re.sub(r'\(\d+\)', '', coltype)
                try:
                    coltype = ischema_names[coltype]
                except KeyError:
                    raise exceptions.AssertionError("Cant get coltype for type '%s'" % coltype)

            colargs = []
            if default_text is not None:
                colargs.append(schema.PassiveDefault(sql.text(default_text)))

            # if name comes back as all upper, assume its case folded
            if colname.upper() == colname:
                colname = colname.lower()

            table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))

        # Use the same lookup name as the column queries so that tables whose
        # names require quoting have their constraints found as well.
        c = connection.execute(constraintSQL, {'table_name' : table_name, 'owner' : owner})
        # constraint name -> (local column names, "remote_table.remote_column" specs)
        fks = {}
        while True:
            row = c.fetchone()
            if row is None:
                break
            (cons_name, cons_type, local_column, remote_table, remote_column) = row
            if cons_type == 'P':
                table.primary_key.add(table.c[local_column])
            elif cons_type == 'R':
                try:
                    fk = fks[cons_name]
                except KeyError:
                    fk = ([], [])
                    fks[cons_name] = fk
                refspec = ".".join([remote_table, remote_column])
                # reflect the referenced table into the same metadata
                schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
                if local_column not in fk[0]:
                    fk[0].append(local_column)
                if refspec not in fk[1]:
                    fk[1].append(refspec)

        for cons_name, (local_columns, refspecs) in fks.items():
            table.append_constraint(schema.ForeignKeyConstraint(local_columns, refspecs, name=cons_name))

    def do_executemany(self, c, statement, parameters, context=None):
        """Execute ``statement`` once per parameter set on cursor ``c``.

        The cursor's rowcount is summed over all executions and, when a context
        is given, stored on it as ``_rowcount``.
        """
        rowcount = 0
        for param in parameters:
            c.execute(statement, param)
            rowcount += c.rowcount
        if context is not None:
            context._rowcount = rowcount

class OracleCompiler(ansisql.ANSICompiler):
    """oracle compiler modifies the lexical structure of Select statements to work under
    non-ANSI configured Oracle databases, if the use_ansi flag is False."""

    def default_from(self):
        """called when a SELECT statement has no froms, and no FROM clause is to be appended.
        gives Oracle a chance to tack on a "FROM DUAL" to the string output. """
        return " FROM DUAL"

    def apply_function_parens(self, func):
        """Return True only when the function has at least one argument clause."""
        return len(func.clauses) > 0

    def visit_join(self, join):
        """Compile a join.

        With use_ansi enabled this defers to the ANSI compiler.  Otherwise both
        sides are listed comma-separated in the FROM text and the ON clause is
        moved into the WHERE criteria; for outer joins the columns of the right
        side are marked with "(+)" by visit_column().
        """
        if self.dialect.use_ansi:
            return ansisql.ANSICompiler.visit_join(self, join)

        self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
        self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
        self.strings[join] = self.froms[join]

        if join.isouter:
            # if outer join, push on the right side table as the current "outertable"
            self._outertable = join.right

            # now re-visit the onclause, which will be used as a where clause
            # (the first visit occured via the Join object itself right before it called visit_join())
            join.onclause.accept_visitor(self)

            self._outertable = None

        self.visit_compound(self.wheres[join])

    def visit_alias(self, alias):
        """oracle doesnt like 'FROM table AS alias'.  is the AS standard SQL??"""
        self.froms[alias] = self.get_from_text(alias.original) + " " + alias.name
        self.strings[alias] = self.get_str(alias.original)

    def visit_column(self, column):
        """Compile a column, appending "(+)" when it belongs to the current outer-join table.

        The marker is only added in non-ANSI mode while visit_join() is
        re-visiting the ON clause of an outer join.
        """
        ansisql.ANSICompiler.visit_column(self, column)
        if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
            self.strings[column] = self.strings[column] + "(+)"

    def visit_insert(self, insert):
        """inserts are required to have the primary keys be explicitly present.
         mapper will by default not put them in the insert statement to comply
         with autoincrement fields that require they not be present.  so,
         put them all in for all primary key columns."""
        for c in insert.table.primary_key:
            if c.key not in self.parameters:
                self.parameters[c.key] = None
        return ansisql.ANSICompiler.visit_insert(self, insert)

    def _limit_by_row_number(self, select):
        """Wrap ``select`` in an outer SELECT filtered on its "ora_rn" label.

        ``select`` must already carry the ROW_NUMBER() column labelled "ora_rn".
        The outer statement selects every other column, applies the offset and
        limit as criteria on ora_rn, is compiled, and its string and FROM text
        are recorded as those of ``select``.
        """
        limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
        if select.offset is not None:
            limitselect.append_whereclause("ora_rn>%d" % select.offset)
            if select.limit is not None:
                limitselect.append_whereclause("ora_rn<=%d" % (select.limit + select.offset))
        else:
            limitselect.append_whereclause("ora_rn<=%d" % select.limit)
        limitselect.accept_visitor(self)
        self.strings[select] = self.strings[limitselect]
        self.froms[select] = self.froms[limitselect]

    def _TODO_visit_compound_select(self, select):
        """need to determine how to get LIMIT/OFFSET into a UNION for oracle"""
        if getattr(select, '_oracle_visit', False):
            # cancel out the compiled order_by on the select
            if hasattr(select, "order_by_clause"):
                self.strings[select.order_by_clause] = ""
            ansisql.ANSICompiler.visit_compound_select(self, select)
            return

        if select.limit is not None or select.offset is not None:
            select._oracle_visit = True
            # to use ROW_NUMBER(), an ORDER BY is required.
            orderby = self.strings[select.order_by_clause]
            if not orderby:
                orderby = select.oid_column
                orderby.accept_visitor(self)
                orderby = self.strings[orderby]
            class SelectVisitor(sql.ClauseVisitor):
                def visit_select(self, select):
                    select.append_column(sql.column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
            select.accept_visitor(SelectVisitor())
            self._limit_by_row_number(select)
        else:
            ansisql.ANSICompiler.visit_compound_select(self, select)

    def visit_select(self, select):
        """looks for LIMIT and OFFSET in a select statement, and if so tries to wrap it in a
        subquery with row_number() criterion."""
        # TODO: put a real copy-container on Select and copy, or somehow make this
        # not modify the Select statement
        if getattr(select, '_oracle_visit', False):
            # second pass (reached through the wrapping select): compile normally,
            # but cancel out the compiled order_by on the select
            if hasattr(select, "order_by_clause"):
                self.strings[select.order_by_clause] = ""
            ansisql.ANSICompiler.visit_select(self, select)
            return

        if select.limit is not None or select.offset is not None:
            select._oracle_visit = True
            # to use ROW_NUMBER(), an ORDER BY is required.
            orderby = self.strings[select.order_by_clause]
            if not orderby:
                orderby = select.oid_column
                orderby.accept_visitor(self)
                orderby = self.strings[orderby]
            select.append_column(sql.column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn"))
            self._limit_by_row_number(select)
        else:
            ansisql.ANSICompiler.visit_select(self, select)

    def limit_clause(self, select):
        """Return an empty string; limiting is done through ROW_NUMBER() in visit_select()."""
        return ""

    def for_update_clause(self, select):
        """Return " FOR UPDATE NOWAIT" for for_update=="nowait", else the inherited clause."""
        if select.for_update=="nowait":
            return " FOR UPDATE NOWAIT"
        else:
            return super(OracleCompiler, self).for_update_clause(select)

class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
    """Schema generator producing Oracle column specifications and sequences."""

    def get_column_specification(self, column, **kwargs):
        """Return the DDL fragment for one column.

        The fragment is the formatted column name and its type text, followed by
        a DEFAULT clause when the column has a default string and NOT NULL when
        the column is not nullable.
        """
        pieces = [
            self.preparer.format_column(column),
            column.type.engine_impl(self.engine).get_col_spec(),
        ]
        default = self.get_column_default_string(column)
        if default is not None:
            pieces.append("DEFAULT " + default)
        if not column.nullable:
            pieces.append("NOT NULL")
        return " ".join(pieces)

    def visit_sequence(self, sequence):
        """Emit CREATE SEQUENCE unless the dialect reports the sequence already exists."""
        if self.engine.dialect.has_sequence(self.connection, sequence.name):
            return
        self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence))
        self.execute()

class OracleSchemaDropper(ansisql.ANSISchemaDropper):
    """Schema dropper that removes Oracle sequences."""

    def visit_sequence(self, sequence):
        """Emit DROP SEQUENCE when the dialect reports that the sequence exists."""
        if not self.engine.dialect.has_sequence(self.connection, sequence.name):
            return
        # NOTE(review): the raw sequence name is used here, whereas
        # OracleSchemaGenerator formats it through its preparer -- confirm
        # whether names that need quoting should be formatted here as well.
        self.append("DROP SEQUENCE %s" % sequence.name)
        self.execute()

class OracleDefaultRunner(ansisql.ANSIDefaultRunner):
    """Default runner that evaluates SQL defaults and sequences by selecting from DUAL."""

    def exec_default_sql(self, default):
        """Select the default's expression from DUAL and return the first column of the first row."""
        compiled = sql.select([default.arg], from_obj=["DUAL"], engine=self.engine).compile()
        result = self.proxy(str(compiled), compiled.get_params())
        return result.fetchone()[0]

    def visit_sequence(self, seq):
        """Return the sequence's next value, fetched with a SELECT against DUAL."""
        statement = "SELECT " + seq.name + ".nextval FROM DUAL"
        return self.proxy(statement).fetchone()[0]

# Module-level alias for this module's dialect class.
dialect = OracleDialect

Generated by  Doxygen 1.6.0   Back to index