NNSearch.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. import math
  2. def dist(x,y):
  3. return math.sqrt(sum(map(lambda z: (z[0] - z[1])**2, zip(x, y))))
  4. def maxby(data, fn, startValue=0.0):
  5. m = startValue
  6. for v in data:
  7. m = max(m, fn(v))
  8. return m
  9. class MaxHeap:
  10. def __init__(self, maxSize=None, isGreaterThan=None, smalestValue=0.0):
  11. self.heap = []
  12. self.size = 0
  13. self.maxSize = maxSize
  14. self.isGreaterThan = isGreaterThan if isGreaterThan is not None else (lambda a, b: a > b)
  15. self.smalestValue = smalestValue
  16. def insert(self, v):
  17. if self.maxSize is not None and self.size >= self.maxSize:
  18. self.replaceMax(v)
  19. return
  20. pos = self.size
  21. self.size += 1
  22. self.heap.append(v)
  23. while pos > 0:
  24. w = self.heap[pos // 2]
  25. if not self.isGreaterThan(v, w):
  26. break
  27. self.heap[pos] = w
  28. pos = pos // 2
  29. self.heap[pos] = v
  30. def childPos(self, pos):
  31. c = (pos + 1) * 2
  32. return (c - 1, c)
  33. def removeMax(self):
  34. if self.heap == []:
  35. self.size = 0
  36. return
  37. self.heap[0] = self.heap[-1]
  38. self.heap = self.heap[:-1]
  39. self.size -= 1
  40. x = self.heap[0]
  41. pos = 0
  42. size = self.size
  43. while pos < size:
  44. (left, right) = self.childPos(pos)
  45. if left >= size:
  46. break
  47. y = self.heap[left]
  48. if right >= size:
  49. if self.isGreaterThan(y, x):
  50. self.heap[pos] = y
  51. self.heap[left] = x
  52. break
  53. z = self.heap[right]
  54. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  55. if not self.isGreaterThan(v, x):
  56. break
  57. self.heap[pos] = v
  58. self.heap[best] = x
  59. pos = best
  60. def replaceMax(self, x):
  61. if self.heap == []:
  62. self.heap = [x]
  63. self.size = 1
  64. return
  65. if self.isGreaterThan(x, self.heap[0]):
  66. return
  67. self.heap[0] = x
  68. pos = 0
  69. size = len(self.heap)
  70. while pos < size:
  71. (left, right) = self.childPos(pos)
  72. if left >= size:
  73. break
  74. y = self.heap[left]
  75. if right >= size:
  76. if self.isGreaterThan(y, x):
  77. self.heap[pos] = y
  78. self.heap[left] = x
  79. break
  80. z = self.heap[right]
  81. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  82. if not self.isGreaterThan(v, x):
  83. break
  84. self.heap[pos] = v
  85. self.heap[best] = x
  86. pos = best
  87. def getMax(self):
  88. if self.heap == []:
  89. return self.smalestValue
  90. return self.heap[0]
  91. def toArray(self, mapFn=None):
  92. if mapFn is None:
  93. return self.heap.copy()
  94. else:
  95. return [mapFn(x) for x in self.heap]
  96. def length(self):
  97. return self.size
  98. class NNSearch:
  99. def __init__(self, nebSize=5):
  100. self.nebSize = nebSize
  101. self.neighbourhoods = []
  102. def fit(self, X, nebSize=None):
  103. if nebSize == None:
  104. nebSize = self.nebSize
  105. isGreaterThan = lambda x, y: x[1] > y[1]
  106. self.neighbourhoods = [MaxHeap(nebSize, isGreaterThan, (None, 0.0)) for _i in range(len(X))]
  107. for (i, x) in enumerate(X):
  108. nbh = self.neighbourhoods[i]
  109. nbh.insert((i, 0.0))
  110. for (j, y) in enumerate(X[i+1:]):
  111. j += i + 1
  112. d = dist(x,y)
  113. nbh.insert((j,d))
  114. self.neighbourhoods[j].insert((i,d))
  115. self.neighbourhoods[i] = nbh.toArray(lambda v: v[0])
  116. def neighbourhoodOfItem(self, i):
  117. return self.neighbourhoods[i]