python实现cnn特征提取_GitHub - lxztju/pytorch_classification: 利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型部署,...

pytorch_classification

利用pytorch实现图像分类,其中包含的densenet,resnext,mobilenet,efficientnet, resnet等图像分类网络,可以根据需要再行利用torchvision扩展其他的分类算法

实现功能

基础功能利用pytorch实现图像分类

包含带有warmup的cosine学习率调整

warmup的step学习率优调整

多模型融合预测,加权与投票融合

利用flask实现模型云端api部署

使用tta测试时增强进行预测

添加label smooth的pytorch实现(标签平滑)

添加使用cnn提取特征,并使用SVM,RF,MLP,KNN等分类器进行分类。

更新添加了模型蒸馏的的训练方法

添加中间层可视化

更新模型部署(采用flask+Redis的方法)

c++ libtorch进行模型部署的简单demo

运行环境

python3.7

pytorch 1.1

torchvision 0.3.0

代码仓库的使用

数据集形式

原始数据集存储形式为,同个类别的图像存储在同一个文件夹下,所有类别的图像存储在一个主文件夹data下。

|-- data

|-- train

|--label1

|--*.jpg

|--label2

|--*.jpg

|--label

|--*.jpg

...

|-- val<

你可能感兴趣的:(python实现cnn特征提取)