/**
*  This file is part of FNLP (formerly FudanNLP).
*  
*  FNLP is free software: you can redistribute it and/or modify
*  it under the terms of the GNU Lesser General Public License as published by
*  the Free Software Foundation, either version 3 of the License, or
*  (at your option) any later version.
*  
*  FNLP is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU Lesser General Public License for more details.
*  
*  You should have received a copy of the GNU General Public License
*  along with FudanNLP.  If not, see <http://www.gnu.org/licenses/>.
*  
*  Copyright 2009-2014 www.fnlp.org. All rights reserved. 
*/

package org.fnlp.nlp.similarity.train;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.util.Date;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.fnlp.data.reader.Reader;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.nlp.similarity.Cluster;
import org.fnlp.util.MyArrays;
import org.fnlp.util.MyCollection;
import org.fnlp.util.MyHashSparseArrays;

import gnu.trove.iterator.TIntFloatIterator;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.iterator.TIntObjectIterator;
import gnu.trove.iterator.hash.TObjectHashIterator;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import gnu.trove.set.hash.TLinkedHashSet;
/**
 * Brown word clustering algorithm, single-threaded version.
 * @author xpqiu
 *
 */
public class WordCluster implements Serializable{

	
	private static final long serialVersionUID = 1632709924496094832L;
	// Threshold handed to MyHashSparseArrays.trim(wordProb, ENERGY) in startClustering();
	// presumably the fraction of probability mass to keep — TODO confirm against trim().
	private static float ENERGY = 0.999f;
	// Maximum number of ids placed into slots before merging starts (see startClustering()).
	public int slotsize = 50;	
	// Next id to hand out to a merged cluster; set to alpahbet.size() at the end of read().
	int lastid;

	// Maps each single-character string seen in read() to its integer id.
	LabelAlphabet alpahbet = new LabelAlphabet();

	// For every id, the set of ids observed immediately before it (filled in read()).
	TIntObjectHashMap<TIntHashSet> leftnodes = new TIntObjectHashMap<TIntHashSet>();
	// For every id, the set of ids observed immediately after it (filled in read()).
	TIntObjectHashMap<TIntHashSet> rightnodes = new TIntObjectHashMap<TIntHashSet>();
	// Cluster object for every non-negative id, including clusters created by merge().
	TIntObjectHashMap<Cluster> clusters = new TIntObjectHashMap<Cluster>();

	/**
	 * Parent node: maps a cluster id to the id of the cluster it was merged into.
	 * A lookup that yields -1 is treated as "no parent" by getHead().
	 */
	TIntIntHashMap heads = new TIntIntHashMap(200,0.5f,-1,-1);

	// Ids of the clusters that are currently candidates for merging.
	TIntHashSet slots = new TIntHashSet();

	/**
	 * Directed edges: pcc.get(a).get(b) holds the statistic for "a followed by b"
	 * (a count while reading, divided by totalword in statisticProb()).
	 */
	TIntObjectHashMap<TIntFloatHashMap> pcc = new TIntObjectHashMap<TIntFloatHashMap>();
	/**
	 * Undirected edges: weight between two ids, stored under the smaller id
	 * and keyed by the larger id (see setweight()/getweight()).
	 */
	TIntObjectHashMap<TIntFloatHashMap> wcc = new TIntObjectHashMap<TIntFloatHashMap>();

	// Per-id statistic: a count while reading, divided by totalword in statisticProb().
	TIntFloatHashMap wordProb = new TIntFloatHashMap();

	// Number of tokens counted by read(): the characters of each sample plus two boundary markers.
	public int totalword;
	/**
	 * Whether to keep merging until a single class remains. When false,
	 * startClustering() returns null once there are no further ids to add.
	 */
	private boolean meger = true;

	/**
	 * Creates an empty instance; all statistics are filled in later by {@link #read(Reader)}.
	 */
	public WordCluster(){

	}

	/**
	 * Reads every sample from the reader and accumulates character statistics.
	 * <p>
	 * Each sample is treated as a character sequence framed by two boundary
	 * markers: id -1 before the first character and id -2 after the last one.
	 * For every adjacent pair the method updates the per-id counts
	 * ({@code wordProb}), the directed pair counts ({@code pcc}) and the
	 * neighbour sets ({@code leftnodes}/{@code rightnodes}). When the reader is
	 * exhausted, {@code lastid} is set to the alphabet size and the counts are
	 * normalised by {@link #statisticProb()}.
	 *
	 * @param reader source of samples; the data of each sample is cast to String
	 */
	public void read(Reader reader) {
		totalword = 0;
		while (reader.hasNext()) {
			String text = (String) reader.next().getData();
			final int length = text.length();
			final int steps = length + 1;

			// start-of-sample marker
			int previous = -1;
			wordProb.adjustOrPutValue(previous, 1, 1);
			totalword += length + 2;

			for (int pos = 0; pos < steps; pos++) {
				// one extra step past the last character emits the end marker
				int current = -2;
				if (pos < length) {
					current = alpahbet.lookupIndex(String.valueOf(text.charAt(pos)));
				}
				wordProb.adjustOrPutValue(current, 1, 1);

				TIntFloatHashMap followers = pcc.get(previous);
				if (followers == null) {
					followers = new TIntFloatHashMap();
					pcc.put(previous, followers);
				}
				followers.adjustOrPutValue(current, 1, 1);

				TIntHashSet predecessors = leftnodes.get(current);
				if (predecessors == null) {
					predecessors = new TIntHashSet();
					leftnodes.put(current, predecessors);
				}
				predecessors.add(previous);

				TIntHashSet successors = rightnodes.get(previous);
				if (successors == null) {
					successors = new TIntHashSet();
					rightnodes.put(previous, successors);
				}
				successors.add(current);

				previous = current;
			}
		}
		lastid = alpahbet.size();

		System.out.println("[总个数:]\t" + totalword);
		int dictionarySize = alpahbet.size();
		System.out.println("[字典大小:]\t" + dictionarySize);

		statisticProb();
	}

	/**
	 * Converts the raw counts gathered by {@code read} into relative
	 * frequencies in one pass.
	 * <p>
	 * Every value in {@code wordProb} and every value in {@code pcc} is divided
	 * by {@code totalword}. A leaf {@link Cluster} is created for each
	 * non-negative id; the negative boundary-marker ids get no cluster.
	 */
	private void statisticProb() {
		System.out.println("统计概率");

		for (TIntFloatIterator unigram = wordProb.iterator(); unigram.hasNext();) {
			unigram.advance();
			float prob = unigram.value() / totalword;
			unigram.setValue(prob);
			int id = unigram.key();
			if (id >= 0) {
				clusters.put(id, new Cluster(id, prob, alpahbet.lookupString(id)));
			}
		}

		for (TIntObjectIterator<TIntFloatHashMap> row = pcc.iterator(); row.hasNext();) {
			row.advance();
			for (TIntFloatIterator cell = row.value().iterator(); cell.hasNext();) {
				cell.advance();
				cell.setValue(cell.value() / totalword);
			}
		}
	}


	/**
	 * Computes the graph weight between two clusters and caches it via
	 * {@code setweight}.
	 * <p>
	 * For two different ids the weight is the sum of both directed terms;
	 * for {@code c1 == c2} it is the single self term.
	 *
	 * @param c1 id of the first cluster
	 * @param c2 id of the second cluster
	 * @return the computed weight
	 */
	private float weight(int c1, int c2) {
		float pc1 = wordProb.get(c1);
		float pc2 = wordProb.get(c2);

		float w;
		if (c1 != c2) {
			float forward = clacW(getProb(c1, c2), pc1, pc2);
			float backward = clacW(getProb(c2, c1), pc2, pc1);
			w = forward + backward;
		} else {
			w = clacW(getProb(c1, c1), pc1, pc2);
		}

		setweight(c1, c2, w);
		return w;
	}


	/**
	 * Computes the weight between {@code k} and the cluster that would result
	 * from merging {@code c1} and {@code c2}. Nothing is stored.
	 * <p>
	 * When {@code k == c1} the result is the self weight of the merged cluster.
	 *
	 * @param c1 id of the first cluster to merge
	 * @param c2 id of the second cluster to merge
	 * @param k id of the other cluster, or {@code c1} for the self weight
	 * @return the weight of the hypothetical merged cluster against {@code k}
	 */
	private float weight(int c1, int c2, int k) {
		float pc1 = wordProb.get(c1);
		float pc2 = wordProb.get(c2);
		float pck = wordProb.get(k);
		// probability of the merged cluster
		float merged = pc1 + pc2;

		if (c1 == k) {
			float within = getProb(c1, c1) + getProb(c2, c2) + getProb(c1, c2) + getProb(c2, c1);
			return clacW(within, merged, merged);
		}

		float outgoing = getProb(c1, k) + getProb(c2, k);
		float outgoingTerm = clacW(outgoing, merged, pck);

		float incoming = getProb(k, c1) + getProb(k, c2);
		float incomingTerm = clacW(incoming, pck, merged);

		return outgoingTerm + incomingTerm;
	}

	/**
	 * Evaluates {@code pcc * (ln pcc - ln pc1 - ln pc2)}.
	 *
	 * @param pcc joint value of the pair; a value of zero yields zero
	 * @param pc1 value of the first element
	 * @param pc2 value of the second element
	 * @return the term above, or 0 when {@code pcc} is zero
	 */
	private float clacW(float pcc, float pc1, float pc2) {
		if (pcc == 0f) {
			return 0;
		}
		double logRatio = Math.log(pcc) - Math.log(pc1) - Math.log(pc2);
		return pcc * (float) logRatio;
	}

	/**
	 * Looks up the directed pair value stored in {@code pcc} for
	 * "{@code c1} followed by {@code c2}".
	 *
	 * @param c1 id of the leading element
	 * @param c2 id of the following element
	 * @return the stored value; 0 when {@code c1} has no row, otherwise
	 *         whatever the row returns for {@code c2}
	 */
	private float getProb(int c1, int c2) {
		TIntFloatHashMap row = pcc.get(c1);
		return row == null ? 0f : row.get(c2);
	}


	/**
	 * Finds the pair of clusters in {@code slots} with the highest score from
	 * {@link #calcL(int, int)} and merges it.
	 * <p>
	 * Every unordered pair {@code (i, j)} with {@code i < j} is scored once; on
	 * ties the first pair encountered wins. If fewer than two clusters are in
	 * {@code slots} there is nothing to merge and the method returns without
	 * touching any state.
	 *
	 * @throws IllegalStateException if no pair has a score greater than negative
	 *         infinity (for example when every score is NaN); merging would
	 *         otherwise be attempted with the invalid ids -1/-1
	 */
	public void mergeCluster() {
		if (slots.size() < 2) {
			// no pair exists; calling merge(-1, -1) would corrupt the bookkeeping
			return;
		}

		int maxc1 = -1;
		int maxc2 = -1;
		float maxL = Float.NEGATIVE_INFINITY;
		TIntIterator it1 = slots.iterator();
		while (it1.hasNext()) {
			int i = it1.next();
			TIntIterator it2 = slots.iterator();
			while (it2.hasNext()) {
				int j = it2.next();
				// score each unordered pair only once
				if (i >= j)
					continue;
				float L = calcL(i, j);
				if (L > maxL) {
					maxL = L;
					maxc1 = i;
					maxc2 = j;
				}
			}
		}

		if (maxc1 < 0) {
			throw new IllegalStateException(
					"no cluster pair with a comparable merge score among " + slots.size() + " slots");
		}

		merge(maxc1, maxc2);
	}
	
	/**
	 * Merges clusters {@code c1} and {@code c2} into a new cluster whose id is
	 * taken from {@code lastid}.
	 * <p>
	 * The method records the new parent in {@code heads}, builds the directed
	 * pair values ({@code pcc}) and the cached weights ({@code wcc}) of the new
	 * cluster against every other cluster currently in {@code slots}, replaces
	 * {@code c1}/{@code c2} by the new id in {@code slots} and {@code wordProb},
	 * and registers the merged {@link Cluster}.
	 * <p>
	 * Values for "new cluster followed by k" and the weight against k are stored
	 * under key {@code k}; only the self value and self weight are stored under
	 * the new id.
	 * <p>
	 * NOTE(review): only ids in {@code slots} are considered. Pair values between
	 * {@code c1}/{@code c2} and ids that have not entered {@code slots} yet are
	 * not carried over to the new id — confirm this is intended.
	 *
	 * @param c1 id of the first cluster to merge
	 * @param c2 id of the second cluster to merge
	 */

	protected void merge(int c1, int c2) {
		int newid = lastid++;
		heads.put(c1, newid);
		heads.put(c2, newid);
		// row of the new cluster: value of "newid followed by x", keyed by x
		TIntFloatHashMap newpcc = new TIntFloatHashMap();
		// column of the new cluster: value of "k followed by newid", keyed by k
		TIntFloatHashMap inewpcc = new TIntFloatHashMap();
		// weight between the new cluster and x, keyed by x
		TIntFloatHashMap newwcc = new TIntFloatHashMap();
		float pc1 = wordProb.get(c1);
		float pc2 = wordProb.get(c2);
		// probability of the new cluster
		float pc = pc1+pc2;

		float w;
		{
			// self entries: everything that stays inside the merged cluster
			float pcc1 = getProb(c1,c1);
			float pcc2 = getProb(c2,c2);
			float pcc3 = getProb(c1,c2);
			float pcc4 = getProb(c2,c1);
			float selfProb = pcc1 + pcc2 + pcc3 + pcc4;
			if(selfProb!=0.0f)
				newpcc.put(newid, selfProb);
			w = clacW(selfProb,pc,pc);
			if(w!=0.0f)
				newwcc.put(newid, w);
		}
		TIntIterator it = slots.iterator();
		while(it.hasNext()){
			int k = it.next();
			if (c1==k||c2==k) {
				continue;
			}
			float pck = wordProb.get(k);

			float pcc1 = getProb(c1,k);
			float pcc2 = getProb(c2,k);
			float pcc12 = pcc1 + pcc2;
			// keyed by k: storing under newid would overwrite the self entry
			if(pcc12!=0.0f)
				newpcc.put(k, pcc12);
			float p1 = clacW(pcc12,pc,pck);

			float pcc3 = getProb(k,c1);
			float pcc4 = getProb(k,c2);
			float pcc34 = pcc3 + pcc4;
			if(pcc34!=0.0f)
				inewpcc.put(k, pcc34);
			float p2 = clacW(pcc34,pck,pc);
			w =  p1 + p2;
			// keyed by k so that the weight ends up in wcc[k][newid] below
			if(w!=0.0f)
				newwcc.put(k, w);
		}

		// update slots
		slots.remove(c1);
		slots.remove(c2);
		slots.add(newid);
		pcc.put(newid, newpcc);
		pcc.remove(c1);
		pcc.remove(c2);
		TIntFloatIterator it2 = inewpcc.iterator();
		while(it2.hasNext()){
			it2.advance();
			TIntFloatHashMap pmap = pcc.get(it2.key());
			if(pmap==null){
				pmap = new TIntFloatHashMap();
				pcc.put(it2.key(), pmap);
			}
			pmap.put(newid, it2.value());
			pmap.remove(c1);
			pmap.remove(c2);
		}

		// newid is always greater than (or equal to) it3.key(), so each weight is
		// stored in the row of it3.key() under column newid, matching getweight()
		wcc.put(newid, new TIntFloatHashMap());
		wcc.remove(c1);
		wcc.remove(c2);
		TIntFloatIterator it3 = newwcc.iterator();
		while(it3.hasNext()){
			it3.advance();
			TIntFloatHashMap pmap = wcc.get(it3.key());
			if(pmap==null){
				// k had no non-zero weight cached so far
				pmap = new TIntFloatHashMap();
				wcc.put(it3.key(), pmap);
			}
			pmap.put(newid, it3.value());
			pmap.remove(c1);
			pmap.remove(c2);
		}

		wordProb.remove(c1);
		wordProb.remove(c2);
		wordProb.put(newid, pc);

		// register the merged cluster
		Cluster cluster = new Cluster(newid, clusters.get(c1),clusters.get(c2),pc);
		clusters.put(newid, cluster);
		System.out.println("合并:"+cluster.rep);

	}

	/**
	 * Scores a candidate merge of {@code c1} and {@code c2}.
	 * <p>
	 * The score is the sum of the hypothetical merged-cluster weights against
	 * every id in {@code slots} except {@code c2}, minus the cached weights of
	 * {@code c1} and of {@code c2} against every id in {@code slots}.
	 * <p>
	 * NOTE(review): the cached weight between {@code c1} and {@code c2} is
	 * subtracted twice (once in each of the two subtractions) — confirm intended.
	 *
	 * @param c1 id of the first cluster
	 * @param c2 id of the second cluster
	 * @return the merge score; larger is preferred by {@code mergeCluster}
	 */
	public float calcL(int c1, int c2) {
		float score = 0;

		for (TIntIterator gained = slots.iterator(); gained.hasNext();) {
			int k = gained.next();
			if (k != c2) {
				score += weight(c1, c2, k);
			}
		}

		for (TIntIterator lost = slots.iterator(); lost.hasNext();) {
			int k = lost.next();
			score -= getweight(c1, k);
			score -= getweight(c2, k);
		}

		return score;
	}



	/**
	 * Caches the weight between two ids in {@code wcc}. The entry is stored in
	 * the row of the smaller id under the larger id. A weight of zero is not
	 * stored.
	 *
	 * @param c1 one id
	 * @param c2 the other id
	 * @param w weight to store
	 */
	private void setweight(int c1, int c2, float w) {
		if (w != 0.0f) {
			int low = Math.min(c1, c2);
			int high = Math.max(c1, c2);
			TIntFloatHashMap row = wcc.get(low);
			if (row == null) {
				row = new TIntFloatHashMap();
				wcc.put(low, row);
			}
			row.put(high, w);
		}
	}

	/**
	 * Reads the cached weight between two ids from {@code wcc}, looking in the
	 * row of the smaller id under the larger id.
	 *
	 * @param c1 one id
	 * @param c2 the other id
	 * @return 0 when the smaller id has no row, otherwise whatever the row
	 *         returns for the larger id
	 */
	private float getweight(int c1, int c2) {
		TIntFloatHashMap row = wcc.get(Math.min(c1, c2));
		if (row == null) {
			return 0;
		}
		return row.get(Math.max(c1, c2));
	}

	/**
	 * Runs the clustering and returns the cluster that remains at the end.
	 * <p>
	 * The boundary-marker ids (-1, -2) are dropped from {@code wordProb}, the
	 * candidate ids are obtained from {@code MyHashSparseArrays.trim}, and the
	 * first {@code slotsize} of them are placed into {@code slots}. Weights for
	 * all pairs in {@code slots} are then cached. While more than one cluster
	 * remains, the best pair is merged and, if candidate ids are left, the next
	 * one is added to {@code slots} together with its weights. After every
	 * round the current result is written to {@code ../tmp/res-<round>}; a
	 * failure to write is printed and otherwise ignored.
	 *
	 * @return the single remaining cluster; {@code null} if merging is stopped
	 *         early because {@code meger} is false, or if no id was selected at
	 *         all (empty input or a non-positive {@code slotsize})
	 */
	public Cluster startClustering() {

		wordProb.remove(-1);
		wordProb.remove(-2);

		int[] idx = MyHashSparseArrays.trim(wordProb, ENERGY);

		int mergeCount  = idx.length;
		int remainCount  = idx.length;

		System.out.println("[待合并个数:]\t" +mergeCount );
		System.out.println("[总个数:]\t" + totalword);

		// fill the initial slots
		int round;
		for (round = 0; round< Math.min(slotsize,mergeCount); round++) {
			slots.add(idx[round]);
			System.out.println(round + "\t" + alpahbet.lookupString(idx[round]) + "\t" + slots.size());

		}

		// cache the weight of every pair (including self pairs) in the initial slots
		TIntIterator it1 = slots.iterator();
		while(it1.hasNext()){
			int i = it1.next();
			TIntIterator it2 = slots.iterator();
			while(it2.hasNext()){
				int j= it2.next();
				if(i>j)
					continue;
				weight(i, j);
			}
		}

		while (slots.size()>1) {
			if(round < mergeCount)
				System.out.println(round + "\t" + alpahbet.lookupString(idx[round]) + "\tSize:\t" +slots.size());
			else
				System.out.println(round + "\t" + "\tSize:\t" +slots.size());
			System.out.println("[待合并个数:]\t" + remainCount-- );
			long starttime = System.currentTimeMillis();
			mergeCluster();
			long endtime = System.currentTimeMillis();
			System.out.println("\tTime:\t" + (endtime-starttime)/1000.0);
			if(round < mergeCount){
				// bring in the next candidate and cache its weights against all slots
				int id = idx[round];
				slots.add(id);
				TIntIterator it = slots.iterator();
				while(it.hasNext()){
					int j= it.next();
					weight(j, id);
				}
			}else{
				if(!meger )
					return null;
			}
			try {
				// best-effort snapshot of the intermediate result
				saveTxt("../tmp/res-"+round);
			} catch (Exception e) {
				e.printStackTrace();
			}
			round++;
		}

		if (slots.size() == 0) {
			// nothing was selected for clustering, so there is no root cluster
			return null;
		}
		return clusters.get(slots.toArray()[0]);

	}

	/**
	 * Renders the current grouping as text, one line per group.
	 * <p>
	 * Every alphabet entry is assigned to the top-most ancestor found through
	 * {@code heads}. Groups with fewer than two members are skipped. Each line
	 * starts with the {@code wordProb} value of the group's id, followed by the
	 * members in insertion order, all separated by single spaces.
	 *
	 * @return the textual listing; empty when no group has two or more members
	 */
	@Override
	public String toString(){
		StringBuilder sb = new StringBuilder();

		TIntObjectHashMap<TLinkedHashSet<String>> sets = new TIntObjectHashMap<TLinkedHashSet<String>>();

		for(int i=0;i<alpahbet.size();i++){
			int head = getHead(i);
			TLinkedHashSet<String> s = sets.get(head);
			if(s==null){
				s = new TLinkedHashSet<String>();
				sets.put(head, s);
			}
			s.add(alpahbet.lookupString(i));
		}

		TIntObjectIterator<TLinkedHashSet<String>> it = sets.iterator();
		while(it.hasNext()){
			it.advance();
			if(it.value().size()<2)
				continue;
			sb.append(wordProb.get(it.key()));
			sb.append(" ");
			TObjectHashIterator<String> itt = it.value().iterator();
			while(itt.hasNext()){
				String ss = itt.next();
				sb.append(ss);
				sb.append(" ");
			}
			sb.append("\n");
		}

		return sb.toString();

	}

	/**
	 * Follows the parent links in {@code heads} upwards from {@code i}.
	 *
	 * @param i id to start from
	 * @return the first id on the chain whose parent lookup yields -1
	 */
	private int getHead(int i) {
		int node = i;
		for (int parent = heads.get(node); parent != -1; parent = heads.get(node)) {
			node = parent;
		}
		return node;
	}

	/**
	 * Serializes this object, gzip-compressed, to the given file.
	 * <p>
	 * Missing parent directories are created first. A path without a parent
	 * component (a bare file name) is written relative to the working directory.
	 * The underlying file is closed even when writing fails.
	 *
	 * @param file path of the model file to write
	 * @throws IOException if the file cannot be opened or written
	 */
	public void saveModel(String file) throws IOException {
		File f = new File(file);
		File path = f.getParentFile();
		// getParentFile() returns null when the path names no parent directory
		if(path != null && !path.exists()){
			path.mkdirs();
		}
		FileOutputStream fos = new FileOutputStream(file);
		try {
			ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream(
					new BufferedOutputStream(fos)));
			try {
				out.writeObject(this);
			} finally {
				// closing the outer stream flushes the buffer and finishes the gzip data
				out.close();
			}
		} finally {
			// covers a failure while the wrapper streams were being constructed
			fos.close();
		}
	}

	/**
	 * Loads an instance previously written by {@link #saveModel(String)}.
	 * <p>
	 * The file is read with Java native deserialization, so it must come from a
	 * trusted source. The underlying file is closed even when reading fails.
	 *
	 * @param file path of the gzip-compressed model file
	 * @return the deserialized instance
	 * @throws IOException if the file cannot be opened or read
	 * @throws ClassNotFoundException if a serialized class cannot be resolved
	 */
	public static  WordCluster loadFrom(String file) throws IOException,
	ClassNotFoundException {
		FileInputStream fis = new FileInputStream(file);
		try {
			ObjectInputStream in = new ObjectInputStream(new GZIPInputStream(
					new BufferedInputStream(fis)));
			try {
				return (WordCluster) in.readObject();
			} finally {
				in.close();
			}
		} finally {
			// covers a failure while the wrapper streams were being constructed
			fis.close();
		}
	}

	/**
	 * Writes the textual result ({@link #toString()}) to a file, encoded as UTF-8.
	 * The file is closed even when writing fails.
	 *
	 * @param file path of the text file to write
	 * @throws Exception if the file cannot be opened or written
	 */
	public void saveTxt(String file) throws Exception {
		FileOutputStream fos = new FileOutputStream(file);
		try {
			BufferedWriter bout = new BufferedWriter(new OutputStreamWriter(
					fos, "UTF8"));
			try {
				bout.write(this.toString());
			} finally {
				// closing the writer flushes buffered characters to the file
				bout.close();
			}
		} finally {
			// covers a failure while the writer was being constructed
			fos.close();
		}

	}

	/**
	 * Command-line entry point: reads a corpus, clusters it and writes the results.
	 * <p>
	 * Options: {@code -slot} (slot size, default 50), {@code -path} (input data,
	 * default {@code ./tmp/news.allsites.txt}) and {@code -res} (result file,
	 * default {@code ./tmp/res.txt}). The model is saved to {@code <res>.m},
	 * the text result to {@code <res>}, and the model is reloaded and written
	 * once more to {@code <res>1}.
	 *
	 * @param args command-line arguments
	 * @throws Exception if reading, clustering or saving fails
	 */
	public static void main(String[] args) throws Exception {

		// parse the command-line arguments
		Options options = new Options();
		options.addOption("path", true, "保存路径");
		options.addOption("res", true, "评测结果保存路径");
		options.addOption("slot", true, "槽大小");

		CommandLine parsed;
		try {
			parsed = new BasicParser().parse(options, args);
		} catch (Exception e) {
			System.err.println("Parameters format error");
			return;
		}

		int slotCount = Integer.parseInt(parsed.getOptionValue("slot", "50"));
		System.out.println("槽大小:"+slotCount);

		String dataPath = parsed.getOptionValue("path", "./tmp/news.allsites.txt");
		System.out.println("数据路径:"+dataPath);

		String resultPath = parsed.getOptionValue("res", "./tmp/res.txt");
		System.out.println("测试结果:"+resultPath);

		SougouCA corpus = new SougouCA(dataPath);

		WordCluster clusterer = new WordCluster();
		clusterer.slotsize = slotCount;
		clusterer.read(corpus);
		clusterer.startClustering();

		// persist the model and text result, then reload the model and dump it again
		String modelPath = resultPath + ".m";
		clusterer.saveModel(modelPath);
		clusterer.saveTxt(resultPath);
		clusterer = WordCluster.loadFrom(modelPath);
		clusterer.saveTxt(resultPath + "1");

		System.out.println(new Date().toString());
		System.out.println("Done");
	}
}