模型 ======================= Spike-Zoo的模型组件 ``model`` 对网络架构 ``arch`` 进行封装,主要负责模型输入脉冲的预处理、输出图像的后处理以及训练相关的损失函数计算等功能。 模型介绍 ---------------- 有待完善........ .. _modelparam: 参数配置 ---------------- 以 BASE 模型为例,配置类定义如下: .. code-block:: python @dataclass class BaseModelConfig: # ------------- Not Recommended to Change ------------- "Registerd model name." model_name: str = "base" "File name of the specified model." model_file_name: str = "nets" "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py." model_cls_name: str = "BaseNet" "Spike input length. (local mode)" model_length: int = 41 "Spike input length for different versions." model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41}) "Model require model parameters or not." require_params: bool = True "Model parameters. (local mode)" model_params: dict = field(default_factory=lambda: {}) "Model parameters for different versions." model_params_dict: dict = field(default_factory=lambda: {"v010": {}, "v023": {}}) # ------------- Config ------------- "Load ckpt path. Used on the local mode." ckpt_path: str = "" "Load pretrained weights or not. (default false, set to true during the evaluation mode.)" load_state: bool = False "Multi-GPU setting." multi_gpu: bool = False "Base url." base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download" "Load the model from local class or spikezoo lib. (None)" model_cls_local: Optional[nn.Module] = None "Load the arch from local class or spikezoo lib. (None)" arch_cls_local: Optional[nn.Module] = None 参数说明如下: - ``model_name`` : 注册的模型名称,如 ``"base"``, ``"spk2imgnet"`` 和 ``"spikeclip"`` - ``model_file_name`` : 模型架构定义文件名称,如 :file:`spikezoo/archs/base/nets.py` - ``model_cls_name`` : 模型类名,对应 ``spikezoo/archs/base/nets.py`` 文件中的 ``BaseNet`` - ``model_length`` : 输入脉冲的标准长度,用于将输入裁剪为指定尺寸(本地模式) - ``model_length_dict`` : 不同发行版本对应的输入脉冲长度 - ``require_params`` : 是否需要进行模型参数学习 - ``model_params`` : 模型初始化参数(本地模式),用于实例化 ``archs`` 中的网络 - ``model_params_dict`` : 不同发行版本对应的模型参数配置 - ``ckpt_path`` : 预训练权重加载路径(仅限本地模式使用) - ``load_state`` : 是否加载预训练权重(默认关闭,评估模式需设为开启) - ``multi_gpu`` : 是否启用多GPU训练模式 - ``base_url`` : 各版本预训练权重的云端存储地址 - ``model_cls_local`` : 调用本地设计的模型类,默认值为 ``None``(即导入 ``spikezoo`` 仓库的模型类) - ``arch_cls_local`` : 调用本地设计的网络架构,默认值为 ``None``(即导入 ``spikezoo`` 仓库的网络架构) 模型构建类 ---------------- 以 BASE 模型为例,模型构建类定义如下: .. code-block:: python class BaseModel(nn.Module): # 初始化模型实例 def __init__(self, cfg: BaseModelConfig): # 前向推理接口:输入脉冲,输出重建图像 def forward(self, spike): # 核心转换方法:将单个脉冲转换为图像(被训练和推理接口调用) def spk2img(self, spike): # 网络构建方法:加载模型架构并选择是否加载权重 def build_network( self, mode: Literal["debug", "train", "eval"] = "debug", version: Literal["local", "v010", "v023"] = "local", ): # 网络权重保存 def save_network(self, save_path): # 输入脉冲长度裁剪 def crop_spike_length(self, spike): # 输入脉冲预处理(尺寸调整、脉冲表征转换等) def preprocess_spike(self, spike): # 输出图像后处理(尺寸还原、亮度校正等) def postprocess_img(self, image): # 获取训练输出字典(训练时可能包含多组输出) def get_outputs_dict(self, batch): # 获取需要保存的可视化图像字典 def get_visual_dict(self, batch, outputs): # 根据输出结果和输入数据计算损失值 def get_loss_dict(self, outputs, batch, loss_weight_dict): # 损失函数定义方法 def get_loss_func(self, name: Literal["l1", "l2"]): # 获取用于计算图像指标的真值-重建图像对 def get_paired_imgs(self, batch, outputs): # 将输入数据载入计算设备 def feed_to_device(self, batch): **注意事项:** 模型的核心作用是实现单段输入脉冲到重构图像的映射关系,但脉冲重构在训练和推理阶段接口不同: - 训练接口 ``get_outputs_dict``: 部分训练方法如 ``ssml`` 和 ``stir`` 存在多个输出构成损失函数,故该接口会输出一个字典并在 ``get_loss_dict`` 接口中计算损失函数 - 推理接口 ``forward``: 直接调用 ``spk2img`` 函数实现脉冲输入到重构图像的映射 - 各函数里使用 ``batch`` 和 ``outputs`` 时注意和数据集给定的字典 ``key`` 对齐 .. _model_use: 实例化 ---------------- 模型除在 ``pipeline`` 中与 ``dataset`` 结合使用外,也提供单独调用方式: .. code-block:: python import spikezoo as sz from spikezoo.models.base_model import BaseModel, BaseModelConfig # 输入数据加载 spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor") spike = spike[None].cuda() print(f"Input spike shape: {spike.shape}") # 网络初始化 net = BaseModel(BaseModelConfig(model_params={"inDim": 41})) net.build_network(mode = "debug") # 推理过程 recon_img = net(spike) print(recon_img.shape,recon_img.max(),recon_img.min()) ``build_network`` 的典型用法: .. code-block:: python # 1. 调试模式构建网络,是否加载权重由配置决定 net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False)) net.build_network(mode="debug") # 2. 训练模式构建网络,是否加载权重由配置决定 net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False)) net.build_network(mode="train") # 3. 评估模式构建网络,自动加载本地配置指定的预训练权重 net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=True,ckpt_path="spikezoo/models/weights/v023/base.pth")) net.build_network(mode="eval", version="local") # 4. 评估模式构建网络,自动加载发行版预训练权重 net = BaseModel(BaseModelConfig()) net.build_network(mode="eval", version="v023") 模式说明: * ``debug`` : 调试模式,验证脉冲到图像的转换流程 * ``eval`` : 评估模式,支持从本地路径 ``ckpt_path`` 或发行版(如 ``v023``)加载权重 * ``train`` : 训练模式,默认不加载权重,可通过 ``load_state`` 参数控制权重加载 其他模型的使用方式: .. code-block:: python from spikezoo.models.tfp_model import TFPModel,TFPConfig from spikezoo.models.tfi_model import TFIModel,TFIConfig from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig from spikezoo.models.wgse_model import WGSE,WGSEConfig from spikezoo.models.ssml_model import SSML,SSMLConfig from spikezoo.models.bsf_model import BSF,BSFConfig from spikezoo.models.stir_model import STIR,STIRConfig from spikezoo.models.ssir_model import SSIR,SSIRConfig from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig 自定义模型开发 ---------------- 除了给定模型以外,Spike-Zoo还支持使用自定义设计模型: **实现步骤:** *方式一、将模型加入到Spike-Zoo仓库中,采样标准调用方式* 1. 创建模型文件 ``spikezoo/models/yourmodel_model.py`` 2. 继承基类并分别实现 ``YourModelConfig`` 和 ``YourModel``: .. code-block:: python from torch.utils.data import Dataset from pathlib import Path from dataclasses import dataclass from typing import Literal, Union from typing import Optional from spikezoo.models.base_model import BaseModel, BaseModelConfig from dataclasses import field import torch.nn as nn @dataclass class YourModelConfig(BaseModelConfig): model_name: str = "yourmodel" # 需与文件名保持一致 model_file_name: str = "arch.net" # archs路径下的模块路径 model_cls_name: str = "YourNet" # 模型类名 model_length: int = 41 require_params: bool = True model_params: dict = field(default_factory=lambda: {"inDim": 41}) class YourModel(BaseModel): def __init__(self, cfg: BaseModelConfig): super(YourModel, self).__init__(cfg) 3. 创建架构文件 ``spikezoo/archs/yourmodel/arch/net.py``, 其中 ``yourmodel`` 对应 ``model_name``,``arch/net.py`` 对应 ``model_file_name``,``YourNet`` 是架构的类名称 4. 架构文件代码如下: .. code-block:: python import torch.nn as nn def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"): ## convolutional layer conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p) relu = nn.ReLU(True) assert norm_layer in ("batch", "instance", "none") if norm_layer == "none": seq = nn.Sequential(*[conv, relu]) else: if norm_layer == "instance": norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False) # instance norm else: momentum = 0.1 norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True) seq = nn.Sequential(*[conv, norm, relu]) return seq class YourNet(nn.Module): """Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)""" def __init__(self, inDim=41): super(YourNet, self).__init__() norm = "none" outDim = 1 convBlock1 = conv_layer(inDim, 64, 3, 1, 1) convBlock2 = conv_layer(64, 128, 3, 1, 1, norm) convBlock3 = conv_layer(128, 64, 3, 1, 1, norm) convBlock4 = conv_layer(64, 16, 3, 1, 1, norm) conv = nn.Conv2d(16, outDim, 3, 1, 1) self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv]) def forward(self, x): return self.seq(x) 5. 本地文件 ``test.py`` 调用自定义模型 .. code-block:: python from spikezoo.models.yourmodel_model import YourModel, YourModelConfig net = YourModel(YourModelConfig()) net.build_network(mode="debug") *方式二、本地直接继承模型基类* 1. 创建本地运行文件 ``test.py`` 2. 同方式一步骤2和4实现 ``YourModelConfig``, ``YourModel`` 和 ``YourNet``, 其中 ``model_file_name``, ``model_cls_name`` 参数可以忽略设置 3. 直接调用自定义模型: .. code-block:: python net = YourModel( YourModelConfig( model_cls_local=YourModel, arch_cls_local=YourNet, load_state=True, ckpt_path="spikezoo/models/weights/v023/base.pth", ) ) net.build_network(mode="eval") 具体例子见:https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourmodel.py *示例、封装已有模型* 1. 将STIR官方代码克隆至 ``spikezoo/archs/stir`` 目录 2. 定位模型定义文件 ``spikezoo/archs/stir/models/networks_STIR.py``,其中类名为 ``STIR`` 3. 在 ``spikezoo/models`` 目录下创建 ``stir_model.py``,配置参数如下: .. code-block:: python @dataclass class STIRConfig(BaseModelConfig): model_name: str = "stir" # 需与文件名保持一致 model_file_name: str = "models.networks_STIR" # archs路径下的模块路径 model_cls_name: str = "STIR" # 模型类名 model_length: int = 61 # 标准输入长度 require_params: bool = True # 需要参数初始化 model_params: dict = field(default_factory=lambda: {}) # 使用默认参数 4. 继承基类实现STIR模型。由于涉及多次输入脉冲下采样处理,需重写脉冲预处理和后处理方法: .. code-block:: python class STIR(BaseModel): def __init__(self, cfg: BaseModelConfig): super(STIR, self).__init__(cfg) def preprocess_spike(self, spike): spike = self.crop_spike_length(spike) if self.spike_size == (250, 400): spike = torch.cat([spike, spike[:, :, -6:]], dim=2) elif self.spike_size == (480, 854): spike = torch.cat([spike, spike[:, :, :, -10:]], dim=3) return spike def postprocess_img(self, image): if self.spike_size == (250, 400): image = image[:, :, :250, :] elif self.spike_size == (480, 854): image = image[:, :, :, :854] return image def get_outputs_dict(self, batch): spike = batch["spike"] rate = batch["rate"].view(-1, 1, 1, 1).float() outputs = {} spike = self.preprocess_spike(spike) img_pred_0, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = self.net(spike) img_pred_0 = self.postprocess_img(img_pred_0) outputs["recon_img"] = img_pred_0 / rate return outputs 5. STIR的多尺度金字塔损失函数可通过重写 ``get_loss_dict`` 实现,但因性能提升有限暂未实现