java调用ONNX模型

一、导出一个onnx模型

这里训练了一个简单的线性回归模型
通过SerializeToString完成导出。

from sklearn.linear_model import LinearRegression
import numpy as np
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# 训练一个简单的线性回归模型
X = np.array([[1], [2], [3], [4]])
y = np.array([2, 4, 6, 8])
model = LinearRegression()
model.fit(X, y)

initial_type = [('float_input', FloatTensorType([None, 1]))]
onnx_model = convert_sklearn(model, initial_types = initial_type)
with open('linear_regression.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())

二、java项目中maven添加依赖

<dependency>
    <groupId>ai.onnxruntimegroupId>
    <artifactId>onnxruntimeartifactId>
    <version>1.12.0version>
dependency>

三、代码调用

import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Map;

public class ONNXModelCaller {
    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession session = env.createSession("linear_regression.onnx", new OrtSession.SessionOptions());

        float[] inputData = {5.0f};
        Tensor inputTensor = Tensor.createFromByteBuffer(FloatBuffer.wrap(inputData), new long[]{1, 1});
        OrtSession.Result results = session.run(Map.of("float_input", inputTensor));
        Tensor outputTensor = results.get(0);
        float[] outputData = (float[]) outputTensor.getData();
        System.out.println("Prediction: " + outputData[0]);

        results.close();
        session.close();
        env.close();
    }
}

基于深度学习框架(如 PyTorch 或 TensorFlow)构建,且可以转换为 ONNX 格式,这种方法可以提供高效的跨语言部署

你可能感兴趣的:(java,开发语言)