机器学习笔记(Chapter 09 - 树回归)

第8章的线性回归创建的模型需要拟合所有的样本点(除了局部加权线性回归)。当数据拥有众多特征并且特征之间关系十分复杂时,构建全局模型的想法就比较困难,并且生活中很多问题是非线性的,无法用全局线性模型来拟合所有数据。一种方法是将数据集递归地切分成很多份易建模的数据,并对可以拟合的小数据集用线性回归建模。

复杂数据的局部性建模

  • 在Chapter03中介绍了贪心算法的决策树,构建算法是ID3,每次选取当前最佳特征来分割数据,并且按照这个特征的所有可能取值来划分,一旦切分完成,这个特征在之后的执行过程中不会再有任何用处。这种方法切分过于迅速,并且需要将连续型数据离散化后才能处理,这样就破坏了连续变量的内在性质。
  • 二元切分法是另一种树构建算法,每次将数据集切分成两半,如果数据的某个特征满足这个切分的条件,就将这些数据放入左子树,否则右子树。二元切分法也节省了树的构建时间,但树一般都是离线构建,因此意义不大。CART(Classification And Regression Trees,分类回归树)使用二元切分来处理连续型变量,并用R^2取代香农熵来分析模型的效果。

连续和离散型特征的树的构建

  • 使用字典存储树的数据结构,每个节点包含以下四个元素:待切分的特征、待切分的特征值、左子树、右子树。Chapter03中的每个节点可能有多个孩子,因此使用字典存储,而CART可以固定数据结构,因为每个非叶节点固定包含两个子树。下面创建回归树(叶节点包含单个值)和模型树(叶节点存储一个线性方程),创建树的代码可以重用,伪代码大致如下。

    • 找到最佳的待切分特征:
    •     如果该节点不能再分,将该节点存为叶节点
    •     执行二元切分
    •     在左右子树分别递归调用
  • CART算法实现 - regTrees.py。binSplitDataSet通过数组过滤切分数据集,createTree递归建立树,输入参数决定树的类型,leafType给出建立叶节点的函数,因此该参数也决定了要建立的是模型树还是回归树,errType代表误差计算函数,ops是一个包含树构建所需的其他参数的元组。代码中的chooseBestSplit函数选取最佳分类方式,尚未实现。github上的附书源码有错误,binSplitDataSet的两行最后没有[0]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from numpy import *

def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float, curLine)
dataMat.append(fltLine)
return dataMat

def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0], :]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0], :]
return mat0, mat1

def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
if feat == None: return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree

将CART算法用于回归

  • 如何实现数据切分要取决于叶节点的建模方式,回归树假设叶节点是常数值,可以通过计算数据的总方差代替香农熵判断数据的混乱度。
  • 函数chooseBestSplit的目标是找到数据切分的最佳位置,它遍历所有的特征及其可能的取值来找到使误差最小化的划分阈值。伪代码大致如下。

    • 对每个特征:对每个特征值:
    •     将数据集划分为两份
    •     计算切分的误差
    •     若当前误差小于最小误差,则更新
    • 返回最佳切分特征和阈值
  • 回归树切分函数 - regTrees.py,regLeaf负责生成叶节点,在回归树中,该模型是目标变量的均值。regErr是误差估计函数,计算目标变量总方差。chooseBestSplit的参数中为ops设定了tolS和tolN,tolS是容许的误差下降值,tolN是切分的最小样本数。在三种情况下chooseBestSplit会停止切分:误差下降不够大、切分子集数目小、剩余的特征值都相同。github的附书源码也有问题,chooseBestSplit函数中,for splitVal in set(dataSet[:,featIndex]):要增加.T.tolist()[0]否则会报无法hash的错误。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def regLeaf(dataSet):
return mean(dataSet[:,-1])

def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops = (1,4)):
tolS = ops[0]; tolN = ops[1]
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
m, n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
return None, leafType(dataSet)
return bestIndex, bestValue
  • 测试代码效果,数据来自ex00.txt和ex0.txt,用matplotlib绘制的图像如下。
1
2
3
4
5
6
7
8
>>> reload(regTrees)
>>> from numpy import *
>>> myDat = mat(regTrees.loadDataSet('ex00.txt'))
>>> regTrees.createTree(myMat)
{'spInd': 0, 'spVal': 0.48813, 'right': -0.044650285714285719, 'left': 1.0180967672413792}
>>> myDat1 = mat(regTrees.loadDataSet('ex0.txt'))
>>> regTrees.createTree(myDat1)
{'spInd': 1, 'spVal': 0.39435, 'right': {'spInd': 1, 'spVal': 0.197834, 'right': -0.023838155555555553, 'left': 1.0289583666666666}, 'left': {'spInd': 1, 'spVal': 0.582002, 'right': 1.980035071428571, 'left': {'spInd': 1, 'spVal': 0.797583, 'right': 2.9836209534883724, 'left': 3.9871631999999999}}}

树剪枝

如果树节点过多,则该模型可能对数据过拟合,通过降低决策树的复杂度来避免过拟合的过程称为剪枝。在函数chooseBestSplit中的三个提前终止条件是“预剪枝”操作,另一种形式的剪枝需要使用测试集和训练集,称作“后剪枝”。

预剪枝

  • 树构建算法对输入的tolS和tolN非常敏感,将ops换为(0,1)会发现生成的树非常臃肿,几乎为数据集中的每个样本都分配了一个叶节点。加载ex2.txt的数据,该数据集和前面ex00.txt的数据分布类似,但数量级是后者的100倍,在这种情况下,ex00构建出的树只有两个叶节点,而ex2构建出的树有非常多的叶节点。原因在于停止条件tolS对误差的数量级非常敏感。显然,通过不断修改停止条件并且比较哪个条件更好并不合理,多数情况下我们并不确定要寻找什么样的结果,计算机应该给出总体的概貌。

后剪枝

  • 使用后剪枝方法需要将数据集交叉验证,首先给定参数,使得构建出的树足够复杂,之后从上而下找到叶节点,判断合并两个叶节点是否能够取得更好的测试误差,如果是就合并。下面是回归树剪枝函数。函数isTree测试输入变量是否为一棵树,getMean对树进行塌陷处理,计算整棵树的平均值。prune函数对树剪枝,参数tree为待剪枝的树,testData是测试集。需要注意的是,测试集合训练集样本的取值范围不一定相同。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def isTree(obj):
return (type(obj).__name__ =='dict')

def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right'])/2.0

def prune(tree, testData):
if shape(testData)[0] == 0: return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errNoMerge = sum(power(lSet[:,-1] - tree['left'], 2)) +\
sum(power(rSet[:,-1] - tree['right'], 2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean, 2))
if errorMerge < errNoMerge:
print "merging"
return treeMean
else:
return tree
else:
return tree

模型树

  • 将叶节点设置为分段线性函数,分段线性指模型由多个线性片段组成。例如下图的数据,可以由0.0~0.3和0.3~1.0的两条直线组成。决策树相比其他机器学习算法易于理解,而模型树的可解释性是它优于回归树的特性之一。模型树同时具备更高的预测准确度。
  • 前面的代码已经给出了构建树的代码,只要修改参数errType和leafType。对于给定的数据集,先用现行的模型对它进行拟合,然后计算真实目标值和模型预测值之间的差距。最后求这些差值的平方和作为误差。modelLeaf函数生成叶节点,linearSolve返回回归系数,modelErr在数据集上调用linearSove,返回yHat和y之间的平方误差。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def linearSolve(dataSet):
m, n = shape(dataSet)
X = mat(ones((m,n))); Y = mat(ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
xTx = X.T * X
if linalg.det(xTx) == 0.0:
raise NameError('This matrix is singular, cannot do inverses,\n\
try increasing the second value of ops')

ws = xTx.I * (X.T * Y)
return ws, X, Y

def modelLeaf(dataSet):
ws, X, Y = linearSolve(dataSet)
return ws

def modelErr(dataSet):
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y - yHat, 2))

>>> myMat2 = mat(regTrees.loadDataSet('exp2.txt'))
>>> regTrees.createTree(myMat2, regTrees.modelLeaf, regTrees.modelErr, (1,10))
{'spInd': 0, 'spVal': 0.285477, 'right': matrix([[ 3.46877936],
[ 1.18521743]]), 'left': matrix([[ 1.69855694e-03],
[ 1.19647739e+01]])}

树回归和标准回归的比较

  • 对于输入的单个数据点,函数treeForeCast返回一个预测值。参数modelEval是对叶节点数据进行预测的函数的引用,函数treeForeCast自顶向下遍历整棵树,直到命中叶节点为止。一旦到达叶节点,它会在输入数据上调用modelEval,该参数默认值是regTreeEval。要对回归树叶节点预测,就调用regTreeEval,要对模型树节点预测,调用modelTreeEval。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def regTreeEval(model, inDat):
return float(model)

def modelTreeEval(model, inDat):
n = shape(inDat)[1]
X = mat(ones((1,n+1)))
X[:,1:n+1] = inDat
return float(X*model)

def treeForeCast(tree, inData, modelEval = regTreeEval):
if not isTree(tree): return modelEval(tree, inData)
if inData[tree['spInd']] > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'], inData, modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData, modelEval)
else:
return modelEval(tree['right'], inData)

def createForeCast(tree, testData, modelEval = regTreeEval):
m = len(testData)
yHat = mat(zeros((m,1)))
for i in range(m):
yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
return yHat
  • 比较回归树、模型树和标准线性回归的R^2数值。可以看出,模型树的结果比回归树好,而树回归方法在预测复杂数据时会比简单的线性模型更有效。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> trainMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_train.txt'))
>>> testMat = mat(regTrees.loadDataSet('bikeSpeedVsIq_test.txt'))
>>> myTree = regTrees.createTree(trainMat, ops=(1,20))
>>> yHat = regTrees.createForeCast(myTree, testMat[:,0])
>>> corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
0.96408523182221395
>>> myTree = regTrees.createTree(trainMat, regTrees.modelLeaf, regTrees.modelErr, ops=(1,20))
>>> yHat = regTrees.createForeCast(myTree, testMat[:,0], regTrees.modelTreeEval)
>>> corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
0.97604121913806363
>>> ws, X, Y = regTrees.linearSolve(trainMat)
>>> ws
matrix([[ 37.58916794],
[ 6.18978355]])
>>> for i in range(shape(testMat)[0]):
... yHat[i] = testMat[i,0]*ws[1,0]+ws[0,0]
>>> corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
0.94346842356747584

Tkinter库创建GUI

  • Tkinter模块的.grid()方法将widget安排在一个二维表格内,,默认widget会显示在0行0列,可以通过设定columnspan和rowspan来告诉布局管理器是否允许一个widget跨行或跨列。界面代码如下 - treeExplore.py。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from numpy import *
from Tkinter import *
import regTrees

def redraw(tolS, tolN):
pass
def drawNewTree():
pass

root = Tk()
Label(root, text="Plot Place Holder").grid(row = 0, columnspan = 3)
Label(root, text="tolN").grid(row = 1, column = 0)
tolNentry = Entry(root)
tolNentry.grid(row = 1, column = 1)
tolNentry.insert(0,'10')
Label(root, text="tolS").grid(row = 2, column =0 )
tolSentry = Entry(root)
tolSentry.grid(row = 2, column = 1)
tolSentry.insert(0,'1.0')
Label(root, text="path").grid(row = 3, column = 0)
datPentry = Entry(root)
datPentry.grid(row = 3, column = 1)
datPentry.insert(0,'sine.txt')

Button(root, text="ReDraw", command = drawNewTree).grid(row = 1, column =2, rowspan =3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable = chkBtnVar)
chkBtn.grid(row = 4, column = 0, columnspan=2)

reDraw.rawDat = mat(regTrees.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]), 0.01)
reDraw(1.0, 10)
root.mainloop()
  • Matplotlib的构建程序包含一个前端面向用户,如plot和scatter方法等,同时创建一个后端,用于实现绘图和不同应用程序之间的接口。改变后端可以将图像绘制不同格式的文件上,将后端在设置为TkAgg,可以在所选GUI框架上调用Agg,呈现在画布上。下面的代码填补了上面的两个占位函数,另外将上面代码中加载文件的语句移入了按钮事件。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure

def reDraw(tolS, tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get():
if tolN < 2: tolN = 2
myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr, (tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval)
else:
myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS, tolN))
yHat = regTrees.createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1], s=5)
reDraw.a.plot(reDraw.testDat, yHat, linewidth = 2.0)
reDraw.canvas.show()

def getInputs():
try: tolN = int(tolNentry.get())
except:
tolN = 10
print "enter Integer for tolN"
tolNentry.delete(0, END)
tolNentry.insert(0, '10')
try: tolS = float(tolSentry.get())
except:
tolS = 1.0
print "enter Integer for tolS"
tolSentry.delete(0, END)
tolSentry.insert(0, '1.0')
try: datPath = str(datPentry.get())
except:
datPath = ''
print "enter path for test data"
tolSentry.delete(0, END)
tolSentry.insert(0, '')
return datPath, tolS, tolN

def drawNewTree():
datPath, tolS, tolN = getInputs()
try:
reDraw.rawDat = mat(regTrees.loadDataSet(datPath))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]), 0.01)
reDraw(tolS, tolN)
except:
print "Cannot find file %s" % datPath
  • 绘制出的GUI界面如下。



参考文献: 《机器学习实战 - 美Peter Harrington》

原创作品,允许转载,转载时无需告知,但请务必以超链接形式标明文章原始出处(https://forec.github.io/2016/02/20/machinelearning9/) 、作者信息(Forec)和本声明。

分享到