#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""Worker ("slave") process used in computing distributed Latent Dirichlet Allocation
(LDA, :class:`~gensim.models.ldamodel.LdaModel`).

Run this script on every node in your cluster. If you wish, you may even run it multiple times on a single machine,
to make better use of multiple cores (just beware that memory footprint increases accordingly).


How to use distributed :class:`~gensim.models.ldamodel.LdaModel`
----------------------------------------------------------------

#. Install needed dependencies (Pyro4) ::

    pip install gensim[distributed]

#. Setup serialization (on each machine) ::

    export PYRO_SERIALIZERS_ACCEPTED=pickle
    export PYRO_SERIALIZER=pickle

#. Run nameserver ::

    python -m Pyro4.naming -n 0.0.0.0 &

#. Run workers (on each machine) ::

    python -m gensim.models.lda_worker &

#. Run dispatcher ::

    python -m gensim.models.lda_dispatcher &

#. Run :class:`~gensim.models.ldamodel.LdaModel` in distributed mode :

.. sourcecode:: pycon

    >>> from gensim.test.utils import common_corpus, common_dictionary
    >>> from gensim.models import LdaModel
    >>>
    >>> model = LdaModel(common_corpus, id2word=common_dictionary, distributed=True)


Command line arguments
----------------------

.. program-output:: python -m gensim.models.lda_worker --help
   :ellipsis: 0, -7

"""

from __future__ import with_statement
import os
import sys
import logging
import threading
import tempfile
import argparse

try:
    import Queue
except ImportError:
    import queue as Queue
import Pyro4
from gensim.models import ldamodel
from gensim import utils

logger = logging.getLogger('gensim.models.lda_worker')


# periodically save intermediate models after every SAVE_DEBUG updates (0 for never)
SAVE_DEBUG = 0

LDA_WORKER_PREFIX = 'gensim.lda_worker'


class Worker:
    """Used as a Pyro4 class with exposed methods.

    Exposes every non-private method and property of the class automatically to be available for remote access.

    """

    def __init__(self):
        """Partly initialize the model."""
        self.model = None

    @Pyro4.expose
    def initialize(self, myid, dispatcher, **model_params):
        """Fully initialize the worker.

        Parameters
        ----------
        myid : int
            An ID number used to identify this worker in the dispatcher object.
        dispatcher : :class:`~gensim.models.lda_dispatcher.Dispatcher`
            The dispatcher responsible for scheduling this worker.
        **model_params
            Keyword parameters to initialize the inner LDA model,see :class:`~gensim.models.ldamodel.LdaModel`.

        """
        self.lock_update = threading.Lock()
        self.jobsdone = 0  # how many jobs has this worker completed?
        # id of this worker in the dispatcher; just a convenience var for easy access/logging TODO remove?
        self.myid = myid
        self.dispatcher = dispatcher
        self.finished = False
        logger.info("initializing worker #%s", myid)
        self.model = ldamodel.LdaModel(**model_params)

    @Pyro4.expose
    @Pyro4.oneway
    def requestjob(self):
        """Request jobs from the dispatcher, in a perpetual loop until :meth:`gensim.models.lda_worker.Worker.getstate`
        is called.

        Raises
        ------
        RuntimeError
            If `self.model` is None (i.e. worker non initialized).

        """
        if self.model is None:
            raise RuntimeError("worker must be initialized before receiving jobs")

        job = None
        while job is None and not self.finished:
            try:
                job = self.dispatcher.getjob(self.myid)
            except Queue.Empty:
                # no new job: try again, unless we're finished with all work
                continue
        if job is not None:
            logger.info("worker #%s received job #%i", self.myid, self.jobsdone)
            self.processjob(job)
            self.dispatcher.jobdone(self.myid)
        else:
            logger.info("worker #%i stopping asking for jobs", self.myid)

    @utils.synchronous('lock_update')
    def processjob(self, job):
        """Incrementally process the job and potentially logs progress.

        Parameters
        ----------
        job : iterable of list of (int, float)
            Corpus in BoW format.

        """
        logger.debug("starting to process job #%i", self.jobsdone)
        self.model.do_estep(job)
        self.jobsdone += 1
        if SAVE_DEBUG and self.jobsdone % SAVE_DEBUG == 0:
            fname = os.path.join(tempfile.gettempdir(), 'lda_worker.pkl')
            self.model.save(fname)
        logger.info("finished processing job #%i", self.jobsdone - 1)

    @Pyro4.expose
    def ping(self):
        """Test the connectivity with Worker."""
        return True

    @Pyro4.expose
    @utils.synchronous('lock_update')
    def getstate(self):
        """Log and get the LDA model's current state.

        Returns
        -------
        result : :class:`~gensim.models.ldamodel.LdaState`
            The current state.

        """
        logger.info("worker #%i returning its state after %s jobs", self.myid, self.jobsdone)
        result = self.model.state
        assert isinstance(result, ldamodel.LdaState)
        self.model.clear()  # free up mem in-between two EM cycles
        self.finished = True
        return result

    @Pyro4.expose
    @utils.synchronous('lock_update')
    def reset(self, state):
        """Reset the worker by setting sufficient stats to 0.

        Parameters
        ----------
        state : :class:`~gensim.models.ldamodel.LdaState`
            Encapsulates information for distributed computation of LdaModel objects.

        """
        assert state is not None
        logger.info("resetting worker #%i", self.myid)
        self.model.state = state
        self.model.sync_state()
        self.model.state.reset()
        self.finished = False

    @Pyro4.oneway
    def exit(self):
        """Terminate the worker."""
        logger.info("terminating worker #%i", self.myid)
        os._exit(0)


def main():
    parser = argparse.ArgumentParser(description=__doc__[:-130], formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None)
    parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int)
    parser.add_argument(
        "--no-broadcast", help="Disable broadcast (default: %(default)s)", action='store_const',
        default=True, const=False
    )
    parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None)
    parser.add_argument(
        '-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel",
        const=logging.INFO, default=logging.WARNING
    )
    args = parser.parse_args()

    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel)
    logger.info("running %s", " ".join(sys.argv))

    ns_conf = {
        "broadcast": args.no_broadcast,
        "host": args.host,
        "port": args.port,
        "hmac_key": args.hmac
    }
    utils.pyro_daemon(LDA_WORKER_PREFIX, Worker(), random_suffix=True, ns_conf=ns_conf)
    logger.info("finished running %s", " ".join(sys.argv))


if __name__ == '__main__':
    main()
