/*******************************************************************************
 * Copyright 2014 Observational Health Data Sciences and Informatics
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package org.ohdsi.databases;

import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.management.RuntimeErrorException;

import org.ohdsi.medlineXmlToDatabase.Abbreviator;
import org.ohdsi.medlineXmlToDatabase.MedlineCitationAnalyser.VariableType;
import org.ohdsi.utilities.StringUtilities;
import org.ohdsi.utilities.files.Row;

/**
 * Wrapper around java.sql.connection to handle any database work that is platform-specific.
 * 
 * @author MSCHUEMI
 * 
 */
public class ConnectionWrapper {
	private Connection	connection;
	private DbType		dbType;
	private boolean		batchMode	= false;
	private Statement	statement;
	
	public ConnectionWrapper(String server, String domain, String user, String password, DbType dbType) {
		this.connection = DBConnector.connect(server, domain, user, password, dbType);
		this.dbType = dbType;
	}
	
	public void setBatchMode(boolean batchMode) {
		try {
			if (this.batchMode && !batchMode) { // turn off batchmode
				this.batchMode = false;
				statement.executeBatch();
				connection.setAutoCommit(true);
			} else {
				this.batchMode = true;
				connection.setAutoCommit(false);
				statement = connection.createStatement();
			}
		} catch (SQLException e) {
			System.err.println("Error: " + e.getMessage());
			e.printStackTrace();
			e = e.getNextException();
			if (e != null) {
				System.err.println("Error: " + e.getMessage());
				e.printStackTrace();
			}
			throw new RuntimeException("Error executing batch data");
		}
	}
	
	/**
	 * Switch the database to use.
	 * 
	 * @param database
	 */
	public void use(String database) {
		if (database == null)
			return;
		if (dbType.equals(DbType.ORACLE))
			execute("ALTER SESSION SET current_schema = " + database);
		else if (dbType.equals(DbType.POSTGRESQL))
			execute("SET search_path TO " + database);
		else
			execute("USE " + database);
	}
	
	public void createDatabase(String database) {
		execute("CREATE SCHEMA " + database);
	}
	
	/**
	 * Execute the given SQL statement.
	 * 
	 * @param sql
	 */
	public void execute(String sql) {
		try {
			if (sql.length() == 0)
				return;
			if (batchMode)
				statement.addBatch(sql);
			else {
				Statement statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
				statement.execute(sql);
				statement.close();
			}
		} catch (SQLException e) {
			System.err.println(sql);
			e.printStackTrace();
			e = e.getNextException();
			if (e != null) {
				System.err.println("Error: " + e.getMessage());
				e.printStackTrace();
			}
			throw new RuntimeException("Error inserting data");
		}
	}
	
	public void insertIntoTable(String table, Map<String, String> field2Value) {
		List<String> fields = new ArrayList<String>(field2Value.keySet());
		
		StringBuilder sql = new StringBuilder();
		sql.append("INSERT INTO ");
		sql.append(Abbreviator.abbreviate(table));
		sql.append(" (");
		boolean first = true;
		for (String field : fields) {
			if (first)
				first = false;
			else
				sql.append(",");
			sql.append(Abbreviator.abbreviate(field));
		}
		
		if (dbType.equals(DbType.MYSQL)) { // MySQL uses double quotes, escape using backslash
			sql.append(") VALUES (\"");
			first = true;
			for (String field : fields) {
				if (first)
					first = false;
				else
					sql.append("\",\"");
				sql.append(field2Value.get(field).replaceAll("\\\\", "\\\\\\\\").replaceAll("\"", "\\\\\""));
			}
			sql.append("\");");
		} else if (dbType.equals(DbType.MSSQL) || dbType.equals(DbType.POSTGRESQL)) { // MSSQL uses single quotes, escape by doubling
			sql.append(") VALUES ('");
			first = true;
			for (String field : fields) {
				if (first)
					first = false;
				else
					sql.append("','");
				sql.append(field2Value.get(field).replaceAll("'", "''"));
			}
			sql.append("')");
		}
		execute(sql.toString());
	}
	
	public void insertIntoTable(String tableName, List<Row> rows, boolean emptyStringToNull) {
		List<String> columns = rows.get(0).getFieldNames();
		String sql = "INSERT INTO " + tableName;
		sql = sql + " (" + StringUtilities.join(columns, ",") + ")";
		sql = sql + " VALUES (?";
		for (int i = 1; i < columns.size(); i++)
			sql = sql + ",?";
		sql = sql + ")";
		try {
			connection.setAutoCommit(false);
			PreparedStatement statement = connection.prepareStatement(sql);
			for (Row row : rows) {
				for (int i = 0; i < columns.size(); i++) {
					String value = row.get(columns.get(i));
					if (value == null)
						System.out.println(row.toString());
					if (value.length() == 0 && emptyStringToNull)
						value = null;
					if (dbType.equals(DbType.POSTGRESQL)) // PostgreSQL does not allow unspecified types
						statement.setObject(i + 1, value, Types.OTHER);
					else if (dbType.equals(DbType.ORACLE)) {
						statement.setString(i + 1, value);
					} else
						statement.setString(i + 1, value);
				}
				statement.addBatch();
			}
			statement.executeBatch();
			connection.commit();
			statement.close();
			connection.setAutoCommit(true);
			connection.clearWarnings();
		} catch (SQLException e) {
			e.printStackTrace();
			if (e instanceof BatchUpdateException) {
				System.err.println(((BatchUpdateException) e).getNextException().getMessage());
			}
		}
	}
	
	public void createTable(String table, List<String> fields, List<String> types, List<String> primaryKey) {
		StringBuilder sql = new StringBuilder();
		sql.append("CREATE TABLE " + table + " (\n");
		boolean first = true;
		for (int i = 0; i < fields.size(); i++) {
			if (first)
				first = false;
			else
				sql.append(",\n");
			sql.append("  " + fields.get(i) + " " + types.get(i));
		}
		if (primaryKey != null && primaryKey.size() != 0)
			sql.append(",\n  PRIMARY KEY (" + StringUtilities.join(primaryKey, ",") + ")\n");
		sql.append(");\n\n");
		execute(Abbreviator.abbreviate(sql.toString()));
	}
	
	public void createTableUsingVariableTypes(String table, List<String> fields, List<VariableType> variableTypes, List<String> primaryKey) {
		List<String> types = new ArrayList<String>(variableTypes.size());
		for (VariableType variableType : variableTypes) {
			if (dbType.equals(DbType.MYSQL)) {
				if (variableType.isNumeric)
					types.add("INT");
				else if (variableType.maxLength > 255)
					types.add("TEXT");
				else
					types.add("VARCHAR(255)");
			} else if (dbType.equals(DbType.MSSQL)) {
				if (variableType.isNumeric) {
					if (variableType.maxLength < 10)
						types.add("INT");
					else
						types.add("BIGINT");
				} else if (variableType.maxLength > 255)
					types.add("VARCHAR(MAX)");
				else
					types.add("VARCHAR(255)");
			} else if (dbType.equals(DbType.POSTGRESQL)) {
				if (variableType.isNumeric) {
					if (variableType.maxLength < 10)
						types.add("INT");
					else
						types.add("BIGINT");
				} else if (variableType.maxLength > 255)
					types.add("TEXT");
				else
					types.add("VARCHAR(255)");
			} else
				throw new RuntimeException("Unknown datasource type " + dbType);
		}
		
		createTable(table, fields, types, primaryKey);
	}
	
	public void close() {
		try {
			connection.close();
		} catch (SQLException e) {
			e.printStackTrace();
		}
	}
	
	public class QueryResult implements Iterable<Row> {
		private String				sql;
		
		private List<DBRowIterator>	iterators	= new ArrayList<DBRowIterator>();
		
		public QueryResult(String sql) {
			this.sql = sql;
		}
		
		@Override
		public Iterator<Row> iterator() {
			DBRowIterator iterator = new DBRowIterator(sql);
			iterators.add(iterator);
			return iterator;
		}
		
		public void close() {
			for (DBRowIterator iterator : iterators) {
				iterator.close();
			}
		}
	}
	
	public List<String> getTableNames(String database) {
		List<String> names = new ArrayList<String>();
		String query = null;
		if (dbType.equals(DbType.MYSQL)) {
			if (database == null)
				query = "SHOW TABLES";
			else
				query = "SHOW TABLES IN " + database;
		} else if (dbType.equals(DbType.MSSQL)) {
			query = "SELECT name FROM " + database + ".sys.tables ";
		} else if (dbType.equals(DbType.ORACLE)) {
			query = "SELECT table_name FROM all_tables WHERE owner='" + database.toUpperCase() + "'";
		} else if (dbType.equals(DbType.POSTGRESQL)) {
			query = "SELECT table_name FROM information_schema.tables WHERE table_schema = '" + database + "'";
		}
		
		for (Row row : query(query))
			names.add(row.get(row.getFieldNames().get(0)));
		return names;
	}
	
	public List<String> getFieldNames(String table) {
		List<String> names = new ArrayList<String>();
		if (dbType.equals(DbType.MSSQL)) {
			for (Row row : query("SELECT name FROM syscolumns WHERE id=OBJECT_ID('" + table + "')"))
				names.add(row.get("name"));
		} else if (dbType.equals(DbType.MYSQL))
			for (Row row : query("SHOW COLUMNS FROM " + table))
				names.add(row.get("COLUMN_NAME"));
		else if (dbType.equals(DbType.POSTGRESQL))
			for (Row row : query("SELECT column_name FROM information_schema.columns WHERE table_name='" + table.toLowerCase() + "'"))
				names.add(row.get("column_name"));
		else
			throw new RuntimeException("DB type not supported");
		
		return names;
	}
	
	public List<FieldInfo> getFieldInfo(String table) {
		List<FieldInfo> fieldInfos = new ArrayList<FieldInfo>();
		try {
			DatabaseMetaData metaData = connection.getMetaData();
			ResultSet resultSet = metaData.getColumns(null, null, table, null);
			while (resultSet.next()) {
				FieldInfo fieldInfo = new FieldInfo();
				fieldInfo.name = resultSet.getString("COLUMN_NAME");
				fieldInfo.type = resultSet.getInt("DATA_TYPE");
				fieldInfo.length = resultSet.getInt("COLUMN_SIZE");
				fieldInfos.add(fieldInfo);
			}
		} catch (SQLException e) {
			throw (new RuntimeException(e));
		}
		return fieldInfos;
	}
	
	public class FieldInfo {
		public int		type;
		public String	name;
		public int		length;
	}
	
	private QueryResult query(String sql) {
		return new QueryResult(sql);
	}
	
	private class DBRowIterator implements Iterator<Row> {
		
		private ResultSet	resultSet;
		
		private boolean		hasNext;
		
		private Set<String>	columnNames	= new HashSet<String>();
		
		public DBRowIterator(String sql) {
			try {
				sql.trim();
				if (sql.endsWith(";"))
					sql = sql.substring(0, sql.length() - 1);
				Statement statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
				resultSet = statement.executeQuery(sql.toString());
				hasNext = resultSet.next();
			} catch (SQLException e) {
				System.err.println(sql.toString());
				System.err.println(e.getMessage());
				throw new RuntimeException(e);
			}
		}
		
		public void close() {
			if (resultSet != null) {
				try {
					resultSet.close();
				} catch (SQLException e) {
					e.printStackTrace();
				}
				resultSet = null;
				hasNext = false;
			}
		}
		
		@Override
		public boolean hasNext() {
			return hasNext;
		}
		
		@Override
		public Row next() {
			try {
				Row row = new Row();
				ResultSetMetaData metaData;
				metaData = resultSet.getMetaData();
				columnNames.clear();
				
				for (int i = 1; i < metaData.getColumnCount() + 1; i++) {
					String columnName = metaData.getColumnName(i);
					if (columnNames.add(columnName)) {
						String value = resultSet.getString(i);
						if (value == null)
							value = "";
						
						row.add(columnName, value.replace(" 00:00:00", ""));
					}
				}
				hasNext = resultSet.next();
				if (!hasNext) {
					resultSet.close();
					resultSet = null;
				}
				return row;
			} catch (SQLException e) {
				e.printStackTrace();
				throw new RuntimeException(e);
			}
		}
		
		@Override
		public void remove() {
		}
	}
	
	public void setDateFormat() {
		try {
			Statement statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE);
			if (dbType.equals(DbType.POSTGRESQL))
				statement.execute("SET datestyle = \"ISO, MDY\"");
			else
				statement.execute("SET DateFormat DMY;");
		} catch (SQLException e) {
			e.printStackTrace();
		}
	}
	
	public void dropTableIfExists(String table) {
		if (dbType.equals(DbType.ORACLE) || dbType.equals(DbType.POSTGRESQL)) {
			try {
				Statement statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
				statement.execute("TRUNCATE TABLE " + table);
				statement.execute("DROP TABLE " + table);
				statement.close();
			} catch (Exception e) {
				// do nothing
			}
		} else if (dbType.equals(DbType.MSSQL)) {
			execute("IF OBJECT_ID('" + table + "', 'U') IS NOT NULL DROP TABLE " + table + ";");
		} else {
			execute("DROP TABLE " + table + " IF EXISTS");
		}
	}
}