From 0a505f8baa79bb1e909990cc398915b039a1c4ce Mon Sep 17 00:00:00 2001 From: Alina Lenk Date: Sun, 25 Sep 2022 23:52:47 +0200 Subject: [PATCH 3/4] generate_packets.py: Introduce type registry class for resolving field types See osdn#45717 Signed-off-by: Alina Lenk --- common/generate_packets.py | 138 +++++++++++++++++++++++++------------ 1 file changed, 93 insertions(+), 45 deletions(-) diff --git a/common/generate_packets.py b/common/generate_packets.py index 05c60b3de9..fb0e3ec8e1 100755 --- a/common/generate_packets.py +++ b/common/generate_packets.py @@ -459,6 +459,24 @@ class RawFieldType(ABC): packets definition file. These types may require the addition of a size in order to be usable; see the array() method and the FieldType class.""" + @abstractmethod + def array(self, size: SizeInfo) -> "FieldType": + raise NotImplementedError + + @abstractmethod + def __str__(self) -> str: + return super().__str__() + + def __repr__(self) -> str: + return "<{self.__class__.__name__} {self}>".format(self = self) + + +FieldTypeConstructor = typing.Callable[[str, str], RawFieldType] + +class TypeRegistry: + """Determines what Python class to use for field types based on their + dataio type and public type.""" + TYPE_INFO_PATTERN = re.compile(r"^([^()]*)\(([^()]*)\)$") """Matches a field type. @@ -466,57 +484,53 @@ class RawFieldType(ABC): - dataio type - public type (aka struct type)""" - @staticmethod - def parse(type_text: str) -> "RawFieldType": + def __init__(self, fallback: FieldTypeConstructor): + self.dataio_types = {} # type: dict[str, FieldTypeConstructor] + self.dataio_patterns = {} # type: dict[typing.Pattern[str], FieldTypeConstructor] + self.public_types = {} # type: dict[str, FieldTypeConstructor] + self.public_patterns = {} # type: dict[typing.Pattern[str], FieldTypeConstructor] + self.fallback = fallback + + def parse(self, type_text: str) -> RawFieldType: """Parse a single field type""" mo = __class__.TYPE_INFO_PATTERN.fullmatch(type_text) if mo is None: raise ValueError("malformed type or undefined alias: %r" % type_text) - dataio_type, public_type = mo.groups("") - - if dataio_type == "worklist": - return WorklistType(dataio_type, public_type) - - if dataio_type == "cm_parameter": - return CmParameterType(dataio_type, public_type) - - if dataio_type == "bitvector": - return BitvectorType(dataio_type, public_type) - - if dataio_type == "memory": - return NeedSizeType(dataio_type, public_type, MemoryType) + return self(*mo.groups()) - if dataio_type in ("string", "estring"): - return NeedSizeType(dataio_type, public_type, StringType) - - mo = BoolType.TYPE_PATTERN.fullmatch(dataio_type) - if mo is not None: - return BoolType(mo, public_type) - - mo = FloatType.TYPE_PATTERN.fullmatch(dataio_type) - if mo is not None: - return FloatType(mo, public_type) - - mo = IntType.TYPE_PATTERN.fullmatch(dataio_type) - if mo is not None: - return IntType(mo, public_type) + def __call__(self, dataio_type: str, public_type: str) -> RawFieldType: + try: + ctor = self.dataio_types[dataio_type] + except KeyError: + pass + else: + return ctor(dataio_type, public_type) - if public_type.startswith("struct "): - return StructType(dataio_type, public_type) + for pat, ctor in self.dataio_patterns.items(): + mo = pat.fullmatch(dataio_type) + if mo is not None: + self.dataio_types[dataio_type] = ctor + return ctor(dataio_type, public_type) - # default fallback case - return BasicType(dataio_type, public_type) + self.dataio_types[dataio_type] = self._by_public + return self._by_public(dataio_type, public_type) - @abstractmethod - def array(self, size: SizeInfo) -> "FieldType": - raise NotImplementedError + def _by_public(self, dataio_type: str, public_type: str) -> RawFieldType: + try: + ctor = self.public_types[public_type] + except KeyError: + pass + else: + return ctor(dataio_type, public_type) - @abstractmethod - def __str__(self) -> str: - return super().__str__() + for pat, ctor in self.public_patterns.items(): + mo = pat.fullmatch(public_type) + if mo is not None: + self.public_types[public_type] = ctor + return ctor(dataio_type, public_type) - def __repr__(self) -> str: - return "<{self.__class__.__name__} {self}>".format(self = self) + self.public_types[public_type] = self.fallback + return self.fallback(dataio_type, public_type) class NeedSizeType(RawFieldType): @@ -623,6 +637,8 @@ if (!DIO_GET({self.dataio_type}, &din, &field_addr, &real_packet->{location})) { def __str__(self) -> str: return "{self.dataio_type}({self.public_type})".format(self = self) +DEFAULT_REGISTRY = TypeRegistry(BasicType) + class IntType(BasicType): """Type information for an integer field""" @@ -666,6 +682,8 @@ result += key->%s; }} """.format(self = self, location = location) +DEFAULT_REGISTRY.dataio_patterns[IntType.TYPE_PATTERN] = IntType + class BoolType(IntType): """Type information for a boolean field""" @@ -691,6 +709,8 @@ class BoolType(IntType): super().__init__(dataio_info, public_type) +DEFAULT_REGISTRY.dataio_patterns[BoolType.TYPE_PATTERN] = BoolType + class FloatType(BasicType): """Type information for a float field""" @@ -741,6 +761,8 @@ if (!DIO_GET({self.dataio_type}, &din, &field_addr, &real_packet->{location}, {s def __str__(self) -> str: return "{self.dataio_type}{self.float_factor:d}({self.public_type})".format(self = self) +DEFAULT_REGISTRY.dataio_patterns[FloatType.TYPE_PATTERN] = FloatType + class BitvectorType(BasicType): """Type information for a bitvector field""" @@ -768,12 +790,27 @@ if (!DIO_BV_GET(&din, &field_addr, real_packet->{location})) {{ }} """.format(self = self, location = location) +DEFAULT_REGISTRY.dataio_types["bitvector"] = BitvectorType + class StructType(BasicType): """Type information for a field of some general struct type""" - def __init__(self, dataio_type: str, public_type: str): - assert public_type.startswith("struct ") + TYPE_PATTERN = re.compile(r"^struct \w+$") + """Matches a struct public type""" + + @typing.overload + def __init__(self, dataio_type: str, public_info: str): ... + @typing.overload + def __init__(self, dataio_type: str, public_info: "re.Match[str]"): ... + def __init__(self, dataio_type: str, public_info: "str | re.Match[str]"): + if isinstance(public_info, str): + mo = self.TYPE_PATTERN.fullmatch(public_info) + if mo is None: + raise ValueError("not a valid struct type") + public_info = mo + public_type = public_info.group(0) + super().__init__(dataio_type, public_type) def get_code_handle_param(self, location: Location) -> str: @@ -798,6 +835,8 @@ differ = !are_{self.dataio_type}s_equal(&old->{location}, &real_packet->{locatio e |= DIO_PUT({self.dataio_type}, &dout, &field_addr, &real_packet->{location}); """.format(self = self, location = location) +DEFAULT_REGISTRY.public_patterns[StructType.TYPE_PATTERN] = StructType + class CmParameterType(StructType): """Type information for a worklist field""" @@ -816,6 +855,8 @@ class CmParameterType(StructType): differ = !cm_are_parameter_equal(&old->{location}, &real_packet->{location}); """.format(self = self, location = location) +DEFAULT_REGISTRY.dataio_types["cm_parameter"] = CmParameterType + class WorklistType(StructType): """Type information for a worklist field""" @@ -834,6 +875,8 @@ class WorklistType(StructType): worklist_copy(&real_packet->{location}, {location}); """.format(location = location) +DEFAULT_REGISTRY.dataio_types["worklist"] = WorklistType + class SizedType(BasicType): """Abstract base class (ABC) for field types that include a size""" @@ -889,6 +932,8 @@ if (!DIO_GET({self.dataio_type}, &din, &field_addr, real_packet->{location}, siz }} """.format(self = self, location = location) +DEFAULT_REGISTRY.dataio_types["string"] = DEFAULT_REGISTRY.dataio_types["estring"] = partial(NeedSizeType, cls = StringType) + class MemoryType(SizedType): """Type information for a memory field""" @@ -927,6 +972,8 @@ if (!DIO_GET({self.dataio_type}, &din, &field_addr, real_packet->{location}, {se }} """.format(self = self, location = location, size_check = size_check) +DEFAULT_REGISTRY.dataio_types["memory"] = partial(NeedSizeType, cls = MemoryType) + class ArrayType(FieldType): """Type information for an array field. Consists of size information and @@ -2626,7 +2673,7 @@ class PacketsDefinition(typing.Iterable[Packet]): def resolve_type(self, type_text: str) -> RawFieldType: """Resolve the given type""" if type_text not in self.types: - self.types[type_text] = RawFieldType.parse(type_text) + self.types[type_text] = self.type_registry.parse(type_text) return self.types[type_text] def define_type(self, alias: str, meaning: str): @@ -2641,7 +2688,8 @@ class PacketsDefinition(typing.Iterable[Packet]): self.types[alias] = self.resolve_type(meaning) - def __init__(self): + def __init__(self, type_registry: "TypeRegistry | None" = None): + self.type_registry = type_registry or DEFAULT_REGISTRY self.types = {} self.packets = [] self.packets_by_type = {} -- 2.34.1