处理管线
处理管线封装了模型与数据集,提供单模型推理/训练和多模型联合推理的统一接口。
推理管线
参数配置
单模型管线
参数配置代码:
@dataclass
class PipelineConfig:
"Loading weights from local or version on the url."
version: Literal["local", "v010", "v023"] = "local"
"Save folder for the code running result."
save_folder: str = ""
"Saved experiment name."
exp_name: str = ""
"Evaluate metrics or not."
save_metric: bool = True
"Metric names for evaluation."
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","niqe","brisque"])
"Save recoverd images or not."
save_img: bool = True
"Normalizing recoverd images and gt or not."
img_norm: bool = False
"Batch size for the test dataloader."
bs_test: int = 1
"Num_workers for the test dataloader."
nw_test: int = 0
"Pin_memory true or false for the dataloader."
pin_memory: bool = False
"Different modes for the pipeline."
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"
参数说明:
version: 权重加载来源(本地路径或发行版本)save_folder: 运行结果存储路径(日志/图像/指标)exp_name: 实验命名标识(默认使用时间戳)save_metric: 是否输出量化指标metric_names: 指定输出的评估指标save_img: 是否保存重建图像img_norm: 保存图像以及测试指标前是否进行归一化处理bs_test: 测试批大小nw_test: 测试数据加载线程数pin_memory: 启用pin_memory模式_mode: 管线运行模式标识(内部参数)
多模型管线:
@dataclass
class EnsemblePipelineConfig(PipelineConfig):
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "multi_mode"
参数说明:
_mode: 强制设置为多模型模式(内部标识参数)
参数设置补充说明
实例化
单模型管线
管线初始化代码接口如下:
class Pipeline:
def __init__(
self,
cfg: PipelineConfig,
model_cfg: Union[sz.METHOD, BaseModelConfig],
dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
):
self.cfg = cfg
self._setup_model_data(model_cfg, dataset_cfg)
self._setup_pipeline()
管线实例化支持以下两种方式:
预设参数初始化
from spikezoo.pipeline import Pipeline, PipelineConfig
import spikezoo as sz
pipeline = Pipeline(
cfg=PipelineConfig(save_folder="results",version="v023"),
model_cfg=sz.METHOD.BASE,
dataset_cfg=sz.DATASET.BASE
)
此方式利用名称来直接指定方法和数据集的默认配置参数,针对数据集需要按照 数据来源 将数据下载至对应位置。
自定义参数初始化(推荐方式)
from spikezoo.pipeline import Pipeline, PipelineConfig
from spikezoo.models.base_model import BaseModelConfig
from spikezoo.datasets.base_dataset import BaseDatasetConfig
import spikezoo as sz
# 方式一:加载发行版v023预训练权重
pipeline = Pipeline(
cfg=PipelineConfig(save_folder="results",version="v023"),
model_cfg=BaseModelConfig(),
dataset_cfg=BaseDatasetConfig()
)
# 方式二:加载本地预训练权重
pipeline = Pipeline(
cfg=PipelineConfig(save_folder="results",version="local"),
model_cfg=BaseModelConfig(ckpt_path="spikezoo/models/weights/v023/base.pth"),
dataset_cfg=BaseDatasetConfig()
)
多模型管线
管线初始化代码接口如下:
class EnsemblePipeline(Pipeline):
def __init__(
self,
cfg: PipelineConfig,
model_cfg_list: Union[List[sz.METHOD], List[BaseModelConfig]],
dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
):
self.cfg = cfg
self._setup_model_data(model_cfg_list, dataset_cfg)
self._setup_pipeline()
支持两种配置方式:
预设参数初始化
import spikezoo as sz
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
pipeline = EnsemblePipeline(
cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
model_cfg_list=[
sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
dataset_cfg=sz.DATASET.BASE,
)
自定义参数初始化(推荐方式)
import spikezoo as sz
from spikezoo.datasets.base_dataset import BaseDatasetConfig
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
from spikezoo.models.base_model import BaseModel,BaseModelConfig
from spikezoo.models.tfp_model import TFPModel,TFPConfig
from spikezoo.models.tfi_model import TFIModel,TFIConfig
from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
from spikezoo.models.wgse_model import WGSE,WGSEConfig
from spikezoo.models.ssml_model import SSML,SSMLConfig
from spikezoo.models.bsf_model import BSF,BSFConfig
from spikezoo.models.stir_model import STIR,STIRConfig
from spikezoo.models.ssir_model import SSIR,SSIRConfig
from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig
pipeline = EnsemblePipeline(
cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
model_cfg_list=[
BaseModelConfig(),TFPConfig(),TFIConfig(),Spk2ImgNetConfig(),WGSEConfig(),
SSMLConfig(),BSFConfig(),STIRConfig(),SpikeCLIPConfig(),SSIRConfig()],
dataset_cfg=BaseDatasetConfig(),
)
训练管线
参数配置
在推理管线的基础上,训练管线配置代码增加了额外的训练控制参数:
@dataclass
class TrainPipelineConfig(PipelineConfig):
# parameters setting
"Training epochs."
epochs: int = 10
"Steps per to save images."
steps_per_save_imgs: int = 10
"Steps per to save model weights."
steps_per_save_ckpt: int = 10
"Steps per to calculate the metrics."
steps_per_cal_metrics: int = 10
"Step for gradient accumulation. (for snn methods)"
steps_grad_accumulation: int = 4
"Pipeline mode."
_mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode"
"Use tensorboard or not"
use_tensorboard: bool = True
"Random seed."
seed: int = 521
# dataloader setting
"Batch size for the train dataloader."
bs_train: int = 8
"Num_workers for the train dataloader."
nw_train: int = 4
# train setting - optimizer & scheduler & loss_dict
"Optimizer config."
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3)
"Scheduler config."
scheduler_cfg: Optional[SchedulerConfig] = None
"Loss dict {loss_name,weight}."
loss_weight_dict: Dict[Literal["l1", "l2"], float] = field(default_factory=lambda: {"l1": 1})
参数详解:
epochs: 总训练轮次steps_per_save_imgs: 重建图像保存间隔(单位:epoch)steps_per_save_ckpt: 模型权重保存间隔(单位:epoch)steps_per_cal_metrics: 指标计算间隔(单位:epoch)steps_grad_accumulation: 梯度累积步数(适用于SNN方法)_mode: 强制设置为训练模式use_tensorboard: 启用TensorBoard可视化seed: 随机数种子bs_train: 训练批大小nw_train: 训练数据加载线程数optimizer_cfg: 优化器配置(默认Adam)scheduler_cfg: 学习率调度策略loss_weight_dict: 损失函数权重配置
实例化
基础训练示例(快速验证):
from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models.base_model import BaseModelConfig
pipeline = TrainPipeline(
cfg=TrainPipelineConfig(save_folder="results", epochs = 10),
dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/reds_base"),
model_cfg=BaseModelConfig(),
)
pipeline.train()
备注
单卡4090 GPU实测:训练耗时约2分钟,PSNR 32.8dB / SSIM 0.92
高级配置示例(完整训练):
from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig
from dataclasses import dataclass, field
from spikezoo.pipeline.train_pipeline import TrainPipelineConfig
from typing import Optional, Dict, List
from spikezoo.pipeline import TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models import BaseModelConfig
@dataclass
class REDS_BASE_TrainConfig(TrainPipelineConfig):
"""REDS-BASE数据集专用训练配置"""
# 参数设置
epochs: int = 600
steps_per_save_imgs: int = 200
steps_per_save_ckpt: int = 500
steps_per_cal_metrics: int = 100
metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"])
# 数据加载设置
bs_train: int = 8
nw_train: int = 4
pin_memory: bool = False
# 训练策略
optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2) # WGSE论文配置
loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})
pipeline = TrainPipeline(
cfg=REDS_BASE_TrainConfig(save_folder="results", exp_name="base"),
dataset_cfg=REDS_BASEConfig(root_dir="spikezoo/data/reds_base", use_aug=True, crop_size=(128, 128)),
model_cfg=BaseModelConfig(),
)
pipeline.train()
备注
完整训练结果:PSNR 36.5dB / SSIM 0.965
更多模型在REDS_BASE数据集上的训练配置示例可参考: https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base
自定义训练
Spike-Zoo 提供通过继承基类的方式来分别实现 model, dataset 和 pipeline,以尽量少的代码修改完成自定义功能设置。
具体例子见:https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base