TensorFlow：通过 HTTP 接口启动和停止模型训练

语言:Python
框架：Flask、TensorFlow（Keras）
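功能描述：提供两个接口，一个用于开始训练模型，另一个用于停止正在进行的训练。
实现：用全局字典 `training_tasks` 以 modelId 为键记录正在训练的任务（训练线程及其停止事件 `threading.Event`）；开始接口在新线程中调用 `model.fit`，停止接口设置对应的停止事件，由自定义 Keras 回调在下一个批次开始时中断训练。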

# Registry of in-flight training jobs, keyed by modelId. Each value is a dict
# with the worker 'thread' and the 'stop_event' used to cancel it (entries are
# created by the /train route).
training_tasks = dict()
# Endpoint that starts training a model.
@train_model.route("/train", methods=["POST"])
def train():
    """Start training the model named by ``modelId`` on a background thread.

    Expects a JSON body containing at least ``modelId``; the whole parsed
    payload is handed to ``start_train``. At most one job per ``modelId`` is
    registered in ``training_tasks`` at a time.

    Returns:
        ``{"success": "success", "message": ...}`` once the thread is started,
        otherwise ``{"success": False, "message": ...}`` (duplicate job, bad
        payload, or failure to start the thread).
    """
    try:
        data = json.loads(request.get_data())
        print(data)
        modelId = data["modelId"]

        stop_event = threading.Event()
        # Run training on its own thread so the HTTP request returns at once.
        train_thread = threading.Thread(
            target=start_train,
            args=(data, stop_event),
            name=f"train-{modelId}",
        )
        task = {
            'thread': train_thread,
            'stop_event': stop_event,
        }

        # Register with one setdefault() call instead of "check, then assign":
        # two concurrent requests for the same modelId can no longer both pass
        # the membership check. (For str/int keys, setdefault runs as a single
        # operation under CPython's GIL.)
        if training_tasks.setdefault(modelId, task) is not task:
            return {"success": False, "message": f"{modelId} 已经在训练中"}

        try:
            train_thread.start()
        except Exception:
            # Don't leave an entry for a thread that never ran, otherwise this
            # modelId would be reported as "already training" forever.
            training_tasks.pop(modelId, None)
            raise
        return {"success": "success", "message": "开始训练"}
    except Exception as e:
        return {"success": False, "message": str(e)}
def start_train(data, stop_event):
    """Stub that only checks the payload carries a ``modelId`` key.

    NOTE(review): a second ``def start_train`` further down this file rebinds
    the name, so this version is never the one the /train route runs; it looks
    like leftover scaffolding and can probably be removed.

    Returns None when ``modelId`` is present, otherwise a failure dict whose
    message is the text of the raised exception. ``stop_event`` is unused.
    """
    try:
        _ = data["modelId"]
    except Exception as err:
        return {"success": False, "message": str(err)}
    return None
class StopTrainingCallback(keras.callbacks.Callback):
    """Keras callback that aborts ``model.fit`` once ``stop_event`` is set.

    Args:
        model: The model being trained. It is attached with ``set_model``;
            ``fit`` also attaches its model to every callback passed in
            ``callbacks=[...]``.
        modelId: Identifier of the training job this callback belongs to.
        stop_event: Event that is set when a stop has been requested.
    """

    def __init__(self, model, modelId, stop_event):
        super().__init__()
        # Use set_model() rather than `self.model = model`: in Keras 3,
        # Callback.model is a read-only property and direct assignment raises
        # AttributeError. set_model() works on both tf.keras 2.x and Keras 3.
        self.set_model(model)
        self.modelId = modelId
        self.stop_event = stop_event

    def on_train_begin(self, logs=None):
        """Flag the model to stop if a stop was requested before training began."""
        if self.stop_event.is_set():
            # Setting this flag makes model.fit terminate early.
            self.model.stop_training = True
            print("训练在开始前被停止")

    def on_batch_begin(self, batch, logs=None):
        """Abort training at the start of a batch once a stop is requested."""
        if self.stop_event.is_set():
            self.model.stop_training = True
            print("训练在批次被停止")
            # Raise so fit() unwinds right away instead of waiting for it to
            # notice stop_training; start_train catches KeyboardInterrupt.
            raise KeyboardInterrupt("训练被用户停止")
# The function that actually trains the model (worker-thread target).
def start_train(data, stop_event):
    """Run ``model.fit`` for one training job and report the outcome.

    The model, datasets and epoch count below are placeholders ("xxx") that
    must be replaced with real objects.

    Args:
        data: Parsed request payload; must contain ``"modelId"``.
        stop_event: Event checked right before ``fit`` and, through
            ``StopTrainingCallback``, at the start of every batch.

    Returns:
        ``{"success": bool, "message": str}``. NOTE(review): when this runs
        as a ``threading.Thread`` target the return value is discarded; log or
        store it if the result is needed.

    On exit the job's ``training_tasks`` entry is removed, but only if it still
    belongs to this run (same ``stop_event``), so a newer job registered under
    the same modelId is never removed by mistake.
    """
    # Use the id from the request so logs, the callback and the registry
    # cleanup all refer to the same job.
    modelId = data["modelId"]
    try:
        # Define the model and training data (placeholders).
        model = "xxx"
        train_dataset = "xxx"
        test_dataset = "xxx"
        # NOTE(review): len(list(ds)) iterates the whole dataset just to count
        # it; for a tf.data.Dataset with known cardinality, ds.cardinality()
        # avoids that pass.
        train_steps = len(list(train_dataset))
        test_steps = len(list(test_dataset))
        epochs = "xxx"
        stoptrainingcallback = StopTrainingCallback(model, modelId, stop_event)

        # Check the stop event immediately before training starts.
        if stop_event.is_set():
            log.info(f"训练 {modelId} 在开始前被停止")
            callback_log.info("模型训练在开始前被停止")
            raise KeyboardInterrupt("Training stopped before start")

        model.fit(
            train_dataset,
            steps_per_epoch=train_steps,
            epochs=epochs,
            verbose=2,
            shuffle=True,
            validation_data=test_dataset,
            validation_steps=test_steps,
            callbacks=[stoptrainingcallback]
        )

        response_data = {"success": True, "message": "Success"}
    except KeyboardInterrupt:
        # Raised by the pre-start check above or by StopTrainingCallback.
        response_data = {"success": False, "message": "模型训练被用户停止."}
    except tf.errors.ResourceExhaustedError:
        # Out of GPU memory.
        response_data = {"success": False, "message": "GPU内存不足,请调整训练参数."}
    except Exception as e:
        print("模型训练失败")
        response_data = {"success": False, "message": str(e)}
    finally:
        # Setup now lives inside the try, so this cleanup also runs when
        # building the model/datasets fails.
        task = training_tasks.get(modelId)
        if task is not None and task.get("stop_event") is stop_event:
            training_tasks.pop(modelId, None)
    return response_data
# Endpoint that stops a running training job.
@stop_train.route('/stop', methods=['POST'])
def stop():
    """Ask the training job identified by ``modelId`` to stop.

    Expects a JSON object body with a ``modelId`` (each model has a unique
    UUID). Responds with the JSON-encoded result of ``stop_train_service``, or
    an error payload when the body cannot be parsed or ``modelId`` is missing,
    empty, or null.
    """
    raw_body = request.get_data()
    try:
        data = json.loads(raw_body)
        modelId = data.get("modelId", '')
        # `not modelId` also rejects a JSON null, which the old `== ''` test
        # let through to be looked up as None.
        if not modelId:
            return jsonify({"success": False, "message": "modelId为空,无法停止训练.", "data": ''})
    except Exception as e:
        print("停止模型训练接口请求数据出错:", str(e))
        return jsonify({"success": False, "message": "参数错误.", "data": ''})

    # Delegate to the service layer.
    result = stop_train_service(modelId)
    print(result["message"])
    return jsonify(result)
# Service layer: request that a training job stop.
def stop_train_service(modelId):
    """Signal the training job registered under *modelId* to stop.

    The registry entry is intentionally kept while the worker thread is still
    running; the training thread removes it in its own ``finally`` once
    ``model.fit`` has returned. Removing it here would let a new /train request
    for the same modelId start a second fit while the old one is still
    winding down.

    Returns:
        A dict whose ``"success"`` is ``"success"`` or ``"error"``, plus a
        ``"message"``.
    """
    task = training_tasks.get(modelId)
    if task is None:
        return {"success": "error", "message": f"没有找到模型 {modelId} 的训练任务"}

    stop_event = task.get('stop_event')
    if not stop_event:
        return {"success": "error", "message": f"模型 {modelId} 的停止训练事件不存在"}

    stop_event.set()

    # If the worker thread has already started and exited without cleaning up
    # (e.g. it crashed), nothing else will remove the entry, so drop it here.
    thread = task.get('thread')
    if thread is not None and thread.ident is not None and not thread.is_alive():
        training_tasks.pop(modelId, None)

    return {"success": "success", "message": f"停止 {modelId} 模型训练的请求已发送"}

你可能感兴趣的:(python,深度学习,人工智能)