Calling a Python-trained model from Java



import org.apache.commons.lang3.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;

public class MatchTensor {
    public static String filePath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\vocab.txt";
    public static String labelPath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\label.txt";
    public static HashMap<String, Integer> hashMap = new HashMap<String, Integer>() {
        {
            put("0", 1693);
        }
    };
    // Mirrors BERT's _is_control: tab/newline/CR count as whitespace, not control.
    public static boolean is_control(char char_){
        if (char_ == '\t' || char_ == '\n' || char_ == '\r'){
            return false;
        }
        int type = Character.getType(char_);
        // 15/16/18/19 = CONTROL/FORMAT/PRIVATE_USE/SURROGATE (Unicode "C*" categories)
        if (type == 15 || type == 16 || type == 18 || type == 19 || type == 7){
            return true;
        }
        return false;
    }
    public static boolean is_whitespace(char char_){
        if (char_ == ' ' || char_ == '\t' || char_ == '\n' || char_ == '\r'){
            return true;
        }
        // 12 = Character.SPACE_SEPARATOR (Unicode category Zs)
        if (Character.getType(char_) == 12){
            return true;
        }
        return false;
    }
    public static String cleanText(String str){
        StringBuilder result = new StringBuilder();
        for (int i = 0; i < str.length(); i++){
            char chr = str.charAt(i);
            // Drop NUL, the replacement character, and control characters.
            if (chr == 0 || chr == 0xfffd || is_control(chr)){
                continue;
            }
            if (is_whitespace(chr)){
                result.append(' ');
            } else {
                result.append(chr);
            }
        }
        return result.toString();
    }
    private static String run_strip_accents(String text){
        // Accent stripping is not implemented; its call site below is commented out.
        return "";
    }
    private static boolean _is_punctuation(char char_){
        int num = (int) char_;
        // ASCII punctuation ranges, as in BERT's _is_punctuation.
        if ((num >= 33 && num <= 47) || (num >= 58 && num <= 64) ||
                (num >= 91 && num <= 96) || (num >= 123 && num <= 126)){
            return true;
        }
        int t = Character.getType(char_);
        // 20-24 and 29-30 cover the Unicode "P*" punctuation categories.
        if ((t >= 20 && t <= 24) || (t >= 29 && t <= 30)){
            return true;
        }
        return false;
    }
    private static ArrayList<String> _run_split_on_punc(String text){
        char[] chars = text.toCharArray();
        // Each inner list collects one run of non-punctuation characters,
        // or a single punctuation character.
        ArrayList<ArrayList<Character>> output = new ArrayList<>();
        int i = 0;
        boolean start_new_word = true;
        while (i < chars.length){
            char c = chars[i];
            if (_is_punctuation(c)){
                ArrayList<Character> single = new ArrayList<>();
                single.add(c);
                output.add(single);
                start_new_word = true;
            } else {
                if (start_new_word){
                    output.add(new ArrayList<Character>());
                }
                start_new_word = false;
                output.get(output.size() - 1).add(c);
            }
            i += 1;
        }
        ArrayList<String> result = new ArrayList<>();
        for (ArrayList<Character> arrayList : output){
            StringBuilder ss = new StringBuilder();
            for (int m = 0; m < arrayList.size(); m++){
                ss.append(arrayList.get(m));
            }
            result.add(ss.toString());
        }
        return result;
    }
    public static ArrayList<String> whitespace_tokenize(String text){
        ArrayList<String> outtexts = new ArrayList<>();
        if (StringUtils.isNotEmpty(text)){
            String[] texts = text.split("\\s+"); // split on runs of whitespace
            for (String i : texts){
                outtexts.add(i);
            }
        }
        return outtexts;
    }
    public static ArrayList<String> basicTokenize(String text){
        String newText = cleanText(text);
        ArrayList<String> outtexts = whitespace_tokenize(newText);
        ArrayList<String> split_tokens = new ArrayList<>();
        for (String s : outtexts){
            String news = s.toLowerCase();
            //String newtoken = run_strip_accents(news);
            ArrayList<String> newss = _run_split_on_punc(news);
            for (String ss : newss){
                split_tokens.add(ss);
            }
        }
        return whitespace_tokenize(String.join(" ", split_tokens));
    }
    public static HashMap<String, Integer> getVocab(String filePath) throws IOException {
        HashMap<String, Integer> hashMap = new HashMap<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), "UTF-8"))) {
            String line;
            int i = 0;
            while ((line = br.readLine()) != null){ // readLine strips the line terminator
                hashMap.put(line, i);
                i += 1;
            }
        }
        return hashMap;
    }
    public static String index2label(int index) throws IOException {
        HashMap<Integer, String> hashMap = new HashMap<>();
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(new FileInputStream(labelPath), "UTF-8"))) {
            String line;
            int i = 0;
            while ((line = br.readLine()) != null){ // readLine strips the line terminator
                hashMap.put(i, line);
                i += 1;
            }
        }
        return hashMap.get(index);
    }
    public static ArrayList<String> wordpiece_tokenizer(String text) throws IOException {
        HashMap<String, Integer> hashMap = getVocab(filePath);
        ArrayList<String> output_tokens = new ArrayList<>();
        for (String s : whitespace_tokenize(text)){
            char[] chars = s.toCharArray();
            if (chars.length > 100){
                output_tokens.add("[UNK]");
                continue;
            }
            boolean is_bad = false;
            int start = 0;
            ArrayList<String> sub_tokens = new ArrayList<>();
            // Greedy longest-match-first: shrink the candidate from the right
            // until a piece is found in the vocabulary.
            while (start < chars.length){
                int end = chars.length;
                String cur_substr = "";
                while (start < end){
                    String substr = s.substring(start, end);
                    if (start > 0){
                        substr = "##" + substr; // continuation piece
                    }
                    if (hashMap.containsKey(substr)){
                        cur_substr = substr;
                        break;
                    }
                    end -= 1;
                }
                if (cur_substr.isEmpty()){
                    is_bad = true;
                    break;
                }
                sub_tokens.add(cur_substr);
                start = end;
            }
            if (is_bad){
                output_tokens.add("[UNK]");
            } else {
                for (String tokens : sub_tokens){
                    output_tokens.add(tokens);
                }
            }
        }
        return output_tokens;
    }
    public static ArrayList<String> tokenize(String text) throws IOException {
        ArrayList<String> split_tokens = new ArrayList<>();
        ArrayList<String> list = basicTokenize(text);
        for (String s : list){
            for (String tokens : wordpiece_tokenizer(s)){
                split_tokens.add(tokens);
            }
        }
        return split_tokens;
    }
    public static InputFeature constructTensor(String data) throws IOException {
        HashMap<String, Integer> hashMap = getVocab(filePath);
        InputExample example = new InputExample();
        example.setGuid("1");
        example.setText_a(data);
        example.setText_b("");
        example.setLabel("0");
        ArrayList<String> tokens_a = tokenize(example.getText_a());
        ArrayList<String> tokens_anew = new ArrayList<>();
        ArrayList<String> tokens = new ArrayList<>();
        ArrayList<Integer> segment_ids = new ArrayList<>();
        tokens.add("[CLS]");
        segment_ids.add(0);
        // max_seq_length is 25, so at most 23 word pieces fit beside [CLS] and [SEP].
        if (tokens_a.size() > 23){
            for (int i = 0; i < 23; i++){
                tokens_anew.add(tokens_a.get(i));
            }
            tokens_a = tokens_anew;
        }
        for (String token : tokens_a){
            tokens.add(token);
            segment_ids.add(0);
        }
        tokens.add("[SEP]");
        segment_ids.add(0);
        ArrayList<Integer> input_ids = new ArrayList<>();
        ArrayList<Integer> input_mask = new ArrayList<>();
        for (String s : tokens){
            input_ids.add(hashMap.get(s));
            input_mask.add(1);
        }
        // Zero-pad up to the fixed sequence length of 25.
        while (input_ids.size() < 25){
            input_ids.add(0);
            input_mask.add(0);
            segment_ids.add(0);
        }
        InputFeature feature = new InputFeature();
        feature.setInput_ids(input_ids);
        feature.setInput_mask(input_mask);
        feature.setSegments_ids(segment_ids);
        feature.setLabel_id(1693);
        return feature;
    }
    public static Session readGraph() throws IOException {
        byte[] graphDef = Files.readAllBytes(Paths.get("D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\graph.db"));
        Graph g = new Graph();
        g.importGraphDef(graphDef);
        return new Session(g);
    }
    public static void main(String[] args) throws Exception {
        Date start = new Date();
        Session session = readGraph();
        Date t1 = new Date();
        System.out.println(t1.getTime() - start.getTime());
        System.out.println("..........t1");
        InputFeature input = constructTensor("........");
        Date t2 = new Date();
        System.out.println(t2.getTime() - t1.getTime());
        System.out.println("..........t2");
        // Copy the boxed lists into primitive int[] arrays: Tensor.create()
        // must be fed int[], not Integer[] (see the note at the end of this post).
        ArrayList<Integer> input_ids1 = input.getInput_ids();
        int[] inputs_ids = new int[25];
        for (int i = 0; i < 25; i++){
            inputs_ids[i] = input_ids1.get(i);
        }
        ArrayList<Integer> input_mask1 = input.getInput_mask();
        int[] input_mask = new int[25];
        for (int i = 0; i < 25; i++){
            input_mask[i] = input_mask1.get(i);
        }
        ArrayList<Integer> segments_ids1 = input.getSegments_ids();
        int[] segments_ids = new int[25];
        for (int i = 0; i < 25; i++){
            segments_ids[i] = segments_ids1.get(i);
        }
        Date t3 = new Date();
        System.out.println(t3.getTime() - t2.getTime());
        System.out.println("..........t3");
        // Feed the three inputs and fetch the softmax output. The node names
        // below are placeholders: use the names from your own exported graph.
        Tensor result = session.runner()
                .feed("input_ids", Tensor.create(new int[][]{inputs_ids}))
                .feed("input_mask", Tensor.create(new int[][]{input_mask}))
                .feed("segment_ids", Tensor.create(new int[][]{segments_ids}))
                .fetch("loss/Softmax")
                .run().get(0);
        float[][] prop = new float[1][1694]; // assumed label count (index 1693 appears above)
        result.copyTo(prop);
        Date t4 = new Date();
        System.out.println(t4.getTime() - t3.getTime());
        System.out.println("..........t4");
        // Map each probability back to its label index, then sort and
        // keep the three highest-scoring labels.
        HashMap<Float, Integer> map = new HashMap<>();
        float[] t = prop[0];
        for (int i = 0; i < t.length; i++){
            map.put(t[i], i);
        }
        Arrays.sort(t);
        HashMap<String, Object> result1 = new HashMap<>();
        List<HashMap<String, Object>> list = new ArrayList<>();
        for (int i = t.length - 1; i >= t.length - 3; i--){
            HashMap<String, Object> hashMap = new HashMap<>();
            int label = map.get(t[i]);
            System.out.println(".................................");
            hashMap.put("weight", t[i]);
            System.out.println(t[i]);
            hashMap.put("text", index2label(label));
            System.out.println(label);
            System.out.println(index2label(label));
            list.add(hashMap);
        }
        Date end = new Date();
        long diff = end.getTime() - start.getTime();
        System.out.println(start.getTime());
        System.out.println(diff);
        result1.put("result", list);
        Date t5 = new Date();
        System.out.println(t5.getTime() - t4.getTime());
        System.out.println("..........t5");
    }



}
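
The listing above references two POJOs, InputExample and InputFeature, that are not shown. Here is a minimal sketch consistent with the getters and setters used above; the field names are inferred, so treat this as an assumption:

import java.util.ArrayList;

class InputExample {
    private String guid, text_a, text_b, label;
    public String getGuid(){ return guid; }
    public void setGuid(String guid){ this.guid = guid; }
    public String getText_a(){ return text_a; }
    public void setText_a(String text_a){ this.text_a = text_a; }
    public String getText_b(){ return text_b; }
    public void setText_b(String text_b){ this.text_b = text_b; }
    public String getLabel(){ return label; }
    public void setLabel(String label){ this.label = label; }
}

class InputFeature {
    private ArrayList<Integer> input_ids, input_mask, segments_ids;
    private int label_id;
    public ArrayList<Integer> getInput_ids(){ return input_ids; }
    public void setInput_ids(ArrayList<Integer> v){ this.input_ids = v; }
    public ArrayList<Integer> getInput_mask(){ return input_mask; }
    public void setInput_mask(ArrayList<Integer> v){ this.input_mask = v; }
    public ArrayList<Integer> getSegments_ids(){ return segments_ids; }
    public void setSegments_ids(ArrayList<Integer> v){ this.segments_ids = v; }
    public int getLabel_id(){ return label_id; }
    public void setLabel_id(int label_id){ this.label_id = label_id; }
}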







If you only want to see how to call a Python model from Java, I recommend this post: https://blog.csdn.net/rabbit_judy/article/details/80054085#commentBox

This post mainly records how to preprocess data in Java the same way BERT does, and then call a BERT model trained in Python. The results are good, but it runs somewhat slowly.
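
For example, the Java tokenizer should produce the same word pieces as the Python one for the same vocab. A quick sanity check (the pieces shown are illustrative; the actual output depends entirely on your vocab.txt):

import java.util.ArrayList;

public class TokenizeDemo {
    public static void main(String[] args) throws Exception {
        // Compare against tokenizer.tokenize("unaffable") in the Python BERT code.
        ArrayList<String> pieces = MatchTensor.tokenize("unaffable");
        System.out.println(pieces); // e.g. [una, ##ffa, ##ble], depending on the vocab
    }
}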

Python's ord('中') is equivalent to Java's (int) '中'.
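
A quick check of that equivalence; both expressions yield the code point 20013 (U+4E2D) for '中':

public class OrdDemo {
    public static void main(String[] args) {
        // Python: ord('中') == 20013
        int codePoint = (int) '中';
        System.out.println(codePoint); // prints 20013
    }
}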

Python's unicodedata.category('中') corresponds to Java's Character.getType('中'): Python returns a category string such as "Lo", while Java returns the matching int constant.
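
For example, '中' is in Unicode category Lo (letter, other), which Java reports as Character.OTHER_LETTER:

public class CategoryDemo {
    public static void main(String[] args) {
        // Python: unicodedata.category('中') == 'Lo'
        int type = Character.getType('中');
        System.out.println(type);                           // 5
        System.out.println(type == Character.OTHER_LETTER); // true
    }
}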

At first, calling the Python-trained model from Java was very slow. The cause turned out to be session.runner(): keep the Session in a static variable, and only the first run is slow; every run after that is fast. Also, never pass a boxed Integer[] array to Tensor.create(); it causes problems, while a primitive int[] array works fine.
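
A minimal sketch of that caching pattern, assuming a lazily initialized static Session and a hypothetical graph path; adapt to your own setup:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.nio.file.Files;
import java.nio.file.Paths;

public class SessionHolder {
    // Load the graph once and reuse the Session across requests, so only
    // the first session.runner() call pays the warm-up cost.
    private static Session session;

    public static synchronized Session get() throws Exception {
        if (session == null) {
            Graph g = new Graph();
            g.importGraphDef(Files.readAllBytes(Paths.get("graph.db"))); // hypothetical path
            session = new Session(g);
        }
        return session;
    }

    public static void main(String[] args) throws Exception {
        int[] ids = new int[25];
        // Tensor.create accepts primitive arrays; a boxed Integer[] will not work.
        Tensor input = Tensor.create(new int[][]{ids});
        System.out.println(input.shape()[1]); // 25
        input.close();
    }
}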
