Pipeline

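The pipeline wraps models and datasets and provides a unified interface for single-model inference/training and multi-model ensemble inference.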

Inference Pipeline

Configuration

  1. Single-model pipeline

Configuration code:

from dataclasses import dataclass, field
from typing import List, Literal

@dataclass
class PipelineConfig:
    "Load weights from a local path or from a released version hosted at a URL."
    version: Literal["local", "v010", "v023"] = "local"
    "Folder for saving the run results."
    save_folder: str = ""
    "Name of the saved experiment."
    exp_name: str = ""
    "Evaluate metrics or not."
    save_metric: bool = True
    "Metric names for evaluation."
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
    "Save recovered images or not."
    save_img: bool = True
    "Normalize the recovered images and the ground truth or not."
    img_norm: bool = False
    "Batch size for the test dataloader."
    bs_test: int = 1
    "num_workers for the test dataloader."
    nw_test: int = 0
    "Whether to enable pin_memory for the dataloader."
    pin_memory: bool = False
    "Different modes for the pipeline."
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"

Parameters:

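  • version : source of the weights to load (local path or release version)

  • save_folder : path where run results are stored (logs/images/metrics)

  • exp_name : experiment name (defaults to a timestamp)

  • save_metric : whether to output quantitative metrics

  • metric_names : the evaluation metrics to output

  • save_img : whether to save the reconstructed images

  • img_norm : whether to normalize images before saving them and computing metrics

  • bs_test : test batch size

  • nw_test : number of dataloader workers for testing

  • pin_memory : whether to enable pin_memory in the dataloader

  • _mode : pipeline run-mode identifier (internal parameter)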

  2. Multi-model (ensemble) pipeline:

@dataclass
class EnsemblePipelineConfig(PipelineConfig):
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "multi_mode"

Parameters:

  • _mode : forced to multi-model mode (internal identifier)
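The ensemble config shown above inherits every field of PipelineConfig and changes only the default of `_mode`. The minimal sketch below illustrates that behavior with Python dataclass inheritance. It redefines a few of the fields above as local stand-ins rather than importing spikezoo, so it is not the library's own code. It also shows why `metric_names` uses `default_factory`: each instance gets its own list.

```python
from dataclasses import dataclass, field
from typing import List, Literal


# Local stand-ins reproducing a subset of the fields shown above (no spikezoo import).
@dataclass
class PipelineConfig:
    version: Literal["local", "v010", "v023"] = "local"
    save_folder: str = ""
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "niqe", "brisque"])
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode"


@dataclass
class EnsemblePipelineConfig(PipelineConfig):
    # Only the default of _mode changes; all other fields are inherited unchanged.
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "multi_mode"


single = PipelineConfig(save_folder="results", version="v023")
ensemble = EnsemblePipelineConfig(save_folder="results", version="v023")
print(single._mode, ensemble._mode)  # single_mode multi_mode

# default_factory builds a fresh list per instance, so this does not touch `single`.
ensemble.metric_names.append("lpips")
print(single.metric_names)  # ['psnr', 'ssim', 'niqe', 'brisque']
```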

  3. Additional notes on parameters

  • version: accepts "local", "v010", or "v023". "local" loads weights from a local path (see Instantiation for usage); for the release versions, see Release Versions.

  • metric_names: the evaluation metrics to compute, e.g. ["psnr", "ssim", "niqe", "brisque"]. See Supported Range for the metrics Spike-Zoo supports.

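  • img_norm: normalizes both the reconstructed image and the ground-truth (sharp) image, which affects both the visualization of the saved images and the computed metrics. (SpikeCLIP is trained with text supervision, so its output does not lie in [0, 1]; its output images are normalized automatically.)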
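The img_norm note says normalization changes both the saved images and the metrics. The minimal NumPy sketch below shows why. It assumes per-image min-max normalization to [0, 1] and PSNR computed with a data range of 1. Spike-Zoo's actual normalization and metric code are not shown in this document. An output that matches the ground truth up to scale and offset scores poorly on raw values but very highly once both images are normalized.

```python
import numpy as np


def min_max_norm(img: np.ndarray) -> np.ndarray:
    """Rescale an image to [0, 1] using its own minimum and maximum (assumed scheme)."""
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi - lo < 1e-12:  # constant image: avoid division by zero
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


def psnr(pred: np.ndarray, gt: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    mse = float(np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2))
    if mse == 0.0:
        return float("inf")
    return 10.0 * np.log10(data_range**2 / mse)


rng = np.random.default_rng(0)
gt = rng.uniform(0.0, 1.0, size=(64, 64))
# Same structure as gt, but scaled and shifted outside [0, 1],
# similar in spirit to a model whose output range is not [0, 1].
pred = 2.0 * gt + 0.5

raw_score = psnr(pred, gt)                               # penalized by the range mismatch
norm_score = psnr(min_max_norm(pred), min_max_norm(gt))  # very high: identical up to rounding
print(f"raw: {raw_score:.2f} dB, normalized: {norm_score:.2f} dB")
```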

Instantiation

  1. Single-model pipeline

The pipeline initialization interface is as follows:

class Pipeline:
    def __init__(
        self,
        cfg: PipelineConfig,
        model_cfg: Union[sz.METHOD, BaseModelConfig],
        dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
    ):
        self.cfg = cfg
        self._setup_model_data(model_cfg, dataset_cfg)
        self._setup_pipeline()

A pipeline can be instantiated in two ways:

  • Initialization with preset parameters

from spikezoo.pipeline import Pipeline, PipelineConfig
import spikezoo as sz
pipeline = Pipeline(
    cfg=PipelineConfig(save_folder="results",version="v023"),
    model_cfg=sz.METHOD.BASE,
    dataset_cfg=sz.DATASET.BASE
)

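This approach selects the default configuration of a method and a dataset directly by name. For the dataset, download the data to the corresponding location as described in Data Sources.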

  • Initialization with custom parameters (recommended)

from spikezoo.pipeline import Pipeline, PipelineConfig
from spikezoo.models.base_model import BaseModelConfig
from spikezoo.datasets.base_dataset import BaseDatasetConfig
import spikezoo as sz
# Option 1: load the pretrained weights of release v023
pipeline = Pipeline(
    cfg=PipelineConfig(save_folder="results",version="v023"),
    model_cfg=BaseModelConfig(),
    dataset_cfg=BaseDatasetConfig()
)
# Option 2: load local pretrained weights
pipeline = Pipeline(
    cfg=PipelineConfig(save_folder="results",version="local"),
    model_cfg=BaseModelConfig(ckpt_path="spikezoo/models/weights/v023/base.pth"),
    dataset_cfg=BaseDatasetConfig()
)
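The `Union[sz.METHOD, BaseModelConfig]` hints in the constructor mean that either a preset name or a config object is accepted. Spike-Zoo's actual resolution logic is not shown in this document. The sketch below uses hypothetical stand-ins (`Method`, `ModelConfig`, `resolve_model_cfg`) to show one plausible way such dispatch can work: a preset is expanded into a default config, and an explicit config is passed through unchanged.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Union


@dataclass
class ModelConfig:
    """Hypothetical stand-in for a model config such as BaseModelConfig."""
    model_name: str = "base"
    ckpt_path: str = ""


class Method(Enum):
    """Hypothetical stand-in for a preset enum such as sz.METHOD."""
    BASE = "base"
    TFP = "tfp"


def resolve_model_cfg(cfg: Union[Method, ModelConfig]) -> ModelConfig:
    """Expand a preset into its default config; return explicit configs unchanged."""
    if isinstance(cfg, Method):
        return ModelConfig(model_name=cfg.value)
    if isinstance(cfg, ModelConfig):
        return cfg
    raise TypeError(f"unsupported model config: {cfg!r}")


print(resolve_model_cfg(Method.TFP))  # ModelConfig(model_name='tfp', ckpt_path='')
custom = ModelConfig(ckpt_path="spikezoo/models/weights/v023/base.pth")
print(resolve_model_cfg(custom) is custom)  # True
```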

  2. Multi-model (ensemble) pipeline

The pipeline initialization interface is as follows:

class EnsemblePipeline(Pipeline):
    def __init__(
        self,
        cfg: PipelineConfig,
        model_cfg_list: Union[List[sz.METHOD], List[BaseModelConfig]],
        dataset_cfg: Union[sz.DATASET, BaseDatasetConfig],
    ):
        self.cfg = cfg
        self._setup_model_data(model_cfg_list, dataset_cfg)
        self._setup_pipeline()
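Both constructors follow the same sequence: store `cfg`, call `_setup_model_data`, then call `_setup_pipeline`. The toy classes below are hypothetical stand-ins, not Spike-Zoo's implementation. They sketch how a subclass can keep that sequence and change only the model/data setup step so that it accepts a list of model configs.

```python
from typing import List


class ToyPipeline:
    """Hypothetical stand-in mirroring the constructor structure of Pipeline."""

    def __init__(self, cfg, model_cfg, dataset_cfg):
        self.cfg = cfg
        self._setup_model_data(model_cfg, dataset_cfg)
        self._setup_pipeline()

    def _setup_model_data(self, model_cfg, dataset_cfg):
        self.model = model_cfg      # placeholder for building a single model
        self.dataset = dataset_cfg

    def _setup_pipeline(self):
        self.ready = True           # placeholder for output folders, metrics, etc.


class ToyEnsemblePipeline(ToyPipeline):
    """Same construction sequence; only the model/data setup step differs."""

    def __init__(self, cfg, model_cfg_list: List, dataset_cfg):
        self.cfg = cfg
        self._setup_model_data(model_cfg_list, dataset_cfg)
        self._setup_pipeline()

    def _setup_model_data(self, model_cfg_list: List, dataset_cfg):
        self.model_list = list(model_cfg_list)  # one entry per method
        self.dataset = dataset_cfg              # one dataset shared by all models


ens = ToyEnsemblePipeline(cfg={}, model_cfg_list=["base", "tfp", "tfi"], dataset_cfg="base")
print(len(ens.model_list), ens.dataset, ens.ready)  # 3 base True
```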

Two configuration approaches are supported:

  • Initialization with preset parameters

import spikezoo as sz
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
pipeline = EnsemblePipeline(
    cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
    model_cfg_list=[
        sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE,
        sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR],
    dataset_cfg=sz.DATASET.BASE,
)

  • Initialization with custom parameters (recommended)

import spikezoo as sz
from spikezoo.datasets.base_dataset import BaseDatasetConfig
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
from spikezoo.models.base_model import BaseModel,BaseModelConfig
from spikezoo.models.tfp_model import TFPModel,TFPConfig
from spikezoo.models.tfi_model import TFIModel,TFIConfig
from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig
from spikezoo.models.wgse_model import WGSE,WGSEConfig
from spikezoo.models.ssml_model import SSML,SSMLConfig
from spikezoo.models.bsf_model import BSF,BSFConfig
from spikezoo.models.stir_model import STIR,STIRConfig
from spikezoo.models.ssir_model import SSIR,SSIRConfig
from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig

pipeline = EnsemblePipeline(
    cfg=EnsemblePipelineConfig(save_folder="results",version="v023"),
    model_cfg_list=[
        BaseModelConfig(),TFPConfig(),TFIConfig(),Spk2ImgNetConfig(),WGSEConfig(),
        SSMLConfig(),BSFConfig(),STIRConfig(),SpikeCLIPConfig(),SSIRConfig()],
    dataset_cfg=BaseDatasetConfig(),
)

Training Pipeline

Configuration

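Building on the inference pipeline configuration, the training pipeline configuration adds the following training-control parameters: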

@dataclass
class TrainPipelineConfig(PipelineConfig):
    # parameter settings
    "Number of training epochs."
    epochs: int = 10
    "Interval (in epochs) for saving images."
    steps_per_save_imgs: int = 10
    "Interval (in epochs) for saving model weights."
    steps_per_save_ckpt: int = 10
    "Interval (in epochs) for calculating the metrics."
    steps_per_cal_metrics: int = 10
    "Number of gradient-accumulation steps (for SNN methods)."
    steps_grad_accumulation: int = 4
    "Pipeline mode."
    _mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode"
    "Use tensorboard or not"
    use_tensorboard: bool = True
    "Random seed."
    seed: int = 521
    # dataloader setting
    "Batch size for the train dataloader."
    bs_train: int = 8
    "Num_workers for the train dataloader."
    nw_train: int = 4

    # train setting - optimizer & scheduler & loss_dict
    "Optimizer config."
    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3)
    "Scheduler config."
    scheduler_cfg: Optional[SchedulerConfig] = None
    "Loss dict {loss_name,weight}."
    loss_weight_dict: Dict[Literal["l1", "l2"], float] = field(default_factory=lambda: {"l1": 1})
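`loss_weight_dict` maps loss names to weights. The minimal pure-Python sketch below shows how such a dictionary can be combined into a single training loss. It assumes that "l1" and "l2" denote mean absolute error and mean squared error and that the weighted terms are summed. This is an assumption, because the document does not show Spike-Zoo's loss code.

```python
def l1(pred, gt):
    """Mean absolute error."""
    return sum(abs(p - g) for p, g in zip(pred, gt)) / len(pred)


def l2(pred, gt):
    """Mean squared error."""
    return sum((p - g) ** 2 for p, g in zip(pred, gt)) / len(pred)


LOSS_FNS = {"l1": l1, "l2": l2}


def total_loss(pred, gt, loss_weight_dict):
    """Weighted sum over the configured losses: sum_k w_k * loss_k(pred, gt)."""
    return sum(w * LOSS_FNS[name](pred, gt) for name, w in loss_weight_dict.items())


pred, gt = [0.5, 0.0, 1.0, 0.25], [0.0, 0.0, 1.0, 0.75]
print(total_loss(pred, gt, {"l1": 1}))             # 0.25 (the default {"l1": 1})
print(total_loss(pred, gt, {"l1": 1, "l2": 0.5}))  # 0.3125 = 0.25 + 0.5 * 0.125
```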

Parameter details:

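  • epochs : total number of training epochs

  • steps_per_save_imgs : interval for saving reconstructed images (unit: epoch)

  • steps_per_save_ckpt : interval for saving model weights (unit: epoch)

  • steps_per_cal_metrics : interval for computing metrics (unit: epoch)

  • steps_grad_accumulation : number of gradient-accumulation steps (for SNN methods)

  • _mode : forced to training mode

  • use_tensorboard : enable TensorBoard visualization

  • seed : random seed

  • bs_train : training batch size

  • nw_train : number of dataloader workers for training

  • optimizer_cfg : optimizer configuration (Adam by default)

  • scheduler_cfg : learning-rate scheduling strategy

  • loss_weight_dict : loss-function weight configuration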
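The steps_per_* parameters are intervals measured in epochs. The sketch below assumes that an action fires at every 0-based epoch that is a multiple of its interval. The exact trigger condition inside TrainPipeline is not shown in this document. The sketch uses the values from the REDS_BASE_TrainConfig example later on this page.

```python
def periodic_events(epoch: int, cfg: dict) -> list:
    """Actions due at a 0-based epoch, assuming each fires when epoch % interval == 0."""
    intervals = {
        "save_imgs": cfg["steps_per_save_imgs"],
        "save_ckpt": cfg["steps_per_save_ckpt"],
        "cal_metrics": cfg["steps_per_cal_metrics"],
    }
    return [name for name, every in intervals.items() if epoch % every == 0]


# Interval values copied from the REDS_BASE_TrainConfig example.
cfg = {
    "epochs": 600,
    "steps_per_save_imgs": 200,
    "steps_per_save_ckpt": 500,
    "steps_per_cal_metrics": 100,
}
schedule = {epoch: periodic_events(epoch, cfg) for epoch in range(cfg["epochs"])}

print(schedule[0])    # ['save_imgs', 'save_ckpt', 'cal_metrics']
print(schedule[400])  # ['save_imgs', 'cal_metrics']
print(sum("cal_metrics" in events for events in schedule.values()))  # 6
```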

Instantiation

Basic training example (quick check):

from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models.base_model import BaseModelConfig
pipeline = TrainPipeline(
    cfg=TrainPipelineConfig(save_folder="results", epochs=10),
    dataset_cfg=REDS_BASEConfig(root_dir="spikezoo/data/reds_base"),
    model_cfg=BaseModelConfig(),
)
pipeline.train()

Note

Measured on a single RTX 4090 GPU: training takes about 2 minutes and reaches PSNR 32.8 dB / SSIM 0.92.

Advanced configuration example (full training):

from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig
from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig
from dataclasses import dataclass, field
from spikezoo.pipeline.train_pipeline import TrainPipelineConfig
from typing import Optional, Dict, List
from spikezoo.pipeline import TrainPipeline
from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig
from spikezoo.models import BaseModelConfig

@dataclass
class REDS_BASE_TrainConfig(TrainPipelineConfig):
    """Training configuration dedicated to the REDS-BASE dataset."""

    # Parameter settings
    epochs: int = 600
    steps_per_save_imgs: int = 200
    steps_per_save_ckpt: int = 500
    steps_per_cal_metrics: int = 100
    metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim", "lpips", "niqe", "brisque", "piqe"])

    # Dataloader settings
    bs_train: int = 8
    nw_train: int = 4
    pin_memory: bool = False

    # Training strategy
    optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4)
    scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2)  # setting from the WGSE paper
    loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1})

pipeline = TrainPipeline(
    cfg=REDS_BASE_TrainConfig(save_folder="results", exp_name="base"),
    dataset_cfg=REDS_BASEConfig(root_dir="spikezoo/data/reds_base", use_aug=True, crop_size=(128, 128)),
    model_cfg=BaseModelConfig(),
)
pipeline.train()
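In the configuration above, `MultiStepSchedulerConfig(milestones=[400], gamma=0.2)` is marked as the WGSE paper setting. Suppose it maps to PyTorch's `torch.optim.lr_scheduler.MultiStepLR` and is stepped once per epoch (an assumption that this document does not confirm). Then the learning rate is base_lr * gamma ** k, where k is the number of milestones already reached. The pure-Python sketch below reproduces that rule without importing torch.

```python
from bisect import bisect_right


def multistep_lr(base_lr: float, milestones: list, gamma: float, epoch: int) -> float:
    """Learning rate at `epoch`: base_lr * gamma ** (number of milestones <= epoch)."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Values from REDS_BASE_TrainConfig: lr=1e-4, milestones=[400], gamma=0.2, 600 epochs.
for epoch in (0, 399, 400, 599):
    print(epoch, multistep_lr(1e-4, [400], 0.2, epoch))
```

Under this assumption, the learning rate stays at 1e-4 for epochs 0 to 399 and drops to about 2e-5 from epoch 400 until the end of the 600 epochs.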

Note

Full training results: PSNR 36.5 dB / SSIM 0.965

More training configuration examples for other models on the REDS_BASE dataset are available at: https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base

Custom Training

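Spike-Zoo lets you implement your own model, dataset, and pipeline by inheriting from the corresponding base classes, so custom functionality needs as few code changes as possible.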

For concrete examples, see: https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base