【无标题】

这篇博客用的代码:

import json
import os
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from threading import Condition, Lock
from typing import Dict, Tuple, Union, Any

import torch

from builder import parser_server, parser_clients
from tools.logger import Logger
from tools.utils import clear_cache, same_seeds


class ExperimentLog(object):
    """Thread-safe nested key/value record of an experiment, persisted as JSON.

    Every call to :meth:`record` merges a value into the in-memory ``records``
    dict and immediately rewrites the whole JSON file at ``save_path``.
    """

    def __init__(self, save_path: str):
        # Nested dict of everything recorded so far; dumped verbatim to JSON.
        self.records = {}
        self.save_path = save_path
        # Serializes record() calls coming from several worker threads.
        self.lock = Lock()

    def _update_iter(self, key, value):
        """Merge ``value`` into ``records`` at the dotted path ``key``.

        Missing intermediate segments are created as dicts. At the leaf, an
        existing list gets ``value`` appended, an existing set gets it added,
        an existing dict is ``update``-d with it; a missing leaf or any other
        existing value is replaced by ``value``.
        """
        *parents, leaf = key.split('.')
        current_record = self.records
        for part in parents:
            current_record = current_record.setdefault(part, {})

        if leaf not in current_record:
            current_record[leaf] = value
            return
        existing = current_record[leaf]
        if isinstance(existing, list):
            existing.append(value)
        elif isinstance(existing, set):
            existing.add(value)
        elif isinstance(existing, dict):
            existing.update(value)
        else:
            current_record[leaf] = value

    def _save_logs(self):
        """Write ``records`` to ``save_path`` as JSON indented by 2 spaces.

        The records are serialized *before* the file is opened, so a value
        that is not JSON-serializable raises without truncating the previous
        log file. Missing parent directories are created; a bare file name
        (empty dirname) is written relative to the current directory.
        """
        payload = json.dumps(self.records, indent=2)
        dirname = os.path.dirname(self.save_path)
        if dirname:
            # exist_ok avoids a race with another process creating it first.
            os.makedirs(dirname, exist_ok=True)
        with open(self.save_path, "w") as f:
            f.write(payload)

    def record(self, key, value):
        """Merge ``value`` at dotted ``key`` and persist the whole log.

        The lock is held through ``with`` so it is released even when the
        update or the write raises; otherwise every later call would block.
        """
        with self.lock:
            self._update_iter(key, value)
            self._save_logs()


class VirtualContainer(object):
    """Thread-safe bookkeeping of device "slots" shared by worker threads.

    Every device starts with ``parallel`` slots. ``acquire_device(count)``
    picks the first device (in the order given) that still has at least one
    free slot, subtracts ``count`` from it (the counter may go negative, which
    keeps other callers off that device until the slots are returned) and
    returns it. If no device has a free slot, the caller waits until
    ``release_device`` returns some.
    """

    def __init__(self, devices: list, parallel: int = 1) -> None:
        super().__init__()
        self.lock = Lock()
        # Shares self.lock; waiting acquirers are woken by release_device().
        self._released = Condition(self.lock)
        # device -> currently free slots (may go negative, see acquire_device).
        self.devices = {device: parallel for device in devices}
        # Fixed capacity, so max_worker() does not shrink while slots are held.
        self._total_slots = sum(self.devices.values())

    def max_worker(self):
        """Return the total configured number of slots across all devices.

        Computed once at construction instead of from the live counters,
        which drop (to zero or below) while other threads hold slots.
        """
        return self._total_slots

    def acquire_device(self, count=1):
        """Take ``count`` slots from the first device that has a free slot.

        Blocks until such a device is available and returns its identifier.

        Raises:
            RuntimeError: if the container has no slots at all, since the
                wait could never finish.
        """
        if self._total_slots <= 0:
            raise RuntimeError('VirtualContainer has no device slots to acquire.')
        with self._released:
            while True:
                for dev, free in self.devices.items():
                    if free > 0:
                        self.devices[dev] = free - count
                        return dev
                self._released.wait()

    def release_device(self, device, count=1):
        """Give ``count`` slots back to ``device`` and wake waiting acquirers."""
        with self._released:
            self.devices[device] += count
            self._released.notify_all()

    def possess_device(self, count=1):
        """Return a context manager that holds ``count`` slots on one device.

        ``with container.possess_device() as device:`` binds the acquired
        device and returns its slots on exit, also when the body raises
        (the exception is not suppressed).
        """
        class VirtualProcess(object):

            def __init__(self, container) -> None:
                super().__init__()
                self.container = container
                self.device = None

            def __enter__(self):
                self.device = self.container.acquire_device(count)
                return self.device

            def __exit__(self, type, value, trace):
                if self.device is not None:
                    self.container.release_device(self.device, count)
                return

        return VirtualProcess(self)


class ExperimentStage(object):
    """Runs the configured federated experiments on a shared pool of devices.

    Intended to be used as a context manager: ``__enter__`` validates the
    runtime environment, ``__exit__`` logs any ``Exception`` raised in the
    ``with`` body and lets it propagate.
    """

    def __init__(self, common_config: Dict, exp_configs: Union[Dict, Tuple[Dict]]):
        self.common_config = common_config
        # A single experiment config is wrapped so run() can always iterate.
        self.exp_configs = [exp_configs] if isinstance(exp_configs, dict) else exp_configs
        self.logger = Logger('stage')
        # 'parallel' slots per configured device, shared by all worker threads.
        self.container = VirtualContainer(self.common_config['device'], self.common_config['parallel'])

    def __enter__(self):
        self.check_environment()
        return self

    def __exit__(self, type, value, trace):
        """Log an ``Exception`` from the ``with`` body, then let it propagate.

        Returning False makes Python re-raise the original exception with its
        own traceback; non-``Exception`` errors such as KeyboardInterrupt are
        propagated unlogged.
        """
        if type is not None and issubclass(type, Exception):
            self.logger.error(value)
        return False

    def check_environment(self):
        """Validate devices and directories from ``common_config``.

        Exits with status 1 if a configured device cannot hold a tensor or
        the datasets directory is missing; only warns when the checkpoint
        directory already contains entries.
        """
        # check runtime device: moving a tiny tensor fails on unusable devices
        for device in self.common_config['device']:
            try:
                torch.Tensor([0]).to(device)
            except Exception as ex:
                self.logger.error(f'Not available for given device {device}:{ex}')
                raise SystemExit(1)

        # check dataset base path
        datasets_dir = self.common_config['datasets_dir']
        if not os.path.exists(datasets_dir):
            self.logger.error(f'Datasets base directory could not be found with {datasets_dir}.')
            raise SystemExit(1)

        # check checkpoint directory: warn only if it already holds something
        checkpoints_dir = self.common_config['checkpoints_dir']
        if os.path.isdir(checkpoints_dir) and os.listdir(checkpoints_dir):
            self.logger.warn(f'Checkpoint directory {checkpoints_dir} is not empty.')

        self.logger.info('Experiment stage build success.')

    def run(self):
        """Run every configured experiment end to end.

        For each config: seed the RNGs, open a time-stamped JSON log, build
        the server and clients, validate all clients once (round 0), then
        simulate ``exp_opts.comm_rounds`` communication rounds.
        """
        for exp_config in self.exp_configs:
            same_seeds(exp_config['random_seed'])

            # Log file path: <logs_dir>/<exp_name>-<YYYY-mm-dd-HH-MM>.json
            format_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
            log = ExperimentLog(os.path.join(
                self.common_config['logs_dir'],
                f"{exp_config['exp_name']}-{format_time}.json"
            ))
            # The full config is kept in the log file under "config".
            log.record('config', exp_config)

            self.logger.info(f"Experiment loading succeed: {exp_config['exp_name']}")
            self.logger.info(f"For more details: {log.save_path}")

            # generate server and clients
            server = parser_server(exp_config, self.common_config)
            clients = parser_clients(exp_config, self.common_config)

            # initial validation of every client, recorded as round 0
            self._run_parallel(self._process_val, clients, log, 0)

            # simulate communication process
            comm_rounds = int(exp_config['exp_opts']['comm_rounds'])
            for curr_round in range(1, comm_rounds + 1):
                self.logger.info(f'Start communication round: {curr_round:0>3d}/{comm_rounds:0>3d}')
                self._process_one_round(curr_round, server, clients, exp_config, log)

            del server, clients, log

    def _run_parallel(self, job, clients, log, curr_round):
        """Run ``job(client, log, curr_round, container)`` for each client in a thread pool.

        Re-raises the first worker exception in completion order; leaving
        the ``with`` block still waits for the remaining jobs to finish.
        """
        with ThreadPoolExecutor(self.container.max_worker()) as pool:
            futures = [
                pool.submit(job, client, log, curr_round, self.container)
                for client in clients
            ]
            for future in as_completed(futures):
                # result() re-raises any exception raised inside the worker.
                future.result()

    def _process_one_round(self, curr_round, server, clients, exp_config, log) -> Any:
        """Simulate one communication round.

        Samples the online clients, pushes server state to them, trains them
        in parallel, validates all clients every ``val_interval`` rounds,
        uploads the online clients' incremental states and finally calls
        ``server.calculate()``.
        """
        exp_opts = exp_config['exp_opts']
        # sample online clients
        online_clients = random.sample(clients, exp_opts['online_clients'])
        val_intervals = exp_opts['val_interval']

        # update clients with server state: a newly registered client receives
        # the integrated state, an already registered one the incremental state
        for client in online_clients:
            if client.client_name not in server.clients.keys():
                server.register_client(client.client_name)
                dispatch_state = server.get_dispatch_integrated_state(client.client_name)
                if dispatch_state is not None:
                    client.update_by_integrated_state(dispatch_state)
            else:
                dispatch_state = server.get_dispatch_incremental_state(client.client_name)
                if dispatch_state is not None:
                    client.update_by_incremental_state(dispatch_state)
            server.save_state(
                f'{curr_round}-{server.server_name}-{client.client_name}',
                dispatch_state, True
            )
            del dispatch_state

        # simulate training for each online client
        self._run_parallel(self._process_train, online_clients, log, curr_round)

        # simulate validation for every client, online or not
        if curr_round % val_intervals == 0:
            self._run_parallel(self._process_val, clients, log, curr_round)

        # communication with server
        for client in online_clients:
            incremental_state = client.get_incremental_state()
            client.save_state(
                f'{curr_round}-{client.client_name}-{server.server_name}',
                incremental_state, True
            )
            if incremental_state is not None:
                server.set_client_incremental_state(client.client_name, incremental_state)
            del incremental_state

        server.calculate()

    @staticmethod
    @clear_cache
    def _process_train(client, log, curr_round, container):
        """Train ``client`` on the next task of its pipeline using one device slot.

        Skips training when the task's ``tr_epochs`` is 0; otherwise records
        training accuracy/loss under ``data.<client>.<round>.<task>``.
        Errors are logged through the client's logger and re-raised.
        """
        with container.possess_device() as device:
            try:
                task = client.task_pipeline.next_task()
                if task['tr_epochs'] != 0:
                    tr_output = client.train(
                        epochs=task['tr_epochs'],
                        task_name=task['task_name'],
                        tr_loader=task['tr_loader'],
                        val_loader=task['query_loader'],
                        device=device
                    )
                    log.record(f"data.{client.client_name}.{curr_round}.{task['task_name']}", {
                        "tr_acc": tr_output['accuracy'],
                        "tr_loss": tr_output['loss']
                    })
            except Exception as ex:
                client.logger.error(ex)
                raise

    @staticmethod
    @clear_cache
    def _process_val(client, log, curr_round, container):
        """Validate ``client`` on every task in its pipeline.

        Records rank-1/3/5/10 accuracy and mAP per task under
        ``data.<client>.<round>.<task>``. Errors are logged and re-raised.
        """
        # Requests max_worker() slots at once - presumably so validation has a
        # device to itself under VirtualContainer's accounting; confirm intent.
        with container.possess_device(container.max_worker()) as device:
            try:
                task_pipeline = client.task_pipeline
                for tid in range(len(task_pipeline.task_list)):
                    task = task_pipeline.get_task(tid)
                    cmc, mAP, avg_rep = client.validate(
                        task_name=task['task_name'],
                        query_loader=task['query_loader'],
                        gallery_loader=task['gallery_loaders'],
                        device=device
                    )
                    log.record(f"data.{client.client_name}.{curr_round}.{task['task_name']}", {
                        "val_rank_1": cmc[0],
                        "val_rank_3": cmc[2],
                        "val_rank_5": cmc[4],
                        "val_rank_10": cmc[9],
                        "val_map": mAP,
                    })
            except Exception as ex:
                client.logger.error(ex)
                raise

1.Python时间模块之datetime模块

from datetime import datetime

# Minute-resolution timestamp string, e.g. "2024-01-31-13-05".
format_time = datetime.strftime(datetime.now(), '%Y-%m-%d-%H-%M')

Python时间模块之datetime模块

格式化时间,格式参照time模块中的strftime方法

from datetime import datetime

# Take a single timestamp so both printed values describe the same instant.
now = datetime.now()
# Formatted string with minute resolution, e.g. 2024-01-31-13-05
format_time = now.strftime('%Y-%m-%d-%H-%M')
# The datetime object itself; print() shows its str() form 'YYYY-MM-DD HH:MM:SS[.ffffff]'.
format_time2 = now
print(format_time)
print(format_time2)

2.python中os库的使用

python中os库的使用

# Fragment from ExperimentStage.run(): create the ExperimentLog whose file is
# <logs_dir>/<exp_name>-<format_time>.json. `self` and `exp_config` come from
# the enclosing method, so this snippet does not run on its own.
log = ExperimentLog(os.path.join(
                self.common_config['logs_dir'],
                f"{exp_config['exp_name']}-{format_time}.json"
            ))

上面这段代码中,common_config 是从 comm_config.yaml 文件读取的配置,这里使用的是其中 logs_dir 的值:

 

python字符串前面加f是什么意思,如何将表达式嵌入字符串中

然后与另外一个字符串进行拼接。

最后组成的是这样的文件名:

【无标题】_第1张图片

将这个路径作为参数传给 ExperimentLog 类进行实例化,最终的效果是

将相关信息保存在同级目录的logs文件夹内:

【无标题】_第2张图片

  • os.path.join(path,*paths):组合path和paths,返回一个路径字符串

import os

# On POSIX this prints "D:/123". On Windows it prints "D:123", which is a path
# relative to the current directory of drive D:, not "D:\123".
joined = os.path.join("D:", "123")
print(joined)

  • os.path.dirname(path):返回path中的目录名称

  • os.path.exists(path):判断path对应文件或目录是否存在,返回True或False

    def _save_logs(self):
        """Write ``self.records`` to ``self.save_path`` as JSON indented by 2 spaces.

        The records are serialized *before* the file is opened, so a value
        that is not JSON-serializable raises without truncating the previous
        log file. Missing parent directories are created; a bare file name
        (empty dirname, where ``os.makedirs('')`` would fail) is written
        relative to the current directory.
        """
        payload = json.dumps(self.records, indent=2)
        dirname = os.path.dirname(self.save_path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(self.save_path, "w") as f:
            f.write(payload)

3.python 线程-- 锁

# Fragment from ExperimentStage.run(); `self` and `exp_config` come from the
# enclosing method, so it is not runnable on its own.
# Build a log file name that embeds the current time (minute resolution).
format_time = datetime.now().strftime('%Y-%m-%d-%H-%M')
# The argument is the log file path <logs_dir>/<exp_name>-<format_time>.json
# (in the author's config: logs_dir = ./logs/, exp_name = fedstil).
log = ExperimentLog(os.path.join(
    self.common_config['logs_dir'],
    f"{exp_config['exp_name']}-{format_time}.json"
))
# Store the whole experiment config under the top-level "config" key.
log.record('config', exp_config)
class ExperimentLog(object):
    """Thread-safe nested key/value record of an experiment, persisted as JSON.

    Every call to :meth:`record` merges a value into the in-memory ``records``
    dict and immediately rewrites the whole JSON file at ``save_path``.
    """

    def __init__(self, save_path: str):
        # Nested dict of everything recorded so far; dumped verbatim to JSON.
        self.records = {}
        self.save_path = save_path
        # Serializes record() calls coming from several worker threads.
        self.lock = Lock()

    def _update_iter(self, key, value):
        """Merge ``value`` into ``records`` at the dotted path ``key``.

        Missing intermediate segments are created as dicts. At the leaf, an
        existing list gets ``value`` appended, an existing set gets it added,
        an existing dict is ``update``-d with it; a missing leaf or any other
        existing value is replaced by ``value``.
        """
        *parents, leaf = key.split('.')
        current_record = self.records
        for part in parents:
            current_record = current_record.setdefault(part, {})

        if leaf not in current_record:
            current_record[leaf] = value
            return
        existing = current_record[leaf]
        if isinstance(existing, list):
            existing.append(value)
        elif isinstance(existing, set):
            existing.add(value)
        elif isinstance(existing, dict):
            existing.update(value)
        else:
            current_record[leaf] = value

    def _save_logs(self):
        """Write ``records`` to ``save_path`` as JSON indented by 2 spaces.

        The records are serialized *before* the file is opened, so a value
        that is not JSON-serializable raises without truncating the previous
        log file. Missing parent directories are created; a bare file name
        (empty dirname) is written relative to the current directory.
        """
        payload = json.dumps(self.records, indent=2)
        dirname = os.path.dirname(self.save_path)
        if dirname:
            # exist_ok avoids a race with another process creating it first.
            os.makedirs(dirname, exist_ok=True)
        with open(self.save_path, "w") as f:
            f.write(payload)

    def record(self, key, value):
        """Merge ``value`` at dotted ``key`` and persist the whole log.

        The lock is held through ``with`` so it is released even when the
        update or the write raises; otherwise every later call would block.
        """
        with self.lock:
            self._update_iter(key, value)
            self._save_logs()

这段代码就是第 2 节中实例化的 ExperimentLog 类;调用它的 record() 方法时会用到线程锁(Lock)。

python 线程(1)-- 常用方法与属性,锁,同步

互斥锁 Lock

线程同步能够保证多个线程安全访问竞争资源,最简单的同步机制是引入互斥锁。互斥锁为资源设置一个状态:锁定和非锁定。某个线程要更改共享数据时,先将其锁定,此时资源的状态为“锁定”,其他线程不能更改;直到该线程释放资源,将资源的状态变成“非锁定”,其他的线程才能再次锁定该资源。互斥锁保证了每次只有一个线程进行写入操作,从而保证了多线程情况下数据的正确性。

简单举例:

import time

import threading
from datetime import datetime

# Shared counter that every worker thread decrements; protected by a Lock.
num = 10  # global variable


def fun(lock, cnt):
    """Decrement the global ``num`` by one while holding ``lock``.

    The read / sleep / write sequence would lose updates without the lock.
    ``cnt`` only tags the printed line with the worker's index. If the lock
    cannot be acquired within 3 seconds, prints a message and returns
    without touching ``num``.
    """
    global num
    # Block for at most 3 seconds; acquire() returns False on timeout.
    if not lock.acquire(timeout=3):
        print('get lock failed')
        return
    try:
        temp = num
        # Widen the gap between read and write so a missing lock would show.
        time.sleep(0.2)
        temp -= 1
        num = temp
        print(f"{cnt}:{datetime.now()}")
    finally:
        # Release even if the body raises, so waiting threads are not starved.
        lock.release()


if __name__ == "__main__":
    print('主线程开始运行……')

    t_lst = []

    lock = threading.Lock()

    for i in range(10):
        t = threading.Thread(target=fun, args=(lock, i))
        t_lst.append(t)
        t.start()

    # Wait for every worker before reading the final value of num.
    for t in t_lst:
        t.join()

    print(f"main:{datetime.now()}")
    print('num最后的值为:{}'.format(num))

    print('主线程结束运行……')

这段举例代码中的重点是join()函数与互斥锁 Lock。

首先代码开始运行:

【无标题】_第3张图片

我们想要实现的效果是:10 个子线程依次修改 num 的值,每个子线程都执行 num -= 1,最后在主线程中输出 num,预计为 0。

t_lst用于保存线程对象。

【无标题】_第4张图片

创建lock对象,然后进行10次循环,实现10个子线程,每个子线程都调用fun方法。

【无标题】_第5张图片

输出符合预期,先运行子线程,子线程全部结束后,输出主线程的时间:

【无标题】_第6张图片

但是如果删除join()后:

删除调用 join() 等待子线程的那一行:

输出结果显示主线程没有等待子线程:主线程输出 num=10 时,子线程虽然已经启动,但还没有完成对 num 的修改。

【无标题】_第7张图片

join方法

join() 方法的作用是:在调用 join() 的地方,让当前线程(这里是主线程)阻塞等待被 join 的线程,直到该线程结束后,当前线程才继续往下执行。

join() 方法必须在 start() 之后调用;对尚未启动的线程调用 join() 会抛出 RuntimeError。

4.Python中logging模块

Python中logging模块 以及重写

# Fragment from ExperimentStage.__init__: the stage gets a Logger named 'stage'.
self.logger = Logger('stage')
import logging

# Configure the root logger once: INFO and above, "time [logger name]: LEVEL - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(name)s]: %(levelname)s - %(message)s')


class Logger(object):
    """Thin wrapper around a named :class:`logging.Logger`.

    Adds two helpers that format training and validation summaries.
    """

    def __init__(self, actuator: str = 'unknown'):
        # Named logger; the name fills the "[%(name)s]" part of each line.
        self.logger = logging.getLogger(actuator)

    def debug(self, msg: str) -> None:
        """Log ``msg`` at DEBUG level."""
        self.logger.debug(msg)

    def info(self, msg: str) -> None:
        """Log ``msg`` at INFO level."""
        self.logger.info(msg)

    def warn(self, msg: str) -> None:
        """Log ``msg`` at WARNING level."""
        self.logger.warning(msg)

    def error(self, msg: str) -> None:
        """Log ``msg`` at ERROR level."""
        self.logger.error(msg)

    def info_train(self, task_name, device, train_cnt, accuracy, loss, current_epoch=0, total_epoch=0):
        """Log a one-line training summary at INFO level.

        The "[epoch/total] " prefix (zero-padded to 3 digits) appears only
        when both ``current_epoch`` and ``total_epoch`` are non-zero.
        ``accuracy`` is shown as a percentage, ``loss`` with 4 decimals.
        """
        epoch_tag = ''
        if current_epoch and total_epoch:
            epoch_tag = f"[{current_epoch:0>3d}/{total_epoch:0>3d}] "
        self.logger.info(
            f"{epoch_tag}Train '{task_name}' on {device} with {train_cnt:,} images, "
            f"accuracy: {accuracy:.2%}, loss: {loss:.4f}."
        )

    def info_validation(self, task_name, query_cnt, gallery_cnt, cmc, mAP) -> None:
        """Log a multi-line validation report at INFO level.

        Uses ``cmc[0]``, ``cmc[2]``, ``cmc[4]`` and ``cmc[9]`` as Rank-1/3/5/10,
        so ``cmc`` needs at least 10 entries; all scores print as percentages.
        """
        ranks = (cmc[0], cmc[2], cmc[4], cmc[9])
        report = """Validation '{}' with {:,} query images on {:,} gallery images:
            |- Rank-1 :  {:.2%}
            |- Rank-3 :  {:.2%}
            |- Rank-5 :  {:.2%}
            |- Rank-10 : {:.2%}
            |- mean AP : {:.2%}
            """.format(task_name, query_cnt, gallery_cnt, *ranks, mAP)
        self.logger.info(report)

# Fetch the logger named 'root' (straight ASCII quotes; typographic quotes are a syntax error).
logging.getLogger(name='root')

获取一个 Logger 对象,name 是日志记录器(logger)的名称

Python的Logging模块

你可能感兴趣的:(做一个完整项目,学习,python)