libtorch学习第八

今天学习 DataLoader(数据加载器)的用法:自定义一个 Dataset,并按批次读取图像数据。

基于 C++17(代码中使用了 std::filesystem,需要 C++17 及以上的编译器)。

#include <exception>
#include <filesystem>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <torch/torch.h>
#include <opencv2/opencv.hpp>

// Bring the console stream helpers used for printing into scope.
using std::cout; using std::endl;

// Short alias for the C++17 filesystem library used for directory traversal.
namespace fs = std::filesystem;

/// Collects image paths and class labels from a folder-per-class dataset.
///
/// Every immediate sub-directory of `path` is scanned for files whose
/// extension equals `type`; matches are appended to the output vectors.
///
/// @param path         Root directory that contains one sub-directory per class.
/// @param type         File extension to accept, without the leading dot (e.g. "jpg").
/// @param list_images  Receives the path of every accepted file (appended, not cleared).
/// @param list_labels  Receives one label per accepted file: 0 when the file's
///                     sub-directory is named "ants", 1 for any other sub-directory.
/// @param label        Reset to 0, then incremented once per sub-directory visited.
void loadDataFromFolder(
	const std::string& path,
	const std::string& type,
	std::vector<std::string>& list_images,
	std::vector<int>& list_labels,
	int& label
);

/// Image-classification dataset backed by a folder-per-class directory tree.
///
/// The constructor indexes the directory once via loadDataFromFolder(); the
/// samples themselves are produced on demand by get().
class datasetClc : public torch::data::Dataset<datasetClc>
{
public:
	/// Indexes every file with extension `type` found below `image_dir`.
	datasetClc(const std::string& image_dir, const std::string& type)
	{
		loadDataFromFolder(image_dir, type, _imagePaths, _labels, _nLabel);
	}

	/// Returns the (image, label) example stored at position `index`.
	torch::data::Example<> get(size_t index) override;

	/// Number of indexed samples, i.e. the number of collected image paths.
	torch::optional<size_t> size() const override { return _imagePaths.size(); }

private:
	std::vector<std::string> _imagePaths; // path of each sample's image file
	std::vector<int> _labels;             // label of each sample, parallel to _imagePaths
	int _nLabel;                          // sub-directory count reported by loadDataFromFolder()
};


/// Demo entry point: builds the dataset from a hard-coded training directory,
/// wraps it in a randomly sampled data loader and prints the tensor shapes of
/// every batch.
///
/// @return 0 on success, 1 when building the dataset or loading a batch throws.
int main()
{
	const int batchSize = 2;
	const std::string imageDir = "F:\\docker_python\\hymenoptera_data\\train";

	try
	{
		// Stack<> collates the individual examples of a batch into one
		// data tensor and one target tensor.
		auto mdataset = datasetClc(imageDir, "jpg").map(torch::data::transforms::Stack<>());

		auto mdataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(mdataset), batchSize);

		for (auto& batch : *mdataloader)
		{
			auto data = batch.data;
			auto target = batch.target;
			cout << data.sizes() <<", "<<target.sizes() << endl;
		}
	}
	catch (const std::exception& e)
	{
		// A missing directory or an unreadable image surfaces here instead of
		// terminating the process without any message.
		std::cerr << "dataloader demo failed: " << e.what() << endl;
		return 1;
	}

	return 0;
}

/// Collects image paths and class labels from a folder-per-class dataset.
///
/// Label definition:
///   'ants': 0,
///   'bees': 1   (any sub-directory not named "ants" is labelled 1)
///
/// @param path         Root directory that contains one sub-directory per class.
/// @param type         File extension to accept, with or without the leading
///                     dot ("jpg" and ".jpg" are equivalent). Matching is
///                     case-sensitive.
/// @param list_images  Receives the path of every accepted file (appended, not cleared).
/// @param list_labels  Receives one label per accepted file, parallel to list_images.
/// @param label        Reset to 0, then incremented once per sub-directory visited.
/// @throws std::filesystem::filesystem_error when a directory cannot be iterated.
void loadDataFromFolder(
	const std::string& path, const std::string& type, std::vector<std::string>& list_images, std::vector<int>& list_labels, int& label)
{
	namespace fs = std::filesystem;

	// path::extension() includes the leading dot, so normalise `type` to that form.
	const std::string wantedExt = (!type.empty() && type.front() == '.') ? type : "." + type;

	label = 0;

	for (const auto& classDir : fs::directory_iterator(path))
	{
		if (!classDir.is_directory()) continue;

		const int classLabel = (classDir.path().filename().string() == "ants") ? 0 : 1;

		for (const auto& file : fs::directory_iterator(classDir.path()))
		{
			// Only real files qualify: a nested directory that happens to be
			// called "something.jpg" must not be reported as an image.
			if (!file.is_regular_file()) continue;

			// Compare the extension of the file name itself rather than
			// searching the whole path string for the last '.'.
			if (file.path().extension().string() != wantedExt) continue;

			list_images.push_back(file.path().string()); // image path
			list_labels.push_back(classLabel);
		}

		label += 1;
	}
}

/// Loads the sample at `index` and converts it to tensors.
///
/// The image is read with OpenCV, resized to 224x224 and returned as a
/// uint8 tensor permuted from (rows, cols, channels) to (channels, rows, cols).
/// The channel order is whatever cv::imread produced (BGR by default; no colour
/// conversion is applied here). The label is returned as a 1-element int64 tensor.
///
/// @param index  Position of the sample; must be smaller than size().
/// @return The example holding the image tensor (data) and label tensor (target).
/// @throws std::out_of_range  when `index` is not a valid sample position.
/// @throws std::runtime_error when the image file cannot be read or decoded.
torch::data::Example<> datasetClc::get(size_t index)
{
	const std::string& imagePath = _imagePaths.at(index);
	const int label = _labels.at(index);

	cv::Mat img = cv::imread(imagePath);
	if (img.empty())
	{
		// cv::imread signals failure with an empty matrix; report the file
		// here instead of letting cv::resize fail with an unrelated assertion.
		throw std::runtime_error("datasetClc::get: failed to read image: " + imagePath);
	}
	cv::resize(img, img, cv::Size(224, 224));

	// from_blob only wraps the cv::Mat buffer without owning it, so the tensor
	// is cloned below before `img` goes out of scope.
	torch::Tensor imgTensor = torch::from_blob(img.data, { img.rows, img.cols,3 },
		torch::kByte).permute({ 2,0,1 });

	// State the dtype explicitly so the label is an integer tensor regardless
	// of how the library infers a dtype from the fill value.
	torch::Tensor labelTensor = torch::full({ 1 }, label, torch::kInt64);

	return { imgTensor.clone(), labelTensor };
}

测试数据集:点击下载

结果

libtorch学习第八_第1张图片

你可能感兴趣的:(pytorch,学习)