#   pycsla-binary: Python implementation of CSLA .NET binary serialisation
#   Copyright (C) 2025  Lee Yingtong Li (RunasSudo)
#
#   This program is free software: you can redistribute it and/or modify
#   it under the terms of the GNU Affero General Public License as published by
#   the Free Software Foundation, either version 3 of the License, or
#   (at your option) any later version.
#
#   This program is distributed in the hope that it will be useful,
#   but WITHOUT ANY WARRANTY; without even the implied warranty of
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#   GNU Affero General Public License for more details.
#
#   You should have received a copy of the GNU Affero General Public License
#   along with this program.  If not, see <https://www.gnu.org/licenses/>.

import io
import struct

from .known_types import CslaKnownTypes
from .serialization_info import ChildData, FieldData, SerializationInfo

class CslaBinaryReader:
	"""Reads binary-serialised CSLA data into SerializationInfo objects
	
	Mirrors the .NET CslaBinaryReader. Multi-byte integers and floats are read
	little-endian, as System.IO.BinaryReader does.
	"""
	
	# Mapping of CslaKnownTypes tag value -> struct format for the fixed-width
	# primitive types; built on first use by _primitive_formats()
	_PRIMITIVE_FORMATS = None
	
	def __init__(self, stream: io.BufferedIOBase):
		self.stream = stream
		# Strings received with a StringWithDictionaryKey tag, keyed by their
		# dictionary key, so later StringDictionaryKey tags can refer back to them
		self.keywords_dictionary = {}
	
	def read(self):
		"""Read every SerializationInfo record in the stream (CslaBinaryReader.Read)
		
		Returns:
			list of SerializationInfo, in stream order.
		
		Raises:
			ValueError: if a header value has the wrong type or the data is malformed.
			NotImplementedError: if a value uses a type this reader cannot decode.
			EOFError: if the stream ends before all records have been read.
		"""
		total_count = self.read_int32()
		# A non-positive count yields no records, like the .NET for-loop
		return [self._read_serialization_info() for _ in range(total_count)]
	
	def _read_serialization_info(self):
		"""Read one record: reference ID, type name, child references, then field values"""
		info = SerializationInfo()
		info.reference_id = self.read_int32()
		info.type_name = self._read_typed_object(str, 'string')
		
		child_count = self.read_int32()
		for _ in range(child_count):
			# Separate statements keep the reads in stream order
			system_name = self._read_typed_object(str, 'string')
			is_dirty = self._read_typed_object(bool, 'bool')
			reference_id = self._read_typed_object(int, 'int')
			info.children.append(ChildData(system_name, is_dirty, reference_id))
		
		value_count = self.read_int32()
		for _ in range(value_count):
			system_name = self._read_typed_object(str, 'string')
			enum_type_name = self._read_typed_object(str, 'string')
			is_dirty = self._read_typed_object(bool, 'bool')
			value = self.read_object()
			# An empty enum type name is stored as None
			info.values.append(FieldData(system_name, enum_type_name or None, is_dirty, value))
		
		return info
	
	def _read_typed_object(self, expected_type, expected_name):
		"""Read a tagged value and check that it is an instance of expected_type
		
		Raises:
			ValueError: 'Expected <expected_name>, got <actual type name>' on mismatch.
		"""
		value = self.read_object()
		# bool is a subclass of int in Python, so reject it explicitly where an int is expected
		is_bool_for_int = expected_type is int and isinstance(value, bool)
		if not isinstance(value, expected_type) or is_bool_for_int:
			raise ValueError('Expected {}, got {}'.format(expected_name, type(value).__name__))
		return value
	
	def read_7bit_encoded_int(self):
		"""Read a 7-bit encoded Int32 (BinaryReader.Read7BitEncodedInt)
		
		"The integer of the value parameter is written out seven bits at a time, starting with the
		seven least-significant bits. The high bit of a byte indicates whether there are more bytes
		to be written after this one."
		
		At most 5 bytes are consumed. In the 5th byte only the low 4 bits may be set. The result is
		reinterpreted as a signed 32-bit integer, as in .NET.
		
		Raises:
			ValueError: if the encoding does not fit in 32 bits.
			EOFError: if the stream ends mid-value.
		"""
		result = 0
		for shift in range(0, 35, 7):
			byte = self._read_exact(1)[0]
			if shift == 28 and byte > 0x0F:
				# The 5th byte may only carry the top 4 bits and no continuation flag
				raise ValueError('Too many bytes in what should have been a 7-bit encoded Int32')
			result |= (byte & 0x7F) << shift
			if not byte & 0x80:
				# MSB is not set - this was the last byte
				break
		# Reinterpret the 32 accumulated bits as a signed Int32
		if result >= 0x80000000:
			result -= 0x100000000
		return result
	
	def read_int32(self):
		"""Read a little-endian signed 32-bit integer (BinaryReader.ReadInt32)"""
		return struct.unpack('<i', self._read_exact(4))[0]
	
	def read_object(self):
		"""Read one value prefixed by a CslaKnownTypes tag byte (CslaBinaryReader.ReadObject)
		
		Returns:
			bool, int, float, str, bytes (ByteArray), list of int (ListOfInt) or
			None (Null), depending on the tag.
		
		Raises:
			NotImplementedError: for a recognised tag whose decoding is not yet supported.
			ValueError: for an unknown tag, a negative length, or an undefined
				StringDictionaryKey.
			EOFError: if the stream ends mid-value.
		"""
		known_type = self._read_exact(1)[0]
		
		primitive_format = self._primitive_formats().get(known_type)
		if primitive_format is not None:
			return self._read_struct(primitive_format)
		
		if known_type == CslaKnownTypes.String.value:
			return self.read_string()
		
		if known_type == CslaKnownTypes.ByteArray.value:
			# Int32 length followed by the raw bytes
			return self._read_exact(self._read_length())
		
		if known_type == CslaKnownTypes.ListOfInt.value:
			# Int32 count followed by that many Int32 values
			count = self._read_length()
			return list(struct.unpack('<{}i'.format(count), self._read_exact(4 * count)))
		
		if known_type == CslaKnownTypes.Null.value:
			return None
		
		if known_type == CslaKnownTypes.StringWithDictionaryKey.value:
			# Full string plus the key that later StringDictionaryKey tags use to refer to it
			system_string = self.read_string()
			dictionary_key = self.read_int32()
			self.keywords_dictionary[dictionary_key] = system_string
			return system_string
		
		if known_type == CslaKnownTypes.StringDictionaryKey.value:
			# Back-reference to a string previously received with StringWithDictionaryKey
			dictionary_key = self.read_int32()
			try:
				return self.keywords_dictionary[dictionary_key]
			except KeyError:
				raise ValueError('Undefined string dictionary key {}'.format(dictionary_key)) from None
		
		# Recognised tags with no unambiguous Python representation yet
		unsupported_types = (
			CslaKnownTypes.Char,
			CslaKnownTypes.Decimal,
			CslaKnownTypes.DateTime,
			CslaKnownTypes.TimeSpan,
			CslaKnownTypes.DateTimeOffset,
			CslaKnownTypes.Guid,
			CslaKnownTypes.CharArray,
		)
		for unsupported in unsupported_types:
			if known_type == unsupported.value:
				raise NotImplementedError('Reading CslaKnownTypes.{} is not implemented'.format(unsupported.name))
		
		raise ValueError('Unexpected object tag {}'.format(known_type))
	
	def read_string(self):
		"""Read a length-prefixed UTF-8 string (BinaryReader.ReadString)
		
		"The string is prefixed with the length, encoded as an integer seven bits at a time."
		The length counts bytes, not characters.
		
		Raises:
			ValueError: if the encoded length is negative.
			EOFError: if the stream ends before the whole string is read.
		"""
		length = self.read_7bit_encoded_int()
		if length < 0:
			raise ValueError('Invalid string length {}'.format(length))
		return self._read_exact(length).decode('utf-8')
	
	@classmethod
	def _primitive_formats(cls):
		"""Return the tag value -> struct format mapping for fixed-width primitives
		
		All formats are little-endian like System.IO.BinaryReader. '<?' reads one byte
		and returns True for any non-zero value, matching BinaryReader.ReadBoolean.
		"""
		if cls._PRIMITIVE_FORMATS is None:
			cls._PRIMITIVE_FORMATS = {
				CslaKnownTypes.Boolean.value: '<?',
				CslaKnownTypes.SByte.value: '<b',
				CslaKnownTypes.Byte.value: '<B',
				CslaKnownTypes.Int16.value: '<h',
				CslaKnownTypes.UInt16.value: '<H',
				CslaKnownTypes.Int32.value: '<i',
				CslaKnownTypes.UInt32.value: '<I',
				CslaKnownTypes.Int64.value: '<q',
				CslaKnownTypes.UInt64.value: '<Q',
				CslaKnownTypes.Single.value: '<f',
				CslaKnownTypes.Double.value: '<d',
			}
		return cls._PRIMITIVE_FORMATS
	
	def _read_struct(self, fmt):
		"""Read and unpack a single value described by the struct format fmt"""
		return struct.unpack(fmt, self._read_exact(struct.calcsize(fmt)))[0]
	
	def _read_length(self):
		"""Read an Int32 length/count prefix, rejecting negative values"""
		length = self.read_int32()
		if length < 0:
			raise ValueError('Invalid negative length {}'.format(length))
		return length
	
	def _read_exact(self, size):
		"""Read exactly size bytes from the stream, raising EOFError if it ends first
		
		Loops because BufferedIOBase.read may return fewer bytes than requested
		(e.g. on interactive streams) without the stream being exhausted.
		"""
		chunks = []
		remaining = size
		while remaining > 0:
			chunk = self.stream.read(remaining)
			if not chunk:
				raise EOFError('Unexpected end of stream: needed {} more byte(s)'.format(remaining))
			chunks.append(chunk)
			remaining -= len(chunk)
		return b''.join(chunks)