# -*- coding: utf-8 -*-
"""
redshift_psql.py is a collection of functions used to load data from S3
into Redshift

Example Instantiation:
rspg = RedshiftPostgres(pipeline_stream_logger, "pg_auth_file", run_local=True)
"""
import socket
import time
from datetime import datetime

import boto
import staticconf
from dateutil.parser import parse as parsedate
from staticconf import read_string
from staticconf import YamlConfiguration

import psycopg2
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
from psycopg2.extensions import QueryCanceledError

from sherlock.common.aws import get_aws_creds

ADD_SCHEMA_PATH = "SET search_path TO '$user', public, %(schema_path)s"
DEFAULT_NAMESPACE = "public"

# Copied from
# http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
from select import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE


def wait_select_inter(conn):
    """Wait callback that allows a KeyboardInterrupt to cancel the running query."""
    while True:
        try:
            state = conn.poll()
            if state == POLL_OK:
                break
            elif state == POLL_READ:
                select([conn.fileno()], [], [])
            elif state == POLL_WRITE:
                select([], [conn.fileno()], [])
            else:
                raise conn.OperationalError(
                    "bad state from poll: %s" % state)
        except KeyboardInterrupt:
            conn.cancel()
            # the loop will be broken by a server error
            continue


def get_namespaced_tablename(tablename, schemaname=None):
    if schemaname is None:
        rs_schema = get_redshift_schema()
    else:
        # note we do lower for backward compatibility
        rs_schema = schemaname.lower()
    if rs_schema == DEFAULT_NAMESPACE:
        return tablename
    return rs_schema + "." + tablename


def get_redshift_schema():
    # note we do lower for backward compatibility
    return read_string('redshift_schema', DEFAULT_NAMESPACE).lower()
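
# Illustrative only (not part of the original module): given the helpers above,
# and assuming staticconf has been loaded with `redshift_schema: analytics`,
# table names would resolve as follows:
#
#   get_namespaced_tablename("events")               # -> "analytics.events"
#   get_namespaced_tablename("events", "Analytics")  # -> "analytics.events"
#   get_namespaced_tablename("events", "public")     # -> "events"
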
class RedshiftPostgres(object):
    """
    This class simplifies running queries on redshift.  The current purpose
    is for creating tables, and copying data into them from S3.  However, it
    can be used for general SQL commands.

    Constructor Args:
    logstrm -- a PipelineStreamLogger to record starts, completes and
        failed sql commands
    psql_auth_file -- the file from which we get a username and password
        for a redshift account
    run_local -- whether to run locally or not
    """

    # together these TCP keepalive settings give a sql command roughly an
    # hour (60 probes, 60 seconds apart) to complete before the connection
    # is considered dead
    SECONDS_BEFORE_SENDING_PROBE = 1
    SECONDS_BETWEEN_SENDING_PROBE = 60
    RETRIES_BEFORE_QUIT = 60

    def __init__(self, logstrm, psql_auth_file, run_local=False):
        self.run_local = run_local
        self.host = staticconf.read_string('redshift_host')
        self.port = staticconf.read_int('redshift_port')
        private_dict = YamlConfiguration(psql_auth_file)
        self.user = private_dict['redshift_user']
        self.password = private_dict['redshift_password']
        self.log_stream = logstrm
        self._aws_key = ''
        self._aws_secret = ''
        self._aws_token = ''
        self._aws_token_expiry = datetime.utcnow()
        self._whitelist = ['select', 'create', 'insert', 'update']
        self._set_aws_auth()
        psycopg2.extensions.set_wait_callback(wait_select_inter)

    def _set_aws_auth(self):
        """
        _set_aws_auth gets key, secret, token and expiration either from a
        file or from a temporary instance and sets them
        """
        cred_tuple = get_aws_creds(self.run_local)
        self._aws_key = cred_tuple.access_key_id
        self._aws_secret = cred_tuple.secret_access_key
        self._aws_token = cred_tuple.token
        self._aws_token_expiry = parsedate(cred_tuple.expiration)

    def get_boto_config(self):
        """Return boto's configuration as a dict, masking the secret key for logging."""
        boto_dict = {}
        for section in boto.config.sections():
            boto_dict[section] = {}
            for option in boto.config.options(section):
                if option != 'aws_secret_access_key':
                    boto_dict[section][option] = boto.config.get(section, option)
                else:
                    boto_dict[section][option] = "xxxxxxxxxxxxxxxx"
        return boto_dict

    def get_connection(self, database):
        """
        gets a connection to a psql database

        Args:
        database -- the name of the database to connect to

        Returns:
        a connection object
        """
        # additional logging to help with connection issues
        boto_config_dict = self.get_boto_config()
        self.log_stream.write_msg('boto', extra_msg=boto_config_dict)
        log_template = "getting connection with host {0} port {1} user {2} db {3}"
        log_msg = log_template.format(self.host, self.port, self.user, database)
        self.log_stream.write_msg('starting', extra_msg=log_msg)
        conn = psycopg2.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            database=database,
            sslmode='require')
        self.log_stream.write_msg('finished', extra_msg=log_msg)
        # enable TCP keepalives so long-running commands don't drop the socket
        fd = conn.fileno()
        sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE,
                        self.SECONDS_BEFORE_SENDING_PROBE)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL,
                        self.SECONDS_BETWEEN_SENDING_PROBE)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT,
                        self.RETRIES_BEFORE_QUIT)
        return conn

    def cleanse_sql(self, command):
        """
        cleanses a psql command of any auth information

        Args:
        command -- the psql command

        Returns:
        the cleansed command
        """
        cmd_list = command.split()
        first_word = cmd_list[0]
        if first_word.lower() in self._whitelist:
            return " ".join(cmd_list)
        return cmd_list[0] + " cleansed "
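
    # Illustrative only (not part of the original module): cleanse_sql keeps
    # whitelisted commands verbatim and truncates anything else down to its
    # first word, which is how COPY statements (whose CREDENTIALS clause
    # embeds AWS keys) are kept out of the logs:
    #
    #   rspg.cleanse_sql("select count(*) from events")
    #   # -> "select count(*) from events"
    #   rspg.cleanse_sql("COPY events FROM 's3://bucket/key' CREDENTIALS '...'")
    #   # -> "COPY cleansed "
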
    def run_sql_ex(self, sql, database, log_msg, s3_needed=False, params=None,
                   output=False, time_est_secs=10, need_commit=True,
                   schema=DEFAULT_NAMESPACE):
        """
        run_sql_ex takes a command and executes it using the connection
        found in get_connection.

        Args:
        sql -- the postgres command to run
        database -- the database on which the command is to be run
        log_msg -- a shortened message for what command we're running
        s3_needed -- True if the sql command requires s3 input, otherwise
            False.  For example, a simple query of a table would have
            s3_needed=False, while a COPY from S3 would have s3_needed=True.
        params -- if there are any parameters for the command
        output -- True if the command returns rows (e.g., a SELECT command)
        time_est_secs -- how long the user estimates the command will run,
            in seconds.  This is used to decide whether to get new
            credentials or not.
        need_commit -- False if the command does not need to be committed
            (ex: vacuum)
        schema -- the schema in the database on which the command is run.
            Anything other than the default namespace must have the schema
            name added to the search_path.  This is ephemeral, so it must be
            done on a per-session basis; since we close the cursor and
            connection after each query, we set it every time.

        Returns:
        if there's a return value, it is the results of the query
        """
        start_time = time.time()
        self.log_stream.write_msg('starting', extra_msg=log_msg)
        exception = None
        if s3_needed:
            try:
                if self._aws_token is None:
                    sql = sql.replace(';token=%s', '')
                    sql = sql % (self._aws_key, self._aws_secret)
                else:
                    sql = sql % (self._aws_key, self._aws_secret,
                                 self._aws_token)
            except TypeError as type_error:
                self.log_stream.write_msg(
                    'error',
                    error_msg=repr(type_error),
                    extra_msg=sql
                )
                raise
        try:
            result = dict()
            with self.get_connection(database) as conn:
                if need_commit is False:
                    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
                cur = conn.cursor()
                if schema != DEFAULT_NAMESPACE:
                    schema_params = {'schema_path': schema}
                    cur.execute(ADD_SCHEMA_PATH, schema_params)
                if params:
                    cur.execute(sql, params)
                else:
                    cur.execute(sql)
                if output:
                    rows = cur.fetchall()
                result['status'] = cur.statusmessage
                cur.close()
            self.log_stream.write_msg(
                'finished',
                job_start_secs=start_time,
                extra_msg=log_msg
            )
            if output:
                result['output'] = rows
            return result
        except QueryCanceledError as cmd_exception:
            exception = cmd_exception
            raise KeyboardInterrupt
        except Exception as cmd_exception:
            exception = cmd_exception
            raise
        finally:
            # only log an error if one actually occurred; on success the
            # 'finished' message above already covers the query
            if exception is not None:
                self.log_stream.write_msg(
                    'error',
                    job_start_secs=start_time,
                    error_msg=repr(exception),
                    extra_msg=log_msg
                )

    def run_sql(self, sql, database, log_msg, s3_needed=False, params=None,
                output=False, time_est_secs=10, need_commit=True,
                schema=DEFAULT_NAMESPACE):
        result = self.run_sql_ex(
            sql, database, log_msg, s3_needed, params, output,
            time_est_secs, need_commit, schema=schema)
        if result is False:
            return False
        return result['output'] if output is True else True
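
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module).  The names
# `pipeline_stream_logger`, the bucket and the key below are hypothetical;
# the COPY template simply follows the placeholder convention run_sql_ex
# expects when s3_needed=True: three %s slots for access key, secret key and
# token, with the literal ';token=%s' dropped when no STS token is present.
#
#   rspg = RedshiftPostgres(pipeline_stream_logger, "pg_auth_file",
#                           run_local=True)
#   copy_sql = (
#       "COPY my_schema.events FROM 's3://my-bucket/events/part-0000' "
#       "CREDENTIALS 'aws_access_key_id=%s;aws_secret_access_key=%s;token=%s' "
#       "DELIMITER '\\t' GZIP"
#   )
#   rspg.run_sql(copy_sql, "dev", "copy events", s3_needed=True,
#                schema="my_schema")
# ---------------------------------------------------------------------------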