from copy import deepcopy
from collections import OrderedDict

from .field import Field
from . import AbstractProtocolComponent
from nameHandling.base import ComponentName, NameStyleBproto
from errors import BprotoDuplicateNameError, BprotoMessageIDAlreadyUsed


class Message(AbstractProtocolComponent):
    """Representation of a bproto message.

    Should not be instantiated directly; use the FactoryMessage class instead.

    Holds a dict of fields, keyed by field name, whose values are the
    corresponding Field objects.
    Inherits from AbstractProtocolComponent.
    """

    def __init__(self, name: ComponentName, index_number: int):
        """Should not be used directly. Use the FactoryMessage class.

        Args:
            name (ComponentName): Name of the message, should be unique.
            index_number (int): Index number of the message, used for
                ordering, should be unique.
        """
        self.name: ComponentName = name
        self.message_index_number = index_number
        # Field name -> Field object; populated by FactoryMessage.assemble().
        self.fields: dict[ComponentName, Field] = {}

    def __deepcopy__(self, memo=None):
        """Return a deep copy of this message (name, index number and fields).

        Args:
            memo (dict, optional): The memo dict passed by copy.deepcopy.
                A fresh dict is created when the method is called directly.

        Returns:
            Message: The copied message.
        """
        # A fresh dict per call: a shared mutable default would make a second
        # direct call return the copy cached by the first one.
        if memo is None:
            memo = {}
        if id(self) in memo:
            return memo[id(self)]

        new_message = Message(
            deepcopy(self.name, memo),
            deepcopy(self.message_index_number, memo),
        )
        # Register the copy before copying the fields, so anything inside a
        # field that refers back to this message resolves to the copy
        # instead of recursing.
        memo[id(self)] = new_message
        new_message.fields = {
            deepcopy(key, memo): deepcopy(value, memo)
            for key, value in self.fields.items()
        }
        return new_message

    # Inherited from AbstractProtocolComponent
    def get_identifier(self):
        """Return the message index number."""
        return self.message_index_number

    def get_name(self):
        """Return the message name."""
        return self.name

    def get_type_name(self):
        """Return the component type name, always "message"."""
        return "message"

    def get_size_bytes(self) -> int:
        """Return the sum of get_size_bytes() over all fields."""
        return sum(field.get_size_bytes() for field in self.fields.values())

    def get_size_bits(self) -> int:
        """Return the sum of get_size_bits() over all fields."""
        return sum(field.get_size_bits() for field in self.fields.values())

    # Methods specific to this class
    def apply_naming(self, naming: str):
        """Prefix the message name with `naming` followed by an underscore.

        Refactor this, should not be used.

        NOTE(review): the f-string stores a plain str in self.name, so after
        this call the name is no longer a ComponentName instance (unless
        ComponentName is a str subclass) - confirm callers can handle that.

        Args:
            naming (str): Prefix to prepend to the message name.
        """
        self.name = f"{naming}_{self.name}"


class FactoryMessage:
    """Factory class for building Message objects representing a bproto message.

    This is meant to be used during the frontend compilation stage, during the
    traversal of the abstract syntax tree.
    Fields are added to the message using the add_field method.
    The message is finalized using the assemble method.
    After that the factory should not be used anymore.

    It is a one-time use class, so create a new instance for each Message
    object. This class should be used to create Message objects.
    """

    def __init__(self):
        # Fields in insertion order; duplicates are only detected in assemble().
        self.fields: list[Field] = []

    def add_field(self, field: Field):
        """Add a finished field to the message.

        This does not check for duplicates; that is done in the assemble
        method.

        Args:
            field (Field): The field to add to the message.
        """
        self.fields.append(field)

    @staticmethod
    def sort_fields_dict(
        fields: dict[ComponentName, Field]
    ) -> dict[ComponentName, Field]:
        """Sort a dictionary of bproto fields by their `pos` attribute.

        Args:
            fields (dict[ComponentName, Field]): The fields to sort.

        Returns:
            dict[ComponentName, Field]: An OrderedDict of the same items,
                ordered by ascending field position.
        """
        return OrderedDict(sorted(fields.items(), key=lambda item: item[1].pos))

    def assemble(self, name: ComponentName, index_number: int) -> Message:
        """Finalize the message and create the Message object, returning it.

        After this method is called, the factory should not be used anymore.

        Args:
            name (ComponentName): Name of the message, should be unique
                (no uniqueness checks are done here). It is passed through
                NameStyleBproto.fromStr. NOTE(review): that suggests a plain
                str is expected here - confirm against callers.
            index_number (int): Index number of the message, used for
                ordering, should be unique (no uniqueness checks are done here).

        Raises:
            BprotoDuplicateNameError: If a field name is used more than once.
            BprotoMessageIDAlreadyUsed: If a field position is used more than
                once.

        Returns:
            Message: The finished message object, with its fields sorted by
                position.
        """
        resulting_message = Message(NameStyleBproto.fromStr(name), index_number)

        field_positions: set[int] = set()
        fields_dict: dict[ComponentName, Field] = {}
        for field in self.fields:
            # fields_dict keys are exactly the names seen so far.
            if field.name in fields_dict:
                raise BprotoDuplicateNameError(field.name, resulting_message)
            if field.pos in field_positions:
                raise BprotoMessageIDAlreadyUsed(field.pos, resulting_message)
            field_positions.add(field.pos)
            fields_dict[field.name] = field

        resulting_message.fields = FactoryMessage.sort_fields_dict(fields_dict)
        return resulting_message