correctly write rosbags

Jan Kowalczyk
2024-12-16 17:30:49 +01:00
parent d118d40795
commit fe6f6f449f


@@ -2,6 +2,7 @@ from pathlib import Path
from sys import exit
import numpy as np
import open3d as o3d
from configargparse import (
ArgParser,
ArgumentDefaultsRawHelpFormatter,
@@ -10,9 +11,32 @@ from configargparse import (
from numpy.lib import recfunctions as rfn
from rich.progress import Progress
from rosbags.highlevel import AnyReader
from rosbags.rosbag1 import Writer
from rosbags.typesys import Stores, get_types_from_msg, get_typestore
from rosbags.typesys.stores.ros1_noetic import (
sensor_msgs__msg__PointCloud2 as PointCloud2,
)
from rosbags.typesys.stores.ros1_noetic import (
sensor_msgs__msg__PointField as PointField,
)
from util import existing_path
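# Convert a structured pointcloud array to an Open3D point cloud, dropping points whose x/y/z are all zero.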
def get_o3d_pointcloud(points: np.ndarray) -> o3d.geometry.PointCloud:
xyz_array = np.stack(
[
points["x"].astype(np.float64),
points["y"].astype(np.float64),
points["z"].astype(np.float64),
],
axis=-1,
)
filtered_xyz = xyz_array[~np.all(xyz_array == 0, axis=1)]
o3d_vector = o3d.utility.Vector3dVector(filtered_xyz)
return o3d.geometry.PointCloud(o3d_vector)
# Mapping of PointField datatypes to NumPy dtypes
POINTFIELD_DATATYPES = {
1: np.int8, # INT8
@@ -25,6 +49,18 @@ POINTFIELD_DATATYPES = {
8: np.float64, # FLOAT64
}
# Reverse map from numpy dtype to PointField datatype code
REVERSE_POINTFIELD_DATATYPES = {
np.dtype(np.int8).type: 1,
np.dtype(np.uint8).type: 2,
np.dtype(np.int16).type: 3,
np.dtype(np.uint16).type: 4,
np.dtype(np.int32).type: 5,
np.dtype(np.uint32).type: 6,
np.dtype(np.float32).type: 7,
np.dtype(np.float64).type: 8,
}
def read_pointcloud(msg):
# Build the dtype dynamically from the fields
@@ -42,10 +78,7 @@ def read_pointcloud(msg):
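# Pad any gap between declared field offsets with raw-byte filler fields so every point deserializes at its correct offset.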
if field.offset > current_offset:
gap_size = field.offset - current_offset
gap_field_name = f"unused_{current_offset}"
dtype_fields[gap_field_name] = (
f"V{gap_size}",
current_offset,
) # Raw bytes as filler
dtype_fields[gap_field_name] = (f"V{gap_size}", current_offset)
current_offset += gap_size
dtype_fields[field.name] = (np_dtype, field.offset)
@@ -61,7 +94,7 @@ def read_pointcloud(msg):
return np.frombuffer(msg.data, dtype=dtype)
def clean_pointcloud(points):
def clean_pointcloud(points) -> np.ndarray:
valid_fields = [
name for name in points.dtype.names if not name.startswith("unused_")
]
@@ -69,6 +102,38 @@ def clean_pointcloud(points):
return cleaned_points
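# Rebuild a PointCloud2 around the cleaned structured array: regenerate the PointField layout without the unused_* filler fields and repack the data buffer.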
def create_pointcloud2_msg(original_msg, cleaned_points):
new_fields = []
offset = 0
for name in cleaned_points.dtype.names:
np_dtype = cleaned_points.dtype[name].type
if np_dtype not in REVERSE_POINTFIELD_DATATYPES:
raise ValueError(f"No PointField datatype code for dtype {np_dtype}")
new_fields.append(
PointField(
name=name,
offset=offset,
datatype=REVERSE_POINTFIELD_DATATYPES[np_dtype],
count=1,
)
)
offset += cleaned_points.dtype[name].itemsize
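# Header, dimensions, and endianness are carried over from the original message; point_step and row_step are recomputed from the packed dtype.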
return PointCloud2(
header=original_msg.header,
height=original_msg.height,
width=original_msg.width,
fields=new_fields,
is_bigendian=original_msg.is_bigendian,
point_step=cleaned_points.itemsize,
row_step=int(
(cleaned_points.itemsize * cleaned_points.size) / original_msg.height
),
is_dense=original_msg.is_dense,
data=cleaned_points.view("uint8"),
)
def main() -> int:
parser = ArgParser(
config_file_parser_class=YAMLConfigFileParser,
@@ -107,24 +172,63 @@ def main() -> int:
output_file_paths=[(output_path / "config.yaml").as_posix()],
)
cleaned_bag_path = output_path / "cleaned_bag.bag"
cleaned_bag_path.unlink(missing_ok=True)
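# Remove any stale output from a previous run before writing the cleaned bag.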
with AnyReader([args.input_experiment_path]) as reader:
connections = reader.connections
topics = dict(sorted({conn.topic: conn for conn in connections}.items()))
assert (
args.pointcloud_topic in topics
), f"Topic {args.pointcloud_topic} not found"
topic = topics[args.pointcloud_topic]
with Progress() as progress:
task = progress.add_task("Analyzing data", total=topic.msgcount)
for connection, timestamp, rawdata in reader.messages(connections=[topic]):
pointcloud_msg = reader.deserialize(rawdata, connection.msgtype)
original_pointcloud = read_pointcloud(pointcloud_msg)
cleaned_pointcloud = clean_pointcloud(original_pointcloud)
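# Collect every topic's message definition and register it with a ROS1 Noetic typestore so the cleaned messages can be serialized when writing.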
original_types = {}
typestore = get_typestore(Stores.ROS1_NOETIC)
for connection in topics.values():
original_types.update(
get_types_from_msg(connection.msgdef, connection.msgtype)
)
typestore.register(original_types)
# TODO: add analysis here
with Writer(cleaned_bag_path) as writer:
# Add all connections to the new bag
new_connections = {}
for topic_name, old_conn in topics.items():
new_conn = writer.add_connection(
topic=topic_name,
msgtype=old_conn.msgtype,
msgdef=old_conn.msgdef,
typestore=typestore,
)
new_connections[old_conn.id] = new_conn
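# Keyed by the reader connection id so each incoming message can be routed to its matching writer connection.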
progress.advance(task)
with Progress() as progress:
task = progress.add_task(
"Processing all messages", total=reader.message_count
)
for connection, timestamp, rawdata in reader.messages():
if connection.topic == args.pointcloud_topic:
# For the pointcloud topic, we need to modify the data
msg = reader.deserialize(rawdata, connection.msgtype)
original_pointcloud = read_pointcloud(msg)
cleaned_pointcloud = clean_pointcloud(original_pointcloud)
o3d_pcd = get_o3d_pointcloud(cleaned_pointcloud)
msg = create_pointcloud2_msg(msg, cleaned_pointcloud)
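# Serialize the rebuilt message with the registered typestore and write it under the new connection.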
writer.write(
connection=new_connections[connection.id],
timestamp=timestamp,
data=typestore.serialize_ros1(
message=msg, typename=msg.__msgtype__
),
)
else:
# For all other topics, rawdata can be written directly; no need to deserialize
writer.write(new_connections[connection.id], timestamp, rawdata)
progress.advance(task)
return 0
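
A quick way to sanity-check the result is to reopen the cleaned bag with the same rosbags reader API and confirm the filler fields are gone. A minimal sketch, not part of the commit; the output path and topic name below are placeholders:

from pathlib import Path
from rosbags.highlevel import AnyReader

# Placeholder path and topic; substitute the actual output directory and pointcloud topic.
with AnyReader([Path("output/cleaned_bag.bag")]) as reader:
    pointcloud_connections = [
        conn for conn in reader.connections if conn.topic == "/pointcloud"
    ]
    for connection, timestamp, rawdata in reader.messages(connections=pointcloud_connections):
        msg = reader.deserialize(rawdata, connection.msgtype)
        # The cleaned message should expose no "unused_*" filler fields.
        print([field.name for field in msg.fields])
        break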