def get_o3d_pointcloud(points: np.ndarray) -> o3d.geometry.PointCloud:
    """Convert a structured point array into an Open3D point cloud.

    The ``x``, ``y`` and ``z`` fields of ``points`` are cast to float64 and
    stacked along a new last axis. Rows whose three coordinates are all
    exactly zero are dropped before the cloud is built.
    """
    coordinates = np.stack(
        [points[axis].astype(np.float64) for axis in ("x", "y", "z")],
        axis=-1,
    )
    # A row is kept as soon as one of its coordinates differs from zero.
    has_nonzero = np.any(coordinates != 0, axis=1)
    return o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(coordinates[has_nonzero])
    )
def create_pointcloud2_msg(original_msg, cleaned_points):
    """Build a PointCloud2 message describing ``cleaned_points``.

    Header, height, width, endianness and density flag are copied from
    ``original_msg``; the field list, point step, row step and data payload
    are derived from ``cleaned_points``.

    Args:
        original_msg: Source PointCloud2 message the points were read from.
        cleaned_points: Structured NumPy array (its ``dtype.names`` are used
            as the field names of the new message).

    Returns:
        A new ``PointCloud2`` whose ``data`` is the uint8 view of
        ``cleaned_points``.

    Raises:
        ValueError: If a field's dtype has no PointField datatype code.
    """
    record_dtype = cleaned_points.dtype
    new_fields = []
    for name in record_dtype.names:
        # Take the offset from the dtype itself so the declared layout always
        # matches the bytes written to ``data``, even if the record is padded.
        field_dtype, field_offset = record_dtype.fields[name][:2]
        np_dtype = field_dtype.type
        if np_dtype not in REVERSE_POINTFIELD_DATATYPES:
            raise ValueError(f"No PointField datatype code for dtype {np_dtype}")
        new_fields.append(
            PointField(
                name=name,
                offset=field_offset,
                datatype=REVERSE_POINTFIELD_DATATYPES[np_dtype],
                count=1,
            )
        )

    point_step = cleaned_points.itemsize
    height = original_msg.height
    if height:
        # Integer division: same value as before, without a float round trip.
        row_step = (point_step * cleaned_points.size) // height
    else:
        # A height of 0 would otherwise divide by zero; fall back to the
        # row length implied by the width.
        row_step = point_step * original_msg.width

    return PointCloud2(
        header=original_msg.header,
        height=height,
        width=original_msg.width,
        fields=new_fields,
        is_bigendian=original_msg.is_bigendian,
        point_step=point_step,
        row_step=row_step,
        is_dense=original_msg.is_dense,
        data=cleaned_points.view("uint8"),
    )
progress: - task = progress.add_task("Analyzing data", total=topic.msgcount) - for connection, timestamp, rawdata in reader.messages(connections=[topic]): - pointcloud_msg = reader.deserialize(rawdata, connection.msgtype) - original_pointcloud = read_pointcloud(pointcloud_msg) - cleaned_pointcloud = clean_pointcloud(original_pointcloud) + original_types = {} + typestore = get_typestore(Stores.ROS1_NOETIC) + for connection in topics.values(): + original_types.update( + get_types_from_msg(connection.msgdef, connection.msgtype) + ) + typestore.register(original_types) - # todo add analysis here + with Writer(cleaned_bag_path) as writer: + # Add all connections to the new bag + new_connections = {} + for topic_name, old_conn in topics.items(): + new_conn = writer.add_connection( + topic=topic_name, + msgtype=old_conn.msgtype, + msgdef=old_conn.msgdef, + typestore=typestore, + ) + new_connections[old_conn.id] = new_conn - progress.advance(task) + with Progress() as progress: + task = progress.add_task( + "Processing all messages", total=reader.message_count + ) + + for connection, timestamp, rawdata in reader.messages(): + if connection.topic == args.pointcloud_topic: + # For the pointcloud topic, we need to modify the data + msg = reader.deserialize(rawdata, connection.msgtype) + original_pointcloud = read_pointcloud(msg) + cleaned_pointcloud = clean_pointcloud(original_pointcloud) + + o3d_pcd = get_o3d_pointcloud(cleaned_pointcloud) + + msg = create_pointcloud2_msg(msg, cleaned_pointcloud) + writer.write( + connection=new_connections[connection.id], + timestamp=timestamp, + data=typestore.serialize_ros1( + message=msg, typename=msg.__msgtype__ + ), + ) + else: + # For all other topics, we can write rawdata directly, no need to deserialize + writer.write(new_connections[connection.id], timestamp, rawdata) + + progress.advance(task) return 0