from pathlib import Path
from sys import exit

import numpy as np
import open3d as o3d
from configargparse import (
    ArgParser,
    ArgumentDefaultsRawHelpFormatter,
    YAMLConfigFileParser,
)
from numpy.lib import recfunctions as rfn
from rich.progress import Progress
from rosbags.highlevel import AnyReader
from rosbags.rosbag1 import Writer
from rosbags.typesys import Stores, get_types_from_msg, get_typestore
from rosbags.typesys.stores.ros1_noetic import (
    sensor_msgs__msg__PointCloud2 as PointCloud2,
)
from rosbags.typesys.stores.ros1_noetic import (
    sensor_msgs__msg__PointField as PointField,
)

from util import existing_path


def get_o3d_pointcloud(points: np.ndarray) -> o3d.geometry.PointCloud:
    """Convert a structured point array to an Open3D point cloud."""
    xyz_array = np.stack(
        [
            points["x"].astype(np.float64),
            points["y"].astype(np.float64),
            points["z"].astype(np.float64),
        ],
        axis=-1,
    )
    # Drop points at the exact origin (typically "no return" placeholders).
    filtered_xyz = xyz_array[~np.all(xyz_array == 0, axis=1)]
    o3d_vector = o3d.utility.Vector3dVector(filtered_xyz)
    return o3d.geometry.PointCloud(o3d_vector)


# Mapping of PointField datatype codes to NumPy dtypes
POINTFIELD_DATATYPES = {
    1: np.int8,  # INT8
    2: np.uint8,  # UINT8
    3: np.int16,  # INT16
    4: np.uint16,  # UINT16
    5: np.int32,  # INT32
    6: np.uint32,  # UINT32
    7: np.float32,  # FLOAT32
    8: np.float64,  # FLOAT64
}

# Reverse map from NumPy dtype to PointField datatype code
REVERSE_POINTFIELD_DATATYPES = {
    np.dtype(np.int8).type: 1,
    np.dtype(np.uint8).type: 2,
    np.dtype(np.int16).type: 3,
    np.dtype(np.uint16).type: 4,
    np.dtype(np.int32).type: 5,
    np.dtype(np.uint32).type: 6,
    np.dtype(np.float32).type: 7,
    np.dtype(np.float64).type: 8,
}


def read_pointcloud(msg: PointCloud2) -> np.ndarray:
    """Parse a PointCloud2 message into a structured NumPy array.

    Padding bytes between or after fields are mapped to void fields named
    ``unused_<offset>`` so they can be stripped out later.
    """
    # Build the dtype dynamically from the message fields.
    dtype_fields = {}
    current_offset = 0
    for field in msg.fields:
        np_dtype = POINTFIELD_DATATYPES.get(field.datatype)
        if np_dtype is None:
            raise ValueError(
                f"Unsupported datatype {field.datatype} for field {field.name}"
            )
        if field.offset > current_offset:
            # Gap before this field: cover it with a void placeholder.
            gap_size = field.offset - current_offset
            dtype_fields[f"unused_{current_offset}"] = (f"V{gap_size}", current_offset)
            current_offset += gap_size
        dtype_fields[field.name] = (np_dtype, field.offset)
        current_offset = field.offset + np_dtype().itemsize
    if current_offset < msg.point_step:
        # Trailing padding up to point_step.
        gap_size = msg.point_step - current_offset
        dtype_fields[f"unused_{current_offset}"] = (f"V{gap_size}", current_offset)
    dtype = np.dtype(dtype_fields)
    return np.frombuffer(msg.data, dtype=dtype)


def clean_pointcloud(points: np.ndarray) -> np.ndarray:
    """Drop the padding fields and repack the remaining fields contiguously."""
    valid_fields = [
        name for name in points.dtype.names if not name.startswith("unused_")
    ]
    return rfn.repack_fields(points[valid_fields])


def create_pointcloud2_msg(
    original_msg: PointCloud2, cleaned_points: np.ndarray
) -> PointCloud2:
    """Build a new PointCloud2 message from the repacked point array."""
    new_fields = []
    offset = 0
    for name in cleaned_points.dtype.names:
        np_dtype = cleaned_points.dtype[name].type
        if np_dtype not in REVERSE_POINTFIELD_DATATYPES:
            raise ValueError(f"No PointField datatype code for dtype {np_dtype}")
        new_fields.append(
            PointField(
                name=name,
                offset=offset,
                datatype=REVERSE_POINTFIELD_DATATYPES[np_dtype],
                count=1,
            )
        )
        offset += cleaned_points.dtype[name].itemsize
    return PointCloud2(
        header=original_msg.header,
        height=original_msg.height,
        width=original_msg.width,
        fields=new_fields,
        is_bigendian=original_msg.is_bigendian,
        point_step=cleaned_points.itemsize,
        row_step=cleaned_points.itemsize * original_msg.width,
        is_dense=original_msg.is_dense,
        data=cleaned_points.view("uint8"),
    )


def main() -> int:
    parser = ArgParser(
        config_file_parser_class=YAMLConfigFileParser,
        default_config_files=["data_analyze_config.yaml"],
        formatter_class=ArgumentDefaultsRawHelpFormatter,
        description="Analyse data from a rosbag or mcap file and output additional data",
    )
    parser.add_argument(
        "--config-file", is_config_file=True, help="yaml config file path"
    )
    parser.add_argument(
        "--input-experiment-path",
        required=True,
        type=existing_path,
        help="path to the experiment (the bag file itself for rosbag1, the parent folder for mcap)",
    )
    parser.add_argument(
        "--pointcloud-topic",
        default="/ouster/points",
        type=str,
        help="topic in the ros/mcap bag file containing the point cloud data",
    )
    parser.add_argument(
        "--output-path",
        default=Path("./output"),
        type=Path,
        help="path the augmented dataset should be written to",
    )
    args = parser.parse_args()

    output_path = args.output_path / args.input_experiment_path.stem
    output_path.mkdir(parents=True, exist_ok=True)
    # Persist the effective configuration next to the output for reproducibility.
    parser.write_config_file(
        parser.parse_known_args()[0],
        output_file_paths=[(output_path / "config.yaml").as_posix()],
    )

    cleaned_bag_path = output_path / "cleaned_bag.bag"
    cleaned_bag_path.unlink(missing_ok=True)

    with AnyReader([args.input_experiment_path]) as reader:
        # One connection per topic; a topic recorded over several connections
        # (e.g. multiple publishers) is collapsed onto a single representative.
        topics = dict(
            sorted({conn.topic: conn for conn in reader.connections}.items())
        )
        assert (
            args.pointcloud_topic in topics
        ), f"Topic {args.pointcloud_topic} not found"

        # Register all message definitions from the source bag with the typestore.
        original_types = {}
        typestore = get_typestore(Stores.ROS1_NOETIC)
        for connection in topics.values():
            original_types.update(
                get_types_from_msg(connection.msgdef, connection.msgtype)
            )
        typestore.register(original_types)

        with Writer(cleaned_bag_path) as writer:
            # Add all connections to the new bag, keyed by topic so that
            # messages from duplicate source connections still find a target.
            new_connections = {}
            for topic_name, old_conn in topics.items():
                new_connections[topic_name] = writer.add_connection(
                    topic=topic_name,
                    msgtype=old_conn.msgtype,
                    msgdef=old_conn.msgdef,
                    typestore=typestore,
                )

            with Progress() as progress:
                task = progress.add_task(
                    "Processing all messages", total=reader.message_count
                )
                for connection, timestamp, rawdata in reader.messages():
                    if connection.topic == args.pointcloud_topic:
                        # The point cloud topic is rewritten: strip padding
                        # fields and re-serialize the cleaned message.
                        msg = reader.deserialize(rawdata, connection.msgtype)
                        original_pointcloud = read_pointcloud(msg)
                        cleaned_pointcloud = clean_pointcloud(original_pointcloud)
                        # Open3D copy of the cloud; currently unused beyond
                        # exercising the conversion.
                        o3d_pcd = get_o3d_pointcloud(cleaned_pointcloud)
                        msg = create_pointcloud2_msg(msg, cleaned_pointcloud)
                        writer.write(
                            connection=new_connections[connection.topic],
                            timestamp=timestamp,
                            data=typestore.serialize_ros1(
                                message=msg, typename=msg.__msgtype__
                            ),
                        )
                    else:
                        # All other topics are copied verbatim; no need to
                        # deserialize (assumes a ROS1 source, so the raw bytes
                        # already match the rosbag1 output format).
                        writer.write(
                            new_connections[connection.topic], timestamp, rawdata
                        )
                    progress.advance(task)
    return 0


if __name__ == "__main__":
    exit(main())
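
# Minimal usage sketch (the script file name and paths below are hypothetical;
# assumes `util.existing_path` rejects paths that do not exist):
#
#   python clean_bag.py \
#       --input-experiment-path ./experiments/run_01.bag \
#       --pointcloud-topic /ouster/points \
#       --output-path ./output
#
# This would write ./output/run_01/cleaned_bag.bag along with a config.yaml
# snapshot of the options that produced it.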