package com.trie.base.bean; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; public class TrieNode implements Comparable<TrieNode>{ /** * 节点的key * */ protected int key; /** *节点对应的字符串 */ protected String word; /** * 节点的前缀节点 */ protected TrieNode fatherNode; /** * 以该节点作为前缀的子节点 */ protected List<TrieNode> childNodes; /** * 在list的index和char的int值之间建立映射 */ protected Map<Integer,Integer> indexMap; /** * 搜索次数 */ protected int click=0; protected Object obj; public List<TrieNode> getChildNodes() { if(null==childNodes){ childNodes=new ArrayList<TrieNode>(); indexMap =new HashMap<Integer, Integer>(); } return childNodes; } public void setChildNodes(List<TrieNode> childNodes) { this.childNodes = childNodes; } public TrieNode getFatherNode() { return fatherNode; } public void setFatherNode(TrieNode fatherNode) { this.fatherNode = fatherNode; } public String getWord() { return word; } public void setWord(String word) { this.word = word; } /** * @param key */ public TrieNode(int key,String word) { super(); this.key = key; this.word=word; } /** * @param key */ public TrieNode() { super(); } public boolean add(TrieNode childNode){ childNode.setFatherNode(this); getChildNodes().add(childNode); getIndexMap().put(childNode.getKey(),getChildNodes().size()); return true; } public TrieNode get(char c){ return get((int)c); } public TrieNode get(int c){ Integer o=getIndexMap().get(c); if(null==o){ return null; } return getChildNodes().get(o-1); } public TrieNode addClick(){ this.click++; return this; } public Map<Integer, Integer> getIndexMap() { if(null==indexMap){ indexMap=new HashMap<Integer, Integer>(); } return indexMap; } public void setIndexMap(Map<Integer, Integer> indexMap) { this.indexMap = indexMap; } public int getClick() { return click; } public void setClick(int click) { this.click = click; } public int getKey() { return key; } public void setKey(int key) { this.key = key; } public List<TrieNode> getAllChild(){ List<TrieNode> list = new ArrayList<TrieNode>(); for(Iterator<TrieNode> it=getChildNodes().iterator();it.hasNext();){ TrieNode node=it.next(); list.add(node); if(node.getChildNodes().size()>0){ list.addAll(node.getAllChild()); } } return list; } public int compareTo(TrieNode o) { if(null==o||o.getClick()<this.click){ return -1; }else if(o.getClick()==this.click){ return 0; }else{ return 1; } } public Object getObj() { return obj; } public void setObj(Object obj) { if(obj==null){ return; } this.obj = obj; } }
package com.trie.base; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import com.trie.base.bean.TrieNode; public class Trie { //该字符树所有节点的数量 protected int entityTrieNodeSize=0; //该字符树所有节点的数量 protected int trieNodeSize=0; //第一层Node protected TrieNode firstLvNode; //禁用词 protected Set<String> filterWords; //搜索链表 protected List<TrieNode> entityNodes; /** * 插入节点 */ public TrieNode insert(String word,Object obj){ if(filter(word)){ return null; }; char[] words=word.toCharArray(); TrieNode temp=null; TrieNode tempFather=null; for(int i=0;i<words.length;i++){ int c=(int)words[i]; //首字符,从第一层Node中搜索 if(0==i){ temp=getFirstLvNode().get(c); //第一层node内不存在此字符,增加 if(null==temp){ temp=addNewTrieNode(getFirstLvNode(),words, i); } tempFather=temp; //其他字符 }else{ temp=tempFather.get(c); //node内不存在此字符,增加 if(null==temp){ temp=addNewTrieNode(tempFather,words, i); } tempFather=temp; } } tempFather.setObj(obj); getEntityNodes().add(tempFather); entityTrieNodeSize++; return tempFather; } public TrieNode insert(String word){ return insert(word,null); } public TrieNode addNewTrieNode(TrieNode tempFather,char[] words,int i){ TrieNode trieNode = new TrieNode((int)words[i],new String(words,0,i+1)); trieNodeSize++; tempFather.add(trieNode); return trieNode; } /** * 搜索输入的词 返回节点 并在搜索次数中+1 */ public TrieNode search(String word){ TrieNode node= get(word); if(null==node){ return null; } return node.addClick(); }; public TrieNode get(String word){ char[] words=word.toCharArray(); TrieNode temp=null; for(int i=0;i<words.length;i++){ int c=(int)words[i]; //首字符 if(0==i){ temp=getFirstLvNode().get(c); if(temp==null){ return null; } //其他层的字符 }else{ if(temp==null){ return null; } temp=temp.get(c); } } return temp; }; /** * 根据设定的禁用词来过滤节点 */ public boolean filter(String word){ if(null==filterWords){ return false; } boolean result = filterWords.add(word); if(result){ filterWords.remove(word); } return !result; } /** * 搜索输入的词与该词所有前缀 */ public List<TrieNode> searchAndPrefix(String word){ List<TrieNode> nodes=new ArrayList<TrieNode>(); char[] words=word.toCharArray(); TrieNode temp=null; for(int i=0;i<words.length;i++){ int c=(int)words[i]; //首字符 if(0==i){ temp=getFirstLvNode().get(c); if(temp==null){ return null; } //其他层的字符 }else{ if(temp==null){ return null; } temp=temp.get(c); } nodes.add(temp); } temp.addClick(); return nodes; }; public List<TrieNode> searchAndPrefix(String word,int lv){ List<TrieNode> nodes=new ArrayList<TrieNode>(); char[] words=word.toCharArray(); TrieNode temp=null; for(int i=0;i<words.length;i++){ int c=(int)words[i]; if(0==i){ temp=getFirstLvNode().get(c); if(temp==null){ return null; } }else{ if(temp==null){ return null; } temp=temp.get(c); } if(i>lv-1){ nodes.add(temp); } } temp.addClick(); return nodes; }; /** * 搜索前缀找出相关的词,按顺序取词 */ public List<TrieNode> searchByPrefix(String word){ return searchByPrefix(word,true,0); }; /** * 搜索前缀找出相关的词,并根据排序取词 */ public List<TrieNode> searchByPrefix(String word,boolean desc,int size){ List<TrieNode> list=new ArrayList<TrieNode>(); TrieNode node=get(word); if(null==node){ return list; } list.add(node); list.addAll(node.getAllChild()); return sort(list,desc,size); }; public List<TrieNode> sort(boolean desc,int size){ return sort(getEntityNodes(),desc,size); } /** * 对所有实体根据click进行排序,取出规定数量的节点 * desc true为默认逆序,false默认顺序 */ public List<TrieNode> sort(List<TrieNode> elist,boolean desc,int size){ java.util.Collections.sort(elist); if(size<=0||size>elist.size()){ size=elist.size(); } if(desc){ return elist.subList(0,size); }else{ List<TrieNode> list =elist.subList(elist.size()-size,elist.size()); java.util.Collections.reverse(list); return list; } } public TrieNode getFirstLvNode() { if(null==firstLvNode){ setFirstLvNode(new TrieNode()); } return firstLvNode; } public void setFirstLvNode(TrieNode firstLvNode) { this.firstLvNode = firstLvNode; } public Set<String> getFilterWords() { if(null==filterWords){ filterWords=new HashSet<String>(); } return filterWords; } public void setFilterWords(Set<String> filterWords) { this.filterWords = filterWords; } public int getTrieNodeSize() { return trieNodeSize; } public void setTrieNodeSize(int trieNodeSize) { this.trieNodeSize = trieNodeSize; } public void addFilterWord(String word){ getFilterWords().add(word); } public int getEntityTrieNodeSize() { return entityTrieNodeSize; } public void setEntityTrieNodeSize(int entityTrieNodeSize) { this.entityTrieNodeSize = entityTrieNodeSize; } public List<TrieNode> getEntityNodes() { if(null==entityNodes){ entityNodes=java.util.Collections.synchronizedList(new ArrayList()); } return entityNodes; } public void setEntityNodes(List entityNodes) { this.entityNodes = entityNodes; } }