快速开始 ======================= 建议用户在运行测试案例前,先熟悉Spike-Zoo的核心架构组成: - ``Dataset:`` 统一数据接口规范,提供脉冲数据与清晰图像的标准化访问 - ``Model:`` 封装脉冲重建网络,集成输入处理、算法核心与输出后处理流程 - ``Pipeline:`` 整合数据与模型,实现指标计算、图像存储及训练管理等全流程功能 本框架设计参考 `NeRFStudio `_ 架构, 通过 ``Pipeline`` 统一调度 ``Model`` 和 ``Dataset`` 实现端到端功能。 .. raw:: html
.. image:: imgs/pipeline.png :width: 450px .. raw:: html
代码组织结构如下: .. code-block:: bash spikezoo ├── archs # 网络架构实现 ├── models # 模型封装(输入输出处理) ├── data # 原始数据存储 ├── datasets # 数据集接口封装 ├── pipeline # 流程管理系统 ├── metrics # 评估指标计算 └── utils # 工具函数集合 框架采用 **配置驱动** 的设计模式,通过配置类 ``MyClassConfig`` 集中管理参数, 结合 ``@dataclass`` 自动生成构造函数,将配置注入目标类 ``MyClass`` 完成实例化。 .. code-block:: python from dataclasses import dataclass # 配置定义 @dataclass class MyClassConfig: name: str = "myclass" # 类实现 class MyClass: def __init__(self,cfg:MyClassConfig): self.cfg = cfg # 实例化过程 config = MyClassConfig() cls = MyClass(config) 推理流程 ---------------- 管线构建 ^^^^^^^^^^^ 支持 **单模型推理** 与 **多模型对比** 两种模式,后者可执行多个模型的推理并进行结果对比。 **单模型管线构建:** .. code-block:: python from spikezoo.pipeline.base_pipeline import Pipeline, PipelineConfig import spikezoo as sz pipeline = Pipeline( cfg=PipelineConfig(save_folder="results",version="v023"), model_cfg=sz.METHOD.BASE, dataset_cfg=sz.DATASET.BASE ) 构建参数解析: - ``cfg`` : 管线参数配置,包括存储路径和模型版本参数。 - ``model_cfg`` : 模型参数配置 - ``dataset_cfg`` : 数据集参数配置 **多模型管线构建:** .. code-block:: python import spikezoo as sz from spikezoo.pipeline.ensemble_pipeline import EnsemblePipeline, EnsemblePipelineConfig pipeline = EnsemblePipeline( cfg=EnsemblePipelineConfig(save_folder="results",version="v023"), model_cfg_list=[ sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE, sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR], dataset_cfg=sz.DATASET.BASE, ) - ``model_cfg_list`` : 多模型参数配置列表 管线参数说明参考 :ref:`eval_config`,针对模型和数据集的加载,提供直接命名和参数实例化两种方式,具体使用参考 :ref:`eval_initial`。 功能接口 ^^^^^^^^^^^^ 单/多模型管线提供统一功能接口: - **I-单段脉冲重建:** 支持三种输入方式生成重建图像并计算指标 .. code-block:: python # 方式1: 从数据集加载测试样本(默认测试集),结果存储于infer_from_dataset pipeline.infer_from_dataset(idx=0) # 方式2: 从.dat文件加载脉冲,结果存储于infer_from_file pipeline.infer_from_file(file_path='data/data.dat', width=400, height=250,rate = 0.6) # 方式3: 直接传入脉冲张量,结果存储于infer_from_spk spike = sz.load_vidar_dat("data/data.dat", width=400, height=250) pipeline.infer_from_spk(spike,rate = 0.6) .. note:: 函数接口参数解释见 :ref:`api_pipeline`,其中 ``rate`` 参数对重构图像进行了亮度矫正,具体作用参见 :ref:`param_rate` 。 - **II-数据集可视化:** 批量保存数据集所有样本的重建结果 .. code-block:: python # 结果存储于infer_from_dataset pipeline.save_imgs_from_dataset() - **III-量化指标计算:** 执行数据集级别的性能评估 .. code-block:: python # 指标结果写入result.log pipeline.cal_metrics() - **IV-模型参数分析:** 计算模型参数量与计算复杂度 .. code-block:: python # 分析结果写入result.log pipeline.cal_params() .. note:: 参数计算主要包含参数量(Params)、计算量(FLOPs)和延迟(Latency),计算代码如下所示: .. code-block:: python def _cal_prams_model(self, model): """Calculate the parameters for the given model.""" network = model.net model_name = model.cfg.model_name.upper() # params params = sum(p.numel() for p in network.parameters()) # latency spike = torch.zeros((1, 200, 250, 400)).cuda() start_time = time.time() for _ in range(100): model.spk2img(spike) latency = (time.time() - start_time) / 100 # flop # todo thop bug for BSF flops, _ = profile((model), inputs=(spike,)) re_msg = ( "Total params: %.4fM" % (params / 1e6), "FLOPs:" + str(flops / 1e9) + "{}".format("G"), "Latency: {:.6f} seconds".format(latency), ) 关于不同模型的指标和参数计算结果,参见 :ref:`version`。 训练流程 ---------------- 1. 下载 ``REDS_BASE`` 数据集并放置在 ``spikezoo/data/reds_base`` 路径下(或者其他路径,在 ``root_dir`` 参数中设置即可),参考 :ref:`dataset_prepare` 。 2. 构建训练管线代码,基于 ``BASE`` 模型开始训练: .. code-block:: python from spikezoo.pipeline.train_pipeline import TrainPipelineConfig, TrainPipeline from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig from spikezoo.models.base_model import BaseModelConfig pipeline = TrainPipeline( cfg=TrainPipelineConfig(save_folder="results"), dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/reds_base"), model_cfg=BaseModelConfig(), ) pipeline.train() .. note:: 单卡NVIDIA RTX 4090实测:训练耗时约2分钟,PSNR 32.8dB / SSIM 0.92。单卡GTX 1050 Ti: 训练耗时约12分钟,PSNR 30.59dB / SSIM 0.86。完整训练配置参考 :ref:`train_pipe`。模型具体参数配置参考 :ref:`modelparam`, 数据集具体参数配置参考 :ref:`datasetparam`。 模型直接调用 ---------------- 除通过管线调用外,也支持模型独立使用,根据给定输入脉冲输出重构图像: .. code-block:: python import spikezoo as sz from spikezoo.models.base_model import BaseModel, BaseModelConfig # 输入数据加载 spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor") spike = spike[None].cuda() print(f"输入脉冲尺寸: {spike.shape}") # 网络初始化 net = BaseModel(BaseModelConfig(model_params={"inDim": 41})) net.build_network(mode = "debug") # 推理执行 recon_img = net(spike) print(recon_img.shape, recon_img.max(), recon_img.min()) 更多高级用法详见 :ref:`model_use` 。 函数调用 ---------------- Spike-Zoo 集成了各种针对脉冲相机设计的函数库,可以参考 :ref:`api_utils` 查看相关使用方式。