/*******************************************************************************
 * Copyright 2013 EMBL-EBI
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/
package net.sf.cram.ref;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.WeakReference;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.cram.io.InputStreamUtils;
import htsjdk.samtools.cram.ref.CRAMReferenceSource;
import htsjdk.samtools.reference.FastaSequenceIndex;
import htsjdk.samtools.reference.ReferenceSequence;
import htsjdk.samtools.reference.ReferenceSequenceFile;
import htsjdk.samtools.reference.ReferenceSequenceFileFactory;
import htsjdk.samtools.util.IOUtil;
import htsjdk.samtools.util.Log;
import net.sf.cram.common.Utils;

/**
 * A central class for automated discovery of reference sequences. The search
 * order is expected to be similar to that of samtools:
 * <ul>
 * <li>Search the in-memory cache by sequence name.</li>
 * <li>Use the local fasta file, if one was supplied as the reference file, and
 * cache the found sequence in memory.</li>
 * <li>Try REF_CACHE env variable.</li>
 * <li>Try all entries in REF_PATH. The default value is the EBI reference
 * service.</li>
 * <li>Try @SQ:UR as a URL for a fasta file with the fasta index next to
 * it.</li>
 * </ul>
 * 
 * @author vadim
 */
public class ReferenceSource implements CRAMReferenceSource {
	private static final int REF_BASES_TO_CHECK_FOR_SANITY = 1000;
	private static final Pattern chrPattern = Pattern.compile("chr.*", Pattern.CASE_INSENSITIVE);
	private static String REF_CACHE = System.getenv("REF_CACHE");
	private static String REF_PATH = System.getenv("REF_PATH");
	private static List<PathPattern> refPatterns = new ArrayList<PathPattern>();

	static {
		if (REF_PATH == null)
			REF_PATH = "http://www.ebi.ac.uk/ena/cram/md5/%s";

		if (REF_CACHE != null)
			refPatterns.add(new PathPattern(REF_CACHE));
		for (String s : REF_PATH.split("(?i)(?<!(http|ftp)):")) {
			refPatterns.add(new PathPattern(s));
		}

	}

	private static Log log = Log.getInstance(ReferenceSource.class);
	private ReferenceSequenceFile rsFile;
	private FastaSequenceIndex fastaSequenceIndex;
	private int downloadTriesBeforeFailing = 2;

	/*
	 * In-memory cache of ref bases by sequence name. Garbage collector will
	 * automatically clean it if memory is low.
	 */
	private Map<String, WeakReference<byte[]>> cacheW = new HashMap<String, WeakReference<byte[]>>();

	public ReferenceSource() {
	}

	public ReferenceSource(File file) {
		if (file != null) {
			rsFile = ReferenceSequenceFileFactory.getReferenceSequenceFile(file);

			File indexFile = new File(file.getAbsoluteFile() + ".fai");
			if (indexFile.exists())
				fastaSequenceIndex = new FastaSequenceIndex(indexFile);
		}
	}

	public ReferenceSource(ReferenceSequenceFile rsFile) {
		this.rsFile = rsFile;
	}

	public void clearCache() {
		cacheW.clear();
	}

	private byte[] findInCache(String name) {
		WeakReference<byte[]> r = cacheW.get(name);
		if (r != null) {
			byte[] bytes = r.get();
			if (bytes != null)
				return bytes;
		}
		return null;
	}

	@Override
	public synchronized byte[] getReferenceBases(SAMSequenceRecord record, boolean tryNameVariants) {
		byte[] bases = findBases(record, tryNameVariants);
		if (bases == null)
			return null;

		cacheW.put(record.getSequenceName(), new WeakReference<byte[]>(bases));

		String md5 = record.getAttribute(SAMSequenceRecord.MD5_TAG);
		if (md5 == null) {
			md5 = Utils.calculateMD5String(bases);
			record.setAttribute(SAMSequenceRecord.MD5_TAG, md5);
		}

		if (REF_CACHE != null)
			addToRefCache(md5, bases);

		return bases;
	}

	private static byte[] readBytesFromFile(File file, int offset, int len) throws IOException {
		long size = file.length();
		if (size < offset || len < 0) {
			log.warn(String.format("Ref request is out of range: %s, size=%d, offset=%d, len=%d",
					file.getAbsolutePath(), size, offset, len));
			return new byte[] {};
		}
		byte[] data = new byte[(int) Math.min(size - offset, len)];
		FileInputStream fis = new FileInputStream(file);
		DataInputStream dis = new DataInputStream(fis);
		dis.skip(offset);
		dis.readFully(data);
		dis.close();
		return data;
	}

	public synchronized ReferenceRegion getRegion(SAMSequenceRecord record, int start_1based, int endInclusive_1based)
			throws IOException {
		{ // check cache by sequence name:
			String name = record.getSequenceName();
			byte[] bases = findInCache(name);
			if (bases != null) {
				log.debug("Reference found in memory cache by name: " + name);
				return ReferenceRegion.copyRegion(bases, record.getSequenceIndex(), record.getSequenceName(),
						start_1based, endInclusive_1based);
			}
		}

		String md5 = record.getAttribute(SAMSequenceRecord.MD5_TAG);
		{ // check cache by md5:
			if (md5 != null) {
				byte[] bases = findInCache(md5);
				if (bases != null) {
					log.debug("Reference found in memory cache by md5: " + md5);
					return ReferenceRegion.copyRegion(bases, record.getSequenceIndex(), record.getSequenceName(),
							start_1based, endInclusive_1based);
				}
			}
		}

		byte[] bases = null;
		if (REF_CACHE != null) {
			PathPattern pathPattern = new PathPattern(REF_CACHE);
			File file = new File(pathPattern.format(md5));
			if (file.exists()) {
				bases = readBytesFromFile(file, start_1based - 1, endInclusive_1based - start_1based + 1);
				return new ReferenceRegion(bases, record.getSequenceIndex(), record.getSequenceName(), start_1based);
			}
		}

		{ // try to fetch sequence by md5:
			if (md5 != null)
				try {
					bases = findBasesByMD5(md5);
				} catch (Exception e) {
					if (e instanceof RuntimeException)
						throw (RuntimeException) e;
					throw new RuntimeException(e);
				}
			if (bases != null) {
				cacheW.put(record.getSequenceName(), new WeakReference<byte[]>(bases));

				if (REF_CACHE != null)
					addToRefCache(md5, bases);
				return ReferenceRegion.copyRegion(bases, record.getSequenceIndex(), record.getSequenceName(),
						start_1based, endInclusive_1based);
			}
		}

		return null;
	}

	protected byte[] findBases(SAMSequenceRecord record, boolean tryNameVariants) {
		{ // check cache by sequence name:
			String name = record.getSequenceName();
			byte[] bases = findInCache(name);
			if (bases != null) {
				log.debug("Reference found in memory cache by name: " + name);
				return bases;
			}
		}

		String md5 = record.getAttribute(SAMSequenceRecord.MD5_TAG);
		{ // check cache by md5:
			if (md5 != null) {
				byte[] bases = findInCache(md5);
				if (bases != null) {
					log.debug("Reference found in memory cache by md5: " + md5);
					return bases;
				}
			}
		}

		byte[] bases;

		{ // try to fetch sequence by name:
			bases = findBasesByName(record.getSequenceName(), tryNameVariants);
			if (bases != null) {
				Utils.upperCase(bases);
				return bases;
			}
		}

		{ // try to fetch sequence by md5:
			if (md5 != null)
				try {
					bases = findBasesByMD5(md5);
				} catch (Exception e) {
					if (e instanceof RuntimeException)
						throw (RuntimeException) e;
					throw new RuntimeException(e);
				}
			if (bases != null) {
				return bases;
			}
		}

		{ // try @SQ:UR file location
			if (record.getAttribute(SAMSequenceRecord.URI_TAG) != null) {
				ReferenceSequenceFromSeekable s = ReferenceSequenceFromSeekable
						.fromString(record.getAttribute(SAMSequenceRecord.URI_TAG));
				bases = s.getSubsequenceAt(record.getSequenceName(), 1, record.getSequenceLength());
				Utils.upperCase(bases);
				return bases;
			}
		}
		return null;
	}

	protected byte[] findBasesByName(String name, boolean tryVariants) {
		if (rsFile == null || !rsFile.isIndexed())
			return null;

		ReferenceSequence sequence = null;
		if (fastaSequenceIndex != null)
			if (fastaSequenceIndex.hasIndexEntry(name))
				sequence = rsFile.getSequence(name);
			else
				sequence = null;

		if (sequence != null)
			return sequence.getBases();

		if (tryVariants) {
			for (String variant : getVariants(name)) {
				try {
					sequence = rsFile.getSequence(variant);
				} catch (Exception e) {
					log.info("Sequence not found: " + variant);
				}
				if (sequence != null) {
					log.debug("Reference found in memory cache for name %s by variant %s", name, variant);
					return sequence.getBases();
				}
			}
		}
		return null;
	}

	/**
	 * @param path
	 * @return true if the path is a valid URL, false otherwise.
	 */
	private static boolean isURL(String path) {
		try {
			URL url = new URL(path);
			return true;
		} catch (MalformedURLException e) {
			return false;
		}
	}

	private byte[] loadFromPath(String path, String md5) throws IOException {
		if (isURL(path)) {
			URL url = new URL(path);
			for (int i = 0; i < downloadTriesBeforeFailing; i++) {
				InputStream is = url.openStream();
				if (is == null)
					return null;

				if (REF_CACHE != null) {
					String localPath = addToRefCache(md5, is);
					File file = new File(localPath);
					if (file.length() > Integer.MAX_VALUE)
						throw new RuntimeException("The reference sequence is too long: " + md5);

					return readBytesFromFile(file, 0, (int) file.length());
				}
				byte[] data = InputStreamUtils.readFully(is);
				is.close();

				if (confirmMD5(md5, data)) {
					// sanitize, Internet is a wild place:
					if (Utils.isValidSequence(data, REF_BASES_TO_CHECK_FOR_SANITY))
						return data;
					else {
						// reject, it looks like garbage
						log.error("Downloaded sequence looks suspicous, rejected: " + url.toExternalForm());
						break;
					}
				}
			}
		} else {
			File file = new File(path);
			if (file.exists()) {
				if (file.length() > Integer.MAX_VALUE)
					throw new RuntimeException("The reference sequence is too long: " + md5);

				byte[] data = readBytesFromFile(file, 0, (int) file.length());

				if (confirmMD5(md5, data))
					return data;
				else
					throw new RuntimeException("MD5 mismatch for cached file: " + file.getAbsolutePath());
			}
		}
		return null;
	}

	protected byte[] findBasesByMD5(String md5) throws MalformedURLException, IOException {
		for (PathPattern p : refPatterns) {
			String path = p.format(md5);
			byte[] data = loadFromPath(path, md5);
			if (data == null)
				continue;
			log.debug("Reference found at the location ", path);
			return data;
		}

		return null;
	}

	private static void addToRefCache(String md5, byte[] data) {
		File cachedFile = new File(new PathPattern(REF_CACHE).format(md5));
		if (!cachedFile.exists()) {
			log.debug(String.format("Adding to REF_CACHE: md5=%s, length=%d", md5, data.length));
			cachedFile.getParentFile().mkdirs();
			File tmpFile;
			try {
				tmpFile = File.createTempFile(md5, ".tmp", cachedFile.getParentFile());
				FileOutputStream fos = new FileOutputStream(tmpFile);
				fos.write(data);
				fos.close();
				if (!cachedFile.exists())
					tmpFile.renameTo(cachedFile);
				else
					tmpFile.delete();
			} catch (IOException e) {
				throw new RuntimeException(e);
			}
		}
	}

	private static String addToRefCache(String md5, InputStream stream) {
		String localPath = new PathPattern(REF_CACHE).format(md5);
		File cachedFile = new File(localPath);
		if (!cachedFile.exists()) {
			log.info(String.format("Adding to REF_CACHE sequence md5=%s", md5));
			cachedFile.getParentFile().mkdirs();
			File tmpFile;
			try {
				tmpFile = File.createTempFile(md5, ".tmp", cachedFile.getParentFile());
				FileOutputStream fos = new FileOutputStream(tmpFile);
				IOUtil.copyStream(stream, fos);
				fos.close();
				if (!cachedFile.exists())
					tmpFile.renameTo(cachedFile);
				else
					tmpFile.delete();
			} catch (IOException e) {
				throw new RuntimeException(e);
			}
		}
		return localPath;
	}

	private boolean confirmMD5(String md5, byte[] data) {
		String downloadedMD5 = null;
		downloadedMD5 = Utils.calculateMD5String(data);
		if (md5.equals(downloadedMD5)) {
			return true;
		} else {
			String message = String.format("Downloaded sequence is corrupt: requested md5=%s, received md5=%s", md5,
					downloadedMD5);
			log.error(message);
			return false;
		}
	}

	protected List<String> getVariants(String name) {
		List<String> variants = new ArrayList<String>();

		if (name.equals("M"))
			variants.add("MT");

		if (name.equals("MT"))
			variants.add("M");

		boolean chrPatternMatch = chrPattern.matcher(name).matches();
		if (chrPatternMatch)
			variants.add(name.substring(3));
		else
			variants.add("chr" + name);

		if ("chrM".equals(name)) {
			// chrM case:
			variants.add("MT");
		}
		return variants;
	}

	public int getDownloadTriesBeforeFailing() {
		return downloadTriesBeforeFailing;
	}

	public void setDownloadTriesBeforeFailing(int downloadTriesBeforeFailing) {
		this.downloadTriesBeforeFailing = downloadTriesBeforeFailing;
	}
}