处理管线 ======================= 处理管线封装了模型与数据集,提供单模型推理/训练和多模型联合推理的统一接口。 推理管线 ---------------- .. _eval_config: 参数配置 ^^^^^^^^^^^ 1. 单模型管线 参数配置代码: .. code-block:: python @dataclass class PipelineConfig: "Loading weights from local or version on the url." version: Literal["local", "v010", "v023"] = "local" "Save folder for the code running result." save_folder: str = "" "Saved experiment name." exp_name: str = "" "Evaluate metrics or not." save_metric: bool = True "Metric names for evaluation." metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","niqe","brisque"]) "Save recoverd images or not." save_img: bool = True "Normalizing recoverd images and gt or not." img_norm: bool = False "Batch size for the test dataloader." bs_test: int = 1 "Num_workers for the test dataloader." nw_test: int = 0 "Pin_memory true or false for the dataloader." pin_memory: bool = False "Different modes for the pipeline." _mode: Literal["single_mode", "multi_mode", "train_mode"] = "single_mode" 参数说明: - ``version`` : 权重加载来源(本地路径或发行版本) - ``save_folder`` : 运行结果存储路径(日志/图像/指标) - ``exp_name`` : 实验命名标识(默认使用时间戳) - ``save_metric`` : 是否输出量化指标 - ``metric_names`` : 指定输出的评估指标 - ``save_img`` : 是否保存重建图像 - ``img_norm`` : 保存图像以及测试指标前是否进行归一化处理 - ``bs_test`` : 测试批大小 - ``nw_test`` : 测试数据加载线程数 - ``pin_memory`` : 启用 ``pin_memory`` 模式 - ``_mode`` : 管线运行模式标识(内部参数) 2. 多模型管线: .. code-block:: python @dataclass class EnsemblePipelineConfig(PipelineConfig): _mode: Literal["single_mode", "multi_mode", "train_mode"] = "multi_mode" 参数说明: - ``_mode`` : 强制设置为多模型模式(内部标识参数) 3. 参数设置补充说明 - ``version``: 支持 ``"local"``, ``"v010"``, ``"v023"`` 三个参数设置,其中 ``"local"``表示从本地路径加载权重,使用方式见 :ref:`eval_initial`,不同发行版本的介绍见 :ref:`version`。 - ``metric_names``: 指定评测指标,例如 ``["psnr", "ssim","niqe","brisque"]``, 通过 :ref:`support` 查看Spike-Zoo支持的指标。 - ``img_norm``: 会同时将重构图像和清晰图像归一化,影响最终保存图像的可视化以及指标计算。(``SpikeCLIP`` 由于是基于文本训练的模型,输出不在`[0,1]`范围内,会自动归一化输出图像) .. _eval_initial: 实例化 ^^^^^^^^^^^ 1. 单模型管线 管线初始化代码接口如下: .. code-block:: python class Pipeline: def __init__( self, cfg: PipelineConfig, model_cfg: Union[sz.METHOD, BaseModelConfig], dataset_cfg: Union[sz.DATASET, BaseDatasetConfig], ): self.cfg = cfg self._setup_model_data(model_cfg, dataset_cfg) self._setup_pipeline() 管线实例化支持以下两种方式: - 预设参数初始化 .. code-block:: python from spikezoo.pipeline import Pipeline, PipelineConfig import spikezoo as sz pipeline = Pipeline( cfg=PipelineConfig(save_folder="results",version="v023"), model_cfg=sz.METHOD.BASE, dataset_cfg=sz.DATASET.BASE ) 此方式利用名称来直接指定方法和数据集的默认配置参数,针对数据集需要按照 :ref:`dataset_prepare` 将数据下载至对应位置。 - 自定义参数初始化(推荐方式) .. code-block:: python from spikezoo.pipeline import Pipeline, PipelineConfig from spikezoo.models.base_model import BaseModelConfig from spikezoo.datasets.base_dataset import BaseDatasetConfig import spikezoo as sz # 方式一:加载发行版v023预训练权重 pipeline = Pipeline( cfg=PipelineConfig(save_folder="results",version="v023"), model_cfg=BaseModelConfig(), dataset_cfg=BaseDatasetConfig() ) # 方式二:加载本地预训练权重 pipeline = Pipeline( cfg=PipelineConfig(save_folder="results",version="local"), model_cfg=BaseModelConfig(ckpt_path="spikezoo/models/weights/v023/base.pth"), dataset_cfg=BaseDatasetConfig() ) 2. 多模型管线 管线初始化代码接口如下: .. code-block:: python class EnsemblePipeline(Pipeline): def __init__( self, cfg: PipelineConfig, model_cfg_list: Union[List[sz.METHOD], List[BaseModelConfig]], dataset_cfg: Union[sz.DATASET, BaseDatasetConfig], ): self.cfg = cfg self._setup_model_data(model_cfg_list, dataset_cfg) self._setup_pipeline() 支持两种配置方式: - 预设参数初始化 .. code-block:: python import spikezoo as sz from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig pipeline = EnsemblePipeline( cfg=EnsemblePipelineConfig(save_folder="results",version="v023"), model_cfg_list=[ sz.METHOD.BASE,sz.METHOD.TFP,sz.METHOD.TFI,sz.METHOD.SPK2IMGNET,sz.METHOD.WGSE, sz.METHOD.SSML,sz.METHOD.BSF,sz.METHOD.STIR,sz.METHOD.SPIKECLIP,sz.METHOD.SSIR], dataset_cfg=sz.DATASET.BASE, ) - 自定义参数初始化(推荐方式) .. code-block:: python import spikezoo as sz from spikezoo.datasets.base_dataset import BaseDatasetConfig from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig from spikezoo.models.base_model import BaseModel,BaseModelConfig from spikezoo.models.tfp_model import TFPModel,TFPConfig from spikezoo.models.tfi_model import TFIModel,TFIConfig from spikezoo.models.spk2imgnet_model import Spk2ImgNet,Spk2ImgNetConfig from spikezoo.models.wgse_model import WGSE,WGSEConfig from spikezoo.models.ssml_model import SSML,SSMLConfig from spikezoo.models.bsf_model import BSF,BSFConfig from spikezoo.models.stir_model import STIR,STIRConfig from spikezoo.models.ssir_model import SSIR,SSIRConfig from spikezoo.models.spikeclip_model import SpikeCLIP,SpikeCLIPConfig pipeline = EnsemblePipeline( cfg=EnsemblePipelineConfig(save_folder="results",version="v023"), model_cfg_list=[ BaseModelConfig(),TFPConfig(),TFIConfig(),Spk2ImgNetConfig(),WGSEConfig(), SSMLConfig(),BSFConfig(),STIRConfig(),SpikeCLIPConfig(),SSIRConfig()], dataset_cfg=BaseDatasetConfig(), ) .. _train_pipe: 训练管线 ---------------- 参数配置 ^^^^^^^^^^^ 在推理管线的基础上,训练管线配置代码增加了额外的训练控制参数: .. code-block:: python @dataclass class TrainPipelineConfig(PipelineConfig): # parameters setting "Training epochs." epochs: int = 10 "Steps per to save images." steps_per_save_imgs: int = 10 "Steps per to save model weights." steps_per_save_ckpt: int = 10 "Steps per to calculate the metrics." steps_per_cal_metrics: int = 10 "Step for gradient accumulation. (for snn methods)" steps_grad_accumulation: int = 4 "Pipeline mode." _mode: Literal["single_mode", "multi_mode", "train_mode"] = "train_mode" "Use tensorboard or not" use_tensorboard: bool = True "Random seed." seed: int = 521 # dataloader setting "Batch size for the train dataloader." bs_train: int = 8 "Num_workers for the train dataloader." nw_train: int = 4 # train setting - optimizer & scheduler & loss_dict "Optimizer config." optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-3) "Scheduler config." scheduler_cfg: Optional[SchedulerConfig] = None "Loss dict {loss_name,weight}." loss_weight_dict: Dict[Literal["l1", "l2"], float] = field(default_factory=lambda: {"l1": 1}) 参数详解: - ``epochs`` : 总训练轮次 - ``steps_per_save_imgs`` : 重建图像保存间隔(单位:epoch) - ``steps_per_save_ckpt`` : 模型权重保存间隔(单位:epoch) - ``steps_per_cal_metrics`` : 指标计算间隔(单位:epoch) - ``steps_grad_accumulation`` : 梯度累积步数(适用于SNN方法) - ``_mode`` : 强制设置为训练模式 - ``use_tensorboard`` : 启用TensorBoard可视化 - ``seed`` : 随机数种子 - ``bs_train`` : 训练批大小 - ``nw_train`` : 训练数据加载线程数 - ``optimizer_cfg`` : 优化器配置(默认Adam) - ``scheduler_cfg`` : 学习率调度策略 - ``loss_weight_dict`` : 损失函数权重配置 实例化 ^^^^^^^^^^^ 基础训练示例(快速验证): .. code-block:: python from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig from spikezoo.models.base_model import BaseModelConfig pipeline = TrainPipeline( cfg=TrainPipelineConfig(save_folder="results", epochs = 10), dataset_cfg=REDS_BASEConfig(root_dir = "spikezoo/data/reds_base"), model_cfg=BaseModelConfig(), ) pipeline.train() .. note:: 单卡4090 GPU实测:训练耗时约2分钟,PSNR 32.8dB / SSIM 0.92 高级配置示例(完整训练): .. code-block:: python from spikezoo.utils.optimizer_utils import OptimizerConfig, AdamOptimizerConfig from spikezoo.utils.scheduler_utils import SchedulerConfig, MultiStepSchedulerConfig from dataclasses import dataclass, field from spikezoo.pipeline.train_pipeline import TrainPipelineConfig from typing import Optional, Dict, List from spikezoo.pipeline import TrainPipeline from spikezoo.datasets.reds_base_dataset import REDS_BASEConfig from spikezoo.models import BaseModelConfig @dataclass class REDS_BASE_TrainConfig(TrainPipelineConfig): """REDS-BASE数据集专用训练配置""" # 参数设置 epochs: int = 600 steps_per_save_imgs: int = 200 steps_per_save_ckpt: int = 500 steps_per_cal_metrics: int = 100 metric_names: List[str] = field(default_factory=lambda: ["psnr", "ssim","lpips","niqe","brisque","piqe"]) # 数据加载设置 bs_train: int = 8 nw_train: int = 4 pin_memory: bool = False # 训练策略 optimizer_cfg: OptimizerConfig = AdamOptimizerConfig(lr=1e-4) scheduler_cfg: Optional[SchedulerConfig] = MultiStepSchedulerConfig(milestones=[400], gamma=0.2) # WGSE论文配置 loss_weight_dict: Dict = field(default_factory=lambda: {"l1": 1}) pipeline = TrainPipeline( cfg=REDS_BASE_TrainConfig(save_folder="results", exp_name="base"), dataset_cfg=REDS_BASEConfig(root_dir="spikezoo/data/reds_base", use_aug=True, crop_size=(128, 128)), model_cfg=BaseModelConfig(), ) pipeline.train() .. note:: 完整训练结果:PSNR 36.5dB / SSIM 0.965 更多模型在REDS_BASE数据集上的训练配置示例可参考: https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base 自定义训练 ^^^^^^^^^^^ Spike-Zoo 提供通过继承基类的方式来分别实现 ``model``, ``dataset`` 和 ``pipeline``,以尽量少的代码修改完成自定义功能设置。 具体例子见:https://github.com/chenkang455/Spike-Zoo/tree/main/examples/train_reds_base