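Models
The model component (model) of Spike-Zoo wraps the network architecture (arch). It handles preprocessing of the input spikes, postprocessing of the output image, and the loss computation used during training.
Model Overview
To be completed.
Configuration
Taking the BASE model as an example, the configuration class is defined as follows: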
from dataclasses import dataclass, field
from typing import Optional
import torch.nn as nn

@dataclass
class BaseModelConfig:
    # ------------- Not Recommended to Change -------------
    "Registered model name."
    model_name: str = "base"
    "File name of the specified model."
    model_file_name: str = "nets"
    "Class name of the specified model in spikezoo/archs/base/{model_file_name}.py."
    model_cls_name: str = "BaseNet"
    "Spike input length. (local mode)"
    model_length: int = 41
    "Spike input length for different versions."
    model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
    "Model require model parameters or not."
    require_params: bool = True
    "Model parameters. (local mode)"
    model_params: dict = field(default_factory=lambda: {})
    "Model parameters for different versions."
    model_params_dict: dict = field(default_factory=lambda: {"v010": {}, "v023": {}})
    # ------------- Config -------------
    "Load ckpt path. Used on the local mode."
    ckpt_path: str = ""
    "Load pretrained weights or not. (default false, set to true during the evaluation mode.)"
    load_state: bool = False
    "Multi-GPU setting."
    multi_gpu: bool = False
    "Base url."
    base_url: str = "https://github.com/chenkang455/Spike-Zoo/releases/download"
    "Load the model from local class or spikezoo lib. (None)"
    model_cls_local: Optional[nn.Module] = None
    "Load the arch from local class or spikezoo lib. (None)"
    arch_cls_local: Optional[nn.Module] = None
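The parameters are described below:
model_name: Registered model name, such as "base", "spk2imgnet", and "spikeclip".
model_file_name: Name of the file that defines the model architecture, e.g. spikezoo/archs/base/nets.py.
model_cls_name: Model class name, corresponding to BaseNet in spikezoo/archs/base/nets.py.
model_length: Standard length of the input spike stream, used to crop the input to the specified length (local mode).
model_length_dict: Input spike length for each release version.
require_params: Whether the model has parameters that need to be learned.
model_params: Model initialization parameters (local mode), used to instantiate the network in archs.
model_params_dict: Model parameter configuration for each release version.
ckpt_path: Path from which pretrained weights are loaded (local mode only).
load_state: Whether to load pretrained weights (off by default; must be turned on in evaluation mode).
multi_gpu: Whether to enable multi-GPU training.
base_url: Cloud storage address of the pretrained weights for each version.
model_cls_local: A locally designed model class to use. Defaults to None (i.e., the model class from the spikezoo repository is imported).
arch_cls_local: A locally designed network architecture to use. Defaults to None (i.e., the network architecture from the spikezoo repository is imported).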
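As a rough illustration of how the local fields and the per-version dictionaries relate, the sketch below uses a stripped-down stand-in config (ToyModelConfig, hypothetical) and a helper resolve_length_and_params (also hypothetical). The rule that "local" reads model_length and model_params while a release tag reads the corresponding *_dict entries is inferred from the field descriptions above; it is not taken from the Spike-Zoo source.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModelConfig:
    """Hypothetical stand-in for BaseModelConfig with only the fields used here."""
    model_length: int = 41
    model_length_dict: dict = field(default_factory=lambda: {"v010": 41, "v023": 41})
    model_params: dict = field(default_factory=lambda: {})
    model_params_dict: dict = field(default_factory=lambda: {"v010": {}, "v023": {}})


def resolve_length_and_params(cfg, version="local"):
    """Pick the spike length and arch parameters for a version (assumed rule)."""
    if version == "local":
        return cfg.model_length, cfg.model_params
    if version not in cfg.model_length_dict:
        raise KeyError(f"unknown version: {version}")
    return cfg.model_length_dict[version], cfg.model_params_dict[version]


cfg = ToyModelConfig(model_length=61, model_params={"inDim": 61})
local_length, local_params = resolve_length_and_params(cfg, "local")
release_length, release_params = resolve_length_and_params(cfg, "v023")
```

Dictionary defaults are declared through field(default_factory=...) because dataclasses reject mutable default values such as a bare {}.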
Model Class
Taking the BASE model as an example, the model class is defined as follows:
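class BaseModel(nn.Module):
    # Initialize the model instance
    def __init__(self, cfg: BaseModelConfig): ...
    # Forward inference interface: takes spikes, returns the reconstructed image
    def forward(self, spike): ...
    # Core conversion method: converts a single spike input to an image (called by both the training and inference interfaces)
    def spk2img(self, spike): ...
    # Network construction: loads the model architecture and optionally loads weights
    def build_network(
        self,
        mode: Literal["debug", "train", "eval"] = "debug",
        version: Literal["local", "v010", "v023"] = "local",
    ): ...
    # Save the network weights
    def save_network(self, save_path): ...
    # Crop the input spike stream to the standard length
    def crop_spike_length(self, spike): ...
    # Preprocess the input spikes (resizing, spike representation conversion, etc.)
    def preprocess_spike(self, spike): ...
    # Postprocess the output image (size restoration, brightness correction, etc.)
    def postprocess_img(self, image): ...
    # Get the training output dictionary (training may produce several outputs)
    def get_outputs_dict(self, batch): ...
    # Get the dictionary of visualization images to save
    def get_visual_dict(self, batch, outputs): ...
    # Compute the loss values from the outputs and the input data
    def get_loss_dict(self, outputs, batch, loss_weight_dict): ...
    # Define the loss function
    def get_loss_func(self, name: Literal["l1", "l2"]): ...
    # Get the ground-truth / reconstructed image pairs used to compute image metrics
    def get_paired_imgs(self, batch, outputs): ...
    # Move the input data to the compute device
    def feed_to_device(self, batch): ...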
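To make the length-cropping step more concrete, here is a minimal sketch that crops a sequence of frames along the time axis to a target length. Taking the centered window is an assumption made for illustration; the source does not say which part of the stream crop_spike_length keeps. center_crop_bounds is a hypothetical helper, and the frames are plain Python integers standing in for spike frames.

```python
def center_crop_bounds(total_length, target_length):
    """Return (start, end) indices of a centered window of target_length frames."""
    if target_length > total_length:
        raise ValueError("input spike stream is shorter than the target length")
    start = (total_length - target_length) // 2
    return start, start + target_length


def crop_spike_length(spike_frames, target_length=41):
    """Crop a sequence of frames along time; centering is an assumption."""
    start, end = center_crop_bounds(len(spike_frames), target_length)
    return spike_frames[start:end]


frames = list(range(101))  # stand-in for 101 spike frames
cropped = crop_spike_length(frames, target_length=41)
```

With 101 input frames and a target length of 41, the window covers frames 30 through 70.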
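Notes:
The core role of a model is to map a single segment of input spikes to a reconstructed image, but the reconstruction interface differs between training and inference:
Training interface
get_outputs_dict: Some training methods, such as SSML and STIR, build their loss from multiple outputs, so this interface returns a dictionary and the loss is computed in get_loss_dict.
Inference interface
forward: Calls spk2img directly to map the input spikes to the reconstructed image.
When using batch and outputs inside these functions, make sure the keys match those of the dictionary provided by the dataset.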
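The split between the two interfaces can be sketched with a toy class that mimics the method names without depending on torch or spikezoo. Everything here is a hypothetical stand-in: ToyModel averages binary spike frames over time, each frame is a flat list of pixels, and the "gt_img" key is assumed purely for illustration (only "spike" and "recon_img" appear elsewhere in this document).

```python
class ToyModel:
    """Hypothetical stand-in that mimics the interface names, without torch."""

    def spk2img(self, spike):
        # Average the binary spike frames over time at each pixel.
        n_frames = len(spike)
        return [sum(frame[i] for frame in spike) / n_frames for i in range(len(spike[0]))]

    def forward(self, spike):
        # Inference interface: spikes in, reconstructed image out.
        return self.spk2img(spike)

    def get_outputs_dict(self, batch):
        # Training interface: return a dict so several outputs can feed the loss.
        return {"recon_img": self.spk2img(batch["spike"])}

    def get_loss_dict(self, outputs, batch, loss_weight_dict):
        # Weighted L1 loss between the reconstruction and the ground truth.
        recon, gt = outputs["recon_img"], batch["gt_img"]
        l1 = sum(abs(r - g) for r, g in zip(recon, gt)) / len(gt)
        return {"l1": loss_weight_dict["l1"] * l1}


model = ToyModel()
batch = {"spike": [[1, 0], [1, 1], [0, 0], [1, 1]], "gt_img": [0.5, 0.5]}
outputs = model.get_outputs_dict(batch)
loss_dict = model.get_loss_dict(outputs, batch, {"l1": 1.0})
```

Both get_outputs_dict and get_loss_dict read keys from batch, which is why those keys have to agree with what the dataset produces.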
Instantiation
Besides being combined with a dataset inside a pipeline, a model can also be used on its own:
import spikezoo as sz
from spikezoo.models.base_model import BaseModel, BaseModelConfig
# Load the input data
spike = sz.load_vidar_dat("data/data.dat", width=400, height=250, out_format="tensor")
spike = spike[None].cuda()
print(f"Input spike shape: {spike.shape}")
# Initialize the network
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}))
net.build_network(mode="debug")
# Run inference
recon_img = net(spike)
print(recon_img.shape, recon_img.max(), recon_img.min())
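Typical uses of build_network:
# 1. Build the network in debug mode; whether weights are loaded is determined by the config
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False))
net.build_network(mode="debug")
# 2. Build the network in training mode; whether weights are loaded is determined by the config
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=False))
net.build_network(mode="train")
# 3. Build the network in evaluation mode; the pretrained weights specified in the local config are loaded automatically
net = BaseModel(BaseModelConfig(model_params={"inDim": 41}, load_state=True, ckpt_path="spikezoo/models/weights/v023/base.pth"))
net.build_network(mode="eval", version="local")
# 4. Build the network in evaluation mode; the pretrained weights of the release version are loaded automatically
net = BaseModel(BaseModelConfig())
net.build_network(mode="eval", version="v023")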
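Modes:
debug: Debug mode, used to verify the spike-to-image conversion flow.
eval: Evaluation mode; weights can be loaded either from the local path ckpt_path or from a release version (e.g. v023).
train: Training mode; weights are not loaded by default, and loading can be controlled through the load_state parameter.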
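One way to picture how mode, version, and load_state interact is the small decision function below. It is a sketch of the behaviour described in the mode list, not the actual build_network implementation: resolve_weight_source is a hypothetical name, and the assumption that debug and train read from ckpt_path when load_state is set is mine.

```python
def resolve_weight_source(mode, version="local", load_state=False, ckpt_path=""):
    """Return where weights would come from, or None if none are loaded (assumed logic)."""
    if mode not in ("debug", "train", "eval"):
        raise ValueError(f"unknown mode: {mode}")
    # Evaluation always loads weights; debug/train follow the load_state flag.
    should_load = mode == "eval" or load_state
    if not should_load:
        return None
    if version == "local":
        return ("local", ckpt_path)
    return ("release", version)


no_weights = resolve_weight_source("debug")
train_resume = resolve_weight_source("train", load_state=True, ckpt_path="w.pth")
eval_release = resolve_weight_source("eval", version="v023")
eval_local = resolve_weight_source(
    "eval", version="local", ckpt_path="spikezoo/models/weights/v023/base.pth"
)
```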
Other models are imported and used in the same way:
from spikezoo.models.tfp_model import TFPModel,TFPConfig
from spikezoo.models.tfi_model import TFIModel,TFIConfig
from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
from spikezoo.models.wgse_model import WGSE,WGSEConfig
from spikezoo.models.ssml_model import SSML,SSMLConfig
from spikezoo.models.bsf_model import BSF,BSFConfig
from spikezoo.models.stir_model import STIR,STIRConfig
from spikezoo.models.ssir_model import SSIR,SSIRConfig
from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig
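Custom Model Development
In addition to the provided models, Spike-Zoo supports custom-designed models.
Implementation steps:
Option 1: Add the model to the Spike-Zoo repository and use the standard calling convention
1. Create the model file spikezoo/models/yourmodel_model.py.
2. Inherit from the base classes and implement YourModelConfig and YourModel: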
from torch.utils.data import Dataset
from pathlib import Path
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from spikezoo.models.base_model import BaseModel, BaseModelConfig
import torch.nn as nn

@dataclass
class YourModelConfig(BaseModelConfig):
    model_name: str = "yourmodel"  # Must be consistent with the file name
    model_file_name: str = "arch.net"  # Module path under the archs directory
    model_cls_name: str = "YourNet"  # Model class name
    model_length: int = 41
    require_params: bool = True
    model_params: dict = field(default_factory=lambda: {"inDim": 41})

class YourModel(BaseModel):
    def __init__(self, cfg: BaseModelConfig):
        super(YourModel, self).__init__(cfg)
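3. Create the architecture file spikezoo/archs/yourmodel/arch/net.py, where yourmodel corresponds to model_name, arch/net.py corresponds to model_file_name, and YourNet is the class name of the architecture.
4. The architecture file contains the following code: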
import torch.nn as nn

def conv_layer(inDim, outDim, ks, s, p, norm_layer="none"):
    ## convolutional layer
    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
    relu = nn.ReLU(True)
    assert norm_layer in ("batch", "instance", "none")
    if norm_layer == "none":
        seq = nn.Sequential(*[conv, relu])
    else:
        if norm_layer == "instance":
            norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)  # instance norm
        else:
            momentum = 0.1
            norm = nn.BatchNorm2d(outDim, momentum=momentum, affine=True, track_running_stats=True)
        seq = nn.Sequential(*[conv, norm, relu])
    return seq

class YourNet(nn.Module):
    """Borrow the structure from the SpikeCLIP. (https://arxiv.org/abs/2501.04477)"""

    def __init__(self, inDim=41):
        super(YourNet, self).__init__()
        norm = "none"
        outDim = 1
        convBlock1 = conv_layer(inDim, 64, 3, 1, 1)
        convBlock2 = conv_layer(64, 128, 3, 1, 1, norm)
        convBlock3 = conv_layer(128, 64, 3, 1, 1, norm)
        convBlock4 = conv_layer(64, 16, 3, 1, 1, norm)
        conv = nn.Conv2d(16, outDim, 3, 1, 1)
        self.seq = nn.Sequential(*[convBlock1, convBlock2, convBlock3, convBlock4, conv])

    def forward(self, x):
        return self.seq(x)
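The naming convention that ties model_name, model_file_name, and model_cls_name to a file on disk can be sketched as follows. The helper names (arch_module_path, arch_file_path, load_class) are hypothetical, and resolving the class with importlib is an assumption about how such a lookup could be done, not a description of Spike-Zoo internals. The final call uses a standard-library module so the snippet runs without spikezoo installed.

```python
import importlib


def arch_module_path(model_name, model_file_name):
    """Build the dotted module path implied by the naming convention."""
    return f"spikezoo.archs.{model_name}.{model_file_name}"


def arch_file_path(model_name, model_file_name):
    """Build the corresponding file path (dots in model_file_name become folders)."""
    return f"spikezoo/archs/{model_name}/{model_file_name.replace('.', '/')}.py"


def load_class(module_path, cls_name):
    """Import a module by dotted path and fetch a class from it by name."""
    module = importlib.import_module(module_path)
    return getattr(module, cls_name)


module_path = arch_module_path("yourmodel", "arch.net")
file_path = arch_file_path("yourmodel", "arch.net")
# Demonstrated on a standard-library module so the snippet runs without spikezoo.
ordered_dict_cls = load_class("collections", "OrderedDict")
```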
5. Call the custom model from a local file test.py:
from spikezoo.models.yourmodel_model import YourModel, YourModelConfig
net = YourModel(YourModelConfig())
net.build_network(mode="debug")
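Option 2: Inherit from the model base class directly in a local file
1. Create a local script test.py.
2. Implement YourModelConfig, YourModel, and YourNet as in steps 2 and 4 of Option 1; the model_file_name and model_cls_name parameters do not need to be set.
3. Call the custom model directly: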
net = YourModel(
    YourModelConfig(
        model_cls_local=YourModel,
        arch_cls_local=YourNet,
        load_state=True,
        ckpt_path="spikezoo/models/weights/v023/base.pth",
    )
)
net.build_network(mode="eval")
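The effect of model_cls_local and arch_cls_local can be thought of as a simple override: a class passed in explicitly is used as-is, and the library lookup only happens when the field is None. The sketch below illustrates that pattern with hypothetical names (pick_class, LocalNet, LibraryNet); it is an assumed reading of the parameter descriptions, not the library's code.

```python
def pick_class(local_cls, library_loader):
    """Prefer an explicitly supplied local class; otherwise call the library loader."""
    if local_cls is not None:
        return local_cls
    return library_loader()


class LocalNet:
    """Hypothetical locally defined architecture."""


class LibraryNet:
    """Hypothetical architecture that the library would import by name."""


chosen_with_local = pick_class(LocalNet, lambda: LibraryNet)
chosen_default = pick_class(None, lambda: LibraryNet)
```

Passing the loader as a callable means the by-name import is skipped entirely when a local class is supplied, which would fit the note that model_file_name and model_cls_name can be left unset in this option.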
For a complete example, see: https://github.com/chenkang455/Spike-Zoo/blob/main/examples/test/test_yourmodel.py
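Example: Wrapping an Existing Model
1. Clone the official STIR code into the spikezoo/archs/stir directory.
2. Locate the model definition file spikezoo/archs/stir/models/networks_STIR.py, in which the class is named STIR.
3. Create stir_model.py in the spikezoo/models directory with the following configuration: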
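from dataclasses import dataclass, field
import torch
from spikezoo.models.base_model import BaseModel, BaseModelConfig

@dataclass
class STIRConfig(BaseModelConfig):
    model_name: str = "stir"  # Must be consistent with the file name
    model_file_name: str = "models.networks_STIR"  # Module path under the archs directory
    model_cls_name: str = "STIR"  # Model class name
    model_length: int = 61  # Standard input length
    require_params: bool = True  # Parameters need to be initialized
    model_params: dict = field(default_factory=lambda: {})  # Use the default parameters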
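4. Inherit from the base class to implement the STIR model. Because the network downsamples the input spikes several times, the spike preprocessing and image postprocessing methods need to be overridden: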
class STIR(BaseModel):
    def __init__(self, cfg: BaseModelConfig):
        super(STIR, self).__init__(cfg)

    def preprocess_spike(self, spike):
        spike = self.crop_spike_length(spike)
        if self.spike_size == (250, 400):
            spike = torch.cat([spike, spike[:, :, -6:]], dim=2)
        elif self.spike_size == (480, 854):
            spike = torch.cat([spike, spike[:, :, :, -10:]], dim=3)
        return spike

    def postprocess_img(self, image):
        if self.spike_size == (250, 400):
            image = image[:, :, :250, :]
        elif self.spike_size == (480, 854):
            image = image[:, :, :, :854]
        return image

    def get_outputs_dict(self, batch):
        spike = batch["spike"]
        rate = batch["rate"].view(-1, 1, 1, 1).float()
        outputs = {}
        spike = self.preprocess_spike(spike)
        img_pred_0, Fs_lv_0, Fs_lv_1, Fs_lv_2, Fs_lv_3, Fs_lv_4, Est = self.net(spike)
        img_pred_0 = self.postprocess_img(img_pred_0)
        outputs["recon_img"] = img_pred_0 / rate
        return outputs
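The hard-coded padding amounts in preprocess_spike (6 rows for 250 x 400 input, 10 columns for 480 x 854 input) are consistent with rounding each spatial dimension up to the next multiple of 16. That this is the actual constraint imposed by the STIR network is an assumption; the arithmetic below only shows that the numbers agree. pad_amount and padded_size are hypothetical helpers.

```python
def pad_amount(size, multiple=16):
    """Number of extra rows/columns needed to reach the next multiple."""
    return (-size) % multiple


def padded_size(height, width, multiple=16):
    """Spatial size after padding both dimensions up to the next multiple."""
    return height + pad_amount(height, multiple), width + pad_amount(width, multiple)


small = padded_size(250, 400)   # height 250 is padded, width 400 is untouched
large = padded_size(480, 854)   # height 480 is untouched, width 854 is padded
```

postprocess_img then undoes the padding by slicing the output back to 250 rows or 854 columns.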
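STIR's multi-scale pyramid loss could be implemented by overriding get_loss_dict, but it has not been implemented because the performance gain is limited.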