# k-d tree construction and nearest-neighbour search with NumPy.
import numpy as np


def createTree(dataset, depth):
    """Recursively build a k-d tree from a set of points.

    The split axis cycles with depth (``depth % n``). At every level the
    points are ordered along that axis and the element at index ``m // 2``
    becomes the node; points before it go to the left subtree and points
    after it go to the right subtree.

    Args:
        dataset: Array-like of shape ``(m, n)`` holding ``m`` points of
            dimension ``n``. An empty dataset yields ``None``.
        depth: Depth of the node being built; pass ``0`` for the root.

    Returns:
        ``None`` for an empty dataset, otherwise a dict with the keys
        ``'split'`` (axis index), ``'median'`` (the point stored at this
        node, as an independent NumPy array), ``'left'`` and ``'right'``
        (subtrees in the same format, or ``None``).

    Raises:
        ValueError: If a non-empty dataset is not 2-D or has no columns.
    """
    points = np.asarray(dataset)
    if points.shape[0] == 0:
        return None
    if points.ndim != 2 or points.shape[1] == 0:
        raise ValueError(
            "dataset must be a 2-D array-like of shape (m, n) with n >= 1"
        )

    m, n = points.shape
    split = depth % n

    # A stable sort keeps the original relative order of points that tie on
    # the split coordinate. Fancy indexing produces a new array, so nothing
    # stored in the tree refers back to the caller's data.
    order = np.argsort(points[:, split], kind='stable')
    points = points[order]
    median = m // 2

    return {
        'split': split,
        # Copy the row so the node does not keep the whole sorted array alive.
        'median': points[median].copy(),
        'left': createTree(points[:median], depth + 1),
        'right': createTree(points[median + 1:], depth + 1),
    }
def searchTree(treeNode, data):
    """Find the stored point nearest to a single query point.

    Descends into the subtree on the query's side of the splitting plane
    first (ties on the split coordinate go left), then compares against the
    point stored at this node, and only visits the opposite subtree when the
    splitting plane is no farther away than the best distance found so far.

    Args:
        treeNode: A node dict with the keys ``'split'``, ``'median'``,
            ``'left'`` and ``'right'``, or ``None`` for an empty subtree.
        data: The query point, a sequence or array of coordinates.

    Returns:
        A ``(nearestPoint, nearestDistance)`` pair using Euclidean distance.
        For an empty tree the point is a list of zeros with the same length
        as ``data`` and the distance is ``inf``; otherwise the point is a
        NumPy array copy of the stored point.
    """
    if treeNode is None:
        # Placeholder result: the infinite distance loses to any real point.
        return [0] * len(data), float('inf')

    split = treeNode['split']
    median_point = treeNode['median']

    # Do the arithmetic on float arrays so that plain lists/tuples work and
    # unsigned integer dtypes cannot wrap around when subtracted.
    query = np.asarray(data, dtype=float)
    pivot = np.asarray(median_point, dtype=float)

    if query[split] <= pivot[split]:
        nearTree, farTree = treeNode['left'], treeNode['right']
    else:
        nearTree, farTree = treeNode['right'], treeNode['left']

    nearestPoint, nearestDistance = searchTree(nearTree, data)

    nowDistance = np.linalg.norm(query - pivot)
    if nowDistance < nearestDistance:
        nearestDistance = nowDistance
        # Return a copy in the stored dtype so callers cannot mutate the tree.
        nearestPoint = np.array(median_point)

    # If the splitting plane is farther than the current best, nothing on
    # the other side of it can be closer.
    splitDistance = abs(query[split] - pivot[split])
    if splitDistance > nearestDistance:
        return nearestPoint, nearestDistance

    farPoint, farDistance = searchTree(farTree, data)
    if farDistance < nearestDistance:
        nearestPoint, nearestDistance = farPoint, farDistance
    return nearestPoint, nearestDistance
# Demo: build a tree from six 2-D points and query the nearest neighbour
# of the point [1, 5].
dataset = np.array([[1, 2], [3, 4], [5, 4], [7, 2], [6, 3], [8, 7]])
tree = createTree(dataset, 0)
tree  # bare expression: displays the tree in an interactive session

m, n = searchTree(tree, [1, 5])
m, n  # bare expression: displays the (point, distance) pair
# Recorded output: (array([3, 4]), 2.23606797749979)