语言:Python
框架:Flask、TensorFlow(Keras)
功能描述:存在两个接口,一个接口实现开始训练模型的功能,一个接口实现停止训练的功能。
实现:用一个全局字典 training_tasks 存储正在训练中的任务(以 modelId 为键,保存训练线程及其停止事件)。
# Registry of all training tasks, keyed by modelId.
# Each value is a dict holding the worker under 'thread' and its
# threading.Event under 'stop_event'.
training_tasks: dict = {}
# Endpoint that starts training a model.
@train_model.route("/train", methods=["POST"])
def train():
    """Start training the requested model in a background thread.

    The raw request body is parsed as JSON and must contain a ``modelId``
    key. A ``threading.Event`` is created for the task and handed to
    ``start_train`` so the training can later be asked to stop; the thread
    and the event are recorded in ``training_tasks`` under ``modelId``.

    Returns:
        dict: ``{"success": "success", "message": ...}`` once the thread has
        been started, or ``{"success": False, "message": ...}`` when the
        model is already registered or any step raised.

    NOTE(review): the success flag is the string ``"success"`` on the happy
    path but the boolean ``False`` on failure; left unchanged because
    clients may compare against the current values.
    """
    try:
        data = json.loads(request.get_data())
        print(data)
        modelId = data["modelId"]

        stop_event = threading.Event()
        # Run the training in a new thread.
        train_thread = threading.Thread(
            target=start_train,
            args=(data, stop_event),
        )
        task = {
            'thread': train_thread,
            'stop_event': stop_event,
        }

        # Register with one dict operation instead of "check, then insert":
        # with two concurrent requests for the same modelId only one of them
        # gets its own task back (relies on CPython's dict.setdefault being
        # atomic for built-in key types).
        if training_tasks.setdefault(modelId, task) is not task:
            return {"success": False, "message": f"{modelId} 已经在训练中"}

        try:
            train_thread.start()
        except Exception:
            # The thread never ran, so nothing else will clean this entry
            # up; without the rollback the model would look "in training"
            # forever.
            if training_tasks.get(modelId) is task:
                training_tasks.pop(modelId, None)
            raise

        return {"success": "success", "message": "开始训练"}
    except Exception as e:
        return {"success": False, "message": str(e)}
def start_train(data, stop_event):
    """Check that the task payload can be indexed with ``"modelId"``.

    Args:
        data: Task parameters; only ``data["modelId"]`` is looked up.
        stop_event: Accepted but not used by this function.

    Returns:
        ``None`` when the lookup succeeds, otherwise a dict of the form
        ``{"success": False, "message": <text of the exception>}``.

    NOTE(review): another ``start_train`` is defined further down in this
    file; if both live in the same module, the later definition replaces
    this one.
    """
    try:
        # Read the task parameter; only the lookup matters, the value is
        # not used afterwards.
        _model_id = data["modelId"]
    except Exception as exc:
        return {"success": False, "message": str(exc)}
    return None
class StopTrainingCallback(keras.callbacks.Callback):
    """Keras callback that aborts training once ``stop_event`` is set.

    The event is checked when training begins and at the start of every
    batch. When it is set, ``stop_training`` is switched on for the model;
    at batch level a ``KeyboardInterrupt`` is raised as well so the caller
    of ``model.fit`` regains control immediately.

    Args:
        model: The model about to be trained. Used only as a fallback when
            Keras has not attached a model to this callback yet.
        modelId: Identifier of the training task (stored, not interpreted).
        stop_event: Object exposing ``is_set()``, e.g. ``threading.Event``.
    """

    def __init__(self, model, modelId, stop_event):
        super().__init__()
        # Deliberately NOT assigned to ``self.model``: that attribute is
        # managed by Keras through ``set_model()`` when ``fit`` starts, and
        # on Keras versions where ``Callback.model`` is a read-only property
        # (Keras 3) the assignment raises AttributeError.
        self._fallback_model = model
        self.modelId = modelId
        self.stop_event = stop_event

    def _flag_stop(self):
        """Set ``stop_training`` on the attached model (or the fallback)."""
        attached = getattr(self, "model", None)
        target = attached if attached is not None else self._fallback_model
        # Setting this flag makes model.fit terminate early.
        target.stop_training = True

    def on_train_begin(self, logs=None):
        """Request a stop if the event was set before training started."""
        if self.stop_event.is_set():
            self._flag_stop()
            print("训练在开始前被停止")

    def on_batch_begin(self, batch, logs=None):
        """Abort at the start of a batch if the stop event is set.

        Raises:
            KeyboardInterrupt: when the stop event is set.
        """
        if self.stop_event.is_set():
            self._flag_stop()
            print("训练在批次被停止")
            # Raise so the stop takes effect right away instead of waiting
            # for Keras to notice the flag.
            raise KeyboardInterrupt("训练被用户停止")
# The function that actually runs model training.
def start_train(data, stop_event):
    """Train the model for one task and always release its registry entry.

    Args:
        data: Task parameters; ``data["modelId"]`` identifies the entry in
            ``training_tasks`` that belongs to this run.
        stop_event: Event that, once set, makes the training stop.

    Returns:
        dict: ``{"success": bool, "message": str}`` describing the outcome.
    """
    try:
        return _run_training(data, stop_event)
    finally:
        # Runs even when the setup inside _run_training raises, so a failed
        # setup can no longer leave the model marked as "in training".
        _release_training_task(data, stop_event)


def _run_training(data, stop_event):
    """Build the model and datasets, call ``model.fit`` and report the outcome."""
    # Model and training data; the "xxx" values are placeholders for the
    # real construction code.
    model = "xxx"
    modelId = "xxx"
    train_dataset = "xxx"
    test_dataset = "xxx"
    # NOTE(review): list(...) materialises each dataset just to count its
    # elements.
    train_steps = len(list(train_dataset))
    test_steps = len(list(test_dataset))
    epochs = "xxx"
    stoptrainingcallback = StopTrainingCallback(model, modelId, stop_event)
    try:
        # Check the stop event right before training starts.
        if stop_event.is_set():
            log.info(f"训练 {modelId} 在开始前被停止")
            callback_log.info("模型训练在开始前被停止")
            raise KeyboardInterrupt("Training stopped before start")
        model.fit(
            train_dataset,
            steps_per_epoch=train_steps,
            epochs=epochs,
            verbose=2,
            shuffle=True,
            validation_data=test_dataset,
            validation_steps=test_steps,
            callbacks=[stoptrainingcallback],
        )
        response_data = {"success": True, "message": "Success"}
    except KeyboardInterrupt:
        # Raised by the pre-start check above or by the stop callback.
        response_data = {"success": False, "message": "模型训练被用户停止."}
    except tf.errors.ResourceExhaustedError:
        # Out of GPU memory.
        response_data = {"success": False, "message": "GPU内存不足,请调整训练参数."}
    except Exception as e:
        print("模型训练失败")
        response_data = {"success": False, "message": str(e)}
    return response_data


def _release_training_task(data, stop_event):
    """Drop this run's entry from ``training_tasks`` if it is still ours."""
    model_id = data["modelId"]
    task = training_tasks.get(model_id)
    # Compare the stop event so that a newer task registered under the same
    # modelId (after this one was stopped) is not removed by the old thread.
    # Best effort only: no lock is held between the lookup and the pop.
    if task is not None and task.get('stop_event') is stop_event:
        # pop() with a default cannot raise if the entry vanished meanwhile,
        # unlike the "if key in dict: del dict[key]" sequence.
        training_tasks.pop(model_id, None)
# Endpoint that stops a running training task.
@stop_train.route('/stop', methods=['POST'])
def stop():
    """Ask the training task named in the request body to stop.

    The raw body is parsed as JSON and ``modelId`` is read from it. A
    missing, empty or ``null`` id is rejected with the "modelId is empty"
    response; a body that cannot be parsed, or an id that cannot be used as
    a dictionary key, is rejected with the "parameter error" response.
    Otherwise the request is delegated to ``stop_train_service`` and its
    result is returned as JSON.
    """
    data = request.get_data()
    try:
        data = json.loads(data)
        modelId = data.get("modelId", '')  # each model has a unique UUID
        # JSON null is treated like a missing id instead of being passed on
        # as a task name.
        if modelId is None or modelId == '':
            return jsonify({"success": False, "message": "modelId为空,无法停止训练.", "data": ''})
        # The task registry is a dict keyed by modelId, so an unhashable id
        # (JSON array/object) would make the lookup raise; hash() surfaces
        # that here, where it becomes a parameter error.
        hash(modelId)
    except Exception as e:
        print("停止模型训练接口请求数据出错:", str(e))
        return jsonify({"success": False, "message": "参数错误.", "data": ''})

    # Delegate to the service layer to stop the training.
    result = stop_train_service(modelId)
    print(result["message"])
    # Send the service result back to the client.
    return jsonify(result)
# Service-layer helper that requests a training task to stop.
def stop_train_service(modelId):
    """Signal the training task registered under ``modelId`` to stop.

    Args:
        modelId: Key of the task in ``training_tasks``.

    Returns:
        dict: ``{"success": "success", "message": ...}`` when the stop event
        was set and the task record removed, or
        ``{"success": "error", "message": ...}`` when no task is registered
        under ``modelId`` or the task has no stop event.
    """
    # A single lookup instead of "membership test, then index": the training
    # thread removes its own entry when it finishes, so the entry may vanish
    # between two separate dict accesses.
    task = training_tasks.get(modelId)
    if task is None:
        return {"success": "error", "message": f"没有找到模型 {modelId} 的训练任务"}

    stop_event = task.get('stop_event')
    if not stop_event:
        return {"success": "error", "message": f"模型 {modelId} 的停止训练事件不存在"}

    stop_event.set()
    # Remove the task record. pop() with a default cannot raise if the
    # training thread already removed it, and the identity check leaves a
    # different task registered under the same id untouched.
    if training_tasks.get(modelId) is task:
        training_tasks.pop(modelId, None)
    return {"success": "success", "message": f"停止 {modelId} 模型训练的请求已发送"}