数据集

数据文件存储在 spikezoo/data/ 路径下,组织形式如下:

spikezoo
├── data
|   ├── base
|   ├── reds_base
|   └── u_caltech
├── archs
...
└── utils

数据来源

1. BASE 数据集

  • 介绍:REDS_BASE 数据集中选取部分数据构建的小数据集。

  • 地址: 已内置于 Spike-Zoo 仓库当中。

base
├── test
│   ├── gt
│   └── spike
└── train
    ├── gt
    └── spike

调用方式:

from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig

2. REDS_BASE 数据集

reds_base
├── test
│   ├── gt
│   └── spike
└── train
    ├── gt
    └── spike

调用方式:

from spikezoo.datasets.reds_base_dataset import REDS_BASE, REDS_BASEConfig

3. RealData 数据集

  • 介绍: 真实拍摄脉数据集接口,可以包含 recVidarReal2019, momVidarReal2021 以及自己拍摄的无清晰图对的真实数据集。

realdata
├── xxx.dat
└── sss.dat

调用方式:

from spikezoo.datasets.realdata_dataset import RealData, RealDataConfig

4. UHSR 数据集

  • 介绍: 真实拍摄脉数据集 UHSR ,包括 U-CALTECHU-CIFAR 两部分。(使用该数据集旨在探索不同模型在暗光场景下的泛化性)

  • 地址: 百度网盘地址

u_caltech
├── test
└── train

调用方式:

from spikezoo.datasets.uhsr_dataset import UHSR, UHSRConfig

5. SZData 数据集

  • 介绍: 基于Spike-Zoo仿真管线构建的数据集

  • 地址: 参考地址

szdata
├── test
│   ├── sharp_data
│   └── spike_data
└── train
    ├── sharp_data
    └── spike_data

调用方式:

from spikezoo.datasets.szdata_dataset import SZData, SZDataConfig

参数配置

BASE 数据集为例,配置类代码定义如下:

@dataclass
class BaseDatasetConfig:
    # ------------- Not Recommended to Change -------------
    "Dataset name."
    dataset_name: str = "base"
    "Directory specifying location of data."
    root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/base")
    "Image width."
    width: int = 400
    "Image height."
    height: int = 250
    "Spike paried with the image or not."
    with_img: bool = True
    "Dataset spike length for the train data."
    spike_length_train: int = -1
    "Dataset spike length for the test data."
    spike_length_test: int = -1
    "Dir name for the spike."
    spike_dir_name: str = "spike"
    "Dir name for the image."
    img_dir_name: str = "gt"
    "Rate. (-1 denotes variant)"
    rate: float = 0.6

    # ------------- Config -------------
    "Use the data augumentation technique or not."
    use_aug: bool = False
    "Use cache mechanism."
    use_cache: bool = False
    "Crop size."
    crop_size: tuple = (-1, -1)
    "Load the dataset from local or spikezoo lib."
    dataset_cls_local: Optional[Dataset] = None
    "Spike load version. [python,cpp]"
    spike_load_version: Literal["python", "cpp"] = "python"

参数解释如下:

  • dataset_name : 数据集的名称,如 "base", "reds_base""uhsr"

  • root_dir : 数据集的根路径。

  • width : 输入脉冲的宽度。

  • height : 输入脉冲的高度。

  • with_img : 输入数据是否包含 GT 清晰图,真实数据集一般设置为 False。

  • spike_length_train : 训练集中输入脉冲的长度,在 BASE 数据集中为 41。(如果设置为 -1,则表示对输出的脉冲不做任何裁剪,可能会导致显存占用较高。)

  • spike_length_test : 测试集中输入脉冲的长度,在 BASE 数据集中为 301。

  • spike_dir_name : 用于存储脉冲数据文件夹的名字,在 BASE 数据集中为 spike

  • img_dir_name : 用于存储清晰图数据文件夹的名字,在 BASE 数据集中为 gt

  • rate : 表示脉冲转化系数,在 REDS_BASE 数据集中默认设置为 0.6。

  • use_aug : 表示是否使用数据增强技术。

  • use_cache : 表示是否使用数据缓存技术。在数据 I/O 较大且 GPU 利用率较低时开启可以加速训练,但可能会增加 RAM 占用。

  • crop_size : 训练时如果使用数据增强技术,裁剪的尺寸大小,默认值为 (-1, -1) 表示不裁剪。

  • dataset_cls_local : 调用本地设计的数据集类,默认值为 None (即导入 spikezoo 仓库的数据类)。

  • spike_load_version : 脉冲加载时使用 python 接口还是 cpp 接口,默认 python 接口。

数据加载类

BASE 数据集为例,数据类代码定义如下:

class BaseDataset(Dataset):
    # 初始化数据集实例
    def __init__(self, cfg: BaseDatasetConfig):
    # 获取数据集样本总数
    def __len__(self):
    # 获取指定索引的样本(统一接口返回字典)
    def __getitem__(self, idx: int):
    # 链接数据集的源数据
    def build_source(self, split: Literal["train", "test"] = "test"):
    # 数据路径预处理
    def prepare_data(self):
    # 脉冲文件检索方法
    def get_spike_files(self, path: Path):
    # 脉冲加载逻辑(支持.dat/.npz格式)
    def load_spike(self, idx):
    # 脉冲获取统一接口
    def get_spike(self, idx):
    # 图像文件检索方法
    def get_image_files(self, path: Path):
    # 图像读取接口
    def get_img(self, idx):
    # 数据缓存机制实现
    def cache_data(self):

实例化

from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig
cfg = BaseDatasetConfig()
dataset = BaseDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
    print(key,val)

输出样本为字典格式,包含以下键值:

  • spike : 脉冲张量(形状 [T,H,W])

  • gt_img : 清晰图像张量(形状 [3,H,W])

  • rate : 脉冲转化系数标量

备注

rate 参数的作用是对重构图像进行亮度矫正 img = img / rate,以消除脉冲重构图和真实清晰图在幅值上的差异(仿真过程中存在光电转化系数,导致脉冲重构图和真实清晰图的像素亮度呈比例关系)。

# data
spike = batch["spike"]
img = batch["gt_img"]
rate = batch["rate"]
# process
tfp = spike.mean(dim = 0,keepdim = False)
print(f"重构图像的均值为{tfp.mean()}")
tfp_correct = tfp / rate
print(f"重构图像矫正后的均值为{tfp_correct.mean()}")
print(f"清晰图像的均值为{img.mean()}")
# 重构图像的均值为0.28766903281211853
# 重构图像矫正后的均值为0.4794484078884125
# 清晰图像的均值为0.48153188824653625

在利用 rate 参数矫正后,重构图像的均值和给定清晰图像的均值近似相等。

自定义数据集开发

除了上述提供的数据集形式,Spike-Zoo还支持使用自定义数据集,这里以标准仿真管线生成的数据集为例,说明如何扩展基础数据集类的使用:

目录结构:

your_data_path
├── test
│   ├── sharp_data
│   └── spike_data
└── train
    ├── sharp_data
    └── spike_data

实现步骤:

方式一、将数据集加入到Spike-Zoo仓库中,采样标准调用方式

  1. 创建数据集文件 spikezoo/datasets/yourdataset_dataset.py 并将数据按上述结构存储在 spikezoo/data/your_data_path 路径下

  2. 继承基类并分别实现 YourDatasetConfigYourDataset:

from torch.utils.data import Dataset
from pathlib import Path
from dataclasses import dataclass
from typing import Literal, Union
from typing import Optional
from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset

@dataclass
class YourDatasetConfig(BaseDatasetConfig):
    dataset_name: str = "yourdataset"
    root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
    width: int = 400
    height: int = 250
    with_img: bool = True
    spike_length_train: int = -1
    spike_length_test: int = -1
    spike_dir_name: str = "spike_data"
    img_dir_name: str = "sharp_data"
    rate: float = 1

class YourDataset(BaseDataset):
    def __init__(self, cfg: BaseDatasetConfig):
        super(YourDataset, self).__init__(cfg)
  1. 本地文件 test.py 调用自定义数据集

from spikezoo.datasets.yourdataset_dataset import YourDataset,YourDatasetConfig
cfg = YourDatasetConfig()
dataset = YourDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
    print(key,val)

方式二、本地直接继承数据集基类

  1. 创建本地运行文件 test.py,数据存储在本地路径 your_data_path/

  2. 同方式一步骤2实现 YourDatasetConfigYourDataset, 其中 root_dir 替换为 Path("your_data_path")

  3. 直接调用自定义数据集:

cfg = YourDatasetConfig(dataset_cls_local=YourDataset)
dataset = YourDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
    print(key,val)

具体例子见:https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourdataset.py