From c7f71672e1c5541bc3b003d113855b33524222f3 Mon Sep 17 00:00:00 2001 From: Alina Lenk Date: Sat, 23 Jul 2022 17:30:38 +0200 Subject: [PATCH 7/9] generate_packets.py: Move Field.get_fill() into types See osdn#45213 Signed-off-by: Alina Lenk --- common/generate_packets.py | 188 +++++++++++++++++++++++++++++-------- 1 file changed, 149 insertions(+), 39 deletions(-) diff --git a/common/generate_packets.py b/common/generate_packets.py index fd45c192ef..1e1d03eb93 100755 --- a/common/generate_packets.py +++ b/common/generate_packets.py @@ -284,6 +284,8 @@ class Location: outside of recursive field types like arrays, this will usually just be a field of a packet, but it serves to concisely handle the recursion.""" + _INDICES = "ijk" + def __init__(self, name: str, location: "str | None" = None, depth: int = 0): self.name = name if location is None: @@ -297,6 +299,21 @@ class Location: self and incremented depth""" return type(self)(self.name, new_location, self.depth + 1) + @property + def index(self) -> str: + """The index name for the current depth""" + if self.depth > len(self._INDICES): + raise ValueError("nested too deeply") + return self._INDICES[self.depth] + + @property + def sub(self) -> "Location": + """A location one level deeper with the current index subscript + added to the end. + + `field` ~> `field[i]` ~> `field[i][j]` etc.""" + return self.deeper("{self}[{self.index}]".format(self = self)) + def __str__(self) -> str: return self.location @@ -420,9 +437,10 @@ class SizeInfo: return hash((__class__, self.declared, self._actual)) -class FieldType(ABC): - """Abstract base class (ABC) for classes representing type information - for fields of a packet""" +class RawFieldType(ABC): + """Abstract base class (ABC) for classes representing types defined in a + packets definition file. These types may require the addition of a size + in order to be usable; see the array() method and the FieldType class.""" TYPE_INFO_PATTERN = re.compile(r"^([^()]*)\(([^()]*)\)$") """Matches a field type. @@ -432,13 +450,19 @@ class FieldType(ABC): - public type (aka struct type)""" @staticmethod - def parse(type_text: str) -> "FieldType": + def parse(type_text: str) -> "RawFieldType": """Parse a single field type""" mo = __class__.TYPE_INFO_PATTERN.fullmatch(type_text) if mo is None: raise ValueError("malformed type or undefined alias: %r" % type_text) dataio_type, public_type = mo.groups("") + if dataio_type == "worklist": + return WorklistType(dataio_type, public_type) + + if dataio_type in ("string", "estring"): + return NeedSizeType(dataio_type, public_type, StringType) + mo = BoolType.TYPE_PATTERN.fullmatch(dataio_type) if mo is not None: return BoolType(mo, public_type) @@ -450,6 +474,37 @@ class FieldType(ABC): # default fallback case return BasicType(dataio_type, public_type) + @abstractmethod + def array(self, size: SizeInfo) -> "FieldType": + raise NotImplementedError + + @abstractmethod + def __str__(self) -> str: + return super().__str__() + + def __repr__(self) -> str: + return "<{self.__class__.__name__} {self}>".format(self = self) + + +class NeedSizeType(RawFieldType): + """Helper class for field types that require a size to be usable.""" + + def __init__(self, dataio_type: str, public_type: str, cls: typing.Callable[[str, str, SizeInfo], "FieldType"]): + self.dataio_type = dataio_type + self.public_type = public_type + self.cls = cls + + def array(self, size: SizeInfo) -> "FieldType": + return self.cls(self.dataio_type, self.public_type, size) + + def __str__(self) -> str: + return "{self.dataio_type}({self.public_type})".format(self = self) + + +class FieldType(RawFieldType): + """Abstract base class (ABC) for classes representing type information + usable for fields of a packet""" + dataio_type = "" public_type = "" @@ -468,16 +523,13 @@ class FieldType(ABC): raise NotImplementedError @abstractmethod - def sizes(self) -> typing.Iterable[SizeInfo]: - """Yield the sizes associated with this type, from outer to inner""" + def get_code_fill(self, location: Location) -> str: raise NotImplementedError @abstractmethod - def __str__(self) -> str: - return super().__str__() - - def __repr__(self) -> str: - return "<{self.__class__.__name__} {self}>".format(self = self) + def sizes(self) -> typing.Iterable[SizeInfo]: + """Yield the sizes associated with this type, from outer to inner""" + raise NotImplementedError class BasicType(FieldType): @@ -492,6 +544,11 @@ class BasicType(FieldType): {self.public_type} {location}; """.format(self = self, location = location) + def get_code_fill(self, location: Location) -> str: + return """\ +real_packet->{location} = {location}; +""".format(location = location) + def sizes(self) -> typing.Iterable[SizeInfo]: return () @@ -563,6 +620,66 @@ class FloatType(BasicType): return "{self.dataio_type}{self.float_factor:d}({self.public_type})".format(self = self) +class WorklistType(BasicType): + """Type information for a worklist field""" + + def __init__(self, dataio_type: str, public_type: str): + if dataio_type != "worklist": + raise ValueError("not a valid worklist type") + + if public_type != "struct worklist": + raise ValueError("worklist dataio type with non-worklist public type: %r" % public_type) + + super().__init__(dataio_type, public_type) + + def get_code_fill(self, location: Location) -> str: + return """\ +worklist_copy(&real_packet->{location}, {location}); +""".format(location = location) + + +class SizedType(BasicType): + """Abstract base class (ABC) for field types that include a size""" + + def __init__(self, dataio_type: str, public_type: str, size: SizeInfo): + super().__init__(dataio_type, public_type) + self.size = size + + def get_code_declaration(self, location: Location) -> str: + return super().get_code_declaration( + location.deeper("%s[%s]" % (location, self.size.declared)) + ) + + @abstractmethod + def get_code_fill(self, location: Location) -> str: + return super().get_code_fill(location) + + def sizes(self) -> typing.Iterable[SizeInfo]: + yield self.size + yield from super().sizes() + + def __str__(self) -> str: + return "%s[%s]" % (super().__str__(), self.size) + + +class StringType(SizedType): + """Type information for a string field""" + + def __init__(self, dataio_type: str, public_type: str, size: SizeInfo): + if dataio_type not in ("string", "estring"): + raise ValueError("not a valid string type") + + if public_type != "char": + raise ValueError("string dataio type with non-char public type: %r" % public_type) + + super().__init__(dataio_type, public_type, size) + + def get_code_fill(self, location: Location) -> str: + return """\ +sz_strlcpy(real_packet->{location}, {location}); +""".format(location = location) + + class ArrayType(FieldType): """Type information for an array field. Consists of size information and another FieldType for the array's elements, which may also be an @@ -585,6 +702,18 @@ class ArrayType(FieldType): location.deeper("%s[%s]" % (location, self.size.declared)) ) + def get_code_fill(self, location: Location) -> str: + inner_fill = prefix(" ", self.elem.get_code_fill(location.sub)) + return """\ +{{ + int {location.index}; + + for ({location.index} = 0; {location.index} < {self.size.real}; {location.index}++) {{ +{inner_fill}\ + }} +}} +""".format(self = self, location = location, inner_fill = inner_fill) + def sizes(self) -> typing.Iterable[SizeInfo]: yield self.size yield from self.elem.sizes() @@ -613,7 +742,7 @@ class Field: - the final array size""" @classmethod - def parse(cls, line: str, resolve_type: typing.Callable[[str], FieldType]) -> "typing.Iterable[Field]": + def parse(cls, line: str, resolve_type: typing.Callable[[str], RawFieldType]) -> "typing.Iterable[Field]": """Parse a single line defining one or more fields""" mo = cls.FIELDS_LINE_PATTERN.fullmatch(line) if mo is None: @@ -633,6 +762,10 @@ class Field: field_text = mo.group(1) field_type = field_type.array(SizeInfo.parse(mo.group(2))) mo = cls.FIELD_ARRAY_PATTERN.fullmatch(field_text) + + if not isinstance(field_type, FieldType): + raise ValueError("need an array size to use type %s" % field_type) + yield Field(field_text, field_type, flag_info) def __init__(self, name: str, type_info: FieldType, flags: FieldFlags): @@ -706,30 +839,7 @@ class Field: # Returns code which copies the arguments of the direct send # functions in the packet struct. def get_fill(self) -> str: - if self.dataio_type=="worklist": - return """\ -worklist_copy(&real_packet->{self.name}, {self.name}); -""".format(self = self) - if self.dimensions == 0: - return """\ -real_packet->{self.name} = {self.name}; -""".format(self = self) - if self.dataio_type=="string" or self.dataio_type=="estring": - return """\ -sz_strlcpy(real_packet->{self.name}, {self.name}); -""".format(self = self) - if self.dimensions == 1: - return """\ -{{ - int i; - - for (i = 0; i < {self.sizes[0].real}; i++) {{ - real_packet->{self.name}[i] = {self.name}[i]; - }} -}} -""".format(self = self) - - return repr(self.__dict__) + return self.type_info.get_code_fill(Location(self.name)) # Returns code which sets "differ" by comparing the field # instances of "old" and "readl_packet". @@ -2093,7 +2203,7 @@ class Packet: CANCEL_PATTERN = re.compile(r"^cancel\((.*)\)$") def __init__(self, packet_type: str, packet_number: int, flags_text: str, - lines: typing.Iterable[str], resolve_type: typing.Callable[[str], FieldType]): + lines: typing.Iterable[str], resolve_type: typing.Callable[[str], RawFieldType]): self.type = packet_type self.type_number = packet_number @@ -2569,10 +2679,10 @@ class PacketsDefinition(typing.Iterable[Packet]): raise ValueError("Unexpected line: " + line) - def resolve_type(self, type_text: str) -> FieldType: + def resolve_type(self, type_text: str) -> RawFieldType: """Resolve the given type""" if type_text not in self.types: - self.types[type_text] = FieldType.parse(type_text) + self.types[type_text] = RawFieldType.parse(type_text) return self.types[type_text] def define_type(self, alias: str, meaning: str): -- 2.34.1