数据集
数据文件存储在 spikezoo/data/ 路径下,组织形式如下:
spikezoo
├── data
| ├── base
| ├── reds_base
| └── u_caltech
├── archs
...
└── utils
数据来源
1. BASE 数据集
介绍: 从
REDS_BASE数据集中选取部分数据构建的小数据集。地址: 已内置于 Spike-Zoo 仓库当中。
base
├── test
│ ├── gt
│ └── spike
└── train
├── gt
└── spike
调用方式:
from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig
2. REDS_BASE 数据集
介绍: 由 Spk2ImgNet 基于REDS数据集仿真生成的脉冲-清晰图成对数据集。
reds_base
├── test
│ ├── gt
│ └── spike
└── train
├── gt
└── spike
调用方式:
from spikezoo.datasets.reds_base_dataset import REDS_BASE, REDS_BASEConfig
3. RealData 数据集
介绍: 真实拍摄脉数据集接口,可以包含
recVidarReal2019,momVidarReal2021以及自己拍摄的无清晰图对的真实数据集。
realdata
├── xxx.dat
└── sss.dat
调用方式:
from spikezoo.datasets.realdata_dataset import RealData, RealDataConfig
4. UHSR 数据集
u_caltech
├── test
└── train
调用方式:
from spikezoo.datasets.uhsr_dataset import UHSR, UHSRConfig
5. SZData 数据集
介绍: 基于Spike-Zoo仿真管线构建的数据集
地址: 参考地址
szdata
├── test
│ ├── sharp_data
│ └── spike_data
└── train
├── sharp_data
└── spike_data
调用方式:
from spikezoo.datasets.szdata_dataset import SZData, SZDataConfig
参数配置
以 BASE 数据集为例,配置类代码定义如下:
@dataclass
class BaseDatasetConfig:
# ------------- Not Recommended to Change -------------
"Dataset name."
dataset_name: str = "base"
"Directory specifying location of data."
root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/base")
"Image width."
width: int = 400
"Image height."
height: int = 250
"Spike paried with the image or not."
with_img: bool = True
"Dataset spike length for the train data."
spike_length_train: int = -1
"Dataset spike length for the test data."
spike_length_test: int = -1
"Dir name for the spike."
spike_dir_name: str = "spike"
"Dir name for the image."
img_dir_name: str = "gt"
"Rate. (-1 denotes variant)"
rate: float = 0.6
# ------------- Config -------------
"Use the data augumentation technique or not."
use_aug: bool = False
"Use cache mechanism."
use_cache: bool = False
"Crop size."
crop_size: tuple = (-1, -1)
"Load the dataset from local or spikezoo lib."
dataset_cls_local: Optional[Dataset] = None
"Spike load version. [python,cpp]"
spike_load_version: Literal["python", "cpp"] = "python"
参数解释如下:
dataset_name: 数据集的名称,如"base","reds_base"和"uhsr"。root_dir: 数据集的根路径。width: 输入脉冲的宽度。height: 输入脉冲的高度。with_img: 输入数据是否包含 GT 清晰图,真实数据集一般设置为 False。spike_length_train: 训练集中输入脉冲的长度,在 BASE 数据集中为 41。(如果设置为 -1,则表示对输出的脉冲不做任何裁剪,可能会导致显存占用较高。)spike_length_test: 测试集中输入脉冲的长度,在 BASE 数据集中为 301。spike_dir_name: 用于存储脉冲数据文件夹的名字,在 BASE 数据集中为spike。img_dir_name: 用于存储清晰图数据文件夹的名字,在 BASE 数据集中为gt。rate: 表示脉冲转化系数,在 REDS_BASE 数据集中默认设置为 0.6。use_aug: 表示是否使用数据增强技术。use_cache: 表示是否使用数据缓存技术。在数据 I/O 较大且 GPU 利用率较低时开启可以加速训练,但可能会增加 RAM 占用。crop_size: 训练时如果使用数据增强技术,裁剪的尺寸大小,默认值为 (-1, -1) 表示不裁剪。dataset_cls_local: 调用本地设计的数据集类,默认值为None(即导入spikezoo仓库的数据类)。spike_load_version: 脉冲加载时使用python接口还是cpp接口,默认python接口。
数据加载类
以 BASE 数据集为例,数据类代码定义如下:
class BaseDataset(Dataset):
# 初始化数据集实例
def __init__(self, cfg: BaseDatasetConfig):
# 获取数据集样本总数
def __len__(self):
# 获取指定索引的样本(统一接口返回字典)
def __getitem__(self, idx: int):
# 链接数据集的源数据
def build_source(self, split: Literal["train", "test"] = "test"):
# 数据路径预处理
def prepare_data(self):
# 脉冲文件检索方法
def get_spike_files(self, path: Path):
# 脉冲加载逻辑(支持.dat/.npz格式)
def load_spike(self, idx):
# 脉冲获取统一接口
def get_spike(self, idx):
# 图像文件检索方法
def get_image_files(self, path: Path):
# 图像读取接口
def get_img(self, idx):
# 数据缓存机制实现
def cache_data(self):
实例化
from spikezoo.datasets.base_dataset import BaseDataset,BaseDatasetConfig
cfg = BaseDatasetConfig()
dataset = BaseDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
print(key,val)
输出样本为字典格式,包含以下键值:
spike: 脉冲张量(形状 [T,H,W])gt_img: 清晰图像张量(形状 [3,H,W])rate: 脉冲转化系数标量
备注
rate 参数的作用是对重构图像进行亮度矫正 img = img / rate,以消除脉冲重构图和真实清晰图在幅值上的差异(仿真过程中存在光电转化系数,导致脉冲重构图和真实清晰图的像素亮度呈比例关系)。
# data
spike = batch["spike"]
img = batch["gt_img"]
rate = batch["rate"]
# process
tfp = spike.mean(dim = 0,keepdim = False)
print(f"重构图像的均值为{tfp.mean()}")
tfp_correct = tfp / rate
print(f"重构图像矫正后的均值为{tfp_correct.mean()}")
print(f"清晰图像的均值为{img.mean()}")
# 重构图像的均值为0.28766903281211853
# 重构图像矫正后的均值为0.4794484078884125
# 清晰图像的均值为0.48153188824653625
在利用 rate 参数矫正后,重构图像的均值和给定清晰图像的均值近似相等。
自定义数据集开发
除了上述提供的数据集形式,Spike-Zoo还支持使用自定义数据集,这里以标准仿真管线生成的数据集为例,说明如何扩展基础数据集类的使用:
目录结构:
your_data_path
├── test
│ ├── sharp_data
│ └── spike_data
└── train
├── sharp_data
└── spike_data
实现步骤:
方式一、将数据集加入到Spike-Zoo仓库中,采样标准调用方式
创建数据集文件
spikezoo/datasets/yourdataset_dataset.py并将数据按上述结构存储在spikezoo/data/your_data_path路径下继承基类并分别实现
YourDatasetConfig和YourDataset:
from torch.utils.data import Dataset
from pathlib import Path
from dataclasses import dataclass
from typing import Literal, Union
from typing import Optional
from spikezoo.datasets.base_dataset import BaseDatasetConfig,BaseDataset
@dataclass
class YourDatasetConfig(BaseDatasetConfig):
dataset_name: str = "yourdataset"
root_dir: Union[str, Path] = Path(__file__).parent.parent / Path("data/your_data_path")
width: int = 400
height: int = 250
with_img: bool = True
spike_length_train: int = -1
spike_length_test: int = -1
spike_dir_name: str = "spike_data"
img_dir_name: str = "sharp_data"
rate: float = 1
class YourDataset(BaseDataset):
def __init__(self, cfg: BaseDatasetConfig):
super(YourDataset, self).__init__(cfg)
本地文件
test.py调用自定义数据集
from spikezoo.datasets.yourdataset_dataset import YourDataset,YourDatasetConfig
cfg = YourDatasetConfig()
dataset = YourDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
print(key,val)
方式二、本地直接继承数据集基类
创建本地运行文件
test.py,数据存储在本地路径your_data_path/下同方式一步骤2实现
YourDatasetConfig和YourDataset, 其中root_dir替换为Path("your_data_path")直接调用自定义数据集:
cfg = YourDatasetConfig(dataset_cls_local=YourDataset)
dataset = YourDataset(cfg)
dataset.build_source(split = "test")
batch = dataset[0]
for key,val in batch.items():
print(key,val)
具体例子见:https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourdataset.py