import org.apache.commons.lang3.StringUtils;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
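// Reimplements BERT's basic + WordPiece tokenization in Java, builds the
// input_ids / input_mask / segment_ids features (max_seq_length = 25), and
// feeds them to a frozen TensorFlow graph exported from Python.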
public class MatchTensor {
public static String filePath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\vocab.txt";
public static String labelPath="D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\label.txt";
// Maps the placeholder label "0" to its row index in label.txt (1693),
// matching the label_id set in constructTensor below.
public static HashMap<String,Integer> hashMap =new HashMap<String,Integer>(){
    {
        put("0",1693);
    }
};
// Mirrors BERT's _is_control: \t, \n and \r count as whitespace, not control;
// everything in the Unicode "C*" categories counts as control.
public static boolean is_control(char char_){
    if (char_=='\t' || char_=='\n' || char_=='\r'){
        return false;
    }
    int type=Character.getType(char_);
    if (type==Character.CONTROL || type==Character.FORMAT
            || type==Character.PRIVATE_USE || type==Character.SURROGATE
            || type==Character.UNASSIGNED){
        return true;
    }
    return false;
}
// Mirrors BERT's _is_whitespace: space, \t, \n, \r, and anything in Zs.
public static boolean is_whitespace(char char_){
    if(char_==' ' || char_=='\t' || char_=='\n' || char_=='\r'){
        return true;
    }
    if (Character.getType(char_)==Character.SPACE_SEPARATOR){
        return true;
    }
    return false;
}
// Drops NUL, U+FFFD and control characters; collapses any whitespace to a space.
public static String cleanText(String str){
    StringBuilder result=new StringBuilder();
    for (int i = 0; i < str.length(); i++){
        char c = str.charAt(i);
        if(c==0 || c==0xfffd || is_control(c)){
            continue;
        }
        if(is_whitespace(c)){
            result.append(' ');
        }else{
            result.append(c);
        }
    }
    return result.toString();
}
// Accent stripping (BERT's _run_strip_accents) is left unimplemented here;
// its call site in basicTokenize below is commented out.
private static String run_strip_accents(String text){
    return "";
}
// Mirrors BERT's _is_punctuation: the four ASCII punctuation ranges plus
// every Unicode P* category.
private static boolean _is_punctuation(char char_){
    int num=(int)char_;
    if ((num >= 33 && num <= 47) || (num >= 58 && num <= 64) ||
            (num >= 91 && num <= 96) || (num >= 123 && num <= 126)){
        return true;
    }
    int t=Character.getType(char_);
    if((t>=Character.DASH_PUNCTUATION && t<=Character.OTHER_PUNCTUATION)
            || t==Character.INITIAL_QUOTE_PUNCTUATION || t==Character.FINAL_QUOTE_PUNCTUATION){
        return true;
    }
    return false;
}
// Mirrors BERT's _run_split_on_punc: every punctuation char becomes its own token.
private static ArrayList<String> _run_split_on_punc(String text){
    char[] chars=text.toCharArray();
    ArrayList<ArrayList<Character>> output=new ArrayList<>();
    int i=0;
    boolean start_new_word=true;
    while(i<chars.length){
        char c=chars[i];
        if(_is_punctuation(c)){
            ArrayList<Character> piece=new ArrayList<>();
            piece.add(c);
            output.add(piece);
            start_new_word=true;
        }else{
            if(start_new_word){
                output.add(new ArrayList<Character>());
            }
            start_new_word=false;
            output.get(output.size()-1).add(c);
        }
        i+=1;
    }
    ArrayList<String> result=new ArrayList<>();
    for(ArrayList<Character> arrayList:output){
        StringBuilder ss=new StringBuilder();
        for(int m = 0;m < arrayList.size(); m++){
            ss.append(arrayList.get(m));
        }
        result.add(ss.toString());
    }
    return result;
}
public static ArrayList<String> whitespace_tokenize(String text){
    ArrayList<String> outtexts=new ArrayList<>();
    if(StringUtils.isNotEmpty(text)){
        String[] texts=text.split("\\s+"); // split on runs of whitespace
        for(String i:texts){
            outtexts.add(i);
        }
    }
    return outtexts;
}
// Mirrors BERT's BasicTokenizer: clean, lowercase, split on whitespace and punctuation.
public static ArrayList<String> basicTokenize(String text){
    String newText=cleanText(text);
    ArrayList<String> outtexts= whitespace_tokenize(newText);
    ArrayList<String> split_tokens=new ArrayList<>();
    for(String s:outtexts){
        String news=s.toLowerCase();
        //String newtoken=run_strip_accents(news);
        split_tokens.addAll(_run_split_on_punc(news));
    }
    StringBuilder all=new StringBuilder();
    for(String s:split_tokens){
        all.append(s).append(' ');
    }
    return whitespace_tokenize(all.toString().trim());
}
// Reads vocab.txt into token -> row index; readLine() already strips the line terminator.
public static HashMap<String,Integer> getVocab(String filePath) throws IOException {
    HashMap<String,Integer> hashMap=new HashMap<>();
    try(BufferedReader br=new BufferedReader(
            new InputStreamReader(new FileInputStream(filePath),"UTF-8"))){
        String line;
        int i=0;
        while((line=br.readLine())!=null){
            hashMap.put(line,i);
            i+=1;
        }
    }
    return hashMap;
}
// Reads label.txt into row index -> label and returns the requested entry.
public static String index2label(int index) throws IOException {
    HashMap<Integer,String> hashMap=new HashMap<>();
    try(BufferedReader br=new BufferedReader(
            new InputStreamReader(new FileInputStream(labelPath),"UTF-8"))){
        String line;
        int i=0;
        while((line=br.readLine())!=null){
            hashMap.put(i,line);
            i+=1;
        }
    }
    return hashMap.get(index);
}
private static HashMap<String,Integer> vocabCache;
// Mirrors BERT's WordpieceTokenizer; vocab.txt is cached so it is only read once.
public static ArrayList<String> wordpiece_tokenizer(String text) throws IOException{
    if(vocabCache==null){
        vocabCache=getVocab(filePath);
    }
    HashMap<String,Integer> hashMap=vocabCache;
    ArrayList<String> output_tokens=new ArrayList<>();
    for(String s:whitespace_tokenize(text)){
        char[] chars = s.toCharArray();
        if(chars.length>100){
            output_tokens.add("[UNK]");
            continue;
        }
        boolean is_bad=false;
        int start=0;
        ArrayList<String> sub_tokens=new ArrayList<>();
        // Greedy longest-match-first: shrink the window until a vocab entry matches.
        while(start<chars.length){
            int end=chars.length;
            String cur_substr="";
            while(start<end){
                String substr=s.substring(start,end);
                if(start>0){
                    substr="##"+substr; // continuation pieces carry the "##" prefix
                }
                if(hashMap.containsKey(substr)){
                    cur_substr=substr;
                    break;
                }
                end-=1;
            }
            if(cur_substr.isEmpty()){
                is_bad=true;
                break;
            }
            sub_tokens.add(cur_substr);
            start=end;
        }
        if(is_bad){
            output_tokens.add("[UNK]");
        }else{
            for(String token:sub_tokens){
                output_tokens.add(token);
            }
        }
    }
    return output_tokens;
}
public static ArrayList<String> tokenize(String text) throws IOException {
    ArrayList<String> split_tokens=new ArrayList<>();
    for(String s:basicTokenize(text)){
        split_tokens.addAll(wordpiece_tokenizer(s));
    }
    return split_tokens;
}
public static InputFeature constructTensor(String data) throws IOException {
    HashMap<String,Integer> hashMap= getVocab(filePath);
    InputExample example=new InputExample();
    example.setGuid("1");
    example.setText_a(data);
    example.setText_b("");
    example.setLabel("0");
    ArrayList<String> tokens_a = tokenize(example.getText_a());
    ArrayList<String> tokens_anew=new ArrayList<>();
    ArrayList<String> tokens=new ArrayList<>();
    ArrayList<Integer> segment_ids=new ArrayList<>();
    tokens.add("[CLS]");
    segment_ids.add(0);
    // max_seq_length is 25, so at most 23 word pieces fit between [CLS] and [SEP].
    if(tokens_a.size()>23){
        for(int i=0;i<23;i++){
            tokens_anew.add(tokens_a.get(i));
        }
    }else{
        tokens_anew=tokens_a;
    }
    for(String token:tokens_anew){
        tokens.add(token);
        segment_ids.add(0);
    }
    tokens.add("[SEP]");
    segment_ids.add(0);
    ArrayList<Integer> input_ids=new ArrayList<>();
    ArrayList<Integer> input_mask=new ArrayList<>();
    for(String s:tokens){
        input_ids.add(hashMap.get(s));
        input_mask.add(1);
    }
    // Zero-pad all three sequences out to max_seq_length = 25.
    while(input_ids.size()<25){
        input_ids.add(0);
        input_mask.add(0);
        segment_ids.add(0);
    }
    InputFeature feature=new InputFeature();
    feature.setInput_ids(input_ids);
    feature.setInput_mask(input_mask);
    feature.setSegments_ids(segment_ids);
    feature.setLabel_id(1693);
    return feature;
}
public static Session readGraph() throws IOException {
    // Load the frozen graph exported from the Python side and wrap it in a Session.
    byte[] graphDef = Files.readAllBytes(Paths.get("D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\graph.db"));
    Graph g = new Graph();
    g.importGraphDef(graphDef);
    return new Session(g);
}
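// End-to-end demo: load the graph once, featurize one query, run inference,
// and report the top-3 labels with stage timings.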
public static void main(String[] args) throws Exception{
    Date start = new Date();
    Session session = readGraph();
    Date t1=new Date();
    System.out.println(t1.getTime()-start.getTime());
    System.out.println("..........t1");
    InputFeature input = constructTensor("........");
    Date t2=new Date();
    System.out.println(t2.getTime()-t1.getTime());
    System.out.println("..........t2");
    ArrayList<Integer> input_ids1 = input.getInput_ids();
    int[] inputs_ids=new int[25];
    for(int i=0;i<input_ids1.size();i++){
        inputs_ids[i]=input_ids1.get(i);
    }
    ArrayList<Integer> input_mask1 = input.getInput_mask();
    int[] input_mask=new int[25];
    for(int i=0;i<input_mask1.size();i++){
        input_mask[i]=input_mask1.get(i);
    }
    ArrayList<Integer> segments_ids1 = input.getSegments_ids();
    int[] segments_ids=new int[25];
    for(int i=0;i<segments_ids1.size();i++){
        segments_ids[i]=segments_ids1.get(i);
    }
    // Feed int[] (not Integer[]) arrays -- see the note at the end of this post.
    // The feed/fetch node names below are assumptions; match them to your exported graph.
    Tensor idsTensor = Tensor.create(new int[][]{inputs_ids});
    Tensor maskTensor = Tensor.create(new int[][]{input_mask});
    Tensor segTensor = Tensor.create(new int[][]{segments_ids});
    Date t3=new Date();
    System.out.println(t3.getTime()-t2.getTime());
    System.out.println("..........t3");
    Tensor result = session.runner()
            .feed("input_ids", idsTensor)
            .feed("input_mask", maskTensor)
            .feed("segment_ids", segTensor)
            .fetch("loss/Softmax")
            .run().get(0);
    Date t4=new Date();
    System.out.println(t4.getTime()-t3.getTime());
    System.out.println("..........t4");
    float[][] prop = new float[1][1694]; // 1694 classes, consistent with label id 1693 above
    result.copyTo(prop);
    HashMap<Float,Integer> map=new HashMap<>();
    float[] t=prop[0];
    // Remember each probability's class index, then sort ascending so the
    // top-3 probabilities sit at the end of the array.
    for(int i=0;i<t.length;i++){
        map.put(t[i],i);
    }
    Arrays.sort(t);
    HashMap<String,Object> result1=new HashMap<>();
    List<HashMap<String,Object>> list=new ArrayList<>();
    for(int i=t.length-1;i>=t.length-3;i--){ // top-3 probabilities
        HashMap<String,Object> hashMap=new HashMap<>();
        int label=map.get(t[i]);
        System.out.println(".................................");
        hashMap.put("weight",t[i]);
        System.out.println(t[i]);
        hashMap.put("text",index2label(label));
        System.out.println(label);
        System.out.println(index2label(label));
        list.add(hashMap);
    }
    Date end=new Date();
    long diff=end.getTime()-start.getTime();
    System.out.println(start.getTime());
    System.out.println(diff);
    result1.put("result",list);
    Date t5=new Date();
    System.out.println(t5.getTime()-t4.getTime());
    System.out.println("..........t5");
}
}
If you only need Java to call a Python model, I recommend this post: https://blog.csdn.net/rabbit_judy/article/details/80054085#commentBox
This post mainly records applying BERT's data preprocessing to the input in Java and then calling a BERT model trained in Python. The results are good, but it is somewhat slow.
Python's ord('中') is equivalent to Java's (int)('中').
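A quick check of that equivalence (plain Java, nothing else assumed):

public class OrdCheck {
    public static void main(String[] args) {
        int codePoint = (int) '中'; // same value as Python's ord('中')
        System.out.println(codePoint); // 20013
    }
}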
Python's unicodedata.category('中') corresponds to Java's Character.getType('中'), except that Java returns its own numeric category constants rather than the two-letter Unicode names.
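For reference, a few of the Character constants the tokenizer above relies on:

public class CategoryCheck {
    public static void main(String[] args) {
        // '中' is Unicode category Lo, which Java encodes as OTHER_LETTER (5).
        System.out.println(Character.getType('中') == Character.OTHER_LETTER); // true
        System.out.println(Character.CONTROL);                 // 15  (Unicode Cc)
        System.out.println(Character.SPACE_SEPARATOR);         // 12  (Unicode Zs)
        System.out.println(Character.DASH_PUNCTUATION);        // 20  (Unicode Pd)
        System.out.println(Character.FINAL_QUOTE_PUNCTUATION); // 30  (Unicode Pf)
    }
}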
At first, calling the Python model from Java was very slow. The cause turned out to be session.runner(): keep the Session in a static variable, and only the first call is slow; every call after that is fast. Also, never pass an Integer[] array to Tensor.create(); it causes problems, whereas int[] works fine.
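A minimal sketch of that caching pattern. The graph path is the one from the code above; the ModelHolder class name is just for illustration:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

public class ModelHolder {
    private static Session session; // built once; only the first call pays the warm-up cost

    public static synchronized Session get() throws IOException {
        if (session == null) {
            byte[] graphDef = Files.readAllBytes(
                    Paths.get("D:\\ideaprojects\\match1\\src\\main\\java\\com\\data\\graph.db"));
            Graph g = new Graph();
            g.importGraphDef(graphDef);
            session = new Session(g);
        }
        return session;
    }

    public static void main(String[] args) throws IOException {
        // Feed primitive int[] arrays: Tensor.create cannot infer a numeric
        // dtype from a boxed Integer[] array, while int[] works fine.
        int[] ids = new int[25];
        Tensor t = Tensor.create(new int[][]{ids});
        System.out.println(t.numDimensions()); // 2
    }
}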