# sqlalchemy_teradata/dialect.py
# Copyright (C) 2015-2016 by Teradata
# <see AUTHORS file>
#
# This module is part of sqlalchemy-teradata and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

from sqlalchemy.engine import default
from sqlalchemy import pool, String, Numeric
from sqlalchemy.sql import select, and_, or_
from sqlalchemy_teradata.compiler import TeradataCompiler, TeradataDDLCompiler, TeradataTypeCompiler
from sqlalchemy_teradata.base import TeradataIdentifierPreparer, TeradataExecutionContext
from sqlalchemy.sql.expression import text, table, column, asc
from sqlalchemy import Table, Column, Index
import sqlalchemy.types as sqltypes
import sqlalchemy_teradata.types as tdtypes
from itertools import groupby

# ischema_names is used for reflecting columns (see get_columns in the
# dialect).  Keys are the normalized (stripped, lower-cased) column type
# codes read from the data dictionary; values are the type classes to build.
ischema_names = {
    None: sqltypes.NullType,

    'cf': tdtypes.CHAR,
    'cv': tdtypes.VARCHAR,
    'uf': sqltypes.NCHAR,
    'uv': sqltypes.NVARCHAR,
    'co': tdtypes.CLOB,
    'n' : tdtypes.NUMERIC,
    'd' : tdtypes.DECIMAL,
    'i' : sqltypes.INTEGER,
    'i1': tdtypes.BYTEINT,
    'i2': sqltypes.SMALLINT,
    'i8': sqltypes.BIGINT,
    'f' : sqltypes.FLOAT,
    'da': sqltypes.DATE,
    'ts': tdtypes.TIMESTAMP,
    'sz': tdtypes.TIMESTAMP,    # timestamp with time zone
    'at': tdtypes.TIME,
    'tz': tdtypes.TIME,         # time with time zone

    # Experimental - binary
    'bf': sqltypes.BINARY,
    'bv': sqltypes.VARBINARY,
    'bo': sqltypes.BLOB
}
# TODO: add the interval types and blob

# Type codes from ischema_names whose mapped class is a String subclass.
stringtypes = [t for t in ischema_names
               if issubclass(ischema_names[t], sqltypes.String)]


class TeradataDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Teradata, using the ``teradata.tdodbc`` DBAPI."""

    name = 'teradata'
    driver = 'teradata'
    default_paramstyle = 'qmark'
    poolclass = pool.SingletonThreadPool

    statement_compiler = TeradataCompiler
    ddl_compiler = TeradataDDLCompiler
    type_compiler = TeradataTypeCompiler
    preparer = TeradataIdentifierPreparer
    execution_ctx_cls = TeradataExecutionContext

    supports_native_boolean = False
    supports_native_decimal = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    postfetch_lastrowid = False
    implicit_returning = False
    preexecute_autoincrement_sequences = False

    # Dialect-specific keyword arguments accepted by Table, Index and Column.
    construct_arguments = [
        (Table, {
            "post_create": None,
            "postfixes": None
        }),

        (Index, {
            "order_by": None,
            "loading": None
        }),

        (Column, {
            "compress": None,
            "identity": None
        })
    ]

    def __init__(self, **kwargs):
        super(TeradataDialect, self).__init__(**kwargs)

    def create_connect_args(self, url):
        """
        Build the ``(args, kwargs)`` pair passed to the DBAPI ``connect()``.

        Host, username and password become positional arguments (after the
        literal ``"Teradata"``); every other URL parameter is passed through
        as a keyword argument.  Returns None when ``url`` is None.
        """
        if url is not None:
            params = super(TeradataDialect, self).create_connect_args(url)[1]
            cargs = ("Teradata", params['host'], params['username'], params['password'])
            positional = ('host', 'username', 'password')
            cparams = {p: params[p] for p in params if p not in positional}
            return (cargs, cparams)

    @classmethod
    def dbapi(cls):
        """ Hook to the dbapi2.0 implementation's module"""
        from teradata import tdodbc
        return tdodbc

    def normalize_name(self, name, **kw):
        """Return ``name`` stripped and lower-cased; None is passed through."""
        if name is not None:
            return name.strip().lower()
        return name

    def has_table(self, connection, table_name, schema=None):
        """Return True if ``table_name`` is listed in dbc.tablesvx for ``schema``."""
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('tablename')],
                      from_obj=[text('dbc.tablesvx')]).where(
                          and_(text('DatabaseName=:schema'),
                               text('TableName=:table_name')))

        res = connection.execute(stmt, schema=schema, table_name=table_name).fetchone()
        return res is not None

    def _resolve_type(self, t, **kw):
        """
        Resolve types for String, Numeric, Date/Time, etc. columns.

        ``t`` is the raw column type code.  ``kw`` must carry ``length``,
        ``chartype``, ``prec``, ``scale`` and ``fmt`` (see _get_column_info).
        Returns an instance of the mapped type, or the NullType class itself
        when the code is not present in ischema_names.
        """
        t = self.normalize_name(t)
        if t not in ischema_names:
            return ischema_names[None]

        t = ischema_names[t]

        if issubclass(t, sqltypes.String):
            length = kw['length']
            if kw['chartype'] == 'UNICODE':
                # The reported length is halved for UNICODE columns.  Use
                # floor division so the result stays an int on Python 3.
                length //= 2
            return t(length=length, charset=kw['chartype'])

        elif issubclass(t, sqltypes.Numeric):
            return t(precision=kw['prec'], scale=kw['scale'])

        elif issubclass(t, sqltypes.Time) or issubclass(t, sqltypes.DateTime):
            # A missing format is treated as "no time zone, precision 0"
            # instead of raising on None / empty string.
            fmt = kw['fmt'] or ''
            # Timezone: flagged by a trailing 'Z' in the column format
            tz = fmt.endswith('Z')
            # Precision: taken from the parenthesised part of the format.
            # For some timestamps and dates there is no precision, or it is
            # indicated in scale (format carries 'F' instead of a number).
            if '(' in fmt:
                prec = fmt[fmt.index('(') + 1: fmt.index(')')]
            else:
                prec = 0
            prec = kw['scale'] if prec == 'F' else int(prec)
            return t(precision=prec, timezone=tz)

        elif issubclass(t, sqltypes.Interval):
            return t(day_precision=kw['prec'], second_precision=kw['scale'])

        else:
            return t()  # For types like Integer, ByteInt

    def _get_column_info(self, row):
        """
        Resolves the column information for get_columns given a row.

        ``row`` must provide the keys columnname, columntype, columnlength,
        chartype, decimaltotaldigits, decimalfractionaldigits, columnformat,
        nullable, defaultvalue and idcoltype.
        """
        chartype = {
            0: None,
            1: 'LATIN',
            2: 'UNICODE',
            3: 'KANJISJIS',
            4: 'GRAPHIC'}

        # Handle unspecified characterset and disregard chartypes specified
        # for non-character types (e.g. binary, json).  The decision is made
        # on the column *type* code; the charset code is only looked up for
        # string types, and unknown / NULL codes resolve to None.
        is_string = self.normalize_name(row['columntype']) in stringtypes
        charset = chartype.get(row['chartype'] if is_string else 0)

        typ = self._resolve_type(row['columntype'],
                                 length=int(row['columnlength'] or 0),
                                 chartype=charset,
                                 prec=int(row['decimaltotaldigits'] or 0),
                                 scale=int(row['decimalfractionaldigits'] or 0),
                                 fmt=row['columnformat'])

        autoinc = row['idcoltype'] in ('GA', 'GD')

        return {
            'name': self.normalize_name(row['columnname']),
            'type': typ,
            'nullable': row['nullable'] == u'Y',
            'default': row['defaultvalue'],
            'attrs': {
                'columnformat': row['columnformat']},
            'autoincrement': autoinc
        }

    def get_columns(self, connection, table_name, schema=None, **kw):
        """
        Return the reflected column dictionaries for ``table_name``.

        For server major versions below 16 the columns are read from
        dbc.ColumnsV and, when the object is a view, each column is resolved
        individually through ``help column``.  Otherwise dbc.ColumnsQV is used.
        """
        helpView = False

        if schema is None:
            schema = self.default_schema_name

        if int(self.server_version_info.split('.')[0]) < 16:
            dbc_columninfo = 'dbc.ColumnsV'

            # Check if the object is a view
            stmt = select([column('tablekind')],
                          from_obj=[text('dbc.tablesV')]).where(
                              and_(text('DatabaseName=:schema'),
                                   text('TableName=:table_name'),
                                   text("tablekind='V'")))
            res = connection.execute(stmt, schema=schema, table_name=table_name).rowcount
            helpView = (res == 1)

        else:
            dbc_columninfo = 'dbc.ColumnsQV'

        stmt = select([column('columnname'), column('columntype'),
                       column('columnlength'), column('chartype'),
                       column('decimaltotaldigits'), column('decimalfractionaldigits'),
                       column('columnformat'),
                       column('nullable'), column('defaultvalue'), column('idcoltype')],
                      from_obj=[text(dbc_columninfo)]).where(
                          and_(text('DatabaseName=:schema'),
                               text('TableName=:table_name')))

        res = connection.execute(stmt, schema=schema, table_name=table_name).fetchall()

        # If this is a view in pre-16 version, get types for individual columns
        if helpView:
            res = [self._get_column_help(connection, schema, table_name, r['columnname'])
                   for r in res]

        return [self._get_column_info(row) for row in res]

    def _get_default_schema_name(self, connection):
        """Return the normalized result of ``select database``."""
        return self.normalize_name(
            connection.execute('select database').scalar())

    def _get_column_help(self, connection, schema, table_name, column_name):
        """
        Run ``help column`` for one column and map its first result row onto
        the keys expected by _get_column_info.  ``defaultvalue`` is always None.
        """
        # NOTE(review): the statement is built by concatenation; the
        # identifiers are not quoted or escaped here.
        stmt = 'help column ' + schema + '.' + table_name + '.' + column_name
        res = connection.execute(stmt).fetchall()[0]
        return {'columnname': res['Column Name'],
                'columntype': res['Type'],
                'columnlength': res['Max Length'],
                'chartype': res['Char Type'],
                'decimaltotaldigits': res['Decimal Total Digits'],
                'decimalfractionaldigits': res['Decimal Fractional Digits'],
                'columnformat': res['Format'],
                'nullable': res['Nullable'],
                'defaultvalue': None,
                'idcoltype': res['IdCol Type']
                }

    def get_table_names(self, connection, schema=None, **kw):
        """Return normalized names of objects with tablekind 'T' or 'O' in ``schema``."""
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('tablename')],
                      from_obj=[text('dbc.TablesVX')]).where(
                          and_(text('DatabaseName = :schema'),
                               or_(text('tablekind=\'T\''),
                                   text('tablekind=\'O\''))))
        res = connection.execute(stmt, schema=schema).fetchall()
        return [self.normalize_name(name['tablename']) for name in res]

    def get_schema_names(self, connection, **kw):
        """Return the normalized user names listed in dbc.UsersV, ordered by name."""
        stmt = select([column('username')],
                      from_obj=[text('dbc.UsersV')],
                      order_by=[text('username')])
        res = connection.execute(stmt).fetchall()
        return [self.normalize_name(name['username']) for name in res]

    def get_view_definition(self, connection, view_name, schema=None, **kw):
        """Return the normalized first value of ``show table <schema>.<view_name>``."""
        if schema is None:
            schema = self.default_schema_name

        # NOTE(review): normalize_name lower-cases the whole returned text,
        # including any string literals inside it -- confirm this is intended.
        res = connection.execute('show table {}.{}'.format(schema, view_name)).scalar()
        return self.normalize_name(res)

    def get_view_names(self, connection, schema=None, **kw):
        """Return normalized names of objects with tablekind 'V' in ``schema``."""
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('tablename')],
                      from_obj=[text('dbc.TablesVX')]).where(
                          and_(text('DatabaseName = :schema'),
                               text('tablekind=\'V\'')))

        res = connection.execute(stmt, schema=schema).fetchall()
        return [self.normalize_name(name['tablename']) for name in res]

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """
        Override
        TODO: Check if we need PRIMARY Indices or PRIMARY KEY Indices
        TODO: Check for border cases (No PK Indices)

        Returns a dict with ``constrained_columns`` (list of normalized column
        names) and ``name`` (None when no row is found).
        """
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('ColumnName'), column('IndexName')],
                      from_obj=[text('dbc.Indices')]).where(
                          and_(text('DatabaseName = :schema'),
                               text('TableName=:table'),
                               text('IndexType=:indextype'))
                      ).order_by(asc(column('IndexNumber')))

        # K for Primary Key
        res = connection.execute(stmt, schema=schema, table=table_name, indextype='K').fetchall()

        index_columns = list()
        index_name = None

        for index_column in res:
            index_columns.append(self.normalize_name(index_column['ColumnName']))
            index_name = self.normalize_name(index_column['IndexName'])  # There should be just one IndexName

        return {
            "constrained_columns": index_columns,
            "name": index_name
        }

    def get_unique_constraints(self, connection, table_name, schema=None, **kw):
        """
        Overrides base class method

        Returns one dict per index name found with IndexType 'U', each with
        ``name`` and ``column_names``.
        """
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('ColumnName'), column('IndexName')], from_obj=[text('dbc.Indices')]) \
            .where(and_(text('DatabaseName = :schema'),
                        text('TableName=:table'),
                        text('IndexType=:indextype'))) \
            .order_by(asc(column('IndexName')))

        # U for Unique
        res = connection.execute(stmt, schema=schema, table=table_name, indextype='U').fetchall()

        def grouper(fk_row):
            return {
                'name': self.normalize_name(fk_row['IndexName']),
            }

        unique_constraints = list()
        for constraint_info, constraint_cols in groupby(res, grouper):
            unique_constraint = {
                'name': self.normalize_name(constraint_info['name']),
                'column_names': list()
            }

            for constraint_col in constraint_cols:
                unique_constraint['column_names'].append(self.normalize_name(constraint_col['ColumnName']))

            unique_constraints.append(unique_constraint)

        return unique_constraints

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """
        Overrides base class method

        Returns one dict per (name, parent schema, parent table) group read
        from DBC.All_RI_ChildrenV, with the constrained and referred columns.
        """
        if schema is None:
            schema = self.default_schema_name

        stmt = select([column('IndexID'), column('IndexName'), column('ChildKeyColumn'), column('ParentDB'),
                       column('ParentTable'), column('ParentKeyColumn')],
                      from_obj=[text('DBC.All_RI_ChildrenV')]) \
            .where(and_(text('ChildTable = :table'),
                        text('ChildDB = :schema'))) \
            .order_by(asc(column('IndexID')))

        res = connection.execute(stmt, schema=schema, table=table_name).fetchall()

        def grouper(fk_row):
            return {
                'name': fk_row.IndexName or fk_row.IndexID,  # ID if IndexName is None
                'schema': fk_row.ParentDB,
                'table': fk_row.ParentTable
            }

        # TODO: Check if there's a better way
        fk_dicts = list()
        for constraint_info, constraint_cols in groupby(res, grouper):
            fk_dict = {
                'name': constraint_info['name'],
                'constrained_columns': list(),
                'referred_table': constraint_info['table'],
                'referred_schema': constraint_info['schema'],
                'referred_columns': list()
            }

            for constraint_col in constraint_cols:
                fk_dict['constrained_columns'].append(self.normalize_name(constraint_col['ChildKeyColumn']))
                fk_dict['referred_columns'].append(self.normalize_name(constraint_col['ParentKeyColumn']))

            fk_dicts.append(fk_dict)

        return fk_dicts

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """
        Overrides base class method

        Returns one dict per index with ``name``, ``column_names`` and
        ``unique``.
        """
        if schema is None:
            schema = self.default_schema_name

        # groupby only merges adjacent rows, and the group key falls back to
        # IndexNumber when IndexName is empty, so IndexNumber is used as a
        # secondary sort key to keep the rows of unnamed indexes together.
        stmt = select(["*"], from_obj=[text('dbc.Indices')]) \
            .where(and_(text('DatabaseName = :schema'),
                        text('TableName=:table'))) \
            .order_by(asc(column('IndexName')), asc(column('IndexNumber')))

        res = connection.execute(stmt, schema=schema, table=table_name).fetchall()

        def grouper(fk_row):
            return {
                'name': fk_row.IndexName or fk_row.IndexNumber,  # If IndexName is None TODO: Check what to do
                'unique': True if fk_row.UniqueFlag == 'Y' else False
            }

        # TODO: Check if there's a better way
        indices = list()
        for index_info, index_cols in groupby(res, grouper):
            index_dict = {
                'name': index_info['name'],
                'column_names': list(),
                'unique': index_info['unique']
            }

            for index_col in index_cols:
                index_dict['column_names'].append(self.normalize_name(index_col['ColumnName']))

            indices.append(index_dict)

        return indices

    def get_transaction_mode(self, connection, **kw):
        """
        Returns the transaction mode set for the current session.
        T = TDBS
        A = ANSI
        """
        stmt = select([text('transaction_mode')],
                      from_obj=[text('dbc.sessioninfov')]).\
            where(text('sessionno=SESSION'))

        res = connection.execute(stmt).scalar()
        return res

    def _get_server_version_info(self, connection, **kw):
        """
        Returns the Teradata Database software version.
        """
        stmt = select([text('InfoData')],
                      from_obj=[text('dbc.dbcinfov')]).\
            where(text('InfoKey=\'VERSION\''))

        res = connection.execute(stmt).scalar()
        return res

    def conn_supports_autocommit(self, connection, **kw):
        """
        Returns True if autocommit is used for this connection (underlying Teradata session)
        else False
        """
        return self.get_transaction_mode(connection) == 'T'


dialect = TeradataDialect