模型

Spike-Zoo的模型组件 model 对网络架构 arch 进行封装,主要负责模型输入脉冲的预处理、输出图像的后处理以及训练相关的损失函数计算等功能。

模型介绍

有待完善........

参数配置

以 BASE 模型为例,配置类定义如下:

@dataclass
class BaseModelConfig:
    # ------------- Not Recommended to Change -------------
    "Registerd model name."
    model_name: str = "base"
    "File name of the specified model."
    model_file_name: str = "nets"
    "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
    model_cls_name: str = "BaseNet"
    "Spike input length. (local mode)"
    model_length: int = 41
    "Spike input length for different versions."
    model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
    "Model require model parameters or not."
    require_params: bool = True
    "Model parameters. (local mode)"
    model_params: dict = field(default_factory=lambda: {})
    "Model parameters for different versions."
    model_params_dict: dict = field(default_factory=lambda: {"v010": {}, "v023": {}})
    # ------------- Config -------------
    "Load ckpt path. Used on the local mode."
    ckpt_path: str = ""
    "Load pretrained weights or not. (default false, set to true during the evaluation mode.)"
    load_state: bool = False
    "Multi-GPU setting."
    multi_gpu: bool = False
    "Base url."
    base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
    "Load the model from local class or spikezoo lib. (None)"
    model_cls_local: Optional[nn.Module] = None
    "Load the arch from local class or spikezoo lib. (None)"
    arch_cls_local: Optional[nn.Module] = None

参数说明如下:

  • model_name : 注册的模型名称,如 "base", "spk2imgnet""spikeclip"

  • model_file_name : 模型架构定义文件名称,如 spikezoo/archs/base/nets.py

  • model_cls_name : 模型类名,对应 spikezoo/archs/base/nets.py 文件中的 BaseNet

  • model_length : 输入脉冲的标准长度,用于将输入裁剪为指定尺寸(本地模式)

  • model_length_dict : 不同发行版本对应的输入脉冲长度

  • require_params : 是否需要进行模型参数学习

  • model_params : 模型初始化参数(本地模式),用于实例化 archs 中的网络

  • model_params_dict : 不同发行版本对应的模型参数配置

  • ckpt_path : 预训练权重加载路径(仅限本地模式使用)

  • load_state : 是否加载预训练权重(默认关闭,评估模式需设为开启)

  • multi_gpu : 是否启用多GPU训练模式

  • base_url : 各版本预训练权重的云端存储地址

  • model_cls_local : 调用本地设计的模型类,默认值为 None``(即导入 ``spikezoo 仓库的模型类)

  • arch_cls_local : 调用本地设计的网络架构,默认值为 None``(即导入 ``spikezoo 仓库的网络架构)

模型构建类

以 BASE 模型为例,模型构建类定义如下:

class BaseModel(nn.Module):
    # 初始化模型实例
    def __init__(self, cfg: BaseModelConfig):
    # 前向推理接口:输入脉冲,输出重建图像
    def forward(self, spike):
    # 核心转换方法:将单个脉冲转换为图像(被训练和推理接口调用)
    def spk2img(self, spike):
    # 网络构建方法:加载模型架构并选择是否加载权重
    def build_network(
        self,
        mode: Literal["debug", "train", "eval"] = "debug",
        version: Literal["local", "v010", "v023"] = "local",
    ):
    # 网络权重保存
    def save_network(self, save_path):
    # 输入脉冲长度裁剪
    def crop_spike_length(self, spike):
    # 输入脉冲预处理(尺寸调整、脉冲表征转换等)
    def preprocess_spike(self, spike):
    # 输出图像后处理(尺寸还原、亮度校正等)
    def postprocess_img(self, image):
    # 获取训练输出字典(训练时可能包含多组输出)
    def get_outputs_dict(self, batch):
    # 获取需要保存的可视化图像字典
    def get_visual_dict(self, batch, outputs):
    # 根据输出结果和输入数据计算损失值
    def get_loss_dict(self, outputs, batch, loss_weight_dict):
    # 损失函数定义方法
    def get_loss_func(self, name: Literal["l1", "l2"]):
    # 获取用于计算图像指标的真值-重建图像对
    def get_paired_imgs(self, batch, outputs):
    # 将输入数据载入计算设备
    def feed_to_device(self, batch):

注意事项:

模型的核心作用是实现单段输入脉冲到重构图像的映射关系,但脉冲重构在训练和推理阶段接口不同:

  • 训练接口 get_outputs_dict: 部分训练方法如 ssmlstir 存在多个输出构成损失函数,故该接口会输出一个字典并在 get_loss_dict 接口中计算损失函数

  • 推理接口 forward: 直接调用 spk2img 函数实现脉冲输入到重构图像的映射

  • 各函数里使用 batchoutputs 时注意和数据集给定的字典 key 对齐

实例化

模型除在 pipeline 中与 dataset 结合使用外,也提供单独调用方式:

import spikezoo as sz
from spikezoo.models.base_model import BaseModel, BaseModelConfig
# 输入数据加载
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
spike = spike[None].cuda()
print(f"Input spike shape: {spike.shape}")
# 网络初始化
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
net.build_network(mode = "debug")
# 推理过程
recon_img = net(spike)
print(recon_img.shape,recon_img.max(),recon_img.min())

build_network 的典型用法:

# 1. 调试模式构建网络,是否加载权重由配置决定
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False))
net.build_network(mode="debug")
# 2. 训练模式构建网络,是否加载权重由配置决定
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False))
net.build_network(mode="train")
# 3. 评估模式构建网络,自动加载本地配置指定的预训练权重
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=True,ckpt_path="spikezoo/models/weights/v023/base.pth"))
net.build_network(mode="eval", version="local")
# 4. 评估模式构建网络,自动加载发行版预训练权重
net = BaseModel(BaseModelConfig())
net.build_network(mode="eval", version="v023")

模式说明:

  • debug : 调试模式,验证脉冲到图像的转换流程

  • eval : 评估模式,支持从本地路径 ckpt_path 或发行版(如 v023)加载权重

  • train : 训练模式,默认不加载权重,可通过 load_state 参数控制权重加载

其他模型的使用方式:

from spikezoo.models.tfp_model import TFPModel,TFPConfig
from spikezoo.models.tfi_model import TFIModel,TFIConfig
from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
from spikezoo.models.wgse_model import WGSE,WGSEConfig
from spikezoo.models.ssml_model import SSML,SSMLConfig
from spikezoo.models.bsf_model import BSF,BSFConfig
from spikezoo.models.stir_model import STIR,STIRConfig
from spikezoo.models.ssir_model import SSIR,SSIRConfig
from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig

自定义模型开发

除了给定模型以外,Spike-Zoo还支持使用自定义设计模型:

实现步骤:

方式一、将模型加入到Spike-Zoo仓库中,采样标准调用方式

  1. 创建模型文件 spikezoo/models/yourmodel_model.py

  2. 继承基类并分别实现 YourModelConfigYourModel:

from torch.utils.data import Dataset
from pathlib import Path
from dataclasses import dataclass
from typing import Literal, Union
from typing import Optional
from spikezoo.models.base_model import BaseModel, BaseModelConfig
from dataclasses import field
import torch.nn as nn

@dataclass
class YourModelConfig(BaseModelConfig):
    model_name: str = "yourmodel"  # 需与文件名保持一致
    model_file_name: str = "arch.net"  # archs路径下的模块路径
    model_cls_name: str = "YourNet"  # 模型类名
    model_length: int = 41
    require_params: bool = True
    model_params: dict = field(default_factory=lambda: {"inDim": 41})

class YourModel(BaseModel):
    def __init__(self, cfg: BaseModelConfig):
        super(YourModel, self).__init__(cfg)
  1. 创建架构文件 spikezoo/archs/yourmodel/arch/net.py, 其中 yourmodel 对应 model_namearch/net.py 对应 model_file_nameYourNet 是架构的类名称

  2. 架构文件代码如下:

import torch.nn as nn

def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
    ## convolutional layer
    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
    relu = nn.ReLU(True)
    assert norm_layer in ("batch", "instance", "none")
    if norm_layer == "none":
        seq = nn.Sequential(*[conv, relu])
    else:
        if norm_layer == "instance":
            norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)  # instance norm
        else:
            momentum = 0.1
            norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
        seq = nn.Sequential(*[conv, norm, relu])
    return seq


class YourNet(nn.Module):
    """Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""

    def __init__(self, inDim=41):
        super(YourNet, self).__init__()
        norm = "none"
        outDim = 1
        convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
        convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
        convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
        convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
        conv = nn.Conv2d(16, outDim, 3, 1, 1)
        self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])

    def forward(self, x):
        return self.seq(x)
  1. 本地文件 test.py 调用自定义模型

from spikezoo.models.yourmodel_model import YourModel, YourModelConfig
net = YourModel(YourModelConfig())
net.build_network(mode="debug")

方式二、本地直接继承模型基类

  1. 创建本地运行文件 test.py

  2. 同方式一步骤2和4实现 YourModelConfig, YourModelYourNet, 其中 model_file_name, model_cls_name 参数可以忽略设置

  3. 直接调用自定义模型:

net = YourModel(
    YourModelConfig(
        model_cls_local=YourModel,
        arch_cls_local=YourNet,
        load_state=True,
        ckpt_path="spikezoo/models/weights/v023/base.pth",
    )
)
net.build_network(mode="eval")

具体例子见:https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourmodel.py

示例、封装已有模型

  1. 将STIR官方代码克隆至 spikezoo/archs/stir 目录

  2. 定位模型定义文件 spikezoo/archs/stir/models/networks_STIR.py,其中类名为 STIR

  3. spikezoo/models 目录下创建 stir_model.py,配置参数如下:

@dataclass
class STIRConfig(BaseModelConfig):
    model_name: str = "stir"  # 需与文件名保持一致
    model_file_name: str = "models.networks_STIR"  # archs路径下的模块路径
    model_cls_name: str = "STIR"  # 模型类名
    model_length: int = 61  # 标准输入长度
    require_params: bool = True  # 需要参数初始化
    model_params: dict = field(default_factory=lambda: {})  # 使用默认参数
  1. 继承基类实现STIR模型。由于涉及多次输入脉冲下采样处理,需重写脉冲预处理和后处理方法:

class STIR(BaseModel):
    def __init__(self, cfg: BaseModelConfig):
        super(STIR, self).__init__(cfg)

    def preprocess_spike(self, spike):
        spike = self.crop_spike_length(spike)
        if self.spike_size == (250, 400):
            spike = torch.cat([spike, spike[:, :, -6:]], dim=2)
        elif self.spike_size == (480, 854):
            spike = torch.cat([spike, spike[:, :, :, -10:]], dim=3)
        return spike

    def postprocess_img(self, image):
        if self.spike_size == (250, 400):
            image = image[:, :, :250, :]
        elif self.spike_size == (480, 854):
            image = image[:, :, :, :854]
        return image

    def get_outputs_dict(self, batch):
        spike = batch["spike"]
        rate = batch["rate"].view(-1, 1, 1, 1).float()
        outputs = {}
        spike = self.preprocess_spike(spike)
        img_pred_0, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = self.net(spike)
        img_pred_0 = self.postprocess_img(img_pred_0)
        outputs["recon_img"] = img_pred_0 / rate
        return outputs
  1. STIR的多尺度金字塔损失函数可通过重写 get_loss_dict 实现,但因性能提升有限暂未实现