# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.


import asyncio
import logging

import pytest

import zmq
import zmq.asyncio

try:
    import tornado

    from zmq.eventloop import zmqstream
except ImportError:
    tornado = None  # type: ignore


# Request the ``io_loop`` fixture for every test in this module. The fixture
# is not defined in this file (presumably provided by conftest.py).
pytestmark = pytest.mark.usefixtures("io_loop")


@pytest.fixture
async def push_pull(socket):
    """Return a connected ``(push, pull)`` pair of ZMQStreams over loopback TCP.

    The PUSH stream binds to a random port on 127.0.0.1 and the PULL stream
    connects to that port. ``socket`` is a fixture defined outside this file;
    it is called with a socket type and its result is wrapped in a ZMQStream.
    """
    sender = zmqstream.ZMQStream(socket(zmq.PUSH))
    receiver = zmqstream.ZMQStream(socket(zmq.PULL))
    port = sender.bind_to_random_port('tcp://127.0.0.1')
    receiver.connect(f'tcp://127.0.0.1:{port}')
    return sender, receiver


@pytest.fixture
def push(push_pull):
    """The PUSH stream from the ``push_pull`` pair."""
    push_stream, _ = push_pull
    return push_stream


@pytest.fixture
def pull(push_pull):
    """The PULL stream from the ``push_pull`` pair."""
    _, pull_stream = push_pull
    return pull_stream


async def test_callable_check(pull):
    """on_send/on_recv accept callables and raise AssertionError otherwise."""

    def noop(*args):
        return None

    pull.on_send(noop)
    pull.on_recv(noop)

    # Each registration method paired with a non-callable argument
    # (an int and a module object) that must be rejected.
    rejected = (
        (pull.on_recv, 1),
        (pull.on_send, 1),
        (pull.on_recv, zmq),
    )
    for register, not_callable in rejected:
        with pytest.raises(AssertionError):
            register(not_callable)


async def test_on_recv_basic(push, pull):
    """A message sent before on_recv() is registered is still delivered."""
    expected = [b'basic']
    push.send_multipart(expected)
    received = asyncio.get_running_loop().create_future()

    def on_message(msg):
        received.set_result(msg)

    pull.on_recv(on_message)
    assert await asyncio.wait_for(received, timeout=5) == expected


async def test_on_recv_wake(push, pull):
    """A callback registered first is invoked by a message sent later."""
    payload = [b'wake']

    received = asyncio.get_running_loop().create_future()
    pull.on_recv(received.set_result)
    # Let the loop sit idle for a while before sending, so delivery has
    # to wake the already-registered callback.
    await asyncio.sleep(0.5)
    push.send_multipart(payload)
    assert await asyncio.wait_for(received, timeout=5) == payload


async def test_on_recv_async(push, pull):
    """on_recv accepts a coroutine function as its callback (tornado >= 5)."""
    if tornado.version_info < (5,):
        pytest.skip()
    payload = [b'wake']

    received = asyncio.get_running_loop().create_future()

    async def deliver(msg):
        # Await something before resolving, so the callback genuinely
        # suspends rather than completing synchronously.
        await asyncio.sleep(0.1)
        received.set_result(msg)

    pull.on_recv(deliver)
    await asyncio.sleep(0.5)
    push.send_multipart(payload)
    result = await asyncio.wait_for(received, timeout=5)
    assert result == payload


async def test_on_recv_async_error(push, pull, caplog):
    """An exception raised in an async on_recv callback is logged via gen_log."""
    logger_name = zmqstream.gen_log.name
    payload = [b'wake']

    received = asyncio.get_running_loop().create_future()

    async def failing_callback(msg):
        received.set_result(msg)
        1 / 0  # deliberately raise after the message has been handed over

    pull.on_recv(failing_callback)
    await asyncio.sleep(0.1)
    with caplog.at_level(logging.ERROR, logger=logger_name):
        push.send_multipart(payload)
        assert await asyncio.wait_for(received, timeout=5) == payload
        # The error is logged a tick after the callback runs, so wait for it.
        await asyncio.sleep(0.5)

    logged = "\n".join(
        record.message
        for record in caplog.get_records("call")
        if record.name == logger_name
    )
    assert "Uncaught exception in ZMQStream callback" in logged


async def test_shadow_socket(context):
    """Wrapping a ``zmq.asyncio.Socket`` in a ZMQStream warns and shadows it.

    Creating the stream must emit a RuntimeWarning. The stream's ``socket``
    must be a plain ``zmq.Socket`` that shares the same ``underlying``
    handle as the asyncio socket it was built from.
    """
    with context.socket(zmq.PUSH, socket_class=zmq.asyncio.Socket) as socket:
        with pytest.warns(RuntimeWarning):
            stream = zmqstream.ZMQStream(socket)
        try:
            # Exact-type check on purpose: isinstance() would also accept
            # zmq.Socket subclasses (presumably including zmq.asyncio.Socket).
            assert type(stream.socket) is zmq.Socket
            assert stream.socket.underlying == socket.underlying
        finally:
            # Close the stream even when an assertion above fails, so it is
            # not left open for the rest of the test session.
            stream.close()
