1
0
Fork 0
mirror of https://git.rwth-aachen.de/acs/public/villas/node/ synced 2025-03-09 00:00:00 +01:00
VILLASnode/python/villas/node/formats.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

331 lines
9.1 KiB
Python
Raw Permalink Normal View History

"""
Author: Philipp Jungkamp <Philipp.Jungkamp@opal-rt.com>
SPDX-FileCopyrightText: 2023 OPAL-RT Germany GmbH
SPDX-License-Identifier: Apache-2.0
""" # noqa: E501
import re
from dataclasses import dataclass, field
from itertools import groupby
from typing import Iterable
from villas.node.sample import Sample, Signal, Timestamp
import villas.node.villas_pb2 as pb
class SignalList(list[type]):
types = {
"b": bool,
"i": int,
"f": float,
"c": complex,
}
_type_to_char = {t: c for c, t in types.items()}
def __init__(self, fmt: str | Sample | Iterable[type] = "64f"):
if isinstance(fmt, Sample):
super().__init__(map(type, fmt.data) if fmt.data else [])
return
elif not isinstance(fmt, str):
super().__init__(fmt)
return
super().__init__()
regex = f"([{''.join(self.__class__.types.keys())}])"
fields = iter(re.split(regex, fmt))
while (count_str := next(fields, None)) is not None:
if count_str:
count = int(count_str)
else:
count = 1
try:
ty_str = next(fields)
except StopIteration:
if count_str:
raise ValueError("Expected type specifier.")
else:
break
try:
ty = self.__class__.types[ty_str]
except KeyError:
raise ValueError(f"Unknown type {ty_str}")
self.extend([ty] * count)
def __str__(self):
fmt = ""
for ty, run in groupby(self):
run_length = sum(1 for _ in run)
c = self.__class__._type_to_char[ty]
if run_length > 1:
fmt += f"{run_length}"
fmt += f"{c}"
return fmt
def __repr__(self):
return f"{self.__class__.__name__}('{self.__str__()}')"
@dataclass(kw_only=True)
class Format:
"""
The base for VILLASnode formats in Python.
"""
ts_origin: bool = True
ts_received: bool = True
sequence: bool = True
data: bool = True
def _strip_sample(self, sample: Sample) -> Sample:
if not self.ts_origin:
sample.ts_origin = None
if not self.ts_received:
sample.ts_received = None
if not self.sequence:
sample.sequence = None
if not self.data:
sample.data = []
return sample
@dataclass
class VillasHuman(Format):
"""
The villas.human format in Python.
"""
signal_list: SignalList = field(default_factory=SignalList)
separator: str = "\t"
delimiter: str = "\n"
def load(self, file) -> list[Sample]:
"""
Load samples from a text mode file object.
"""
return self.loads(file.read())
def loads(self, s: str) -> list[Sample]:
"""
Load samples from a string.
"""
s.strip(self.separator + self.delimiter)
sample_strs = s.split(sep=self.delimiter)
samples = (self.load_sample(sample) for sample in sample_strs)
return [s for s in samples if s is not None]
def dump(self, samples: Iterable[Sample], file):
"""
Dump samples to a text mode file object.
"""
return file.write(self.dumps(samples))
def dumps(self, samples: Iterable[Sample]) -> str:
"""
Dump samples to a string.
"""
sample_strs = (self.dump_sample(sample) for sample in iter(samples))
return "".join(sample_strs)
def load_sample(self, sample: str) -> Sample | None:
"""
Load a single sample from a string.
"""
sample = sample.strip(self.delimiter)
if sample.startswith("#"):
return None
fields = sample.split(sep=self.separator)
if not fields[0]:
return None
m = re.match(
r"(\d+)(?:\.(\d+))?([-+]\d+(?:\.\d+)?"
r"(?:e[+-]?\d+)?)?(?:\((\d+)\))?(F)?",
fields[0],
)
if m is None:
raise ValueError(f"Invalid header: {fields[0]}")
ts_seconds = int(m.group(1))
ts_nanoseconds = int(m.group(2)) if m.group(2) else 0
ts_offset = float(m.group(3)) if m.group(3) else None
sequence = int(m.group(4)) if m.group(4) else None
new_frame = bool(m.group(5))
ts_origin = Timestamp(ts_seconds, ts_nanoseconds)
if ts_offset is not None:
ts_received_raw = ts_origin.timestamp() + ts_offset
ts_received = Timestamp.fromtimestamp(ts_received_raw)
else:
ts_received = None
data: list[Signal] = []
for ty, value in zip(self.signal_list, fields[1:]):
if ty is bool:
data.append(bool(int(value)))
elif ty is int:
data.append(int(value))
elif ty is float:
data.append(float(value))
elif ty is complex:
data.append(self._unpack_complex(value))
return self._strip_sample(
Sample(
ts_origin=ts_origin,
ts_received=ts_received,
sequence=sequence,
new_frame=new_frame,
data=data,
)
)
def dump_sample(self, smp: Sample) -> str:
"""
Dump a single sample to a string.
"""
smp = self._strip_sample(smp)
s = ""
if smp.ts_origin is not None:
s += f"{smp.ts_origin.seconds}"
if smp.ts_origin.nanoseconds != 0:
s += f".{smp.ts_origin.nanoseconds:09}"
if smp.ts_received is not None:
off = smp.ts_received.timestamp() - smp.ts_origin.timestamp()
s += f"+{off}"
if smp.sequence is not None:
s += f"({smp.sequence})"
if smp.new_frame:
s += "F"
for ty, value in zip(self.signal_list, smp.data):
s += self.separator
assert ty == type(value)
match value:
case bool():
s += str(int(value))
case int():
s += str(value)
case float():
s += str(value)
case complex():
s += self._pack_complex(value)
s += self.delimiter
return s
def _unpack_complex(self, s: str) -> complex:
return complex(s.lower().replace("i", "j"))
def _pack_complex(self, z: complex) -> str:
return f"{z.real}+{z.imag}i"
@dataclass
class Protobuf(Format):
"""
The protobuf format in Python.
"""
def loadb(self, b: bytes) -> list[Sample]:
msg = pb.Message()
msg.ParseFromString(b)
return [self.load_sample(sample) for sample in msg.samples]
def dumpb(self, samples: Iterable[Sample]) -> bytes:
msg = pb.Message()
for sample in samples:
msg.samples.append(self.dump_sample(sample))
return msg.SerializeToString()
def load_sample(self, pb_sample: pb.Sample) -> Sample:
sample = Sample()
if pb_sample.HasField("ts_origin"):
sample.ts_origin = Timestamp(
pb_sample.ts_origin.sec, pb_sample.ts_origin.nsec
)
if pb_sample.HasField("ts_received"):
sample.ts_received = Timestamp(
pb_sample.ts_received.sec, pb_sample.ts_received.nsec
)
if pb_sample.HasField("new_frame"):
sample.new_frame = pb_sample.new_frame
if pb_sample.HasField("sequence"):
sample.sequence = pb_sample.sequence
for value in pb_sample.values:
if value.HasField("i"):
sample.data.append(value.i)
elif value.HasField("f"):
sample.data.append(value.f)
elif value.HasField("b"):
sample.data.append(value.b)
elif value.HasField("z"):
sample.data.append(complex(value.z.real, value.z.imag))
else:
raise Exception("Missing value")
return self._strip_sample(sample)
def dump_sample(self, sample: Sample) -> pb.Sample:
pb_sample = pb.Sample()
pb_sample.type = pb.Sample.Type.DATA
pb_sample.new_frame = sample.new_frame
if sample.ts_origin:
pb_sample.ts_origin.sec = sample.ts_origin.seconds
pb_sample.ts_origin.nsec = sample.ts_origin.nanoseconds
if sample.ts_received:
pb_sample.ts_received.sec = sample.ts_received.seconds
pb_sample.ts_received.nsec = sample.ts_received.nanoseconds
if sample.sequence:
pb_sample.sequence = sample.sequence
for value in sample.data:
pb_value = pb.Value()
if isinstance(value, int):
pb_value.i = value
elif isinstance(value, float):
pb_value.f = value
elif isinstance(value, bool):
pb_value.b = value
elif isinstance(value, complex):
pb_value.z.real = value.real
pb_value.z.imag = value.imag
pb_sample.values.append(pb_value)
return pb_sample