import re
from io import StringIO, TextIOWrapper
from pathlib import Path
from typing import Any, Callable, List, Optional, Text, Union

import ruamel.yaml as raml
from epicstuff import open
from ruamel.yaml.comments import CommentedMap


class TAML(raml.YAML):
	def __init__(self: Any, *, typ: Optional[Union[List[Text], Text]] = None, pure: Any = False, output: Any = None, plug_ins: Any = None) -> None:
		'''
		typ: 'rt'/None -> RoundTripLoader/RoundTripDumper,  (default)
			'safe'    -> SafeLoader/SafeDumper,
			'unsafe'  -> normal/unsafe Loader/Dumper (pending deprecation)
			'full'    -> full Dumper only, including python built-ins that are potentially unsafe to load
			'base'    -> baseloader
		pure: if True only use Python modules
		input/output: needed to work as context manager
		plug_ins: a list of plug-in files
		'''
		super().__init__(typ=typ, pure=pure, output=output, plug_ins=plug_ins)

		self.indent(mapping=2, sequence=2, offset=2)
		# self.default_flow_style = None
		# self.width = 4096
	def load(self, stream: Union[str, Path, TextIOWrapper]) -> CommentedMap:
		'''
		at this point you either have the non-pure Parser (which has its own reader and
		scanner) or you have the pure Parser.
		If the pure Parser is set, then set the Reader and Scanner, if not already set.
		If either the Scanner or Reader are set, you cannot use the non-pure Parser,
			so reset it to the pure parser and set the Reader resp. Scanner if necessary

		this description was copied from ruamel.yaml
		'''
		if isinstance(stream, (str, Path)):
			with open(stream, 'r') as f:
				file = f.read()
		elif isinstance(stream, TextIOWrapper):
			file = stream.read()
		else:
			raise TypeError

		# replace tabs after newline with 2 spaces for each tab, generated by copilot
		file = re.sub(r'(?<=\n)(\t+)', lambda match: '  ' * len(match.group(1)), file)

		return super().load(StringIO(''.join(file)))

	def dump(self, data, stream: Any = Union[str, Path, TextIOWrapper], *, transform: Callable = None) -> None:
		# create temporary "text file"
		tmp = StringIO()

		# dump to tmp
		if self.map_indent != 2 or self.sequence_indent != 2 or self.sequence_dash_offset != 2:
			print('warning: changing indentation may cause this to not work')
		super().dump(data, tmp, transform=transform)
		tmp.seek(0)

		# convert indentation to tabs
		tmp = tmp.read()
		tmp = re.sub(r'(?<=\n)( {2})+', lambda match: '\t' * (len(match.group(0)) // 2), tmp)

		# dump to stream
		if isinstance(stream, (str, Path)):
			with open(stream, 'w') as f:
				f.write(tmp)
		elif isinstance(stream, TextIOWrapper):
			stream.write(tmp)
		else:
			raise TypeError


taml = TAML()
