数据集 ======================= 数据文件存储在 ``spikezoo/data/`` 路径下,组织形式如下: .. code-block:: console spikezoo ├── data | ├── base | ├── reds_base | └── u_caltech ├── archs ... └── utils .. _dataset_prepare: 数据来源 ---------------- 1. ``BASE`` 数据集 ^^^^^^^^^^^^^ - **介绍:** 从 ``REDS_BASE`` 数据集中选取部分数据构建的小数据集。 - **地址:** 已内置于 Spike-Zoo 仓库当中。 .. code-block:: console base ├── test │ ├── gt │ └── spike └── train ├── gt └── spike 调用方式: .. code-block:: python from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig 2. ``REDS_BASE`` 数据集 ^^^^^^^^^^^^^ - **介绍:** 由 `Spk2ImgNet `_ 基于REDS数据集仿真生成的脉冲-清晰图成对数据集。 - **地址:** 训练集: `下载地址 `_ | 测试集: `下载地址 `_ .. code-block:: console reds_base ├── test │ ├── gt │ └── spike └── train ├── gt └── spike 调用方式: .. code-block:: python from spikezoo.datasets.reds_base_dataset import REDS_BASE, REDS_BASEConfig 3. ``RealData`` 数据集 ^^^^^^^^^^^^^ - **介绍:** 真实拍摄脉数据集接口,可以包含 ``recVidarReal2019``, ``momVidarReal2021`` 以及自己拍摄的无清晰图对的真实数据集。 .. code-block:: console realdata ├── xxx.dat └── sss.dat 调用方式: .. code-block:: python from spikezoo.datasets.realdata_dataset import RealData, RealDataConfig 4. ``UHSR`` 数据集 ^^^^^^^^^^^^^ - **介绍:** 真实拍摄脉数据集 `UHSR `_ ,包括 ``U-CALTECH`` 和 ``U-CIFAR`` 两部分。(使用该数据集旨在探索不同模型在暗光场景下的泛化性) - **地址:** `百度网盘地址 `_ .. code-block:: console u_caltech ├── test └── train 调用方式: .. code-block:: python from spikezoo.datasets.uhsr_dataset import UHSR, UHSRConfig 5. ``SZData`` 数据集 ^^^^^^^^^^^^^ - **介绍:** 基于Spike-Zoo仿真管线构建的数据集 - **地址:** `参考地址 `_ .. code-block:: console szdata ├── test │ ├── sharp_data │ └── spike_data └── train ├── sharp_data └── spike_data 调用方式: .. code-block:: python from spikezoo.datasets.szdata_dataset import SZData, SZDataConfig .. _datasetparam: 参数配置 ---------------- 以 ``BASE`` 数据集为例,配置类代码定义如下: .. code-block:: python @dataclass class BaseDatasetConfig: # ------------- Not Recommended to Change ------------- "Dataset name." dataset_name: str = "base" "Directory specifying location of data." root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/base") "Image width." width: int = 400 "Image height." height: int = 250 "Spike paried with the image or not." with_img: bool = True "Dataset spike length for the train data." spike_length_train: int = -1 "Dataset spike length for the test data." spike_length_test: int = -1 "Dir name for the spike." spike_dir_name: str = "spike" "Dir name for the image." img_dir_name: str = "gt" "Rate. (-1 denotes variant)" rate: float = 0.6 # ------------- Config ------------- "Use the data augumentation technique or not." use_aug: bool = False "Use cache mechanism." use_cache: bool = False "Crop size." crop_size: tuple = (-1, -1) "Load the dataset from local or spikezoo lib." dataset_cls_local: Optional[Dataset] = None "Spike load version. [python,cpp]" spike_load_version: Literal["python", "cpp"] = "python" 参数解释如下: - ``dataset_name`` : 数据集的名称,如 ``"base"``, ``"reds_base"`` 和 ``"uhsr"``。 - ``root_dir`` : 数据集的根路径。 - ``width`` : 输入脉冲的宽度。 - ``height`` : 输入脉冲的高度。 - ``with_img`` : 输入数据是否包含 GT 清晰图,真实数据集一般设置为 False。 - ``spike_length_train`` : 训练集中输入脉冲的长度,在 BASE 数据集中为 41。(如果设置为 -1,则表示对输出的脉冲不做任何裁剪,可能会导致显存占用较高。) - ``spike_length_test`` : 测试集中输入脉冲的长度,在 BASE 数据集中为 301。 - ``spike_dir_name`` : 用于存储脉冲数据文件夹的名字,在 BASE 数据集中为 ``spike``。 - ``img_dir_name`` : 用于存储清晰图数据文件夹的名字,在 BASE 数据集中为 ``gt``。 - ``rate`` : 表示脉冲转化系数,在 REDS_BASE 数据集中默认设置为 0.6。 - ``use_aug`` : 表示是否使用数据增强技术。 - ``use_cache`` : 表示是否使用数据缓存技术。在数据 I/O 较大且 GPU 利用率较低时开启可以加速训练,但可能会增加 RAM 占用。 - ``crop_size`` : 训练时如果使用数据增强技术,裁剪的尺寸大小,默认值为 (-1, -1) 表示不裁剪。 - ``dataset_cls_local`` : 调用本地设计的数据集类,默认值为 ``None`` (即导入 ``spikezoo`` 仓库的数据类)。 - ``spike_load_version`` : 脉冲加载时使用 ``python`` 接口还是 ``cpp`` 接口,默认 ``python`` 接口。 数据加载类 ---------------- 以 ``BASE`` 数据集为例,数据类代码定义如下: .. code-block:: python class BaseDataset(Dataset): # 初始化数据集实例 def __init__(self, cfg: BaseDatasetConfig): # 获取数据集样本总数 def __len__(self): # 获取指定索引的样本(统一接口返回字典) def __getitem__(self, idx: int): # 链接数据集的源数据 def build_source(self, split: Literal["train", "test"] = "test"): # 数据路径预处理 def prepare_data(self): # 脉冲文件检索方法 def get_spike_files(self, path: Path): # 脉冲加载逻辑(支持.dat/.npz格式) def load_spike(self, idx): # 脉冲获取统一接口 def get_spike(self, idx): # 图像文件检索方法 def get_image_files(self, path: Path): # 图像读取接口 def get_img(self, idx): # 数据缓存机制实现 def cache_data(self): .. _param_rate: 实例化 ---------------- .. code-block:: python from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig cfg = BaseDatasetConfig() dataset = BaseDataset(cfg) dataset.build_source(split = "test") batch = dataset[0] for key,val in batch.items(): print(key,val) 输出样本为字典格式,包含以下键值: - ``spike`` : 脉冲张量(形状 [T,H,W]) - ``gt_img`` : 清晰图像张量(形状 [3,H,W]) - ``rate`` : 脉冲转化系数标量 .. note:: ``rate`` 参数的作用是对重构图像进行亮度矫正 ``img = img / rate``,以消除脉冲重构图和真实清晰图在幅值上的差异(仿真过程中存在光电转化系数,导致脉冲重构图和真实清晰图的像素亮度呈比例关系)。 .. code-block:: python # data spike = batch["spike"] img = batch["gt_img"] rate = batch["rate"] # process tfp = spike.mean(dim = 0,keepdim = False) print(f"重构图像的均值为{tfp.mean()}") tfp_correct = tfp / rate print(f"重构图像矫正后的均值为{tfp_correct.mean()}") print(f"清晰图像的均值为{img.mean()}") # 重构图像的均值为0.28766903281211853 # 重构图像矫正后的均值为0.4794484078884125 # 清晰图像的均值为0.48153188824653625 在利用 ``rate`` 参数矫正后,重构图像的均值和给定清晰图像的均值近似相等。 自定义数据集开发 ---------------- 除了上述提供的数据集形式,Spike-Zoo还支持使用自定义数据集,这里以标准仿真管线生成的数据集为例,说明如何扩展基础数据集类的使用: **目录结构:** .. code-block:: console your_data_path ├── test │ ├── sharp_data │ └── spike_data └── train ├── sharp_data └── spike_data **实现步骤:** *方式一、将数据集加入到Spike-Zoo仓库中,采样标准调用方式* 1. 创建数据集文件 ``spikezoo/datasets/yourdataset_dataset.py`` 并将数据按上述结构存储在 ``spikezoo/data/your_data_path`` 路径下 2. 继承基类并分别实现 ``YourDatasetConfig`` 和 ``YourDataset``: .. code-block:: python from torch.utils.data import Dataset from pathlib import Path from dataclasses import dataclass from typing import Literal, Union from typing import Optional from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset @dataclass class YourDatasetConfig(BaseDatasetConfig): dataset_name: str = "yourdataset" root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path") width: int = 400 height: int = 250 with_img: bool = True spike_length_train: int = -1 spike_length_test: int = -1 spike_dir_name: str = "spike_data" img_dir_name: str = "sharp_data" rate: float = 1 class YourDataset(BaseDataset): def __init__(self, cfg: BaseDatasetConfig): super(YourDataset, self).__init__(cfg) 3. 本地文件 ``test.py`` 调用自定义数据集 .. code-block:: python from spikezoo.datasets.yourdataset_dataset import YourDataset,YourDatasetConfig cfg = YourDatasetConfig() dataset = YourDataset(cfg) dataset.build_source(split = "test") batch = dataset[0] for key,val in batch.items(): print(key,val) *方式二、本地直接继承数据集基类* 1. 创建本地运行文件 ``test.py``,数据存储在本地路径 ``your_data_path/`` 下 2. 同方式一步骤2实现 ``YourDatasetConfig`` 和 ``YourDataset``, 其中 ``root_dir`` 替换为 ``Path("your_data_path")`` 3. 直接调用自定义数据集: .. code-block:: python cfg = YourDatasetConfig(dataset_cls_local=YourDataset) dataset = YourDataset(cfg) dataset.build_source(split = "test") batch = dataset[0] for key,val in batch.items(): print(key,val) 具体例子见:https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourdataset.py