huggingface/pytorch-image-models
Single GPU:
python train.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 128 --validation-batch-size 128 --color-jitter-prob 0.2 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
Multi-GPU; the 4 in the command below means 4 GPUs train together:
sh distributed_train.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
Another multi-GPU form, which also changes the master port being listened on:
python -m torch.distributed.launch --nproc_per_node=3 --master_port=29501 train_v2.py --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
Export the best checkpoint to ONNX:
python onnx_export.py huggingface\pytorch-image-models\output\train\20240529-132242-vit_base_patch16_224-224\model_best.onnx --mean 0 0 0 --std 1 1 1 --img-size 224 --checkpoint huggingface\pytorch-image-models\output\train\20240529-132242-vit_base_patch16_224-224\model_best.pth.tar
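The exported file can be sanity-checked before deployment; a minimal hedged sketch, assuming onnxruntime is installed and the path is adjusted to the exported file above:

import numpy as np
import onnxruntime as ort

# open the exported model on CPU and run one dummy 1x3x224x224 input through it
sess = ort.InferenceSession('model_best.onnx', providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
out = sess.run(None, {inp.name: np.zeros((1, 3, 224, 224), np.float32)})
print(inp.name, out[0].shape)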
The training set is organized like yolov8_cls:
imagenet/
├──train/
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ......
│ ├── ......
├──val/
│ ├── n01440764
│ │ ├── ILSVRC2012_val_00000293.JPEG
│ │ ├── ILSVRC2012_val_00002138.JPEG
│ │ ├── ......
│ ├── ......
The same multi-GPU launch using the v2 script:
sh distributed_train_v2.sh 4 --pretrained --input-size 3 224 224 --mean 0 0 0 --std 1 1 1 --batch-size 64 --validation-batch-size 64 --color-jitter-prob 0.5 --grayscale-prob 0.2 --gaussian-blur-prob 0.2 --save-images
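This tree maps directly onto timm's default folder reader; a minimal hedged sketch (an empty dataset name selects the ImageFolder/ImageTar reader, and the split names follow the sub-directories above):

from timm.data import create_dataset

# is_training selects train-time vs. eval-time behavior; splits match the folder names
dataset_train = create_dataset('', root='imagenet', split='train', is_training=True)
dataset_eval = create_dataset('', root='imagenet', split='val', is_training=False)
print(len(dataset_train), len(dataset_eval))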
#!/usr/bin/env python3
""" ImageNet Training Script
This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.
This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)
NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler
# user modification: restrict training to GPUs 0-3 (must be set before CUDA is initialized)
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
has_compile = hasattr(torch, 'compile')
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
help='path to dataset (positional is *deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR', default=r'/media/lg/C2032F933B04C4E6/00.Data/009.Uniform/81.version-2024.05.25/00.train_224_224',
help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,
metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "vit_base_patch16_224")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--pretrained-path', default='/home/test/pytorch-image-models/output/train/20240528-142446-vit_base_patch16_224-224/last.pth.tar', type=str,
help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Load this checkpoint into model after initialization (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=3000, metavar='N',
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=1.0, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=128, metavar='N',
help='Validation batch size override (default: 128)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
help='Head initialization bias value')
# scripting / codegen
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
help="Enable compilation w/ specified backend (default: inductor).")
# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
help='torch.cuda.synchronize() end of each step')
group.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
help="Python imports for device backend modules.")
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,
help='Crop-mode in train'),
group.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
help='Random resize scale (default: 0.5 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.5,
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-sum', action='store_true', default=False,
help='Sum over classes when using BCE loss.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled).')
group.add_argument('--bce-pos-weight', type=float, default=None,
help='Positive weighting for BCE loss.')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
help='Enable warmup for model EMA decay.')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
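# Example YAML consumed via --config / -c; keys mirror the argparse dests above
# (values here are hypothetical, for illustration only):
#   model: vit_base_patch16_224
#   batch_size: 64
#   epochs: 100
#   mean: [0.0, 0.0, 0.0]
#   std: [1.0, 1.0, 1.0]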
def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if args.device_modules:
for module in args.device_modules:
importlib.import_module(module)
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
args.prefetcher = not args.no_prefetcher
args.grad_accum_steps = max(1, args.grad_accum_steps)
device = utils.init_distributed_device(args)
if args.distributed:
_logger.info(
'Training in distributed mode with multiple processes, 1 device per process.'
f'Process {args.rank}, total {args.world_size}, device {args.device}.')
else:
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
if args.amp:
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
use_amp = 'apex'
assert args.amp_dtype == 'float16'
else:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
use_amp = 'native'
assert args.amp_dtype in ('float16', 'bfloat16')
if args.amp_dtype == 'bfloat16':
amp_dtype = torch.bfloat16
utils.random_seed(args.seed, args.rank)
if args.fuser:
utils.set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
factory_kwargs = {}
if args.pretrained_path:
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
factory_kwargs['pretrained_cfg_overlay'] = dict(
file=args.pretrained_path,
num_classes=-1, # force head adaptation
)
model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
**factory_kwargs,
**args.model_kwargs,
)
if args.head_init_scale is not None:
with torch.no_grad():
model.get_classifier().weight.mul_(args.head_init_scale)
model.get_classifier().bias.mul_(args.head_init_scale)
if args.head_init_bias is not None:
nn.init.constant_(model.get_classifier().bias, args.head_init_bias)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.to(device=device)
if args.channels_last:
model.to(memory_format=torch.channels_last)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
args.dist_bn = '' # disable dist_bn when sync BN active
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN used with Apex AMP
# WARNING this won't currently work with models using BatchNormAct2d
model = convert_syncbn_model(model)
else:
model = convert_sync_batchnorm(model)
if utils.is_primary(args):
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not args.torchcompile
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if not args.lr:
global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
batch_ratio = global_batch_size / args.lr_base_size
if not args.lr_base_scale:
on = args.opt.lower()
args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
if args.lr_base_scale == 'sqrt':
batch_ratio = batch_ratio ** 0.5
args.lr = args.lr_base * batch_ratio
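# worked example for the 4-GPU command above: global batch = 64 * 4 * 1 = 256;
# sgd uses linear scaling, so lr = lr_base * 256 / lr_base_size = 0.1 * 256 / 256 = 0.1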
if utils.is_primary(args):
_logger.info(
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
assert device.type == 'cuda'
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
try:
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
except (AttributeError, TypeError):
# fallback to CUDA only AMP for PyTorch < 1.10
assert device.type == 'cuda'
amp_autocast = torch.cuda.amp.autocast
if device.type == 'cuda' and amp_dtype == torch.float16:
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
loss_scaler = NativeScaler()
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model,
args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=utils.is_primary(args),
)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
model_ema = utils.ModelEmaV3(
model,
decay=args.model_ema_decay,
use_warmup=args.model_ema_warmup,
device='cpu' if args.model_ema_force_cpu else None,
)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
if args.torchcompile:
model_ema = torch.compile(model_ema, backend=args.torchcompile)
# setup distributed training
if args.distributed:
if has_apex and use_amp == 'apex':
# Apex DDP preferred unless native amp is activated
if utils.is_primary(args):
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if utils.is_primary(args):
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
if args.torchcompile:
# torch compile should be done after DDP
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
model = torch.compile(model, backend=args.torchcompile)
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
if args.input_img_mode is None:
input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
else:
input_img_mode = args.input_img_mode
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
split=args.train_split,
is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
seed=args.seed,
repeats=args.epoch_repeats,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.train_num_samples,
)
if args.val_split:
dataset_eval = create_dataset(
args.dataset,
root=args.data_dir,
split=args.val_split,
is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
input_img_mode=input_img_mode,
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.val_num_samples,
)
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup,
cutmix_alpha=args.cutmix,
cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob,
switch_prob=args.mixup_switch_prob,
mode=args.mixup_mode,
label_smoothing=args.smoothing,
num_classes=args.num_classes
)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support de-interleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train,
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
train_crop_mode=args.train_crop_mode,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
color_jitter_prob=args.color_jitter_prob,
grayscale_prob=args.grayscale_prob,
gaussian_blur_prob=args.gaussian_blur_prob,
auto_augment=args.aa,
num_aug_repeats=args.aug_repeats,
num_aug_splits=num_aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = None
if args.val_split:
eval_workers = args.workers
if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size or args.batch_size,
is_training=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=eval_workers,
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
device=device,
use_prefetcher=args.prefetcher,
)
# setup loss function
if args.jsd_loss:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
target_threshold=args.bce_target_thresh,
sum_classes=args.bce_sum,
pos_weight=args.bce_pos_weight,
)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = BinaryCrossEntropy(
smoothing=args.smoothing,
target_threshold=args.bce_target_thresh,
sum_classes=args.bce_sum,
pos_weight=args.bce_pos_weight,
)
else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.to(device=device)
validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric if loader_eval is not None else 'loss'
decreasing_metric = eval_metric == 'loss'
best_metric = None
best_epoch = None
saver = None
output_dir = None
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
saver = utils.CheckpointSaver(
model=model,
optimizer=optimizer,
args=args,
model_ema=model_ema,
amp_scaler=loss_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing_metric,
max_history=args.checkpoint_hist
)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch
updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
lr_scheduler, num_epochs = create_scheduler_v2(
optimizer,
**scheduler_kwargs(args, decreasing_metric=decreasing_metric),
updates_per_epoch=updates_per_epoch,
)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
if args.sched_on_updates:
lr_scheduler.step_update(start_epoch * updates_per_epoch)
else:
lr_scheduler.step(start_epoch)
if utils.is_primary(args):
_logger.info(
f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
results = []
try:
for epoch in range(start_epoch, num_epochs):
if hasattr(dataset_train, 'set_epoch'):
dataset_train.set_epoch(epoch)
elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch,
model,
loader_train,
optimizer,
train_loss_fn,
args,
lr_scheduler=lr_scheduler,
saver=saver,
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if utils.is_primary(args):
_logger.info("Distributing BatchNorm running means and vars")
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
if loader_eval is not None:
eval_metrics = validate(
model,
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema,
loader_eval,
validate_loss_fn,
args,
device=device,
amp_autocast=amp_autocast,
log_suffix=' (EMA)',
)
eval_metrics = ema_eval_metrics
else:
eval_metrics = None
if output_dir is not None:
lrs = [param_group['lr'] for param_group in optimizer.param_groups]
utils.update_summary(
epoch,
train_metrics,
eval_metrics,
filename=os.path.join(output_dir, 'summary.csv'),
lr=sum(lrs) / len(lrs),
write_header=best_metric is None,
log_wandb=args.log_wandb and has_wandb,
)
if eval_metrics is not None:
latest_metric = eval_metrics[eval_metric]
else:
latest_metric = train_metrics[eval_metric]
if saver is not None:
# save proper checkpoint with eval metric
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, latest_metric)
results.append({
'epoch': epoch,
'train': train_metrics,
'validation': eval_metrics,
})
except KeyboardInterrupt:
pass
results = {'all': results}
if best_metric is not None:
results['best'] = results['all'][best_epoch - start_epoch]
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
print(f'--result\n{json.dumps(results, indent=4)}')
def train_one_epoch(
epoch,
model,
loader,
optimizer,
loss_fn,
args,
device=torch.device('cuda'),
lr_scheduler=None,
saver=None,
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_ema=None,
mixup_fn=None,
num_updates_total=None,
):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
has_no_sync = hasattr(model, "no_sync")
update_time_m = utils.AverageMeter()
data_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
model.train()
accum_steps = args.grad_accum_steps
last_accum_steps = len(loader) % accum_steps
updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
num_updates = epoch * updates_per_epoch
last_batch_idx = len(loader) - 1
last_batch_idx_to_accum = len(loader) - last_accum_steps
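# e.g. with len(loader) == 10 and accum_steps == 4: last_accum_steps == 2,
# updates_per_epoch == 3, and batches 8-9 form the shorter final accumulation window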
data_start_time = update_start_time = time.time()
optimizer.zero_grad()
update_sample_count = 0
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_batch_idx
need_update = last_batch or (batch_idx + 1) % accum_steps == 0
update_idx = batch_idx // accum_steps
if batch_idx >= last_batch_idx_to_accum:
accum_steps = last_accum_steps
if not args.prefetcher:
input, target = input.to(device), target.to(device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# multiply by accum steps to get equivalent for full update
data_time_m.update(accum_steps * (time.time() - data_start_time))
def _forward():
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
if accum_steps > 1:
loss /= accum_steps
return loss
def _backward(_loss):
if loss_scaler is not None:
loss_scaler(
_loss,
optimizer,
clip_grad=args.clip_grad,
clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order,
need_update=need_update,
)
else:
_loss.backward(create_graph=second_order)
if need_update:
if args.clip_grad is not None:
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad,
mode=args.clip_mode,
)
optimizer.step()
if has_no_sync and not need_update:
with model.no_sync():
loss = _forward()
_backward(loss)
else:
loss = _forward()
_backward(loss)
if not args.distributed:
losses_m.update(loss.item() * accum_steps, input.size(0))
update_sample_count += input.size(0)
if not need_update:
data_start_time = time.time()
continue
num_updates += 1
optimizer.zero_grad()
if model_ema is not None:
model_ema.update(model, step=num_updates)
if args.synchronize_step and device.type == 'cuda':
torch.cuda.synchronize()
time_now = time.time()
update_time_m.update(time.time() - update_start_time)
update_start_time = time_now
if update_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
update_sample_count *= args.world_size
if utils.is_primary(args):
_logger.info(
f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
f'LR: {lr:.3e} '
f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
)
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True
)
if saver is not None and args.recovery_interval and (
(update_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=update_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
update_sample_count = 0
data_start_time = time.time()
# end for
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.avg)])
def validate(
model,
loader,
loss_fn,
args,
device=torch.device('cuda'),
amp_autocast=suppress,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()
model.eval()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
target = target.to(device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
acc1 = utils.reduce_tensor(acc1, args.world_size)
acc5 = utils.reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
if device.type == 'cuda':
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
)
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
return metrics
if __name__ == '__main__':
main()
Optionally, change which network train.py uses by editing the --model default:
group.add_argument('--model', default='vit_base_patch16_224', type=str, metavar='MODEL',
help='Name of model to train (default: "vit_base_patch16_224")')
The supported networks (selected by name) can be listed as follows:
# This example shows how to list every model timm supports, and every model that ships pretrained weights
import timm
# 1. All models supported by timm
# timm supports a wide variety of pretrained and non-pretrained models for number of Image based tasks.
# The list_models function returns a list of models ordered alphabetically that are supported by timm.
supported_models = timm.list_models()
print(supported_models[:5])
print(len(supported_models))
# 2. All timm models with pretrained weights
# To list all the models that have pretrained weights, timm provides a convenience parameter pretrained that could be passed in list_models function as below.
supported_pretrained_models = timm.list_models(pretrained=True)
print(supported_pretrained_models[:5])
print(len(supported_pretrained_models))
# 3. Use a wildcard to view a particular family, e.g. timm.list_models('*resnet*')
resnet_models = timm.list_models('*resnet*')
print(f"ResNet系列模型:{resnet_models}")
vit_models = timm.list_models('*vit*')
print(f"vit系列模型:{vit_models}")
pytorch-image-models/timm/models
Model-definition files are the files that provide network structures: the .py files under pytorch-image-models/timm/models that contain @register_model (i.e., the .py files whose job is to provide network models).
Core role: define every supported model architecture (ResNet, ViT, EfficientNet, etc.) and expose a unified model-creation interface.
Each file corresponds to one model family (e.g. resnet.py, vision_transformer.py, efficientnet.py) and defines the network structure plus its pretrained-weight configurations.
The register_model function in timm's _registry.py implements the automatic model-registration mechanism: after import timm, and before timm.create_model is ever called, every model has already been stored in _model_entrypoints.
How the mechanism works:
register_model is essentially a decorator. A decorator modifies a function's behavior; here, it adds registration to each model-builder function. When @register_model decorates a model-builder function (e.g. resnet50), the decorator stores the model's name (resnet50) and the corresponding builder function in the _model_entrypoints dict. The model files in the timm library (e.g. resnet.py) are executed automatically when timm is imported, and these model files all use the @register_model decorator, so every model is registered the moment timm is imported. The create_model call: when you call timm.create_model('resnet50'), the create_model function looks up the builder function named resnet50 in the _model_entrypoints dict and calls it to create the model instance. In short, register_model implements automatic model registration through the decorator mechanism: importing timm registers every model into the _model_entrypoints dict, which is what lets timm.create_model create model instances so conveniently.
Advantages: adding a new model only requires decorating its builder with @register_model, with no other code changes; users simply call timm.create_model to get a model instance, without caring about registration details.
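A minimal sketch of this registration pattern (a simplified illustration, not timm's actual implementation):

_model_entrypoints = {}  # model name -> builder function, as in timm/models/_registry.py

def register_model(fn):
    # record the builder under its function name, then return it unchanged
    _model_entrypoints[fn.__name__] = fn
    return fn

@register_model
def resnet50(pretrained=False, **kwargs):
    return f'<resnet50 instance, pretrained={pretrained}>'  # stand-in for the real builder

def create_model(model_name, **kwargs):
    # look up the builder that was registered at import time and call it
    return _model_entrypoints[model_name](**kwargs)

print(create_model('resnet50', pretrained=True))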
Non-model-definition files are the files that do not provide network structures: the .py files under pytorch-image-models/timm/models that do not contain @register_model (i.e., the .py files that are not there to provide network models). What are these files, and what do they do?
After inspecting the .py files under timm/models, the following files do not contain @register_model:
__init__.py: This file imports all the other model definition files and any helper modules/classes used in the model definitions. This makes it easier to import models and related functions from a single location (timm.models).
_builder.py: This file contains helper functions for building models with configurations. It includes functions for loading pretrained weights, adapting models for feature extraction, and handling various model configurations.
_efficientnet_blocks.py: This file defines the building blocks for EfficientNet-like models, such as Squeeze-and-Excitation blocks, different convolutional blocks, and attention modules.
_efficientnet_builder.py: This file provides the logic for constructing EfficientNet models based on architecture definitions. It handles scaling of model depths and widths, block configurations, and weight initialization.
_factory.py: This file contains functions for creating models based on their names. It includes logic for parsing model names, handling pretrained weight loading, and setting model attributes.
_features.py: This file defines helper classes and functions for extracting features from models. It includes tools for specifying feature extraction points, handling feature metadata, and creating feature extraction wrappers.
_features_fx.py: This file contains PyTorch FX based feature extraction helpers. It provides tools for tracing model graphs, identifying feature extraction nodes, and creating feature extraction wrappers using FX.
_helpers.py: This file contains helper functions for model creation, weight loading, and state dictionary manipulation. It includes functions for cleaning state dictionaries, loading checkpoints, and remapping state dictionaries.
_hub.py: This file contains helper functions for interacting with the Hugging Face Hub. It includes functions for downloading and saving models and configurations, as well as pushing models to the Hub.
_manipulate.py: This file contains helper functions for manipulating model parameters and modules. It includes functions for applying functions to named modules, grouping modules and parameters, and flattening nested modules.
_pretrained.py: This file defines data classes and functions for handling pretrained model configurations. It includes classes for storing pretrained weight URLs, input configurations, and other metadata.
_prune.py: This file contains helper functions for model pruning. It includes functions for extracting and modifying specific layers in a model, as well as adapting models based on pruning configurations.
These files play crucial roles in defining, building, and manipulating models within the timm library, but they don't directly provide network architectures like the files that use @register_model.
Roles of the individual .py files
1. _efficientnet_blocks.py vs. _efficientnet_builder.py
_efficientnet_blocks.py: This module defines the individual building blocks that make up EfficientNet and related models. These blocks are often variations of inverted residual blocks, depthwise separable convolutions, or attention mechanisms. It focuses on the micro-architecture of the models.
_efficientnet_builder.py: This module handles the construction of complete EfficientNet models by assembling the blocks defined in _efficientnet_blocks.py. It takes care of scaling model depths and widths, per-stage block configuration, and weight initialization.
In essence, _efficientnet_blocks.py provides the ingredients, while _efficientnet_builder.py provides the recipe and cooking instructions for creating EfficientNet models.
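As a quick illustration, the blocks can also be used standalone; a hedged sketch (SqueezeExcite does live in _efficientnet_blocks.py, though exact keyword names may vary across timm versions):

import torch
from timm.models._efficientnet_blocks import SqueezeExcite

# channel-attention block used inside EfficientNet stages; output shape matches input
se = SqueezeExcite(64, rd_ratio=0.25)
print(se(torch.randn(1, 64, 32, 32)).shape)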
2. _features.py vs. _features_fx.py
_features.py: This module provides tools for extracting intermediate features from models using traditional PyTorch methods, such as specifying feature extraction points, attaching hooks to them, and creating feature extraction wrappers around a model.
_features_fx.py: This module leverages the PyTorch FX framework for feature extraction. FX allows for symbolic tracing of model execution, making it easier to trace model graphs, identify feature extraction nodes, and create feature extraction wrappers.
FX generally provides a more flexible and efficient way to extract features, especially for complex models.
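Both paths surface through the public API; a brief sketch using the standard features_only interface (real timm usage, though defaults may vary by version):

import torch
import timm

# wraps the backbone so the forward pass returns a pyramid of intermediate feature maps
model = timm.create_model('resnet50', pretrained=False, features_only=True, out_indices=(1, 2, 3, 4))
feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
print(model.feature_info.channels())  # channel counts of the selected stages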
3. _manipulate.py
This module offers a collection of helper functions for manipulating model parameters and modules. Common use cases include applying a function to named modules, grouping modules and parameters, and flattening nested modules (e.g. an nn.Sequential inside another nn.Sequential) to simplify model structure.
Usage examples:
# group_parameters, named_apply, and flatten_modules live in timm/models/_manipulate.py;
# `model` and `init_weights` are assumed defined elsewhere, and the matcher patterns
# below are illustrative only.
from timm.models._manipulate import flatten_modules, group_parameters, named_apply

# Group parameters based on module names
grouped_params = group_parameters(model, {'stem': '^conv1', 'blocks': '^layer'})
# Apply weight initialization to specific modules
named_apply(init_weights, model, depth_first=False, include_root=True)
# Flatten nested sequential modules
flattened_modules = flatten_modules(model.named_modules())
4. _prune.py
This module provides tools for model pruning, which involves removing less important connections or neurons to reduce model size and complexity. Key functions include extracting and replacing specific layers in a model (extract_layer, set_layer) and adapting a model based on a pruning configuration (adapt_model_from_file).
Usage examples (prune_conv_layer is a hypothetical helper shown for illustration; set_layer and adapt_model_from_file are real):
# Prune a specific convolutional layer
pruned_conv = prune_conv_layer(original_conv, sparsity=0.5)
set_layer(model, 'blocks.2.1.conv', pruned_conv)
# Adapt a model from a pruning configuration file
adapted_model = adapt_model_from_file(model, 'resnet50_pruned')
These helper modules streamline common model manipulation tasks, making it easier to customize and optimize models within the timm library.
pytorch-image-models/timm/layers
timm custom layers
The timm/layers directory contains a variety of modules and functions that serve as building blocks for the models in the timm library. Here's a breakdown of the key roles of each file:
Core Layers & Functions:
activations.py: Provides a collection of activation functions (ReLU, Swish, Mish, etc.) with a consistent interface for easy swapping and potential JIT scripting/export.
activations_me.py: Offers memory-efficient versions of some activations using custom autograd, but these are not compatible with JIT or ONNX export.
adaptive_avgmax_pool.py: Implements adaptive average and max pooling layers, including combinations and concatenation.
attention2d.py: Defines 2D attention mechanisms, including multi-query attention and spatial attention with downsampling.
attention_pool.py: Implements attention pooling with a latent query, useful for global feature aggregation.
attention_pool2d.py: Provides 2D attention pooling mechanisms, including those with learned and rotary position embeddings.
blur_pool.py: Implements BlurPool, an anti-aliasing technique that combines blurring and downsampling.
bottleneck_attn.py: Defines the Bottleneck Attention module, a type of self-attention used in Bottleneck Transformers.
cbam.py: Implements the Convolutional Block Attention Module (CBAM), a combination of channel and spatial attention.
classifier.py: Provides classifier heads with pooling, dropout, and fully connected layers.
cond_conv2d.py: Implements Conditionally Parameterized Convolutions (CondConv), which dynamically adjust filters based on input.
config.py: Manages global configuration flags for layers, such as JIT scripting, ONNX export, and fused attention settings.
conv2d_same.py: Offers "SAME" padding convolution layers, similar to TensorFlow's padding behavior.
conv_bn_act.py: Combines convolution, batch normalization, and activation into a single module.
create_act.py: Factory function for creating activation functions and layers based on names.
create_attn.py: Factory function for creating attention modules based on names.
create_conv2d.py: Factory function for creating different types of 2D convolutions (standard, mixed, conditional); a usage sketch follows this list.
create_norm.py: Factory function for creating normalization layers (BatchNorm, GroupNorm, LayerNorm, etc.).
create_norm_act.py: Factory function for creating combined normalization and activation layers.
drop.py: Implements DropBlock and DropPath (Stochastic Depth) regularization techniques.
eca.py: Defines the Efficient Channel Attention (ECA) module.
evo_norm.py: Implements EvoNorm layers, a type of normalization.
fast_norm.py: Provides optimized implementations of GroupNorm and LayerNorm for mixed precision training.
filter_response_norm.py: Implements Filter Response Normalization (FRN) layers.
format.py: Utilities for handling different tensor formats (NCHW, NHWC, etc.).
gather_excite.py: Defines the Gather-Excite attention module.
global_context.py: Implements the Global Context (GC) attention block.
grid.py: Functions for generating N-dimensional grids.
grn.py: Implements the Global Response Normalization (GRN) layer.
halo_attn.py: Defines the Halo Attention module.
helpers.py: Various helper functions for layers (e.g., make_divisible).
hybrid_embed.py: Provides layers for embedding CNN feature maps into a transformer-compatible format.
inplace_abn.py: Implements Inplace Activated Batch Normalization (InplaceABN).
interpolate.py: Interpolation utilities for layers.
lambda_layer.py: Defines the Lambda Layer, an attention-like mechanism.
layer_scale.py: Implements LayerScale, a scaling factor applied to layer outputs.
linear.py: A modified linear layer with support for mixed precision training.
median_pool.py: Implements a median pooling layer.
mixed_conv2d.py: Defines MixedConv2d, which uses multiple kernel sizes in a single convolution.
mlp.py: Implements Multi-Layer Perceptrons (MLPs) with various configurations.
ml_decoder.py: Provides an ML decoder head.
non_local_attn.py: Implements Non-Local Attention blocks.
norm.py: Normalization layers with fast norm options.
norm_act.py: Combines normalization and activation layers into a single module.
padding.py: Helper functions for padding operations, including "SAME" padding.
patch_dropout.py: Implements PatchDropout, a type of dropout.
patch_embed.py: Provides layers for converting images to patches for transformer input.
pool2d_same.py: Pooling layers with "SAME" padding.
pos_embed.py: Utilities for absolute position embeddings.
pos_embed_rel.py: Modules and functions for relative position embeddings.
pos_embed_sincos.py: Implements sin-cos, Fourier, and rotary position embeddings.
selective_kernel.py: Defines Selective Kernel Convolution and Attention modules.
separable_conv.py: Implements depthwise separable convolutions.
space_to_depth.py: Provides space-to-depth and depth-to-space operations.
split_attn.py: Implements Split Attention (Splat) used in ResNeSt models.
split_batchnorm.py: Defines SplitBatchNorm, which splits the input batch into multiple parts for separate normalization.
squeeze_excite.py: Implements Squeeze-and-Excitation (SE) attention blocks.
std_conv.py: Provides convolutions with weight standardization.
test_time_pool.py: Implements test-time pooling for improved inference accuracy.
trace_utils.py: Helper functions for symbolic tracing.
typing.py: Type hints for layers.
weight_init.py: Weight initialization functions.
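For example, the create_act.py and create_conv2d.py factories can build layers by name. A minimal sketch (keyword arguments such as padding='same' reflect my reading of the factory code and should be treated as assumptions):
import torch
from timm.layers import create_act_layer, create_conv2d

# create_act.py factory: build an activation layer from its string name.
act = create_act_layer('swish', inplace=True)

# create_conv2d.py factory: a depthwise 3x3 conv with TF-style "SAME" padding.
conv = create_conv2d(32, 32, kernel_size=3, stride=2, padding='same', groups=32)

x = torch.randn(1, 32, 56, 56)
print(act(conv(x)).shape)  # torch.Size([1, 32, 28, 28])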
Why does timm define custom layers? PyTorch already provides the layer definitions needed to build networks, such as relu, conv, and pooling, so why does pytorch-image-models/timm/layers define these additional layers?
While PyTorch offers a solid foundation for building neural networks, timm provides its own set of layer definitions for several reasons:
1. Enhanced Functionality & Flexibility: timm includes layers not readily available in core PyTorch (e.g., CondConv, ECA, and BlurPool from the list above). timm layers often offer more configuration options than their PyTorch counterparts; for example, ConvNormAct combines convolution, normalization, and activation in a single module with flexible choices for each component (see the sketch after this list). timm layers are also designed to be modular and easily combined, promoting code reusability and experimentation with different architectures.
2. Optimization & Efficiency: timm provides memory-efficient versions of some activations (activations_me.py) using custom autograd, which can be beneficial for large models or limited memory, as well as optimized normalization implementations (fast_norm.py) for improved performance in mixed precision training.
3. Consistency & Compatibility: timm layers often provide a more consistent interface, such as ensuring the channel dimension is always the first argument in normalization layers, and are designed with TorchScript and ONNX export in mind, making it easier to deploy models.
4. Research & Experimentation: timm incorporates layers from recent research papers, allowing for quick experimentation with new architectures and ideas, and makes it easy to add or modify layers to explore new research directions.
In summary, while PyTorch provides the essentials, timm builds upon them by offering a richer set of layers, optimized implementations, and greater flexibility for building and experimenting with state-of-the-art models.
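As a concrete illustration of that configurability, here is a minimal ConvNormAct sketch (argument names follow conv_bn_act.py as I read it; treat defaults as assumptions):
import torch
import torch.nn as nn
from timm.layers import ConvNormAct

# One module bundling conv + norm + act, with each component swappable.
block = ConvNormAct(16, 32, kernel_size=3, act_layer=nn.GELU)

x = torch.randn(2, 16, 32, 32)
print(block(x).shape)  # torch.Size([2, 32, 32, 32])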
pytorch-image-models/timm/loss
The timm/loss directory contains various loss functions for training deep learning models, particularly for image-related tasks. Here's a breakdown of the files and their purposes:
asymmetric_loss.py: Implements the Asymmetric Loss function, designed to address class imbalance in multi-label classification problems. It assigns different weights to positive and negative samples to improve learning in the presence of under-represented classes.
binary_cross_entropy.py: Provides a Binary Cross Entropy (BCE) loss with additional features like label smoothing, target thresholding, and optional one-hot conversion for dense targets. Useful for binary or multi-label classification tasks.
cross_entropy.py: Implements two variations of cross-entropy loss: LabelSmoothingCrossEntropy applies label smoothing to the standard cross-entropy loss, preventing overconfidence and improving generalization; SoftTargetCrossEntropy calculates cross-entropy with soft targets (probability distributions) instead of hard labels, useful for distillation or other knowledge transfer tasks.
jsd.py: Implements the JSD (Jensen-Shannon Divergence) Cross Entropy loss. This loss combines cross-entropy with a Jensen-Shannon Divergence term, which encourages the model to produce consistent predictions across augmented versions of the same input. Useful for improving robustness and uncertainty estimation.
In summary, the timm/loss module provides a collection of loss functions that go beyond the standard PyTorch offerings. These losses address issues like class imbalance, overconfidence, and robustness, making them valuable tools for training image models. A short usage sketch follows.
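A minimal sketch of the two cross_entropy.py variants (constructor arguments match the class definitions as I understand them; verify against your timm version):
import torch
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(8, 1000)                # model outputs
hard_targets = torch.randint(0, 1000, (8,))  # integer class labels
soft_targets = torch.softmax(torch.randn(8, 1000), dim=-1)  # e.g., Mixup or teacher targets

# Label smoothing over hard labels to reduce overconfidence.
print(LabelSmoothingCrossEntropy(smoothing=0.1)(logits, hard_targets))

# Cross-entropy against full probability distributions.
print(SoftTargetCrossEntropy()(logits, soft_targets))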
pytorch-image-models/timm/optim
optim_factory.py: Contains functions for creating and registering optimizers, handling parameter groups, and applying weight decay and layer decay.
Every file other than optim_factory.py implements an optimizer provided by timm.
The timm/optim directory houses a collection of optimizers that extend beyond the standard optimizers available in PyTorch's torch.optim module. These optimizers incorporate various enhancements and modifications to improve training performance, convergence speed, and memory efficiency. A factory sketch follows.
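A minimal sketch of the factory, mirroring the create_optimizer_v2 import used by train.py (the specific hyperparameter values are arbitrary):
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(128, 10)  # stand-in for any timm model

# Build an optimizer by name; the factory also handles parameter groups
# and weight-decay exclusions.
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)
print(type(optimizer).__name__)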
pytorch-image-models/timm/scheduler
scheduler.py defines a base class named Scheduler for implementing optimizer parameter schedulers (e.g., learning-rate schedulers). Unlike PyTorch's built-in schedulers, it emphasizes adjusting the optimizer's parameters at the end of each training period, i.e., per epoch or after each optimizer update. It also supports injecting noise, adding some randomness during training to help avoid overfitting. The core idea is to call step or step_update to adjust parameters such as the learning rate at each epoch or each optimizer update.
Every .py file other than scheduler.py and scheduler_factory.py implements a learning-rate schedule that inherits from the Scheduler base class in scheduler.py. scheduler.py and scheduler_factory.py are linked through these concrete schedule implementations, such as CosineLRScheduler.
scheduler_factory.py defines a factory function for learning-rate schedulers (create_scheduler) along with helper functions and argument handling, so that the appropriate scheduler can be created from a config file or command-line arguments. It supports multiple scheduler types (such as cosine, step, multistep, plateau, and poly) and different scheduling strategies, plus advanced features such as warmup, noise perturbation, and cyclic learning rates, which add flexibility and can improve training. Adjusting the scheduler's behavior gives fine-grained control over model training and helps optimize the training process. A minimal factory sketch follows.
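A minimal sketch using create_scheduler_v2, as imported in train.py (keyword names like sched, num_epochs, and warmup_epochs reflect my reading of scheduler_factory.py and should be verified):
import torch.nn as nn
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2

model = nn.Linear(128, 10)
optimizer = create_optimizer_v2(model, opt='sgd', lr=0.1)

# Cosine schedule with warmup; returns the scheduler and resolved epoch count.
scheduler, num_epochs = create_scheduler_v2(
    optimizer, sched='cosine', num_epochs=100, warmup_epochs=5, min_lr=1e-5)

for epoch in range(num_epochs):
    # ... train one epoch here ...
    scheduler.step(epoch + 1)  # per-epoch update; step_update exists for per-iteration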
pytorch-image-models/timm/data
The timm/data directory contains modules and functions that handle various aspects of data loading, preprocessing, augmentation, and overall data pipeline management for training and evaluating image models. Here's a categorized summary:
Data Augmentation:
auto_augment.py: Implements various automatic augmentation strategies, including AutoAugment, RandAugment, and AugMix.
mixup.py: Implements Mixup and Cutmix augmentation techniques, where images and labels are mixed to improve model robustness and generalization.
random_erasing.py: Implements Random Erasing, an augmentation that randomly erases rectangular regions of an image.
Dataset and Sampler:
dataset.py: Defines ImageDataset (for standard image folders and tar files) and IterableImageDataset (for iterable datasets like TFDS and WDS). Also includes AugMixDataset for applying AugMix.
distributed_sampler.py: Provides samplers for distributed training: OrderedDistributedSampler ensures each process gets a distinct subset of data, while RepeatAugSampler allows different augmentations of the same sample to be seen by different processes.
dataset_factory.py: Acts as a factory for creating various types of datasets. It provides the create_dataset function, which can instantiate datasets from different sources, including timm's own ImageDataset, torchvision datasets, Hugging Face standard (HFDS) and iterable (HFIDS) datasets, IterableImageDataset for TFDS datasets, and IterableImageDataset for WDS datasets.
dataset_info.py: Defines the DatasetInfo abstract base class, a common interface for accessing information about datasets (such as class names and descriptions), plus a CustomDatasetInfo class for easily creating dataset information for custom datasets.
Transforms:
transforms.py: Defines a set of image transformations, including RandomResizedCropAndInterpolation (randomly crops and resizes with various interpolation modes), CenterCropOrPad (center crops or pads an image to a target size), RandomCropOrPad (randomly crops or pads), ResizeKeepRatio (resizes while maintaining aspect ratio), and TrimBorder (trims a border from an image).
transforms_factory.py: Provides factory functions (transforms_noaug_train, transforms_imagenet_train, transforms_imagenet_eval) to create sets of transforms optimized for different training and evaluation scenarios.
Data Configuration and Utilities:
config.py: Provides functions (resolve_data_config, resolve_model_data_config) to determine image size, mean, standard deviation, and other data processing parameters based on model and dataset configurations.
constants.py: Defines constants like default crop percentage, ImageNet mean and standard deviation, etc.
imagenet_info.py: Provides a class (ImageNetInfo) to access information about ImageNet dataset subsets, such as class labels, descriptions, and mappings.
real_labels.py: Implements an evaluator (RealLabelsImagenet) to assess model performance using ImageNet's "real" labels.
Data Loading:
loader.py: Contains functions for creating data loaders, including create_loader, which handles collation (fast_collate for optimized collation) and GPU prefetching via PrefetchLoader.
This comprehensive set of modules and functions within timm/data streamlines data loading, preprocessing, and augmentation, making it easier to train and evaluate state-of-the-art image models. A short sketch follows.
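A minimal sketch wiring create_dataset and create_loader together, using the folder layout shown earlier (an empty dataset name selecting the plain ImageDataset, and the keyword defaults, are assumptions):
from timm.data import create_dataset, create_loader

dataset = create_dataset('', root='imagenet/train', is_training=True)

# create_loader assembles the transforms, fast_collate, and (by default on
# CUDA) a PrefetchLoader around the dataset.
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=True,
)
images, targets = next(iter(loader))
print(images.shape, targets.shape)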
TensorFlow related:
tf_preprocessing.py: Enables the use of TensorFlow's image preprocessing pipeline within PyTorch transforms. It defines the TfPreprocessTransform class, which leverages TensorFlow's preprocessing functions (like preprocess_for_train and preprocess_for_eval) to perform operations such as cropping and resizing.
pytorch-image-models/timm/data/readers
Dataset Readers:
reader_image_folder.py: Defines the ReaderImageFolder class, which reads images from folders. It scans folders recursively, infers labels from the folder structure, and provides a mapping between class names and indices.
reader_image_in_tar.py: Defines the ReaderImageInTar class, designed for reading images from tar files. It handles single tar files, folders of tar files, and even nested tar files, and also manages class mappings and caching of tar file information.
reader_image_tar.py: Defines ReaderImageTar, which reads images from a single tar file. It is similar to ReaderImageInTar but with more limited functionality, and is likely to be deprecated in the future.
reader_tfds.py: Provides the ReaderTfds class, which wraps TensorFlow Datasets (TFDS) for use in PyTorch. It handles dataset loading, shuffling, batching, and decoding of image samples from TFDS datasets.
reader_hfds.py: Defines the ReaderHfds class, which wraps Hugging Face Datasets (HFDS) for use in PyTorch. It handles loading and decoding image samples from HFDS datasets.
reader_hfids.py: Defines the ReaderHfids class, which wraps Hugging Face Iterable Datasets (HFIDS) for use in PyTorch. It handles streaming and decoding image samples from HFIDS datasets.
reader_wds.py: Provides the ReaderWds class, which wraps WebDataset (WDS) for use in PyTorch. It handles loading and decoding samples from WDS, including support for sharding and distributed training.
Utilities and Helpers:
class_map.py: Contains the load_class_map function, which loads a class map from a text or pickle file; the map associates class names with indices.
img_extensions.py: Manages the supported image file extensions, with functions like get_img_extensions, is_img_extension, and set_img_extensions to control which file types are recognized as images.
shared_count.py: Defines the SharedCount class, which shares a counter across multiple processes; useful for things like epoch tracking in distributed training.
reader.py: Defines the base Reader class, which provides a common interface for all dataset readers.
reader_factory.py: Contains the create_reader factory function, which instantiates the appropriate reader class based on the provided configuration parameters; a direct usage sketch follows.
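Readers are usually created indirectly through create_dataset, but the factory can also be called directly. A rough sketch (the empty name falling through to ReaderImageFolder is my reading of reader_factory.py, not a documented guarantee):
from timm.data.readers import create_reader

# An empty/unknown name prefix falls back to the folder reader (ReaderImageFolder).
reader = create_reader('', root='imagenet/train')
img_file, target = reader[0]   # readers yield (file-like object, class index) pairs
print(len(reader), target)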
pytorch-image-models/timm/utils
The timm/utils directory contains a collection of utility modules and functions that support various aspects of model training, evaluation, and manipulation. Here's a breakdown:
Model and Optimization:
agc.py: Implements Adaptive Gradient Clipping (AGC), a technique to clip gradients based on the unit-wise norm of parameters, preventing excessive weight updates.
clip_grad.py: Provides a function (dispatch_clip_grad) to apply different gradient clipping methods, including norm-based clipping, value-based clipping, and AGC.
model_ema.py: Implements Exponential Moving Average (EMA) of model weights, a technique to maintain a smoothed version of the model's parameters for better generalization. Includes multiple versions of EMA with varying performance and compatibility.
CUDA and AMP:
cuda.py: Provides utilities for mixed precision training with Automatic Mixed Precision (AMP). Includes both the older ApexScaler and the newer NativeScaler for handling gradient scaling and unscaling.
Training and Checkpointing:
checkpoint_saver.py: Implements a CheckpointSaver class to manage saving and loading model checkpoints, including tracking the best performing checkpoints and handling recovery checkpoints.
decay_batch.py: Provides functions for decaying the batch size during training, which can be useful for improving stability and generalization.
Distributed Training:
distributed.py: Contains utilities for distributed training, including functions to reduce tensors across processes and distribute BatchNorm statistics.
Logging and Metrics:
log.py: Provides functions for setting up logging, including a custom formatter (FormatterNoInfo) that omits 'INFO' level logging to the console.
metrics.py: Defines an AverageMeter class for tracking average values and an accuracy function for calculating top-k accuracy.
summary.py: Implements the update_summary function to log training and evaluation metrics to a CSV file and optionally to Weights & Biases (wandb).
Miscellaneous:
attention_extract.py: Implements the AttentionExtract class, which allows extracting attention maps from models using either PyTorch FX or hooks.
jit.py: Provides functions to configure the JIT (Just-In-Time) compiler for scripting and tracing models.
misc.py: Contains miscellaneous helper functions, including natural_key for natural sorting of strings and functions for adding boolean arguments to an argument parser.
model.py: Provides utilities for working with models, including: unwrap_model (gets the underlying model from wrappers like DataParallel or ModelEma), get_state_dict (extracts the state dictionary from a model), freeze and unfreeze (freeze or unfreeze model parameters), and reparameterize_model (converts a model to a deployable form by fusing layers and/or reparameterizing modules).
onnx.py: Provides functions for exporting models to ONNX format and verifying the exported models.
random.py: Provides a random_seed function to set seeds for the torch, numpy, and random modules (see the sketch below).
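A short sketch of two of these helpers together (random.py and metrics.py; the names match the timm.utils exports):
from timm import utils

utils.random_seed(42)  # seeds torch, numpy, and random in one call

loss_meter = utils.AverageMeter()
for step in range(10):
    loss = 1.0 / (step + 1)        # stand-in for a real loss value
    loss_meter.update(loss, n=64)  # n weights the average by batch size
print(f'avg loss: {loss_meter.avg:.4f}')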
pytorch-image-models/tests
Model Tests:
test_models.py: Contains comprehensive tests for various aspects of the models in the timm library. It covers a wide range of model architectures, including tests for forward passes, feature extraction, and TorchScript compatibility.
Layer Tests:
test_layers.py: Focuses on testing individual layers and modules used within the timm models, including attention modules such as Attention2d and MultiQueryAttentionV2.
Optimizer Tests:
test_optim.py: Contains tests for the optimizers implemented in the timm/optim directory, covering a variety of optimizers.
Utility Tests:
test_utils.py: Tests various utility functions and classes provided in the timm/utils directory. It includes tests for: the freeze and unfreeze functions that control parameter freezing; ActivationStatsHook and related functions for extracting activation statistics; the reparameterize_model function that converts models to a deployable form; and the get_state_dict function with custom unwrap functions.
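These are standard pytest suites, so subsets can be run directly, e.g.:
pytest tests/test_layers.py -v
pytest tests/test_models.py -k vit -v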
pytorch-image-models/*.py
Training and Evaluation:
train.py: This is the primary script for training models on ImageNet or similar datasets. It provides a comprehensive training loop with options covering optimization, augmentation, regularization, mixed precision, and distributed training.
validate.py: This script evaluates trained or pretrained models on ImageNet or similar datasets, with options for model selection, checkpoint loading, and evaluation settings; see the example below.
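For example, to evaluate a checkpoint produced by train.py (paths are placeholders; the flags follow validate.py's argument parser, so double-check them against --help):
python validate.py /path/to/imagenet --model vit_base_patch16_224 --checkpoint path/to/model_best.pth.tar --batch-size 128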
Benchmarking and Profiling:
benchmark.py: This script benchmarks the inference and training performance of models, measuring metrics like throughput (samples per second) and per-step latency.
ONNX Export and Validation:
onnx_export.py: This script exports PyTorch models to ONNX format, allowing them to be used in other frameworks and environments. It supports various options for controlling the export process, such as opset version and dynamic size.
onnx_validate.py: This script validates the accuracy and performance of exported ONNX models using the ONNX runtime. It compares the ONNX model's outputs to the original PyTorch model's outputs to ensure correctness.
Other Utilities:
avg_checkpoints.py: This script averages the weights of multiple model checkpoints, which can be useful for improving model performance and stability.
bulk_runner.py: This script runs the validate.py or benchmark.py script in a separate process for each model in a specified list, enabling bulk validation or benchmarking of multiple models.
clean_checkpoint.py: This script cleans a model checkpoint by removing unnecessary data like optimizer state and GPU tensors, making it suitable for sharing and distribution.
hubconf.py: This file defines the torch.hub entry points for the timm models, so they can be loaded with torch.hub.load, as sketched below.
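A quick sketch of loading a model through those entry points (standard torch.hub usage; the model name is just an example):
import torch

# hubconf.py exposes timm model entrypoints to torch.hub.
model = torch.hub.load('huggingface/pytorch-image-models', 'vit_base_patch16_224', pretrained=True)
model.eval()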
inference.py: This script performs inference on a dataset using a specified model and outputs the results in various formats (CSV, JSON, etc.).
These scripts and utilities provide a comprehensive toolkit for training, evaluating, benchmarking, exporting, and managing image models within the timm library.