The most classic algorithm in association rule mining is Apriori, whose fatal drawback is that it must scan the transaction database many times. This led to various methods of pruning the data set to reduce I/O cost, and Jiawei Han's FP-Tree algorithm is a particularly efficient one among them.
Strictly speaking, both Apriori and FP-Tree are algorithms for finding frequent itemsets; a frequent itemset is simply an itemset whose "support" is high enough. Let us first explain the concepts of support and confidence.
Suppose the transaction database is:
A E F G
A F G
A B E F G
E F G
Then the support count of {A,F,G} is 3 and its support is 3/4.
The support count of {F,G} is 4 and its support is 4/4.
The support count of {A} is 3 and its support is 3/4.
The confidence of {F,G}=>{A} is the support count of {A,F,G} divided by the support count of {F,G}, i.e. 3/4.
The confidence of {A}=>{F,G} is the support count of {A,F,G} divided by the support count of {A}, i.e. 3/3.
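To make the arithmetic concrete, here is a minimal standalone Java sketch (the class and method names are made up for illustration and are not part of the code later in this post) that computes these support counts and confidences for the 4-transaction toy database:

import java.util.*;

public class SupportDemo {
    // Number of transactions that contain every item of the given itemset
    static int supportCount(List<Set<String>> db, Set<String> itemset) {
        int count = 0;
        for (Set<String> trans : db)
            if (trans.containsAll(itemset))
                count++;
        return count;
    }

    public static void main(String[] args) {
        List<Set<String>> db = new ArrayList<Set<String>>();
        db.add(new HashSet<String>(Arrays.asList("A", "E", "F", "G")));
        db.add(new HashSet<String>(Arrays.asList("A", "F", "G")));
        db.add(new HashSet<String>(Arrays.asList("A", "B", "E", "F", "G")));
        db.add(new HashSet<String>(Arrays.asList("E", "F", "G")));

        Set<String> afg = new HashSet<String>(Arrays.asList("A", "F", "G"));
        Set<String> fg = new HashSet<String>(Arrays.asList("F", "G"));
        Set<String> a = new HashSet<String>(Arrays.asList("A"));

        System.out.println("support({A,F,G}) = " + supportCount(db, afg) + "/" + db.size()); // 3/4
        // confidence({F,G} => {A}) = supportCount({A,F,G}) / supportCount({F,G}) = 3/4
        System.out.println("conf({F,G}=>{A}) = " + supportCount(db, afg) + "/" + supportCount(db, fg));
        // confidence({A} => {F,G}) = supportCount({A,F,G}) / supportCount({A}) = 3/3
        System.out.println("conf({A}=>{F,G}) = " + supportCount(db, afg) + "/" + supportCount(db, a));
    }
}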
Strong association rule mining looks for all patterns whose confidence reaches a given threshold while also meeting a minimum support.
Let us walk through a complete implementation of the FP-Tree algorithm on an example.
The transaction database is shown below, one shopping record per line:
milk,eggs,bread,chips
eggs,popcorn,chips,beer
eggs,bread,chips
milk,eggs,bread,popcorn,chips,beer
milk,bread,beer
eggs,bread,beer
milk,bread,chips
milk,eggs,bread,butter,chips
milk,eggs,butter,chips
Our goal is to find out which items tend to be bought together. For example, people who buy chips usually also buy eggs, so [chips, eggs] is a frequent pattern.
Step 1 of the FP-Tree algorithm: scan the transaction database, sort the items by descending frequency, and drop every item whose frequency is below the minimum support MinSup. (This is the first scan of the database.)
chips:7  eggs:7  bread:7  milk:6  beer:4   (here we set MinSup = 3)
The result above is the frequent 1-itemset, denoted F1.
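As a quick illustration, this first scan can be sketched as the following standalone snippet (names are invented here; the buildHeaderTable method later in this post does the same job, but on TreeNode objects):

import java.util.*;

public class FirstScanDemo {
    public static void main(String[] args) {
        List<List<String>> db = Arrays.asList(
                Arrays.asList("milk", "eggs", "bread", "chips"),
                Arrays.asList("eggs", "popcorn", "chips", "beer"),
                Arrays.asList("eggs", "bread", "chips"),
                Arrays.asList("milk", "eggs", "bread", "popcorn", "chips", "beer"),
                Arrays.asList("milk", "bread", "beer"),
                Arrays.asList("eggs", "bread", "beer"),
                Arrays.asList("milk", "bread", "chips"),
                Arrays.asList("milk", "eggs", "bread", "butter", "chips"),
                Arrays.asList("milk", "eggs", "butter", "chips"));
        int minSup = 3;

        // One pass over the database: count how often each item occurs
        Map<String, Integer> freq = new HashMap<String, Integer>();
        for (List<String> record : db)
            for (String item : record)
                freq.put(item, freq.containsKey(item) ? freq.get(item) + 1 : 1);

        // Keep the items with count >= minSup and sort them by descending count: this is F1
        List<Map.Entry<String, Integer>> f1 = new ArrayList<Map.Entry<String, Integer>>();
        for (Map.Entry<String, Integer> e : freq.entrySet())
            if (e.getValue() >= minSup)
                f1.add(e);
        Collections.sort(f1, new Comparator<Map.Entry<String, Integer>>() {
            public int compare(Map.Entry<String, Integer> a, Map.Entry<String, Integer> b) {
                return b.getValue() - a.getValue(); // descending by count
            }
        });
        // Prints chips=7, eggs=7, bread=7, milk=6, beer=4 (ties may appear in either order)
        System.out.println(f1);
    }
}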
Step 2: reorder every purchase record according to the item order in F1. (This is the second and last scan of the database.)
chips,eggs,bread,milk
chips,eggs,beer
chips,eggs,bread
chips,eggs,bread,milk,beer
bread,milk,beer
eggs,bread,beer
chips,bread,milk
chips,eggs,bread,milk
chips,eggs,milk
Step 3: insert each record obtained in step 2 into the FP-Tree. Initially the suffix pattern is empty.
After inserting the first record (chips,eggs,bread,milk):
After inserting the second record (chips,eggs,beer):
After inserting the record (bread,milk,beer):
By now you can probably see how the remaining insertions go. The final FP-Tree is:
The column on the left in the figure above is called the header table. Nodes with the same name in the tree are linked together, and the first element of each linked list is the corresponding entry in the header table.
If the FP-Tree is empty (it contains only a virtual root node), the FP-Growth function returns.
At that point, each entry of the header table is output together with postPattern, and the support is the count of the corresponding entry in the header table.
Step 4: extract the frequent patterns from the FP-Tree.
Go through every entry of the header table (we will use "milk:6" as the example) and perform the following operations on each of them:
(1) Find all the "milk" nodes in the FP-Tree and walk up through their ancestor nodes, which yields 4 paths:
chips:7, eggs:6, milk:1
chips:7, eggs:6, bread:4, milk:3
chips:7, bread:1, milk:1
bread:1, milk:1
For every node on each path, set its count to the count of the milk node on that path:
chips:1, eggs:1, milk:1
chips:3, eggs:3, bread:3, milk:3
chips:1, bread:1, milk:1
bread:1, milk:1
Since every path ends with milk, milk can be dropped, which gives the conditional pattern base (CPB); the suffix pattern at this point is (milk):
chips:1, eggs:1
chips:3, eggs:3, bread:3
chips:1, bread:1
bread:1
(2) Treat the result above as a new transaction database, go back to step 3, and iterate recursively.
If that was not clear enough, you can refer to this blog post, or just look at the core code directly:
public void FPGrowth(List<List<String>> transRecords, List<String> postPattern,
        Context context) throws IOException, InterruptedException {
    // Build the header table, which is also the frequent 1-itemset
    ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
    // Build the FP-Tree
    TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
    // If the FP-Tree is empty, return
    if (treeRoot.getChildren() == null || treeRoot.getChildren().size() == 0)
        return;
    // Output every entry of the header table + postPattern
    if (postPattern != null) {
        for (TreeNode header : HeaderTable) {
            String outStr = header.getName();
            int count = header.getCount();
            for (String ele : postPattern)
                outStr += "\t" + ele;
            context.write(new IntWritable(count), new Text(outStr));
        }
    }
    // Find the conditional pattern base of every header-table entry and recurse
    for (TreeNode header : HeaderTable) {
        // Extend the suffix pattern by one item
        List<String> newPostPattern = new LinkedList<String>();
        newPostPattern.add(header.getName());
        if (postPattern != null)
            newPostPattern.addAll(postPattern);
        // Collect header's conditional pattern base (CPB) into newTransRecords
        List<List<String>> newTransRecords = new LinkedList<List<String>>();
        TreeNode backnode = header.getNextHomonym();
        while (backnode != null) {
            int counter = backnode.getCount();
            List<String> prenodes = new ArrayList<String>();
            TreeNode parent = backnode;
            // Walk up backnode's ancestors and collect them into prenodes
            while ((parent = parent.getParent()).getName() != null) {
                prenodes.add(parent.getName());
            }
            while (counter-- > 0) {
                newTransRecords.add(prenodes);
            }
            backnode = backnode.getNextHomonym();
        }
        // Recurse
        FPGrowth(newTransRecords, newPostPattern, context);
    }
}
When the FP-Tree has degenerated into a single path, there is no need to call FPGrowth recursively any more; we can simply output every combination of the nodes on that path together with postPattern. For example, when the FP-Tree is:
we directly output:
3 A+postPattern
3 B+postPattern
3 A+B+postPattern
and we are done.
If we instead follow the code above, it first outputs:
3 A+postPattern
3 B+postPattern
then prepends B to postPattern and builds a new FP-Tree, which now contains only A, so it outputs:
3 A+(B+postPattern)
The two approaches give the same result, but rebuilding the FP-Tree does cost extra computation.
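The single-path shortcut is not implemented in the code in this post. As a rough sketch (class and method names are invented here), enumerating the combinations of a single path could look like this:

import java.util.*;

// When the FP-Tree is one path, every non-empty subset of the path's nodes is a
// frequent pattern, and its support is the count of the deepest chosen node
// (counts never increase going down the path).
public class SinglePathDemo {
    static void enumerate(List<String> names, List<Integer> counts,
                          int idx, List<String> chosen, int minCount,
                          List<String> postPattern) {
        if (idx == names.size()) {
            if (chosen.isEmpty()) return;
            StringBuilder sb = new StringBuilder();
            for (String s : chosen) sb.append(s).append("\t");
            for (String s : postPattern) sb.append(s).append("\t");
            System.out.println(minCount + "\t" + sb.toString().trim());
            return;
        }
        // Either skip the node at idx ...
        enumerate(names, counts, idx + 1, chosen, minCount, postPattern);
        // ... or take it, in which case the support drops to this node's count
        chosen.add(names.get(idx));
        enumerate(names, counts, idx + 1, chosen, Math.min(minCount, counts.get(idx)), postPattern);
        chosen.remove(chosen.size() - 1);
    }

    public static void main(String[] args) {
        // The path A:3 -> B:3 from the example above, with an empty suffix pattern;
        // prints "3 B", "3 A" and "3 A B"
        enumerate(Arrays.asList("A", "B"), Arrays.asList(3, 3),
                  0, new ArrayList<String>(), Integer.MAX_VALUE, new LinkedList<String>());
    }
}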
FP-Tree node definition
package fptree;

import java.util.ArrayList;
import java.util.List;

public class TreeNode implements Comparable<TreeNode> {

    private String name;             // node name
    private int count;               // count
    private TreeNode parent;         // parent node
    private List<TreeNode> children; // child nodes
    private TreeNode nextHomonym;    // next node with the same name

    public TreeNode() {
    }

    public TreeNode(String name) {
        this.name = name;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public int getCount() {
        return count;
    }

    public void setCount(int count) {
        this.count = count;
    }

    public TreeNode getParent() {
        return parent;
    }

    public void setParent(TreeNode parent) {
        this.parent = parent;
    }

    public List<TreeNode> getChildren() {
        return children;
    }

    public void addChild(TreeNode child) {
        if (this.getChildren() == null) {
            List<TreeNode> list = new ArrayList<TreeNode>();
            list.add(child);
            this.setChildren(list);
        } else {
            this.getChildren().add(child);
        }
    }

    public TreeNode findChild(String name) {
        List<TreeNode> children = this.getChildren();
        if (children != null) {
            for (TreeNode child : children) {
                if (child.getName().equals(name)) {
                    return child;
                }
            }
        }
        return null;
    }

    public void setChildren(List<TreeNode> children) {
        this.children = children;
    }

    public void printChildrenName() {
        List<TreeNode> children = this.getChildren();
        if (children != null) {
            for (TreeNode child : children) {
                System.out.print(child.getName() + " ");
            }
        } else {
            System.out.print("null");
        }
    }

    public TreeNode getNextHomonym() {
        return nextHomonym;
    }

    public void setNextHomonym(TreeNode nextHomonym) {
        this.nextHomonym = nextHomonym;
    }

    public void countIncrement(int n) {
        this.count += n;
    }

    @Override
    public int compareTo(TreeNode arg0) {
        int count0 = arg0.getCount();
        // Reverse of the natural ordering, so Collections.sort()/Arrays.sort() sort in descending order
        return count0 - this.count;
    }
}
Mining frequent patterns
package fptree;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class FPTree {

    private int minSuport;

    public int getMinSuport() {
        return minSuport;
    }

    public void setMinSuport(int minSuport) {
        this.minSuport = minSuport;
    }

    // Read transaction records from one or more files
    public List<List<String>> readTransRocords(String... filenames) {
        List<List<String>> transaction = null;
        if (filenames.length > 0) {
            transaction = new LinkedList<List<String>>();
            for (String filename : filenames) {
                try {
                    FileReader fr = new FileReader(filename);
                    BufferedReader br = new BufferedReader(fr);
                    try {
                        String line;
                        List<String> record;
                        while ((line = br.readLine()) != null) {
                            if (line.trim().length() > 0) {
                                String str[] = line.split(",");
                                record = new LinkedList<String>();
                                for (String w : str)
                                    record.add(w);
                                transaction.add(record);
                            }
                        }
                    } finally {
                        br.close();
                    }
                } catch (IOException ex) {
                    System.out.println("Read transaction records failed." + ex.getMessage());
                    System.exit(1);
                }
            }
        }
        return transaction;
    }

    // The FP-Growth algorithm
    public void FPGrowth(List<List<String>> transRecords, List<String> postPattern) {
        // Build the header table, which is also the frequent 1-itemset
        ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
        // Build the FP-Tree
        TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
        // If the FP-Tree is empty, return
        if (treeRoot.getChildren() == null || treeRoot.getChildren().size() == 0)
            return;
        // Output every entry of the header table + postPattern
        if (postPattern != null) {
            for (TreeNode header : HeaderTable) {
                System.out.print(header.getCount() + "\t" + header.getName());
                for (String ele : postPattern)
                    System.out.print("\t" + ele);
                System.out.println();
            }
        }
        // Find the conditional pattern base of every header-table entry and recurse
        for (TreeNode header : HeaderTable) {
            // Extend the suffix pattern by one item
            List<String> newPostPattern = new LinkedList<String>();
            newPostPattern.add(header.getName());
            if (postPattern != null)
                newPostPattern.addAll(postPattern);
            // Collect header's conditional pattern base (CPB) into newTransRecords
            List<List<String>> newTransRecords = new LinkedList<List<String>>();
            TreeNode backnode = header.getNextHomonym();
            while (backnode != null) {
                int counter = backnode.getCount();
                List<String> prenodes = new ArrayList<String>();
                TreeNode parent = backnode;
                // Walk up backnode's ancestors and collect them into prenodes
                while ((parent = parent.getParent()).getName() != null) {
                    prenodes.add(parent.getName());
                }
                while (counter-- > 0) {
                    newTransRecords.add(prenodes);
                }
                backnode = backnode.getNextHomonym();
            }
            // Recurse
            FPGrowth(newTransRecords, newPostPattern);
        }
    }

    // Build the header table, which is also the frequent 1-itemset
    public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
        ArrayList<TreeNode> F1 = null;
        if (transRecords.size() > 0) {
            F1 = new ArrayList<TreeNode>();
            Map<String, TreeNode> map = new HashMap<String, TreeNode>();
            // Count the support of every item in the transaction database
            for (List<String> record : transRecords) {
                for (String item : record) {
                    if (!map.keySet().contains(item)) {
                        TreeNode node = new TreeNode(item);
                        node.setCount(1);
                        map.put(item, node);
                    } else {
                        map.get(item).countIncrement(1);
                    }
                }
            }
            // Add the items whose support is at least minSup to F1
            Set<String> names = map.keySet();
            for (String name : names) {
                TreeNode tnode = map.get(name);
                if (tnode.getCount() >= minSuport) {
                    F1.add(tnode);
                }
            }
            Collections.sort(F1);
            return F1;
        } else {
            return null;
        }
    }

    // Build the FP-Tree
    public TreeNode buildFPTree(List<List<String>> transRecords, ArrayList<TreeNode> F1) {
        TreeNode root = new TreeNode(); // create the root of the tree
        for (List<String> transRecord : transRecords) {
            LinkedList<String> record = sortByF1(transRecord, F1);
            TreeNode subTreeRoot = root;
            TreeNode tmpRoot = null;
            if (root.getChildren() != null) {
                // Follow the existing prefix as far as possible, incrementing counts
                while (!record.isEmpty()
                        && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {
                    tmpRoot.countIncrement(1);
                    subTreeRoot = tmpRoot;
                    record.poll();
                }
            }
            addNodes(subTreeRoot, record, F1);
        }
        return root;
    }

    // Sort the items of a transaction in descending order of frequency
    public LinkedList<String> sortByF1(List<String> transRecord, ArrayList<TreeNode> F1) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (String item : transRecord) {
            // F1 is already sorted by descending count, so use each item's index in F1 as its rank
            for (int i = 0; i < F1.size(); i++) {
                TreeNode tnode = F1.get(i);
                if (tnode.getName().equals(item)) {
                    map.put(item, i);
                }
            }
        }
        ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(
                map.entrySet());
        Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {
            @Override
            public int compare(Entry<String, Integer> arg0, Entry<String, Integer> arg1) {
                // Ascending by F1 index, i.e. descending by frequency
                return arg0.getValue() - arg1.getValue();
            }
        });
        LinkedList<String> rest = new LinkedList<String>();
        for (Entry<String, Integer> entry : al) {
            rest.add(entry.getKey());
        }
        return rest;
    }

    // Insert record into the tree as descendants of ancestor
    public void addNodes(TreeNode ancestor, LinkedList<String> record, ArrayList<TreeNode> F1) {
        if (record.size() > 0) {
            while (record.size() > 0) {
                String item = record.poll();
                TreeNode leafnode = new TreeNode(item);
                leafnode.setCount(1);
                leafnode.setParent(ancestor);
                ancestor.addChild(leafnode);
                // Append the new node to the end of its header entry's homonym chain
                for (TreeNode f1 : F1) {
                    if (f1.getName().equals(item)) {
                        while (f1.getNextHomonym() != null) {
                            f1 = f1.getNextHomonym();
                        }
                        f1.setNextHomonym(leafnode);
                        break;
                    }
                }
                addNodes(leafnode, record, F1);
            }
        }
    }

    public static void main(String[] args) {
        FPTree fptree = new FPTree();
        fptree.setMinSuport(3);
        List<List<String>> transRecords = fptree.readTransRocords("/home/orisun/test/market");
        fptree.FPGrowth(transRecords, null);
    }
}
Input file
milk,eggs,bread,chips
eggs,popcorn,chips,beer
eggs,bread,chips
milk,eggs,bread,popcorn,chips,beer
milk,bread,beer
eggs,bread,beer
milk,bread,chips
milk,eggs,bread,butter,chips
milk,eggs,butter,chips
Output
6  chips  eggs
5  chips  bread
5  eggs  bread
4  chips  eggs  bread
5  chips  milk
5  bread  milk
4  eggs  milk
4  chips  bread  milk
4  chips  eggs  milk
3  bread  eggs  milk
3  chips  bread  eggs  milk
3  eggs  beer
3  bread  beer
In the code above we put the whole transaction database into a List<List<String>> and pass it to FPGrowth. In practice this is not feasible, because memory cannot hold the entire transaction database; we may have to read the records one by one from a relational database to build the FP-Tree. In any case the FP-Tree itself must live in memory, so what if even that does not fit? Moreover, FPGrowth is still very time-consuming, so how can we speed it up? The answer: divide and conquer, and compute in parallel.
We split the original transaction database into N parts, run FPGrowth mining on N nodes in parallel, and finally merge the association rules together. The key question is how to "partition" the data without missing any association rule; see this blog post. To enable parallel computation, a "redundant" partitioning is used here, i.e. the union of all parts is larger than the original set. The rules obtained this way are therefore also redundant: for example, node 1 may produce the rule (6: beer, diapers) and node 2 the rule (3: diapers, beer); the rule from node 2 is clearly redundant, so a follow-up step is needed to remove redundant rules.
Code:
Record.java
package fptree;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedList;

import org.apache.hadoop.io.WritableComparable;

public class Record implements WritableComparable<Record> {

    LinkedList<String> list;

    public Record() {
        list = new LinkedList<String>();
    }

    public Record(String[] arr) {
        list = new LinkedList<String>();
        for (int i = 0; i < arr.length; i++)
            list.add(arr[i]);
    }

    @Override
    public String toString() {
        String str = list.get(0);
        for (int i = 1; i < list.size(); i++)
            str += "\t" + list.get(i);
        return str;
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        list.clear();
        String line = in.readUTF();
        String[] arr = line.split("\\s+");
        for (int i = 0; i < arr.length; i++)
            list.add(arr[i]);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.toString());
    }

    @Override
    public int compareTo(Record obj) {
        Collections.sort(list);
        Collections.sort(obj.list);
        return this.toString().compareTo(obj.toString());
    }
}
DC_FPTree.java
package fptree;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class DC_FPTree extends Configured implements Tool {

    private static final int GroupNum = 10;
    private static final int minSuport = 6;

    public static class GroupMapper extends Mapper<LongWritable, Text, IntWritable, Record> {

        List<String> freq = new LinkedList<String>();                    // frequent 1-itemset
        List<List<String>> freq_group = new LinkedList<List<String>>(); // frequent 1-itemset split into groups

        @Override
        public void setup(Context context) throws IOException {
            // Read the frequent 1-itemset from a file
            FileSystem fs = FileSystem.get(context.getConfiguration());
            Path freqFile = new Path("/user/orisun/input/F1");
            FSDataInputStream in = fs.open(freqFile);
            InputStreamReader isr = new InputStreamReader(in);
            BufferedReader br = new BufferedReader(isr);
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String[] str = line.split("\\s+");
                    String word = str[0];
                    freq.add(word);
                }
            } finally {
                br.close();
            }
            // Split the frequent 1-itemset into groups
            Collections.shuffle(freq); // randomize the order
            int cap = freq.size() / GroupNum; // items per group
            for (int i = 0; i < GroupNum; i++) {
                List<String> list = new LinkedList<String>();
                for (int j = 0; j < cap; j++) {
                    list.add(freq.get(i * cap + j));
                }
                freq_group.add(list);
            }
            int remainder = freq.size() % GroupNum;
            int base = GroupNum * cap;
            for (int i = 0; i < remainder; i++) {
                freq_group.get(i).add(freq.get(base + i));
            }
        }

        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String[] arr = value.toString().split("\\s+");
            Record record = new Record(arr);
            LinkedList<String> list = record.list;
            BitSet bs = new BitSet(freq_group.size());
            bs.clear();
            while (record.list.size() > 0) {
                String item = list.peekLast(); // take the last item of the record
                int i = 0;
                for (; i < freq_group.size(); i++) {
                    if (bs.get(i))
                        continue;
                    if (freq_group.get(i).contains(item)) {
                        bs.set(i);
                        break;
                    }
                }
                if (i < freq_group.size()) {
                    // found a matching group: emit the current prefix of the record to it
                    context.write(new IntWritable(i), record);
                }
                record.list.pollLast();
            }
        }
    }

    public static class FPReducer extends Reducer<IntWritable, Record, IntWritable, Text> {

        public void reduce(IntWritable key, Iterable<Record> values, Context context)
                throws IOException, InterruptedException {
            List<List<String>> trans = new LinkedList<List<String>>();
            while (values.iterator().hasNext()) {
                Record record = values.iterator().next();
                LinkedList<String> list = new LinkedList<String>();
                for (String ele : record.list)
                    list.add(ele);
                trans.add(list);
            }
            FPGrowth(trans, null, context);
        }

        // The FP-Growth algorithm
        public void FPGrowth(List<List<String>> transRecords, List<String> postPattern,
                Context context) throws IOException, InterruptedException {
            // Build the header table, which is also the frequent 1-itemset
            ArrayList<TreeNode> HeaderTable = buildHeaderTable(transRecords);
            // Build the FP-Tree
            TreeNode treeRoot = buildFPTree(transRecords, HeaderTable);
            // If the FP-Tree is empty, return
            if (treeRoot.getChildren() == null || treeRoot.getChildren().size() == 0)
                return;
            // Output every entry of the header table + postPattern
            if (postPattern != null) {
                for (TreeNode header : HeaderTable) {
                    String outStr = header.getName();
                    int count = header.getCount();
                    for (String ele : postPattern)
                        outStr += "\t" + ele;
                    context.write(new IntWritable(count), new Text(outStr));
                }
            }
            // Find the conditional pattern base of every header-table entry and recurse
            for (TreeNode header : HeaderTable) {
                // Extend the suffix pattern by one item
                List<String> newPostPattern = new LinkedList<String>();
                newPostPattern.add(header.getName());
                if (postPattern != null)
                    newPostPattern.addAll(postPattern);
                // Collect header's conditional pattern base (CPB) into newTransRecords
                List<List<String>> newTransRecords = new LinkedList<List<String>>();
                TreeNode backnode = header.getNextHomonym();
                while (backnode != null) {
                    int counter = backnode.getCount();
                    List<String> prenodes = new ArrayList<String>();
                    TreeNode parent = backnode;
                    // Walk up backnode's ancestors and collect them into prenodes
                    while ((parent = parent.getParent()).getName() != null) {
                        prenodes.add(parent.getName());
                    }
                    while (counter-- > 0) {
                        newTransRecords.add(prenodes);
                    }
                    backnode = backnode.getNextHomonym();
                }
                // Recurse
                FPGrowth(newTransRecords, newPostPattern, context);
            }
        }

        // Build the header table, which is also the frequent 1-itemset
        public ArrayList<TreeNode> buildHeaderTable(List<List<String>> transRecords) {
            ArrayList<TreeNode> F1 = null;
            if (transRecords.size() > 0) {
                F1 = new ArrayList<TreeNode>();
                Map<String, TreeNode> map = new HashMap<String, TreeNode>();
                // Count the support of every item in the transaction database
                for (List<String> record : transRecords) {
                    for (String item : record) {
                        if (!map.keySet().contains(item)) {
                            TreeNode node = new TreeNode(item);
                            node.setCount(1);
                            map.put(item, node);
                        } else {
                            map.get(item).countIncrement(1);
                        }
                    }
                }
                // Add the items whose support is at least minSup to F1
                Set<String> names = map.keySet();
                for (String name : names) {
                    TreeNode tnode = map.get(name);
                    if (tnode.getCount() >= minSuport) {
                        F1.add(tnode);
                    }
                }
                Collections.sort(F1);
                return F1;
            } else {
                return null;
            }
        }

        // Build the FP-Tree
        public TreeNode buildFPTree(List<List<String>> transRecords, ArrayList<TreeNode> F1) {
            TreeNode root = new TreeNode(); // create the root of the tree
            for (List<String> transRecord : transRecords) {
                LinkedList<String> record = sortByF1(transRecord, F1);
                TreeNode subTreeRoot = root;
                TreeNode tmpRoot = null;
                if (root.getChildren() != null) {
                    while (!record.isEmpty()
                            && (tmpRoot = subTreeRoot.findChild(record.peek())) != null) {
                        tmpRoot.countIncrement(1);
                        subTreeRoot = tmpRoot;
                        record.poll();
                    }
                }
                addNodes(subTreeRoot, record, F1);
            }
            return root;
        }

        // Sort the items of a transaction in descending order of frequency
        public LinkedList<String> sortByF1(List<String> transRecord, ArrayList<TreeNode> F1) {
            Map<String, Integer> map = new HashMap<String, Integer>();
            for (String item : transRecord) {
                // F1 is already sorted by descending count, so use each item's index in F1 as its rank
                for (int i = 0; i < F1.size(); i++) {
                    TreeNode tnode = F1.get(i);
                    if (tnode.getName().equals(item)) {
                        map.put(item, i);
                    }
                }
            }
            ArrayList<Entry<String, Integer>> al = new ArrayList<Entry<String, Integer>>(
                    map.entrySet());
            Collections.sort(al, new Comparator<Map.Entry<String, Integer>>() {
                @Override
                public int compare(Entry<String, Integer> arg0, Entry<String, Integer> arg1) {
                    // Ascending by F1 index, i.e. descending by frequency
                    return arg0.getValue() - arg1.getValue();
                }
            });
            LinkedList<String> rest = new LinkedList<String>();
            for (Entry<String, Integer> entry : al) {
                rest.add(entry.getKey());
            }
            return rest;
        }

        // Insert record into the tree as descendants of ancestor
        public void addNodes(TreeNode ancestor, LinkedList<String> record, ArrayList<TreeNode> F1) {
            if (record.size() > 0) {
                while (record.size() > 0) {
                    String item = record.poll();
                    TreeNode leafnode = new TreeNode(item);
                    leafnode.setCount(1);
                    leafnode.setParent(ancestor);
                    ancestor.addChild(leafnode);
                    for (TreeNode f1 : F1) {
                        if (f1.getName().equals(item)) {
                            while (f1.getNextHomonym() != null) {
                                f1 = f1.getNextHomonym();
                            }
                            f1.setNextHomonym(leafnode);
                            break;
                        }
                    }
                    addNodes(leafnode, record, F1);
                }
            }
        }
    }

    public static class InverseMapper extends Mapper<LongWritable, Text, Record, IntWritable> {

        @Override
        public void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String[] arr = value.toString().split("\\s+");
            int count = Integer.parseInt(arr[0]);
            Record record = new Record();
            for (int i = 1; i < arr.length; i++) {
                record.list.add(arr[i]);
            }
            context.write(record, new IntWritable(count));
        }
    }

    public static class MaxReducer extends Reducer<Record, IntWritable, IntWritable, Record> {

        public void reduce(Record key, Iterable<IntWritable> values, Context context)
                throws IOException, InterruptedException {
            // Keep only the largest support seen for this itemset, discarding redundant copies
            int max = -1;
            for (IntWritable value : values) {
                int i = value.get();
                if (i > max)
                    max = i;
            }
            context.write(new IntWritable(max), key);
        }
    }

    @Override
    public int run(String[] arg0) throws Exception {
        Configuration conf = getConf();
        conf.set("mapred.task.timeout", "6000000");

        // Job 1: group the transactions and run FP-Growth within each group
        Job job = new Job(conf);
        job.setJarByClass(DC_FPTree.class);
        FileSystem fs = FileSystem.get(getConf());
        FileInputFormat.setInputPaths(job, "/user/orisun/input/data");
        Path outDir = new Path("/user/orisun/output");
        fs.delete(outDir, true);
        FileOutputFormat.setOutputPath(job, outDir);
        job.setMapperClass(GroupMapper.class);
        job.setReducerClass(FPReducer.class);
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Record.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);
        boolean success = job.waitForCompletion(true);

        // Job 2: remove redundant rules by keeping the maximum support of each itemset
        job = new Job(conf);
        job.setJarByClass(DC_FPTree.class);
        FileInputFormat.setInputPaths(job, "/user/orisun/output/part-r-*");
        Path outDir2 = new Path("/user/orisun/output2");
        fs.delete(outDir2, true);
        FileOutputFormat.setOutputPath(job, outDir2);
        job.setMapperClass(InverseMapper.class);
        job.setReducerClass(MaxReducer.class);
        // job.setNumReduceTasks(0);
        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setMapOutputKeyClass(Record.class);
        job.setMapOutputValueClass(IntWritable.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Record.class);
        success |= job.waitForCompletion(true);

        return success ? 0 : 1;
    }

    public static void main(String[] args) throws Exception {
        int res = ToolRunner.run(new Configuration(), new DC_FPTree(), args);
        System.exit(res);
    }
}