import sys

import numpy as np

import unittest
from numba.core.compiler import compile_isolated, Flags
from numba import jit
from numba.core import types, errors, utils
from numba.tests.support import (captured_stdout, tag, TestCase,
                                 EnableNRTStatsMixin)


# Compiler flags with ``enable_pyobject`` switched on; passed to
# compile_isolated() by the array-printing case in
# TestPrint.test_print_values.
enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True

# Compiler flags with ``force_pyobject`` switched on.
# NOTE(review): not referenced anywhere in the visible part of this file.
force_pyobj_flags = Flags()
force_pyobj_flags.force_pyobject = True


def print_value(x):
    """Print the single argument *x* (compiled by the tests below)."""
    print(x)

def print_array_item(arr, i):
    """Print the ``x`` attribute of element *i* of *arr*."""
    print(arr[i].x)

def print_values(a, b, c):
    """Print the three arguments in a single print() call."""
    print(a, b, c)

def print_empty():
    """Call print() with no arguments."""
    print()

def print_string(x):
    """Print *x* followed by a string literal and a float literal."""
    print(x, "hop!", 3.5)

def print_vararg(a, b, c):
    """Print *a*, *b*, then the elements of *c* unpacked with ``*c``."""
    print(a, b, *c)

def print_string_vararg(a, b, c):
    """Print *a*, a string literal, *b*, then the elements of *c* unpacked."""
    print(a, "hop!", b, *c)

def make_print_closure(x):
    """
    Build a nopython-jitted closure that prints the captured value *x*.

    The previous implementation passed *x* itself to ``jit(nopython=True)``
    and never used the inner function, so the factory could not return a
    usable closure.  The inner function now prints *x* (matching its name
    and the other helpers in this module) and it is the closure that gets
    compiled.

    Parameters
    ----------
    x : value captured by the closure and printed when it is called.

    Returns
    -------
    The object returned by ``jit(nopython=True)`` for the zero-argument
    closure.
    """
    def print_closure():
        print(x)
    # Compile the closure, not the captured value.
    return jit(nopython=True)(print_closure)


class TestPrint(EnableNRTStatsMixin, TestCase):
    """
    Tests for print() inside Numba-compiled functions.

    Each test compiles one of the module-level helper functions (or a
    locally defined jitted function), runs it while stdout is captured,
    and compares the captured text with the expected output.  The
    assertions sit inside the ``captured_stdout()`` block because they
    read ``sys.stdout.getvalue()`` from the replaced stream.
    """

    def test_print_values(self):
        """
        Test printing a single argument value.
        """
        pyfunc = print_value

        def check_values(typ, values):
            # Compile print_value for *typ* and check that each value
            # prints the same text as str(value) plus a newline.
            cr = compile_isolated(pyfunc, (typ,))
            cfunc = cr.entry_point
            for val in values:
                with captured_stdout():
                    cfunc(val)
                    self.assertEqual(sys.stdout.getvalue(), str(val) + '\n')

        # Various scalars
        check_values(types.int32, (1, -234))
        check_values(types.int64, (1, -234,
                                   123456789876543210, -123456789876543210))
        check_values(types.uint64, (1, 234,
                                   123456789876543210, 2**63 + 123))
        check_values(types.boolean, (True, False))
        check_values(types.float64, (1.5, 100.0**10.0, float('nan')))
        check_values(types.complex64, (1+1j,))
        check_values(types.NPTimedelta('ms'), (np.timedelta64(100, 'ms'),))

        cr = compile_isolated(pyfunc, (types.float32,))
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1.1)
            # Float32 will lose precision
            got = sys.stdout.getvalue()
            expect = '1.10000002384'
            self.assertTrue(got.startswith(expect))
            self.assertTrue(got.endswith('\n'))

        # NRT-enabled type
        with self.assertNoNRTLeak():
            x = [1, 3, 5, 7]
            with self.assertRefCount(x):
                check_values(types.List(types.int32), (x,))

        # Array will have to use object mode
        arraytype = types.Array(types.int32, 1, 'C')
        cr = compile_isolated(pyfunc, (arraytype,), flags=enable_pyobj_flags)
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(np.arange(10, dtype=np.int32))
            self.assertEqual(sys.stdout.getvalue(),
                             '[0 1 2 3 4 5 6 7 8 9]\n')

    def test_print_array_item(self):
        """
        Test printing a Numpy character sequence
        """
        dtype = np.dtype([('x', 'S4')])
        arr = np.frombuffer(bytearray(range(1, 9)), dtype=dtype)

        pyfunc = print_array_item
        cfunc = jit(nopython=True)(pyfunc)
        # The index itself is an argument of the compiled function, so
        # iterate over positions rather than over the records.
        for i in range(len(arr)):
            with captured_stdout():
                cfunc(arr, i)
                self.assertEqual(sys.stdout.getvalue(), str(arr[i]['x']) + '\n')

    def test_print_multiple_values(self):
        """
        Test printing several int32 arguments in one print() call.
        """
        pyfunc = print_values
        cr = compile_isolated(pyfunc, (types.int32,) * 3)
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1, 2, 3)
            self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')

    def test_print_nogil(self):
        """
        Test print() in a function jitted with nogil=True.
        """
        pyfunc = print_values
        cfunc = jit(nopython=True, nogil=True)(pyfunc)
        with captured_stdout():
            cfunc(1, 2, 3)
            self.assertEqual(sys.stdout.getvalue(), '1 2 3\n')

    def test_print_empty(self):
        """
        Test print() with no arguments: only a newline is written.
        """
        pyfunc = print_empty
        cr = compile_isolated(pyfunc, ())
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc()
            self.assertEqual(sys.stdout.getvalue(), '\n')

    def test_print_strings(self):
        """
        Test printing an argument mixed with string and float literals.
        """
        pyfunc = print_string
        cr = compile_isolated(pyfunc, (types.int32,))
        cfunc = cr.entry_point
        with captured_stdout():
            cfunc(1)
            self.assertEqual(sys.stdout.getvalue(), '1 hop! 3.5\n')

    def test_print_vararg(self):
        """
        Test print() called with ``*args`` unpacking.
        """
        # Test *args support for print().  This is desired since
        # print() can use a dedicated IR node.
        pyfunc = print_vararg
        cfunc = jit(nopython=True)(pyfunc)
        with captured_stdout():
            cfunc(1, (2, 3), (4, 5j))
            self.assertEqual(sys.stdout.getvalue(), '1 (2, 3) 4 5j\n')

        pyfunc = print_string_vararg
        cfunc = jit(nopython=True)(pyfunc)
        with captured_stdout():
            cfunc(1, (2, 3), (4, 5j))
            self.assertEqual(sys.stdout.getvalue(), '1 hop! (2, 3) 4 5j\n')

    def test_inner_fn_print(self):
        """
        Test print() inside a jitted function called from another one.
        """
        @jit(nopython=True)
        def foo(x):
            print(x)

        @jit(nopython=True)
        def bar(x):
            foo(x)
            foo('hello')

        # Printing an array requires the Env.
        # We need to make sure the inner function can obtain the Env.
        x = np.arange(5)
        with captured_stdout():
            bar(x)
            self.assertEqual(sys.stdout.getvalue(), '[0 1 2 3 4]\nhello\n')

    def test_print_w_kwarg_raises(self):
        """
        Test that print() with a keyword argument raises UnsupportedError.
        """
        @jit(nopython=True)
        def print_kwarg():
            print('x', flush=True)

        with self.assertRaises(errors.UnsupportedError) as raises:
            print_kwarg()
        expected = ("Numba's print() function implementation does not support "
                    "keyword arguments.")
        # assertIn(member, container): the expected text must appear in the
        # raised message.  The reverse order would also accept an empty or
        # truncated message.
        self.assertIn(expected, raises.exception.msg)

    def test_print_no_truncation(self):
        ''' See: https://github.com/numba/numba/issues/3811
        '''
        @jit(nopython=True)
        def foo():
            print(''.join(['a'] * 10000))
        with captured_stdout():
            foo()
            self.assertEqual(sys.stdout.getvalue(), ''.join(['a'] * 10000) + '\n')

# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
