快速开始
建议用户在运行测试案例前,先熟悉Spike-Zoo的核心架构组成:
Dataset:统一数据接口规范,提供脉冲数据与清晰图像的标准化访问Model:封装脉冲重建网络,集成输入处理、算法核心与输出后处理流程Pipeline:整合数据与模型,实现指标计算、图像存储及训练管理等全流程功能
本框架设计参考 NeRFStudio 架构,
通过 Pipeline 统一调度 Model 和 Dataset 实现端到端功能。
代码组织结构如下:
spikezoo
├── archs # 网络架构实现
├── models # 模型封装(输入输出处理)
├── data # 原始数据存储
├── datasets # 数据集接口封装
├── pipeline # 流程管理系统
├── metrics # 评估指标计算
└── utils # 工具函数集合
框架采用 配置驱动 的设计模式,通过配置类 MyClassConfig 集中管理参数,
结合 @dataclass 自动生成构造函数,将配置注入目标类 MyClass 完成实例化。
from dataclasses import dataclass
# 配置定义
@dataclass
class MyClassConfig:
name: str = "myclass"
# 类实现
class MyClass:
def __init__(self,cfg:MyClassConfig):
self.cfg = cfg
# 实例化过程
config = MyClassConfig()
cls = MyClass(config)
推理流程
管线构建
支持 单模型推理 与 多模型对比 两种模式,后者可执行多个模型的推理并进行结果对比。
单模型管线构建:
from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
import spikezoo as sz
pipeline = Pipeline(
cfg=PipelineConfig(save_folder="results",version="v023"),
model_cfg=sz.METHOD.BASE,
dataset_cfg=sz.DATASET.BASE
)
构建参数解析:
cfg: 管线参数配置,包括存储路径和模型版本参数。model_cfg: 模型参数配置dataset_cfg: 数据集参数配置
多模型管线构建:
import spikezoo as sz
from spikezoo.pipeline.ensemble_pipeline import EnsemblePipeline, EnsemblePipelineConfig
pipeline = EnsemblePipeline(
cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
model_cfg_list=[
sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
dataset_cfg=sz.DATASET.BASE,
)
model_cfg_list: 多模型参数配置列表
功能接口
单/多模型管线提供统一功能接口:
I-单段脉冲重建: 支持三种输入方式生成重建图像并计算指标
# 方式1: 从数据集加载测试样本(默认测试集),结果存储于infer_from_dataset
pipeline.infer_from_dataset(idx=0)
# 方式2: 从.dat文件加载脉冲,结果存储于infer_from_file
pipeline.infer_from_file(file_path='data/data.dat', width=400, height=250,rate = 0.6)
# 方式3: 直接传入脉冲张量,结果存储于infer_from_spk
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250)
pipeline.infer_from_spk(spike,rate = 0.6)
备注
函数接口参数解释见 spikezoo.pipeline,其中 rate 参数对重构图像进行了亮度矫正,具体作用参见 实例化 。
II-数据集可视化: 批量保存数据集所有样本的重建结果
# 结果存储于infer_from_dataset
pipeline.save_imgs_from_dataset()
III-量化指标计算: 执行数据集级别的性能评估
# 指标结果写入result.log
pipeline.cal_metrics()
IV-模型参数分析: 计算模型参数量与计算复杂度
# 分析结果写入result.log
pipeline.cal_params()
备注
参数计算主要包含参数量(Params)、计算量(FLOPs)和延迟(Latency),计算代码如下所示:
def _cal_prams_model(self, model):
"""Calculate the parameters for the given model."""
network = model.net
model_name = model.cfg.model_name.upper()
# params
params = sum(p.numel() for p in network.parameters())
# latency
spike = torch.zeros((1, 200, 250, 400)).cuda()
start_time = time.time()
for _ in range(100):
model.spk2img(spike)
latency = (time.time() - start_time) / 100
# flop # todo thop bug for BSF
flops, _ = profile((model), inputs=(spike,))
re_msg = (
"Total params: %.4fM" % (params / 1e6),
"FLOPs:" + str(flops / 1e9) + "{}".format("G"),
"Latency: {:.6f} seconds".format(latency),
)
关于不同模型的指标和参数计算结果,参见 发行版本介绍。
训练流程
下载
REDS_BASE数据集并放置在spikezoo/data/reds_base路径下(或者其他路径,在root_dir参数中设置即可),参考 数据来源 。构建训练管线代码,基于
BASE模型开始训练:
from spikezoo.pipeline.train_pipeline import TrainPipelineConfig, TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models.base_model import BaseModelConfig
pipeline = TrainPipeline(
cfg=TrainPipelineConfig(save_folder="results"),
dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/reds_base"),
model_cfg=BaseModelConfig(),
)
pipeline.train()
模型直接调用
除通过管线调用外,也支持模型独立使用,根据给定输入脉冲输出重构图像:
import spikezoo as sz
from spikezoo.models.base_model import BaseModel, BaseModelConfig
# 输入数据加载
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
spike = spike[None].cuda()
print(f"输入脉冲尺寸: {spike.shape}")
# 网络初始化
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
net.build_network(mode = "debug")
# 推理执行
recon_img = net(spike)
print(recon_img.shape, recon_img.max(), recon_img.min())
更多高级用法详见 实例化 。
函数调用
Spike-Zoo 集成了各种针对脉冲相机设计的函数库,可以参考 spikezoo.utils 查看相关使用方式。