/* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). * You may not use this file except in compliance with the License. * A copy of the License is located at * * http://aws.amazon.com/apache2.0 * * or in the "license" file accompanying this file. This file is distributed * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either * express or implied. See the License for the specific language governing * permissions and limitations under the License. */ package com.amazonaws.xray.sql.postgres; import com.amazonaws.xray.AWSXRay; import com.amazonaws.xray.entities.Namespace; import com.amazonaws.xray.entities.Subsegment; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationHandler; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.net.URI; import java.net.URISyntaxException; import java.sql.CallableStatement; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Statement; import java.util.HashMap; import java.util.Map; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.tomcat.jdbc.pool.ConnectionPool; import org.apache.tomcat.jdbc.pool.JdbcInterceptor; import org.apache.tomcat.jdbc.pool.PooledConnection; /* * Inspired by: http://grepcode.com/file/repo1.maven.org/maven2/org.apache.tomcat/tomcat-jdbc/8.0.24/org/apache/tomcat/jdbc/pool/interceptor/AbstractQueryReport.java#AbstractQueryReport */ public class TracingInterceptor extends JdbcInterceptor { protected static final String CREATE_STATEMENT = "createStatement"; protected static final int CREATE_STATEMENT_INDEX = 0; protected static final String PREPARE_STATEMENT = "prepareStatement"; protected static final int PREPARE_STATEMENT_INDEX = 1; protected static final String PREPARE_CALL = "prepareCall"; protected static final int PREPARE_CALL_INDEX = 2; protected static final String[] STATEMENT_TYPES = {CREATE_STATEMENT, PREPARE_STATEMENT, PREPARE_CALL}; protected static final int STATEMENT_TYPE_COUNT = STATEMENT_TYPES.length; protected static final String EXECUTE = "execute"; protected static final String EXECUTE_QUERY = "executeQuery"; protected static final String EXECUTE_UPDATE = "executeUpdate"; protected static final String EXECUTE_BATCH = "executeBatch"; protected static final String[] EXECUTE_TYPES = {EXECUTE, EXECUTE_QUERY, EXECUTE_UPDATE, EXECUTE_BATCH}; /** * @deprecated For internal use only. */ @SuppressWarnings("checkstyle:ConstantName") @Deprecated protected static final Constructor<?>[] constructors = new Constructor[STATEMENT_TYPE_COUNT]; private static final Log logger = LogFactory.getLog(TracingInterceptor.class); private static final String DEFAULT_DATABASE_NAME = "database"; /** * Creates a constructor for a proxy class, if one doesn't already exist * * @param index the index of the constructor * @param clazz the interface that the proxy will implement * @return returns a constructor used to create new instances * @throws NoSuchMethodException */ protected Constructor<?> getConstructor(int index, Class<?> clazz) throws NoSuchMethodException { if (constructors[index] == null) { Class<?> proxyClass = Proxy.getProxyClass(TracingInterceptor.class.getClassLoader(), new Class[] {clazz}); constructors[index] = proxyClass.getConstructor(new Class[] {InvocationHandler.class}); } return constructors[index]; } public Object createStatement(Object proxy, Method method, Object[] args, Object statementObject) { try { String name = method.getName(); String sql = null; Constructor<?> constructor = null; Map<String, Object> additionalParams = new HashMap<>(); if (compare(CREATE_STATEMENT, name)) { //createStatement constructor = getConstructor(CREATE_STATEMENT_INDEX, Statement.class); } else if (compare(PREPARE_STATEMENT, name)) { additionalParams.put("preparation", "statement"); sql = (String) args[0]; constructor = getConstructor(PREPARE_STATEMENT_INDEX, PreparedStatement.class); } else if (compare(PREPARE_CALL, name)) { additionalParams.put("preparation", "call"); sql = (String) args[0]; constructor = getConstructor(PREPARE_CALL_INDEX, CallableStatement.class); } else { //do nothing, might be a future unsupported method //so we better bail out and let the system continue return statementObject; } Statement statement = ((Statement) statementObject); Connection connection = statement.getConnection(); DatabaseMetaData metadata = connection.getMetaData(); // parse cname for subsegment name additionalParams.put("url", metadata.getURL()); additionalParams.put("user", metadata.getUserName()); additionalParams.put("driver_version", metadata.getDriverVersion()); additionalParams.put("database_type", metadata.getDatabaseProductName()); additionalParams.put("database_version", metadata.getDatabaseProductVersion()); String hostname = DEFAULT_DATABASE_NAME; try { URI normalizedUri = new URI(new URI(metadata.getURL()).getSchemeSpecificPart()); hostname = connection.getCatalog() + "@" + normalizedUri.getHost(); } catch (URISyntaxException e) { logger.warn("Unable to parse database URI. Falling back to default '" + DEFAULT_DATABASE_NAME + "' for subsegment name.", e); } logger.debug("Instantiating new statement proxy."); return constructor.newInstance(new TracingStatementProxy(statementObject, sql, hostname, additionalParams)); } catch (SQLException | InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) { logger.warn("Unable to create statement proxy for tracing.", e); } return statementObject; } protected class TracingStatementProxy implements InvocationHandler { protected boolean closed = false; protected Object delegate; protected final String query; protected final String hostname; protected Map<String, Object> additionalParams; public TracingStatementProxy(Object parent, String query, String hostname, Map<String, Object> additionalParams) { this.delegate = parent; this.query = query; this.hostname = hostname; this.additionalParams = additionalParams; } @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { //get the name of the method for comparison final String name = method.getName(); //was close invoked? boolean close = compare(JdbcInterceptor.CLOSE_VAL, name); //allow close to be called multiple times if (close && closed) { return null; } //are we calling isClosed? if (compare(JdbcInterceptor.ISCLOSED_VAL, name)) { return Boolean.valueOf(closed); } //if we are calling anything else, bail out if (closed) { throw new SQLException("Statement closed."); } //check to see if we are about to execute a query final boolean process = isExecute(method); Object result = null; Subsegment subsegment = null; if (process) { subsegment = AWSXRay.beginSubsegment(hostname); } try { if (process && null != subsegment) { subsegment.putAllSql(additionalParams); subsegment.setNamespace(Namespace.REMOTE.toString()); } result = method.invoke(delegate, args); //execute the query } catch (Throwable t) { if (null != subsegment) { subsegment.addException(t); } if (t instanceof InvocationTargetException && t.getCause() != null) { throw t.getCause(); } else { throw t; } } finally { if (process && null != subsegment) { AWSXRay.endSubsegment(); } } //perform close cleanup if (close) { closed = true; delegate = null; } return result; } } /** * {@inheritDoc} */ @Override public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { if (compare(CLOSE_VAL, method)) { return super.invoke(proxy, method, args); } else { boolean process = isStatement(method); if (process) { Object statement = super.invoke(proxy, method, args); return createStatement(proxy, method, args, statement); } else { return super.invoke(proxy, method, args); } } } private boolean isStatement(Method method) { return isMemberOf(STATEMENT_TYPES, method); } private boolean isExecute(Method method) { return isMemberOf(EXECUTE_TYPES, method); } protected boolean isMemberOf(String[] names, Method method) { boolean member = false; final String name = method.getName(); for (int i = 0; !member && i < names.length; i++) { member = compare(names[i], name); } return member; } @Override public void reset(ConnectionPool parent, PooledConnection con) { //do nothing } }