"""Test module for compatibility with plain HDF5 files written by other tools."""

import shutil
import tempfile
from pathlib import Path

import numpy as np

import tables as tb
from tables.tests import common


class PaddedArrayTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Check reading of a compound (Table) dataset whose records are padded.

    Regression test for GitHub issue gh-734.

    The sample file ``itemsize.h5`` was written with h5py and stores, in the
    dataset ``/Test``, the records that the test body compares against.
    Fields 'A' and 'B' take 4 + 4 bytes and are followed by 8 bytes of
    padding, giving a 16-byte record:

    $ h5ls -v itemsize.h5
    Test                     Dataset {3/3}
    Location:  1:800
    Links:     1
    Storage:   48 logical bytes, 48 allocated bytes, 100.00% utilization
    Type:      struct {
                   "A"                +0    native unsigned int
                   "B"                +4    native unsigned int
               } 16 bytes

    """
    h5fname = common.test_filename('itemsize.h5')

    def test(self):
        # Same layout as the file: two little-endian uint32 fields at offsets
        # 0 and 4, with the record padded out to 16 bytes.
        padded_dtype = {
            'names': ['A', 'B'],
            'formats': ['<u4', '<u4'],
            'offsets': [0, 4],
            'itemsize': 16,
        }
        expected = np.array([(1, 11), (2, 12), (3, 13)], dtype=padded_dtype)

        node = self.h5file.get_node('/Test')
        self.assertTrue(common.areArraysEqual(node.read(), expected))


class EnumTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Check reading of an HDF5 enumerated datatype stored in an Array.

    The sample file mirrors the one built by HDF5's own enum test; see
    ftp://ftp.hdfgroup.org/HDF5/current/src/unpacked/test/enum.c.

    """

    h5fname = common.test_filename('smpl_enum.h5')

    def test(self):
        node_path = '/EnumTest'
        self.assertIn(node_path, self.h5file)

        node = self.h5file.get_node(node_path)
        self.assertIsInstance(node, tb.Array)

        colors = ['RED', 'GREEN', 'BLUE', 'WHITE', 'BLACK']
        enum = node.get_enum()
        self.assertEqual(enum, tb.Enum(colors))

        # The stored values run through the full color list twice.
        expected = [enum[color] for color in colors * 2]
        self.assertEqual(list(node.read()), expected)


class NumericTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Base test for several numeric datatypes.

    Each sample file holds a 6x5 array ``/TestArray`` whose element
    ``[i, j]`` equals ``i + j``. Concrete subclasses set ``h5fname``,
    ``type`` (the expected atom type) and ``byteorder`` (the expected
    on-disk byte order).

    See
    ftp://ftp.ncsa.uiuc.edu/HDF/files/hdf5/samples/[fiu]l?{8,16,32,64}{be,le}.c
    (they seem to be no longer available).

    """

    # Expected atom type and byte order. ``None`` marks this abstract base
    # class, which has no sample file of its own.
    type = None
    byteorder = None

    def setUp(self):
        # Test runners that collect every TestCase subclass (pytest,
        # unittest discovery) would otherwise run ``test`` on this base class
        # with no file configured; skip before the inherited setUp runs.
        if self.type is None:
            self.skipTest("abstract base class; run a concrete subclass")
        super().setUp()

    def test(self):
        self.assertIn('/TestArray', self.h5file)

        arr = self.h5file.get_node('/TestArray')
        self.assertIsInstance(arr, tb.Array)

        self.assertEqual(arr.atom.type, self.type)
        self.assertEqual(arr.byteorder, self.byteorder)
        self.assertEqual(arr.shape, (6, 5))

        data = arr.read()
        expectedData = np.array([
            [0, 1, 2, 3, 4],
            [1, 2, 3, 4, 5],
            [2, 3, 4, 5, 6],
            [3, 4, 5, 6, 7],
            [4, 5, 6, 7, 8],
            [5, 6, 7, 8, 9]], dtype=self.type)
        self.assertTrue(common.areArraysEqual(data, expectedData))


class F64BETestCase(NumericTestCase):
    """Big-endian 64-bit floating point sample."""

    type = 'float64'
    byteorder = 'big'
    h5fname = common.test_filename('smpl_f64be.h5')


class F64LETestCase(NumericTestCase):
    """Little-endian 64-bit floating point sample."""

    type = 'float64'
    byteorder = 'little'
    h5fname = common.test_filename('smpl_f64le.h5')


class I64BETestCase(NumericTestCase):
    """Big-endian 64-bit signed integer sample."""

    type = 'int64'
    byteorder = 'big'
    h5fname = common.test_filename('smpl_i64be.h5')


class I64LETestCase(NumericTestCase):
    """Little-endian 64-bit signed integer sample."""

    type = 'int64'
    byteorder = 'little'
    h5fname = common.test_filename('smpl_i64le.h5')


class I32BETestCase(NumericTestCase):
    """Big-endian 32-bit signed integer sample."""

    type = 'int32'
    byteorder = 'big'
    h5fname = common.test_filename('smpl_i32be.h5')


class I32LETestCase(NumericTestCase):
    """Little-endian 32-bit signed integer sample."""

    type = 'int32'
    byteorder = 'little'
    h5fname = common.test_filename('smpl_i32le.h5')


class ChunkedCompoundTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Check reading of a chunked compound dataset with multidimensional fields.

    The sample file was generated by a chunked version of the example in
    ftp://ftp.ncsa.uiuc.edu/HDF/files/hdf5/samples/compound2.c.

    """

    h5fname = common.test_filename('smpl_compound_chunked.h5')

    def test(self):
        node_path = '/CompoundChunked'
        self.assertIn(node_path, self.h5file)

        tbl = self.h5file.get_node(node_path)
        self.assertIsInstance(tbl, tb.Table)

        # Expected (column type, per-row field shape), in column order.
        expected_columns = {
            'a_name': ('int32', ()),
            'c_name': ('string', ()),
            'd_name': ('int16', (5, 10)),
            'e_name': ('float32', ()),
            'f_name': ('float64', (10,)),
            'g_name': ('uint8', ()),
        }
        self.assertEqual(tbl.colnames, list(expected_columns))
        for colname, (coltype, colshape) in expected_columns.items():
            self.assertEqual(tbl.coltypes[colname], coltype)
            self.assertEqual(tbl.coldtypes[colname].shape, colshape)

        # Rows are fetched by index on purpose: iterating the table directly
        # (``for m, row in enumerate(tbl)``, which goes through
        # ``iterrows()``) seemed to fail for this file.
        for m in range(len(tbl)):
            row = tbl[m]
            self.assertEqual(row['a_name'], m)
            self.assertEqual(row['c_name'], b"Hello!")

            d_values = row['d_name']
            for n in range(5):
                for o in range(10):
                    self.assertEqual(d_values[n][o], m + n + o)

            self.assertAlmostEqual(row['e_name'], m * 0.96, places=6)

            f_values = row['f_name']
            for n in range(10):
                self.assertAlmostEqual(f_values[n], m * 1024.9637)

            self.assertEqual(row['g_name'], ord('m'))


class ContiguousCompoundTestCase(common.TestFileMixin,
                                 common.PyTablesTestCase):
    """Check reading of a native, contiguous (non-chunked) compound dataset.

    The sample file was provided by Dav Clark.

    """

    h5fname = common.test_filename('non-chunked-table.h5')

    def test(self):
        node_path = '/test_var/structure variable'
        self.assertIn(node_path, self.h5file)

        tbl = self.h5file.get_node(node_path)
        self.assertIsInstance(tbl, tb.Table)

        # (column name, column type, per-row field shape), in column order.
        expected_columns = [
            ('a', 'float64', ()),
            ('b', 'float64', ()),
            ('c', 'float64', (2,)),
            ('d', 'string', ()),
        ]
        expected_names = [name for name, _, _ in expected_columns]
        self.assertEqual(tbl.colnames, expected_names)
        for name, coltype, shape in expected_columns:
            self.assertEqual(tbl.coltypes[name], coltype)
            self.assertEqual(tbl.coldtypes[name].shape, shape)

        expected_c = np.array([2.0, 3.0], dtype="float64")
        for row in tbl.iterrows():
            self.assertEqual(row['a'], 3.0)
            self.assertEqual(row['b'], 4.0)
            self.assertTrue(common.allequal(row['c'], expected_c))
            self.assertEqual(row['d'], b"d")

        # Explicit close; presumably the mixin's teardown closes it again,
        # which is harmless for an already-closed file.
        self.h5file.close()


class ContiguousCompoundAppendTestCase(common.TestFileMixin,
                                       common.PyTablesTestCase):
    """Test for appending data to native contiguous compound datasets.

    Appending to the non-chunked table must raise ``HDF5ExtError`` through
    both ``Table.append`` and the Row interface. The check runs on a
    temporary copy so the shared sample file is never modified.

    """

    h5fname = common.test_filename('non-chunked-table.h5')

    def test(self):
        self.assertIn('/test_var/structure variable', self.h5file)
        self.h5file.close()
        # Work on a private copy inside a temporary directory. Unlike the
        # deprecated ``tempfile.mktemp()`` this is race-free, and the copy is
        # removed even when one of the assertions below fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            h5fname_copy = str(Path(tmpdir) / Path(self.h5fname).name)
            shutil.copy(self.h5fname, h5fname_copy)
            # Reopen in 'a'ppend mode.
            try:
                self.h5file = tb.open_file(h5fname_copy, 'a')
            except OSError as exc:
                # Probably no permission to write the copy. Report the test
                # as skipped instead of letting it silently pass.
                self.skipTest(
                    f"cannot open {h5fname_copy!r} for writing: {exc}")
            try:
                tbl = self.h5file.get_node('/test_var/structure variable')
                # Adding rows to a non-chunked table must raise an error...
                self.assertRaises(tb.HDF5ExtError, tbl.append,
                                  [(4.0, 5.0, [2.0, 3.0], 'd')])
                # ...and so must appending through the Row interface.
                self.assertRaises(tb.HDF5ExtError, tbl.row.append)
            finally:
                # Close the handle before the temporary directory is removed
                # (open files cannot be deleted on Windows).
                self.h5file.close()


class ExtendibleTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Check reading of an extendible dataset as an EArray.

    The sample file follows the example programs in the Introduction to
    HDF5.

    """

    h5fname = common.test_filename('smpl_SDSextendible.h5')

    def test(self):
        node_path = '/ExtendibleArray'
        self.assertIn(node_path, self.h5file)

        earr = self.h5file.get_node(node_path)
        self.assertIsInstance(earr, tb.EArray)

        # Storage metadata.
        self.assertEqual(earr.byteorder, 'big')
        self.assertEqual(earr.atom.type, 'int32')
        # Geometry: 10 rows, extendible along the first dimension.
        self.assertEqual(earr.shape, (10, 5))
        self.assertEqual(earr.extdim, 0)
        self.assertEqual(len(earr), 10)

        # Two rows of 1s/3s, one row of 1s, then seven rows led by a 2.
        expected = np.array(
            [[1, 1, 1, 3, 3]] * 2
            + [[1, 1, 1, 0, 0]]
            + [[2, 0, 0, 0, 0]] * 7,
            dtype=earr.atom.type)
        self.assertTrue(common.areArraysEqual(earr.read(), expected))


class SzipTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Check the filters reported for an szip-compressed native dataset."""

    h5fname = common.test_filename('test_szip.h5')

    def test(self):
        node_path = '/dset_szip'
        self.assertIn(node_path, self.h5file)

        expected_repr = (
            "Filters(complib='szip', shuffle=False, bitshuffle=False, "
            "fletcher32=False, least_significant_digit=None)"
        )
        dset = self.h5file.get_node(node_path)
        self.assertEqual(repr(dset.filters), expected_repr)


# Regression test for GitHub issue #203.
class MatlabFileTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Look up a node in a MATLAB ``.mat`` file using several string types."""

    h5fname = common.test_filename('matlab_file.mat')

    def _assert_node_a_shape(self, where, name):
        """Fetch node ``name`` under ``where`` and check it is 3x1."""
        node = self.h5file.get_node(where, name)
        self.assertEqual(node.shape, (3, 1))

    def test_unicode(self):
        self._assert_node_a_shape('/', 'a')

    # On Python 3 ``str`` is already unicode, so this duplicates
    # ``test_unicode``.
    def test_string(self):
        self._assert_node_a_shape('/', 'a')

    def test_numpy_str(self):
        # ``numpy.str_`` instances must be accepted as path arguments too.
        self._assert_node_a_shape(np.str_('/'), np.str_('a'))


class ObjectReferenceTestCase(common.TestFileMixin, common.PyTablesTestCase):
    """Read a MATLAB file whose array presumably holds object references."""

    h5fname = common.test_filename('test_ref_array1.mat')

    def _my_arr(self):
        """Return the ``/ANN/my_arr`` node under test."""
        return self.h5file.get_node('/ANN/my_arr')

    def test_node_var(self):
        self.assertEqual(self._my_arr().shape, (1, 3))

    def test_ref_utf_str(self):
        # Indexing into the first element yields the referenced data.
        referenced = self._my_arr()[0][0][0]
        expected = np.array([0, 0], dtype=np.uint64)
        self.assertTrue(common.areArraysEqual(referenced, expected))


class ObjectReferenceRecursiveTestCase(common.TestFileMixin,
                                       common.PyTablesTestCase):
    """Follow single and nested (reference-to-reference) object references."""

    h5fname = common.test_filename('test_ref_array2.mat')

    def _var(self):
        """Return the ``/var`` node under test."""
        return self.h5file.get_node('/var')

    @staticmethod
    def _char_column(text):
        """Return ``text`` as a column of uint16 character codes."""
        return np.array([[ord(char)] for char in text], dtype=np.uint16)

    def test_var(self):
        self.assertEqual(self._var().shape, (3, 1))

    def test_ref_str(self):
        referenced = self._var()[1][0][0]
        self.assertTrue(common.areArraysEqual(
            referenced, self._char_column('test')))

    def test_double_ref(self):
        # The third element references data that itself holds a reference.
        referenced = self._var()[2][0][0][1][0]
        self.assertTrue(common.areArraysEqual(
            referenced, self._char_column('inside')))


def suite():
    """Return a test suite consisting of all the test cases in the module."""

    # ``unittest.makeSuite`` was deprecated in Python 3.11 and removed in
    # 3.13; ``TestLoader.loadTestsFromTestCase`` is its documented
    # replacement and applies the same default ``test`` method prefix.
    loader = common.unittest.defaultTestLoader
    test_cases = (
        PaddedArrayTestCase,
        EnumTestCase,
        F64BETestCase,
        F64LETestCase,
        I64BETestCase,
        I64LETestCase,
        I32BETestCase,
        I32LETestCase,
        ChunkedCompoundTestCase,
        ContiguousCompoundTestCase,
        ContiguousCompoundAppendTestCase,
        ExtendibleTestCase,
        SzipTestCase,
        MatlabFileTestCase,
        ObjectReferenceTestCase,
        ObjectReferenceRecursiveTestCase,
    )

    theSuite = common.unittest.TestSuite()
    niter = 1

    for _ in range(niter):
        for test_case in test_cases:
            theSuite.addTest(loader.loadTestsFromTestCase(test_case))

    return theSuite


if __name__ == '__main__':
    # Script entry point: consume PyTables-specific command-line options,
    # print library versions, then run the ``suite()`` defined above.
    from sys import argv

    common.parse_argv(argv)
    common.print_versions()
    common.unittest.main(defaultTest='suite')
