correctly write rosbags

This commit is contained in:
Jan Kowalczyk
2024-12-16 17:30:49 +01:00
parent d118d40795
commit fe6f6f449f

View File

@@ -2,6 +2,7 @@ from pathlib import Path
from sys import exit from sys import exit
import numpy as np import numpy as np
import open3d as o3d
from configargparse import ( from configargparse import (
ArgParser, ArgParser,
ArgumentDefaultsRawHelpFormatter, ArgumentDefaultsRawHelpFormatter,
@@ -10,9 +11,32 @@ from configargparse import (
from numpy.lib import recfunctions as rfn from numpy.lib import recfunctions as rfn
from rich.progress import Progress from rich.progress import Progress
from rosbags.highlevel import AnyReader from rosbags.highlevel import AnyReader
from rosbags.rosbag1 import Writer
from rosbags.typesys import Stores, get_types_from_msg, get_typestore
from rosbags.typesys.stores.ros1_noetic import (
sensor_msgs__msg__PointCloud2 as PointCloud2,
)
from rosbags.typesys.stores.ros1_noetic import (
sensor_msgs__msg__PointField as PointField,
)
from util import existing_path from util import existing_path
def get_o3d_pointcloud(points: np.ndarray) -> o3d.geometry.PointCloud:
    """Convert a structured point array into an Open3D point cloud.

    The ``x``, ``y`` and ``z`` fields of ``points`` are cast to float64 and
    stacked along the last axis. Every point whose three coordinates are all
    exactly zero is dropped before the remainder is handed to Open3D.
    """
    coordinates = np.stack(
        [points[axis].astype(np.float64) for axis in ("x", "y", "z")],
        axis=-1,
    )
    # A row is kept as soon as one coordinate differs from zero, so only the
    # points lying exactly at the origin are discarded.
    keep_mask = np.any(coordinates != 0, axis=1)
    return o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(coordinates[keep_mask])
    )
# Mapping of PointField datatypes to NumPy dtypes # Mapping of PointField datatypes to NumPy dtypes
POINTFIELD_DATATYPES = { POINTFIELD_DATATYPES = {
1: np.int8, # INT8 1: np.int8, # INT8
@@ -25,6 +49,18 @@ POINTFIELD_DATATYPES = {
8: np.float64, # FLOAT64 8: np.float64, # FLOAT64
} }
# Inverse lookup: NumPy scalar type -> PointField datatype code.
# The codes are consecutive, running from 1 (INT8) up to 8 (FLOAT64), so the
# table is built by enumerating the scalar types in code order.
REVERSE_POINTFIELD_DATATYPES = {
    scalar_type: code
    for code, scalar_type in enumerate(
        (
            np.int8,
            np.uint8,
            np.int16,
            np.uint16,
            np.int32,
            np.uint32,
            np.float32,
            np.float64,
        ),
        start=1,
    )
}
def read_pointcloud(msg): def read_pointcloud(msg):
# Build the dtype dynamically from the fields # Build the dtype dynamically from the fields
@@ -42,10 +78,7 @@ def read_pointcloud(msg):
if field.offset > current_offset: if field.offset > current_offset:
gap_size = field.offset - current_offset gap_size = field.offset - current_offset
gap_field_name = f"unused_{current_offset}" gap_field_name = f"unused_{current_offset}"
dtype_fields[gap_field_name] = ( dtype_fields[gap_field_name] = (f"V{gap_size}", current_offset)
f"V{gap_size}",
current_offset,
) # Raw bytes as filler
current_offset += gap_size current_offset += gap_size
dtype_fields[field.name] = (np_dtype, field.offset) dtype_fields[field.name] = (np_dtype, field.offset)
@@ -61,7 +94,7 @@ def read_pointcloud(msg):
return np.frombuffer(msg.data, dtype=dtype) return np.frombuffer(msg.data, dtype=dtype)
def clean_pointcloud(points): def clean_pointcloud(points) -> np.ndarray:
valid_fields = [ valid_fields = [
name for name in points.dtype.names if not name.startswith("unused_") name for name in points.dtype.names if not name.startswith("unused_")
] ]
@@ -69,6 +102,38 @@ def clean_pointcloud(points):
return cleaned_points return cleaned_points
def create_pointcloud2_msg(original_msg, cleaned_points: np.ndarray):
    """Build a PointCloud2 message that carries ``cleaned_points``.

    Header, height, width, the big-endian flag and the dense flag are copied
    from ``original_msg``. The field descriptions, ``point_step``,
    ``row_step`` and the data buffer are derived from ``cleaned_points``.

    Args:
        original_msg: PointCloud2 message the points were originally read
            from.
        cleaned_points: Structured array with one scalar field per point
            attribute.

    Returns:
        A new ``PointCloud2`` message.

    Raises:
        ValueError: If the dtype of a field has no PointField datatype code.
    """
    point_dtype = cleaned_points.dtype

    new_fields = []
    for name in point_dtype.names:
        # Take the offset NumPy actually uses for this field instead of
        # accumulating item sizes: both agree for a packed dtype, but only
        # the real offset stays correct when the dtype contains padding.
        field_dtype, field_offset = point_dtype.fields[name][:2]
        try:
            datatype = REVERSE_POINTFIELD_DATATYPES[field_dtype.type]
        except KeyError:
            raise ValueError(
                f"No PointField datatype code for dtype {field_dtype.type}"
            ) from None
        new_fields.append(
            PointField(
                name=name,
                offset=field_offset,
                datatype=datatype,
                count=1,
            )
        )

    # Integer arithmetic only (no float round trip); a message with
    # height == 0 gets row_step 0 instead of raising ZeroDivisionError.
    total_bytes = point_dtype.itemsize * cleaned_points.size
    height = original_msg.height
    row_step = int(total_bytes // height) if height else 0

    return PointCloud2(
        header=original_msg.header,
        height=height,
        width=original_msg.width,
        fields=new_fields,
        is_bigendian=original_msg.is_bigendian,
        point_step=point_dtype.itemsize,
        row_step=row_step,
        is_dense=original_msg.is_dense,
        # Reinterpreting the records as raw bytes needs contiguous memory;
        # ascontiguousarray is a no-op when the array already is contiguous.
        data=np.ascontiguousarray(cleaned_points).view(np.uint8),
    )
def main() -> int: def main() -> int:
parser = ArgParser( parser = ArgParser(
config_file_parser_class=YAMLConfigFileParser, config_file_parser_class=YAMLConfigFileParser,
@@ -107,24 +172,63 @@ def main() -> int:
output_file_paths=[(output_path / "config.yaml").as_posix()], output_file_paths=[(output_path / "config.yaml").as_posix()],
) )
cleaned_bag_path = output_path / "cleaned_bag.bag"
cleaned_bag_path.unlink(missing_ok=True)
with AnyReader([args.input_experiment_path]) as reader: with AnyReader([args.input_experiment_path]) as reader:
connections = reader.connections connections = reader.connections
topics = dict(sorted({conn.topic: conn for conn in connections}.items())) topics = dict(sorted({conn.topic: conn for conn in connections}.items()))
assert ( assert (
args.pointcloud_topic in topics args.pointcloud_topic in topics
), f"Topic {args.pointcloud_topic} not found" ), f"Topic {args.pointcloud_topic} not found"
topic = topics[args.pointcloud_topic]
with Progress() as progress: original_types = {}
task = progress.add_task("Analyzing data", total=topic.msgcount) typestore = get_typestore(Stores.ROS1_NOETIC)
for connection, timestamp, rawdata in reader.messages(connections=[topic]): for connection in topics.values():
pointcloud_msg = reader.deserialize(rawdata, connection.msgtype) original_types.update(
original_pointcloud = read_pointcloud(pointcloud_msg) get_types_from_msg(connection.msgdef, connection.msgtype)
cleaned_pointcloud = clean_pointcloud(original_pointcloud) )
typestore.register(original_types)
# todo add analysis here with Writer(cleaned_bag_path) as writer:
# Add all connections to the new bag
new_connections = {}
for topic_name, old_conn in topics.items():
new_conn = writer.add_connection(
topic=topic_name,
msgtype=old_conn.msgtype,
msgdef=old_conn.msgdef,
typestore=typestore,
)
new_connections[old_conn.id] = new_conn
progress.advance(task) with Progress() as progress:
task = progress.add_task(
"Processing all messages", total=reader.message_count
)
for connection, timestamp, rawdata in reader.messages():
if connection.topic == args.pointcloud_topic:
# For the pointcloud topic, we need to modify the data
msg = reader.deserialize(rawdata, connection.msgtype)
original_pointcloud = read_pointcloud(msg)
cleaned_pointcloud = clean_pointcloud(original_pointcloud)
o3d_pcd = get_o3d_pointcloud(cleaned_pointcloud)
msg = create_pointcloud2_msg(msg, cleaned_pointcloud)
writer.write(
connection=new_connections[connection.id],
timestamp=timestamp,
data=typestore.serialize_ros1(
message=msg, typename=msg.__msgtype__
),
)
else:
# For all other topics, we can write rawdata directly, no need to deserialize
writer.write(new_connections[connection.id], timestamp, rawdata)
progress.advance(task)
return 0 return 0