#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2014 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html


"""
Python wrapper for Latent Dirichlet Allocation (LDA) from MALLET, the Java topic modelling
toolkit [1]_.

This module allows both LDA model estimation from a training corpus and inference of topic
distribution on new, unseen documents, using an (optimized version of) collapsed
Gibbs sampling from MALLET.

MALLET's LDA training requires O(#corpus_words) of memory, keeping the entire corpus in RAM.
If you find yourself running out of memory, either decrease the `workers` constructor
parameter, or use `LdaModel` which needs only O(1) memory.

The wrapped model can NOT be updated with new documents for online training -- use gensim's `LdaModel` for that.

Example:

>>> model = gensim.models.LdaMallet('/Users/kofola/mallet-2.0.7/bin/mallet', corpus=my_corpus, num_topics=20, id2word=dictionary)
>>> print model[my_vector]  # print LDA topics of a document
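
A rough sketch of further usage (`new_doc_bow` stands in for another bag-of-words vector, not defined above):

>>> model.show_topics(num_topics=5, num_words=10)  # most probable words for a few topics
>>> model[[my_vector, new_doc_bow]]  # infer topic distributions for several unseen documents at once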

.. [1] http://mallet.cs.umass.edu/

"""


import logging
import random
import tempfile
import os
from subprocess import call

import numpy

from gensim import utils

logger = logging.getLogger('gensim.models.ldamallet')


def read_doctopics(fname, eps=1e-6):
    """
    Yield document topic vectors from MALLET's "doc-topics" format, as sparse gensim vectors.
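
    For illustration, a (hypothetical) doc-topics file looks roughly like this: a header
    line, then one document per line with a doc id, a source field and alternating
    topic-id/proportion pairs::

        #doc source topic proportion ...
        0 file:/tmp/corpus.txt 2 0.7 5 0.3

    This function would yield that document as the sparse vector ``[(2, 0.7), (5, 0.3)]``.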

    """
    with utils.smart_open(fname) as fin:
        next(fin)  # skip the header line
        for lineno, line in enumerate(fin):
            parts = line.split()[2:]  # skip "doc" and "source" columns
            if len(parts) % 2 != 0:
                raise RuntimeError("invalid doc topics format at line %i in %s" % (lineno + 2, fname))  # +2: 1-based numbering + skipped header
            doc = [(int(id), float(weight)) for id, weight in zip(parts[::2], parts[1::2]) if abs(float(weight)) > eps]
            # explicitly normalize probs to sum up to 1.0, just to be sure...
            weights = float(sum([weight for _, weight in doc]))
            yield [] if weights == 0 else sorted((id, 1.0 * weight / weights) for id, weight in doc)



class LdaMallet(utils.SaveLoad):
    """
    Class for LDA training using MALLET. Communication between MALLET and Python
    takes place by passing around data files on disk and calling Java with subprocess.call().

    """
    def __init__(self, mallet_path, corpus=None, num_topics=100, id2word=None, workers=4, prefix=None,
                 optimize_interval=0, iterations=1000):
        """
        `mallet_path` is path to the mallet executable, e.g. `/home/kofola/mallet-2.0.7/bin/mallet`.
        `corpus` is a gensim corpus, aka a stream of sparse document vectors.
        `id2word` is a mapping between token ids and tokens (words).
        `workers` is the number of threads, for parallel training.
        `prefix` is the string prefix under which all data files will be stored; default: system temp + random filename prefix.
        `optimize_interval` re-optimize hyperparameters every `optimize_interval` iterations (sometimes leads to a Java exception; set to 0 to switch off hyperparameter optimization).
        `iterations` is the number of sampling iterations.

        """
        self.mallet_path = mallet_path
        self.id2word = id2word
        if self.id2word is None:
            logger.warning("no word id mapping provided; initializing from corpus, assuming identity")
            self.id2word = utils.dict_from_corpus(corpus)
            self.num_terms = len(self.id2word)
        else:
            self.num_terms = 0 if not self.id2word else 1 + max(self.id2word.keys())
        if self.num_terms == 0:
            raise ValueError("cannot compute LDA over an empty collection (no terms)")
        self.num_topics = num_topics
        if prefix is None:
            rand_prefix = hex(random.randint(0, 0xffffff))[2:] + '_'
            prefix = os.path.join(tempfile.gettempdir(), rand_prefix)
        self.prefix = prefix
        self.workers = workers
        self.optimize_interval = optimize_interval
        self.iterations = iterations

        if corpus is not None:
            self.train(corpus)

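    # Paths of the intermediate and output files exchanged with MALLET, all derived from `self.prefix`.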
    def finferencer(self):
        return self.prefix + 'inferencer.mallet'

    def ftopickeys(self):
        return self.prefix + 'topickeys.txt'

    def fstate(self):
        return self.prefix + 'state.mallet.gz'

    def fdoctopics(self):
        return self.prefix + 'doctopics.txt'

    def fcorpustxt(self):
        return self.prefix + 'corpus.txt'

    def fcorpusmallet(self):
        return self.prefix + 'corpus.mallet'

    def fwordweights(self):
        return self.prefix + 'wordweights.txt'

    def convert_input(self, corpus, infer=False):
        """
        Serialize documents (a gensim corpus of bag-of-words vectors) to a temporary
        plain-text file, then convert that file to MALLET's internal binary format.
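
        For example, assuming ``id2word == {0: 'human', 3: 'graph'}``, the document
        ``[(0, 2), (3, 1)]`` is written out as the line ``0 0 human human graph``
        (document number, a dummy label, then the tokens).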

        """
        logger.info("serializing temporary corpus to %s" % self.fcorpustxt())
        # write out the corpus in a file format that MALLET understands: one document per line:
        # document id[SPACE]label (not used)[SPACE]whitespace delimited utf8-encoded tokens
        with utils.smart_open(self.fcorpustxt(), 'wb') as fout:
            for docno, doc in enumerate(corpus):
                if self.id2word:
                    tokens = sum(([self.id2word[tokenid]] * int(cnt) for tokenid, cnt in doc), [])
                else:
                    tokens = sum(([str(tokenid)] * int(cnt) for tokenid, cnt in doc), [])
                fout.write(utils.to_utf8("%s 0 %s\n" % (docno, ' '.join(tokens))))

        # convert the text file above into MALLET's internal format
        cmd = self.mallet_path + " import-file --preserve-case --keep-sequence --remove-stopwords --token-regex '\S+' --input %s --output %s"
        if infer:
            cmd += ' --use-pipe-from ' + self.fcorpusmallet()
            cmd = cmd % (self.fcorpustxt(), self.fcorpusmallet() + '.infer')
        else:
            cmd = cmd % (self.fcorpustxt(), self.fcorpusmallet())
        logger.info("converting temporary corpus to MALLET format with %s" % cmd)
        retval = call(cmd, shell=True)
        if retval != 0:
            raise RuntimeError("MALLET import-file failed with non-zero return code %s" % retval)


    def train(self, corpus):
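        """
        Train MALLET LDA on `corpus`: convert the corpus to MALLET's internal format,
        run `mallet train-topics` as a subprocess, then load the resulting word-topic
        assignments into `self.word_topics`.

        """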
        self.convert_input(corpus, infer=False)
        cmd = self.mallet_path + " train-topics --input %s --num-topics %s --optimize-interval %s "\
            "--num-threads %s --output-state %s --output-doc-topics %s --output-topic-keys %s "\
            "--num-iterations %s --inferencer-filename %s"
        cmd = cmd % (self.fcorpusmallet(), self.num_topics, self.optimize_interval, self.workers,
            self.fstate(), self.fdoctopics(), self.ftopickeys(), self.iterations, self.finferencer())
        # NOTE "--keep-sequence-bigrams" / "--use-ngrams true" poorer results + runs out of memory
        logger.info("training MALLET LDA with %s" % cmd)
        retval = call(cmd, shell=True)
        if retval != 0:
            raise RuntimeError("MALLET train-topics failed with non-zero return code %s" % retval)
        self.word_topics = self.load_word_topics()


    def __getitem__(self, bow, iterations=100):
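        """
        Infer topic distributions for `bow`, using the inferencer saved during training.

        `bow` may be a single document (a bag-of-words vector) or an iterable corpus of
        documents; the result is a single sparse topic vector or a list of them, respectively.

        """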
        is_corpus, corpus = utils.is_corpus(bow)
        if not is_corpus:
            # query is a single document => make a corpus out of it
            bow = [bow]

        self.convert_input(bow, infer=True)
        cmd = self.mallet_path + " infer-topics --input %s --inferencer %s --output-doc-topics %s --num-iterations %s"
        cmd = cmd % (self.fcorpusmallet() + '.infer', self.finferencer(), self.fdoctopics() + '.infer', iterations)
        logger.info("inferring topics with MALLET LDA '%s'" % cmd)
        retval = call(cmd, shell=True)
        if retval != 0:
            raise RuntimeError("MALLET failed with error %s on return" % retval)
        result = list(read_doctopics(self.fdoctopics() + '.infer'))
        return result if is_corpus else result[0]


    def load_word_topics(self):
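        """
        Parse MALLET's Gibbs sampling state file: read the per-topic `alpha` priors and
        return a `num_topics` x `num_terms` matrix of word-topic assignment counts
        (also stored as `self.wordtopics`).

        """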
        logger.info("loading assigned topics from %s" % self.fstate())
        wordtopics = numpy.zeros((self.num_topics, self.num_terms), dtype=numpy.float32)
        with utils.smart_open(self.fstate()) as fin:
            _ = next(fin)  # header
            self.alpha = numpy.array([float(val) for val in next(fin).split()[2:]])
            assert len(self.alpha) == self.num_topics, "mismatch between MALLET vs. requested topics"
            _ = next(fin)  # beta
            for lineno, line in enumerate(fin):
                line = utils.to_unicode(line)
                doc, source, pos, typeindex, token, topic = line.split()
                tokenid = self.id2word.token2id[token] if hasattr(self.id2word, 'token2id') else int(token)
                wordtopics[int(topic), tokenid] += 1
        logger.info("loaded assigned topics for %i tokens" % wordtopics.sum())
        self.wordtopics = wordtopics
        self.print_topics(15)
        return wordtopics


    def print_topics(self, num_topics=10, num_words=10):
        return self.show_topics(num_topics, num_words, log=True)


    def show_topics(self, num_topics=10, num_words=10, log=False, formatted=True):
        """
        Return `num_words` most probable words for `num_topics` topics, as a list with
        one entry per topic. Set `num_topics=-1` to show all topics.

        When only a subset of topics is requested, the topics with the most extreme
        `alpha` values (half smallest, half largest) are selected.

        Set `formatted=True` to return each topic as a formatted string, or `False` to
        return it as a list of (weight, word) pairs.

        """
        if num_topics < 0 or num_topics >= self.num_topics:
            num_topics = self.num_topics
            chosen_topics = range(num_topics)
        else:
            num_topics = min(num_topics, self.num_topics)
            sort_alpha = self.alpha + 0.0001 * numpy.random.rand(len(self.alpha)) # add a little random jitter, to randomize results around the same alpha
            sorted_topics = list(numpy.argsort(sort_alpha))
            chosen_topics = sorted_topics[ : num_topics//2] + sorted_topics[-num_topics//2 : ]
        shown = []
        for i in chosen_topics:
            if formatted:
                topic = self.print_topic(i, topn=num_words)
            else:
                topic = self.show_topic(i, topn=num_words)
            shown.append(topic)
            if log:
                logger.info("topic #%i (%.3f): %s" % (i, self.alpha[i], topic))
        return shown


    def show_topic(self, topicid, topn=10):
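        """
        Return the `topn` most probable words of topic `topicid`, as (weight, word) pairs,
        with weights taken from the topic's normalized word distribution.

        """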
        topic = self.wordtopics[topicid]
        topic = topic / topic.sum() # normalize to probability dist
        bestn = numpy.argsort(topic)[::-1][:topn]
        beststr = [(topic[id], self.id2word[id]) for id in bestn]
        return beststr


    def print_topic(self, topicid, topn=10):
        return ' + '.join(['%.3f*%s' % v for v in self.show_topic(topicid, topn)])