#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking transformation algorithms (the models package).
"""


import logging
import numbers
import os
import unittest
import copy

import numpy as np
from numpy.testing import assert_allclose

from gensim.corpora import mmcorpus, Dictionary
from gensim.models import ldamodel, ldamulticore
from gensim import matutils, utils
from gensim.test import basetmtests
from gensim.test.utils import datapath, get_tmpfile, common_texts

# True when the RUNNER_OS environment variable says we are on a Windows runner.
GITHUB_ACTIONS_WINDOWS = os.getenv('RUNNER_OS') == 'Windows'

# Module-level fixtures shared by the tests below: a dictionary built from
# `common_texts`, and the bag-of-words representation of each of those texts.
dictionary = Dictionary(common_texts)
corpus = list(map(dictionary.doc2bow, common_texts))


def test_random_state():
    """Check that `utils.get_random_state` yields a `RandomState` for several kinds of seed.

    The seeds covered are the `np.random` module itself, `None`, an existing
    `np.random.RandomState` instance and a plain integer.
    """
    # NOTE: the first entry used to be `np.random.seed(0)`. That call returns None,
    # so it merely duplicated the `None` case while reseeding the global generator
    # as a side effect; the module object itself is the seed that was meant.
    seeds = [np.random, None, np.random.RandomState(0), 0]
    for seed in seeds:
        assert isinstance(utils.get_random_state(seed), np.random.RandomState)


class TestLdaModel(unittest.TestCase, basetmtests.TestBaseTopicModel):
    def setUp(self):
        """Prepare the fixtures used by every test.

        Sets the model class under test, loads the Matrix Market test corpus and
        trains a two-topic model (100 passes) on the module-level `corpus`.
        """
        self.class_ = ldamodel.LdaModel
        self.corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm'))
        self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)

    def test_sync_state(self):
        """Check that a model given a copy of another model's state behaves like that model.

        After `sync_state()` both models must report the same term topics and
        topic matrix, and must still agree after one more identically seeded
        training pass.
        """
        trained = self.model
        replica = self.class_(corpus=self.corpus, id2word=dictionary, num_topics=2, passes=1)
        replica.state = copy.deepcopy(trained.state)
        replica.sync_state()

        def check_models_agree():
            # compare both the per-term topics of word id 2 and the full topic matrix
            assert_allclose(trained.get_term_topics(2), replica.get_term_topics(2), rtol=1e-5)
            assert_allclose(trained.get_topics(), replica.get_topics(), rtol=1e-5)

        check_models_agree()

        # training must continue properly from the synced state: give both models
        # the same seed and a single pass, then update each on the same corpus
        for lda in (trained, replica):
            lda.random_state = np.random.RandomState(0)
            lda.passes = 1
        for lda in (trained, replica):
            lda.update(self.corpus)

        check_models_agree()

    def test_transform(self):
        """Train a two-topic model and check the topics inferred for the first document.

        The expected distribution is [0.13, 0.87] up to re-ordering, within an
        absolute tolerance of 0.1.
        """
        expected = [0.13, 0.87]
        passed = False
        # LDA training sometimes gets stuck at a local minimum; when that happens,
        # re-train from scratch hoping for a better random initialization.
        for attempt in range(25):  # restart at most 25 times
            # build and train the transformation model
            model = self.class_(id2word=dictionary, num_topics=2, passes=100)
            model.update(self.corpus)

            # transform one document and densify the result for easier comparison
            first_doc = list(corpus)[0]
            vec = matutils.sparse2full(model[first_doc], 2)

            # must contain the same values, up to re-ordering
            passed = np.allclose(sorted(vec), sorted(expected), atol=1e-1)
            if not passed:
                logging.warning(
                    "LDA failed to converge on attempt %i (got %s, expected %s)", attempt, sorted(vec), sorted(expected)
                )
                continue
            break
        self.assertTrue(passed)

    def test_alpha_auto(self):
        """Check that training with alpha='auto' ends with a different alpha than alpha='symmetric'."""
        symmetric_model = self.class_(corpus, id2word=dictionary, alpha='symmetric', passes=10)
        auto_model = self.class_(corpus, id2word=dictionary, alpha='auto', passes=10)

        # did we learn something? the two priors must not be element-wise identical
        alphas_identical = all(np.equal(symmetric_model.alpha, auto_model.alpha))
        self.assertFalse(alphas_identical)

    def test_alpha(self):
        """Check how the `alpha` argument is interpreted for a two-topic model.

        Accepted values must produce an alpha vector of shape (2,) with the
        expected contents; wrongly shaped sequences must raise AssertionError and
        an unknown string must raise ValueError.
        """
        common = dict(id2word=dictionary, num_topics=2)
        expected_shape = (2,)

        # should not raise anything
        self.class_(alpha=None, **common)

        # (alpha argument, expected alpha vector, extra tolerance arguments)
        accepted = [
            ('symmetric', np.array([0.5, 0.5]), {}),
            ('asymmetric', [0.630602, 0.369398], {'rtol': 1e-5}),
            (0.3, np.array([0.3, 0.3]), {}),
            (3, np.array([3, 3]), {}),
            ([0.3, 0.3], np.array([0.3, 0.3]), {}),
            (np.array([0.3, 0.3]), np.array([0.3, 0.3]), {}),
        ]
        for alpha, expected, tolerance in accepted:
            model = self.class_(alpha=alpha, **common)
            self.assertEqual(model.alpha.shape, expected_shape)
            assert_allclose(model.alpha, expected, **tolerance)

        # all should raise an exception for being wrong shape
        for wrong_shape in ([0.3, 0.3, 0.3], [[0.3], [0.3]], [0.3]):
            self.assertRaises(AssertionError, self.class_, alpha=wrong_shape, **common)

        # an unrecognised string is rejected as well
        self.assertRaises(ValueError, self.class_, alpha="gensim is cool", **common)

    def test_eta_auto(self):
        """Check that training with eta='auto' ends with a different eta than eta='symmetric'."""
        symmetric_model = self.class_(corpus, id2word=dictionary, eta='symmetric', passes=10)
        auto_model = self.class_(corpus, id2word=dictionary, eta='auto', passes=10)

        # did we learn something? the two priors must not be numerically close
        etas_close = np.allclose(symmetric_model.eta, auto_model.eta)
        self.assertFalse(etas_close)

    def test_eta(self):
        """Check how the `eta` argument is interpreted for a two-topic model.

        Accepted values must produce an eta vector with one entry per dictionary
        term; a (num_topics x num_terms) matrix must be accepted; wrongly shaped
        input must raise AssertionError and unsupported strings ValueError.
        """
        num_terms = len(dictionary)
        common = dict(id2word=dictionary, num_topics=2)
        expected_shape = (num_terms,)

        # (eta argument, value every entry of the resulting eta must have)
        accepted = [
            (None, 0.5),
            ('symmetric', 0.5),
            (0.3, 0.3),
            (3, 3),
            ([0.3] * num_terms, 0.3),
            (np.array([0.3] * num_terms), 0.3),
        ]
        for eta, fill_value in accepted:
            model = self.class_(eta=eta, **common)
            self.assertEqual(model.eta.shape, expected_shape)
            assert_allclose(model.eta, np.full(num_terms, fill_value))

        # should be ok with num_topics x num_terms
        full_eta = np.full((2, len(dictionary)), 0.5)
        self.class_(eta=full_eta, **common)

        # (expected exception, eta argument)
        rejected = [
            # all of these should raise an exception for being wrong shape
            (AssertionError, full_eta.reshape(full_eta.shape[::-1])),
            (AssertionError, [0.3]),
            (AssertionError, [0.3] * (num_terms + 1)),
            # strings that are not accepted for eta
            (ValueError, "gensim is cool"),
            (ValueError, "asymmetric"),
        ]
        for error, eta in rejected:
            self.assertRaises(error, self.class_, eta=eta, **common)

    def test_top_topics(self):
        """Check the types returned by `top_topics`.

        Every result is a (topic, score) pair where the topic is a list of
        (floating-point value, str) pairs and the score is a float.
        """
        for topic, score in self.model.top_topics(self.corpus):
            self.assertTrue(isinstance(topic, list))
            self.assertTrue(isinstance(score, float))

            for weight, word in topic:
                self.assertTrue(isinstance(word, str))
                self.assertTrue(np.issubdtype(weight, np.floating))

    def test_get_topic_terms(self):
        """Check that `get_topic_terms` yields (integral id, floating-point value) pairs."""
        for term_id, weight in self.model.get_topic_terms(1):
            self.assertTrue(isinstance(term_id, numbers.Integral))
            self.assertTrue(np.issubdtype(weight, np.floating))

    @unittest.skipIf(GITHUB_ACTIONS_WINDOWS, 'see <https://github.com/RaRe-Technologies/gensim/pull/2836>')
    def test_get_document_topics(self):
        """Check the structure of `get_document_topics` output.

        Covers three calls: the whole corpus without per-word information, the
        whole corpus with `per_word_topics=True` (each result is a tuple of
        document topics, word topics and word phi values), and a single document.
        It also checks that `minimum_probability` / `minimum_phi_value` filter
        out at least some entries.
        """
        def check_doc_topics(doc_topics):
            # every entry is an (integral topic id, floating-point probability) pair
            for topic_id, probability in doc_topics:
                self.assertIsInstance(topic_id, numbers.Integral)
                self.assertTrue(np.issubdtype(probability, np.floating))

        def check_word_entries(entries):
            # every entry is an (integral word id, list) pair
            for word_id, values in entries:
                self.assertIsInstance(word_id, numbers.Integral)
                self.assertIsInstance(values, list)

        # Seed via an integer rather than `np.random.seed(0)`: that call returns None
        # (so no seed was actually passed) and reseeds the global generator for
        # every test that runs afterwards.
        model = self.class_(self.corpus, id2word=dictionary, num_topics=2, passes=100, random_state=0)

        for doc_topics in model.get_document_topics(self.corpus):
            self.assertIsInstance(doc_topics, list)
            check_doc_topics(doc_topics)

        # Test case to use the get_document_topic function for the corpus
        all_topics = model.get_document_topics(self.corpus, per_word_topics=True)

        self.assertEqual(model.state.numdocs, len(corpus))

        for doc_result in all_topics:
            self.assertIsInstance(doc_result, tuple)
            doc_topics, word_topics, word_phis = doc_result
            check_doc_topics(doc_topics)
            check_word_entries(word_topics)
            check_word_entries(word_phis)

        # Test case to check the filtering effect of minimum_probability and minimum_phi_value
        docs_with_topics = 0  # documents that keep at least one topic after filtering
        words_with_phis = 0  # words that keep at least one phi value after filtering

        all_topics = model.get_document_topics(
            self.corpus, minimum_probability=0.8, minimum_phi_value=1.0, per_word_topics=True
        )

        self.assertEqual(model.state.numdocs, len(corpus))

        for doc_result in all_topics:
            self.assertIsInstance(doc_result, tuple)
            doc_topics, word_topics, word_phis = doc_result

            check_doc_topics(doc_topics)
            # counted once per document; with two topics and a 0.8 threshold at most
            # one topic can survive, so this matches counting the surviving entries
            if len(doc_topics) != 0:
                docs_with_topics += 1

            check_word_entries(word_topics)

            check_word_entries(word_phis)
            words_with_phis += sum(1 for _, phi_values in word_phis if len(phi_values) != 0)

        # the thresholds must have removed at least one document topic and one word phi
        self.assertGreater(model.state.numdocs, docs_with_topics)
        self.assertGreater(sum(len(doc) for doc in corpus), words_with_phis)

        # single-document call: returns the three parts directly
        doc_topics, word_topics, word_phis = model.get_document_topics(self.corpus[1], per_word_topics=True)
        check_doc_topics(doc_topics)
        check_word_entries(word_topics)
        check_word_entries(word_phis)

        # word_topics looks like this: ({word_id => [topic_id_most_probable, topic_id_second_most_probable, ...]).
        # we check one case in word_topics, i.e of the first word in the doc, and its likely topics.

        # FIXME: Fails on osx and win
        # expected_word = 0
        # self.assertEqual(word_topics[0][0], expected_word)
        # self.assertTrue(0 in word_topics[0][1])

    def test_term_topics(self):
        """Check that `get_term_topics` returns (int topic id, floating-point probability) pairs.

        The word is looked up both by its integer id and by its string form.
        """
        # Seed via an integer rather than `np.random.seed(0)`: that call returns None
        # (so no seed was actually passed) and reseeds the global generator for
        # every test that runs afterwards.
        model = self.class_(self.corpus, id2word=dictionary, num_topics=2, passes=100, random_state=0)

        def check_term_topics(result):
            for topic_no, probability in result:
                self.assertIsInstance(topic_no, int)
                self.assertTrue(np.issubdtype(probability, np.floating))

        # check with word_type
        check_term_topics(model.get_term_topics(2))

        # checks if topic '1' is in the result list
        # FIXME: Fails on osx and win
        # self.assertTrue(1 in result[0])

        # if user has entered word instead, check with word
        check_term_topics(model.get_term_topics(str(model.id2word[2])))

        # checks if topic '1' is in the result list
        # FIXME: Fails on osx and win
        # self.assertTrue(1 in result[0])

    def test_passes(self):
        """Check that the final rhot after each update does not depend on `passes`.

        Reference values are collected from a model built without `passes`; models
        built with several `passes` settings must reproduce them update by update.
        """
        # long message includes the original error message with a custom one
        self.longMessage = True

        def final_rhot(lda):
            # rhot computed from the model's offset, num_updates, chunksize and decay
            return (lda.offset + (1 * lda.num_updates) / lda.chunksize) ** -lda.decay

        # construct what we expect when passes aren't involved:
        # generate 5 updates to test rhot on
        baseline = self.class_(id2word=dictionary, chunksize=1, num_topics=2)
        test_rhots = []
        for _ in range(5):
            baseline.update(self.corpus)
            test_rhots.append(final_rhot(baseline))

        expected_total = len(corpus) * len(test_rhots)

        for passes in (1, 5, 10, 50, 100):
            model = self.class_(id2word=dictionary, chunksize=1, num_topics=2, passes=passes)
            self.assertEqual(final_rhot(model), 1.0)
            # make sure the rhot matches the test after each update
            for test_rhot in test_rhots:
                model.update(self.corpus)

                msg = ", ".join(map(str, (passes, model.num_updates, model.state.numdocs)))
                self.assertAlmostEqual(final_rhot(model), test_rhot, msg=msg)

            self.assertEqual(model.state.numdocs, expected_total)
            self.assertEqual(model.num_updates, expected_total)

    # def test_topic_seeding(self):
    #     for topic in range(2):
    #         passed = False
    #         for i in range(5):  # restart at most this many times, to mitigate LDA randomness
    #             # try seeding it both ways round, check you get the same
    #             # topics out but with which way round they are depending
    #             # on the way round they're seeded
    #             eta = np.ones((2, len(dictionary))) * 0.5
    #             system = dictionary.token2id[u'system']
    #             trees = dictionary.token2id[u'trees']

    #             # aggressively seed the word 'system', in one of the
    #             # two topics, 10 times higher than the other words
    #             eta[topic, system] *= 10.0

    #             model = self.class_(id2word=dictionary, num_topics=2, passes=200, eta=eta)
    #             model.update(self.corpus)

    #             topics = [{word: p for p, word in model.show_topic(j, topn=None)} for j in range(2)]

    #             # check that the word 'system' in the topic we seeded got a high weight,
    #             # and the word 'trees' (the main word in the other topic) a low weight --
    #             # and vice versa for the other topic (which we didn't seed with 'system')
    #             passed = (
    #                 (topics[topic][u'system'] > topics[topic][u'trees'])
    #                 and
    #                 (topics[1 - topic][u'system'] < topics[1 - topic][u'trees'])
    #             )
    #             if passed:
    #                 break
    #             logging.warning("LDA failed to converge on attempt %i (got %s)", i, topics)
    #         self.assertTrue(passed)

    def test_persistence(self):
        """Save the model, load it back and check that the copy matches the original."""
        fname = get_tmpfile('gensim_models_lda.tst')
        original = self.model
        original.save(fname)
        restored = self.class_.load(fname)

        self.assertEqual(original.num_topics, restored.num_topics)
        same_beta = np.allclose(original.expElogbeta, restored.expElogbeta)
        self.assertTrue(same_beta)

        # try projecting an empty vector
        empty_doc = []
        self.assertTrue(np.allclose(original[empty_doc], restored[empty_doc]))

    def test_model_compatibility_with_python_versions(self):
        """Load the 'ldamodel_python_2_7' and 'ldamodel_python_3_5' fixtures and check they agree."""
        model_2_7 = self.class_.load(datapath('ldamodel_python_2_7'))
        model_3_5 = self.class_.load(datapath('ldamodel_python_3_5'))

        self.assertEqual(model_2_7.num_topics, model_3_5.num_topics)
        self.assertTrue(np.allclose(model_2_7.expElogbeta, model_3_5.expElogbeta))

        # try projecting an empty vector
        empty_doc = []
        self.assertTrue(np.allclose(model_2_7[empty_doc], model_3_5[empty_doc]))

        # both models must know exactly the same word ids
        ids_2_7 = {word_id for word_id, _ in model_2_7.id2word.iteritems()}
        ids_3_5 = {word_id for word_id, _ in model_3_5.id2word.iteritems()}
        self.assertEqual(ids_2_7, ids_3_5)

    def test_persistence_ignore(self):
        """Check that `id2word` is absent after saving with it listed in `ignore`.

        The attribute name is passed once as a plain string and once inside a list.
        """
        fname = get_tmpfile('gensim_models_lda_testPersistenceIgnore.tst')
        model = ldamodel.LdaModel(self.corpus, num_topics=2)

        for ignore in ('id2word', ['id2word']):
            model.save(fname, ignore=ignore)
            reloaded = ldamodel.LdaModel.load(fname)
            self.assertTrue(reloaded.id2word is None)

    def test_persistence_compressed(self):
        """Save the model to a '.gz' file, load it with `mmap=None` and compare with the original."""
        fname = get_tmpfile('gensim_models_lda.tst.gz')
        original = self.model
        original.save(fname)
        restored = self.class_.load(fname, mmap=None)

        self.assertEqual(original.num_topics, restored.num_topics)
        same_beta = np.allclose(original.expElogbeta, restored.expElogbeta)
        self.assertTrue(same_beta)

        # try projecting an empty vector
        empty_doc = []
        self.assertTrue(np.allclose(original[empty_doc], restored[empty_doc]))

    def test_large_mmap(self):
        """Save with `sep_limit=0`, load with `mmap='r'` and check the arrays are memory-mapped."""
        fname = get_tmpfile('gensim_models_lda.tst')
        original = self.model

        # simulate storing large arrays separately
        original.save(fname, sep_limit=0)

        # test loading the large model arrays with mmap
        mapped = self.class_.load(fname, mmap='r')
        self.assertEqual(original.num_topics, mapped.num_topics)
        self.assertTrue(isinstance(mapped.expElogbeta, np.memmap))
        same_beta = np.allclose(original.expElogbeta, mapped.expElogbeta)
        self.assertTrue(same_beta)

        # try projecting an empty vector
        empty_doc = []
        self.assertTrue(np.allclose(original[empty_doc], mapped[empty_doc]))

    def test_large_mmap_compressed(self):
        """Check that loading a '.gz' model with `mmap='r'` raises IOError."""
        fname = get_tmpfile('gensim_models_lda.tst.gz')

        # simulate storing large arrays separately
        self.model.save(fname, sep_limit=0)

        # test loading the large model arrays with mmap
        with self.assertRaises(IOError):
            self.class_.load(fname, mmap='r')

    def test_random_state_backward_compatibility(self):
        """Check that a model saved by a pre-0.13.2 Gensim can print topics before and after re-saving."""
        def check_printed_topics(lda):
            # set `num_topics` less than the model's own `num_topics` so that its `random_state` is used
            for topic in lda.print_topics(num_topics=2, num_words=3):
                self.assertTrue(isinstance(topic[0], int))
                self.assertTrue(isinstance(topic[1], str))

        # load a model saved using a pre-0.13.2 version of Gensim
        model_pre_0_13_2 = self.class_.load(datapath('pre_0_13_2_model'))
        check_printed_topics(model_pre_0_13_2)

        # save back the loaded model using a post-0.13.2 version of Gensim
        post_0_13_2_fname = get_tmpfile('gensim_models_lda_post_0_13_2_model.tst')
        model_pre_0_13_2.save(post_0_13_2_fname)

        # load a model saved using a post-0.13.2 version of Gensim
        model_post_0_13_2 = self.class_.load(post_0_13_2_fname)
        check_printed_topics(model_post_0_13_2)

    def test_dtype_backward_compatibility(self):
        """Load the 'lda_3_0_1_model' fixture and check its output on a predefined document."""
        # the fixture can be regenerated with: self.model.save(datapath('lda_3_0_1_model'))

        # load a model saved using a 3.0.1 version of Gensim
        model = self.class_.load(datapath('lda_3_0_1_model'))

        # and test it on a predefined document
        test_doc = [(0, 1), (1, 1), (2, 1)]
        expected_topics = [(0, 0.87005886977475178), (1, 0.12994113022524822)]
        matches = np.allclose(expected_topics, model[test_doc])
        self.assertTrue(matches)

# endclass TestLdaModel


class TestLdaMulticore(TestLdaModel):
    """Run the `TestLdaModel` tests against `ldamulticore.LdaMulticore`."""

    def setUp(self):
        """Prepare the same fixtures as the parent class, with the multicore model class."""
        self.class_ = ldamulticore.LdaMulticore
        self.corpus = mmcorpus.MmCorpus(datapath('testcorpus.mm'))
        self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)

    # override LdaModel because multicore does not allow alpha=auto
    def test_alpha_auto(self):
        """Check that constructing the multicore model with alpha='auto' raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            self.class_(alpha='auto')


# endclass TestLdaMulticore


if __name__ == '__main__':
    # when run directly as a script: log at DEBUG level, then run the tests
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s : %(levelname)s : %(message)s')
    unittest.main()
