Source code for event_aug.spike_injection

from typing import Tuple, Union

import cv2
import h5py
import numpy as np

from event_aug.utils import video_to_array


def _prepare_spikes_frame(frame, resize_size, crop_size):
    """Turn one spikes frame into a binary ``uint8`` mask, then resize and crop it."""
    spikes_frame = frame.astype(np.uint8)

    # Frames holding values above 1 are treated as 0-255 intensity images and
    # are binarised by rounding ``value / 255``.
    if (spikes_frame > 1).any():
        spikes_frame = np.round(spikes_frame / 255).astype(np.uint8)

    if resize_size is not None:
        # Nearest-neighbour interpolation does not introduce new pixel values.
        spikes_frame = cv2.resize(
            spikes_frame, resize_size, interpolation=cv2.INTER_NEAREST
        ).astype(np.uint8)

    if crop_size is not None:
        frame_size = spikes_frame.shape
        assert (
            frame_size[0] > crop_size[0] and frame_size[1] > crop_size[1]
        ), "Crop size must be smaller than the frame size"

        # Centre crop; crop_size is indexed as (height, width).
        start_x = frame_size[1] // 2 - crop_size[1] // 2
        start_y = frame_size[0] // 2 - crop_size[0] // 2
        spikes_frame = spikes_frame[
            start_y : start_y + crop_size[0], start_x : start_x + crop_size[1]
        ]

    return spikes_frame


def inject_event_spikes(
    event_data_path: str,
    save_path: str,
    spikes_video_path: str = None,
    spikes_arr: Union[np.ndarray, str] = None,
    spikes_arr_from_file: bool = False,
    memory_map: bool = True,
    resize_size: Tuple[int, int] = None,
    crop_size: Tuple[int, int] = None,
    fps: int = None,
    label: int = None,
    polarity: int = None,
    timestamp_keys: Tuple[str] = None,
    xy_keys: Tuple[str] = None,
    label_keys: Tuple[str] = None,
    polarity_keys: Tuple[str] = None,
    iterations: int = 1,
    verbose=False,
):
    """
    Injects specified event spikes into existing event recordings data to serve
    as augmentation.

    Every pixel equal to 1 in a (binarised) spikes frame becomes one injected
    event. The events of frame ``k`` receive the timestamp
    ``round(k / fps * 1e6)`` and are inserted in front of the first existing
    event whose timestamp is not smaller, so the timestamp arrays are expected
    to be sorted in ascending order.

    Parameters
    ----------
    event_data_path : str
        Path to the .h5 file containing the event recordings data.
    save_path : str
        Path to save the augmented event data with spikes injected to.
    spikes_video_path : str
        Path to the .mp4 video containing the event spikes data as frames.
    spikes_arr : np.ndarray or str
        NumPy array of shape (n_frames, height, width) containing the event
        spikes data as frames, or the path to a .npy file holding such an
        array when `spikes_arr_from_file` is True.
    spikes_arr_from_file : bool
        Whether to read the spikes video array from a .npy file.
    memory_map : bool
        Whether to use memory mapping to read the spikes video array.
    resize_size : Tuple[int, int]
        If specified, the size to resize the event spikes frames to. Passed
        unchanged to ``cv2.resize`` as its ``dsize`` argument.
    crop_size : Tuple[int, int]
        If specified, the (height, width) of the centre crop taken from the
        event spikes frames. Must be strictly smaller than the frame size.
    fps : int
        Frame rate to use for the event spikes to be injected. Defaults to the
        video's frame rate when `spikes_video_path` is given.
    label : int
        Label to assign for the event spikes to be injected. Defaults to 0.
    polarity : int
        Polarity to assign for the event spikes to be injected. Defaults to 0.
    timestamp_keys : Tuple[str]
        Keys to use to index the event timestamp data. Defaults to every
        top-level key starting with "timestamp".
    xy_keys : Tuple[str]
        Keys to use to index the event xy coordinates data. Defaults to every
        top-level key starting with "xy".
    label_keys : Tuple[str]
        Keys to use to index the event label data. Defaults to every top-level
        key starting with "label".
    polarity_keys : Tuple[str]
        Keys to use to index the event polarity data. Defaults to every
        top-level key starting with "polarity".
    iterations : int
        Number of times to inject event spikes. Each further pass continues
        the frame count (and therefore the timestamps) where the previous
        pass stopped.
    verbose : bool
        Whether to print progress messages.

    Raises
    ------
    AssertionError
        If any of the input checks on paths, keys, array shape, frame rate or
        crop size fails.
    """

    assert event_data_path.endswith(".h5"), "Event data path must point to a .h5 file"

    with h5py.File(event_data_path, "r") as f:
        keys = list(f.keys())

        if timestamp_keys is None:
            timestamp_keys = [key for key in keys if key.startswith("timestamp")]
        assert len(timestamp_keys) != 0, "No timestamp data found in event data"

        if xy_keys is None:
            xy_keys = [key for key in keys if key.startswith("xy")]
        assert len(xy_keys) != 0, "No xy coordinate data found in event data"

        if label_keys is None:
            label_keys = [key for key in keys if key.startswith("label")]
        assert len(label_keys) != 0, "No label data found in event data"

        if polarity_keys is None:
            polarity_keys = [key for key in keys if key.startswith("polarity")]
        assert len(polarity_keys) != 0, "No polarity data found in event data"

        # Load everything into memory so the arrays outlive the file handle.
        timestamps = {key: f[key][:].copy() for key in timestamp_keys}
        xy_coords = {key: f[key][:].copy() for key in xy_keys}
        labels = {key: f[key][:].copy() for key in label_keys}
        polarities = {key: f[key][:].copy() for key in polarity_keys}

    if label is None:
        label = 0
        print("\nNo label provided. Using label '0' for all event spikes to be injected")

    if polarity is None:
        polarity = 0
        print(
            "\nNo polarity provided. Using polarity '0' for all event spikes to be"
            " injected\n"
        )

    assert (spikes_video_path is not None) or (
        spikes_arr is not None
    ), "Either spikes_video_path or spikes_arr must be specified"

    if spikes_video_path is not None:
        assert spikes_video_path.endswith(
            ".mp4"
        ), "Spikes video path must point to a .mp4 file"

        spikes_arr, vid_fps = video_to_array(
            spikes_video_path, grayscale=True, return_fps=True
        )
        if fps is None:
            fps = vid_fps
            print(
                "No frame rate provided for injection. Using the video's frame rate:"
                f" {fps}"
            )

    if spikes_arr_from_file:
        assert isinstance(
            spikes_arr, str
        ), "Spikes array must be a path to a .npy file"
        assert spikes_arr.endswith(".npy"), "Spikes array must be a .npy file"

        if memory_map:
            spikes_arr = np.load(spikes_arr, mmap_mode="r")
        else:
            # NOTE(review): the original indentation of this cast was lost; it is
            # assumed to belong to the in-memory branch only, since casting a
            # memory-mapped array would read the whole file - confirm.
            spikes_arr = np.load(spikes_arr).astype("bool")

    assert spikes_arr.ndim == 3, (
        "Spikes array must be of shape (n_frames, height, width), where each 2D frame"
        " contains the event spikes data as 0s or 1s"
    )
    assert fps is not None, "fps must be specified if a spikes array is given as input"

    n_frames = spikes_arr.shape[0]
    total_events_injected = 0

    for iteration in range(iterations):
        print(f"\nIteration {iteration + 1} of {iterations}")

        # Later passes keep counting frames so their timestamps keep increasing.
        completed_frames = iteration * n_frames

        for n_frame in range(n_frames):
            if verbose:
                print(f"\nProcessing frame {n_frame} of the event spikes video/array")

            # Frame index -> seconds -> integer timestamp scaled by 1e6.
            timestep = round((completed_frames + n_frame) / fps * 1e6)

            spikes_frame = _prepare_spikes_frame(
                spikes_arr[n_frame], resize_size, crop_size
            )

            # argwhere yields (row, col); reverse each pair to get (x, y).
            spike_coords = np.argwhere(spikes_frame == 1)[:, ::-1]
            n_spikes = len(spike_coords)
            total_events_injected += n_spikes

            if verbose:
                print(
                    f"Injecting event spikes found at {n_spikes} locations in"
                    " the frame"
                )

            if n_spikes == 0:
                # Nothing to insert; skip the full-array copies np.insert makes.
                continue

            insertion_timesteps = [timestep] * n_spikes
            insertion_labels = [label] * n_spikes
            insertion_polarities = [polarity] * n_spikes

            # The four dicts are paired positionally; zip stops at the shortest,
            # so surplus keys in any group receive no injected events.
            for timestamp_key, xy_key, label_key, polarity_key in zip(
                timestamps, xy_coords, labels, polarities
            ):
                insert_id = np.searchsorted(timestamps[timestamp_key], timestep)

                timestamps[timestamp_key] = np.insert(
                    timestamps[timestamp_key], insert_id, insertion_timesteps
                )
                xy_coords[xy_key] = np.insert(
                    xy_coords[xy_key], insert_id, spike_coords, axis=0
                )
                labels[label_key] = np.insert(
                    labels[label_key], insert_id, insertion_labels
                )
                polarities[polarity_key] = np.insert(
                    polarities[polarity_key], insert_id, insertion_polarities
                )

    print(f"\nInjected {total_events_injected} events into the event data\n")

    if verbose:
        print(f"Saving event data with specified event spikes injected to {save_path}\n")

    with h5py.File(save_path, "w") as f:
        for timestamp_key, xy_key, label_key, polarity_key in zip(
            timestamps, xy_coords, labels, polarities
        ):
            f.create_dataset(timestamp_key, data=timestamps[timestamp_key])
            f.create_dataset(xy_key, data=xy_coords[xy_key])
            f.create_dataset(label_key, data=labels[label_key])
            f.create_dataset(polarity_key, data=polarities[polarity_key])