快速开始

建议用户在运行测试案例前,先熟悉Spike-Zoo的核心架构组成:

  • Dataset: 统一数据接口规范,提供脉冲数据与清晰图像的标准化访问

  • Model: 封装脉冲重建网络,集成输入处理、算法核心与输出后处理流程

  • Pipeline: 整合数据与模型,实现指标计算、图像存储及训练管理等全流程功能

本框架设计参考 NeRFStudio 架构, 通过 Pipeline 统一调度 ModelDataset 实现端到端功能。

代码组织结构如下:

spikezoo
├── archs     # 网络架构实现
├── models    # 模型封装(输入输出处理)
├── data      # 原始数据存储
├── datasets  # 数据集接口封装
├── pipeline  # 流程管理系统
├── metrics   # 评估指标计算
└── utils     # 工具函数集合

框架采用 配置驱动 的设计模式,通过配置类 MyClassConfig 集中管理参数, 结合 @dataclass 自动生成构造函数,将配置注入目标类 MyClass 完成实例化。

from dataclasses import dataclass
# 配置定义
@dataclass
class MyClassConfig:
    name: str = "myclass"
# 类实现
class MyClass:
    def __init__(self,cfg:MyClassConfig):
        self.cfg = cfg
# 实例化过程
config = MyClassConfig()
cls = MyClass(config)

推理流程

管线构建

支持 单模型推理多模型对比 两种模式,后者可执行多个模型的推理并进行结果对比。

单模型管线构建:

from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig
import spikezoo as sz
pipeline = Pipeline(
    cfg=PipelineConfig(save_folder="results",version="v023"),
    model_cfg=sz.METHOD.BASE,
    dataset_cfg=sz.DATASET.BASE
)

构建参数解析:

  • cfg : 管线参数配置,包括存储路径和模型版本参数。

  • model_cfg : 模型参数配置

  • dataset_cfg : 数据集参数配置

多模型管线构建:

import spikezoo as sz
from spikezoo.pipeline.ensemble_pipeline import EnsemblePipeline, EnsemblePipelineConfig
pipeline = EnsemblePipeline(
    cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
    model_cfg_list=[
        sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
        sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
    dataset_cfg=sz.DATASET.BASE,
)
  • model_cfg_list : 多模型参数配置列表

管线参数说明参考 参数配置,针对模型和数据集的加载,提供直接命名和参数实例化两种方式,具体使用参考 实例化

功能接口

单/多模型管线提供统一功能接口:

  • I-单段脉冲重建: 支持三种输入方式生成重建图像并计算指标

# 方式1: 从数据集加载测试样本(默认测试集),结果存储于infer_from_dataset
pipeline.infer_from_dataset(idx=0)
# 方式2: 从.dat文件加载脉冲,结果存储于infer_from_file
pipeline.infer_from_file(file_path='data/data.dat', width=400, height=250,rate = 0.6)
# 方式3: 直接传入脉冲张量,结果存储于infer_from_spk
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250)
pipeline.infer_from_spk(spike,rate = 0.6)

备注

函数接口参数解释见 spikezoo.pipeline,其中 rate 参数对重构图像进行了亮度矫正,具体作用参见 实例化

  • II-数据集可视化: 批量保存数据集所有样本的重建结果

# 结果存储于infer_from_dataset
pipeline.save_imgs_from_dataset()
  • III-量化指标计算: 执行数据集级别的性能评估

# 指标结果写入result.log
pipeline.cal_metrics()
  • IV-模型参数分析: 计算模型参数量与计算复杂度

# 分析结果写入result.log
pipeline.cal_params()

备注

参数计算主要包含参数量(Params)、计算量(FLOPs)和延迟(Latency),计算代码如下所示:

def _cal_prams_model(self, model):
    """Calculate the parameters for the given model."""
    network = model.net
    model_name = model.cfg.model_name.upper()
    # params
    params = sum(p.numel() for p in network.parameters())
    # latency
    spike = torch.zeros((1, 200, 250, 400)).cuda()
    start_time = time.time()
    for _ in range(100):
        model.spk2img(spike)
    latency = (time.time() - start_time) / 100
    # flop # todo thop bug for BSF
    flops, _ = profile((model), inputs=(spike,))
    re_msg = (
        "Total params: %.4fM" % (params / 1e6),
        "FLOPs:" + str(flops / 1e9) + "{}".format("G"),
        "Latency: {:.6f} seconds".format(latency),
    )

关于不同模型的指标和参数计算结果,参见 发行版本介绍

训练流程

  1. 下载 REDS_BASE 数据集并放置在 spikezoo/data/reds_base 路径下(或者其他路径,在 root_dir 参数中设置即可),参考 数据来源

  2. 构建训练管线代码,基于 BASE 模型开始训练:

from spikezoo.pipeline.train_pipeline import TrainPipelineConfig, TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models.base_model import BaseModelConfig
pipeline = TrainPipeline(
    cfg=TrainPipelineConfig(save_folder="results"),
    dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/reds_base"),
    model_cfg=BaseModelConfig(),
)
pipeline.train()

备注

单卡NVIDIA RTX 4090实测:训练耗时约2分钟,PSNR 32.8dB / SSIM 0.92。单卡GTX 1050 Ti: 训练耗时约12分钟,PSNR 30.59dB / SSIM 0.86。完整训练配置参考 训练管线。模型具体参数配置参考 参数配置, 数据集具体参数配置参考 参数配置

模型直接调用

除通过管线调用外,也支持模型独立使用,根据给定输入脉冲输出重构图像:

import spikezoo as sz
from spikezoo.models.base_model import BaseModel, BaseModelConfig
# 输入数据加载
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
spike = spike[None].cuda()
print(f"输入脉冲尺寸: {spike.shape}")
# 网络初始化
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
net.build_network(mode = "debug")
# 推理执行
recon_img = net(spike)
print(recon_img.shape, recon_img.max(), recon_img.min())

更多高级用法详见 实例化

函数调用

Spike-Zoo 集成了各种针对脉冲相机设计的函数库,可以参考 spikezoo.utils 查看相关使用方式。