NNSearch.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. import math
  2. import tensorflow as tf
  3. import numpy as np
  4. from sklearn.neighbors import NearestNeighbors
  5. def dist(x,y):
  6. return sum(map(lambda z: (z[0] - z[1])*(z[0] - z[1]), zip(x, y)))
  7. def maxby(data, fn, startValue=0.0):
  8. m = startValue
  9. for v in data:
  10. m = max(m, fn(v))
  11. return m
  12. def distancesToPoint(p, points):
  13. w = np.array(np.repeat([p], len(points), axis=0))
  14. d = tf.keras.layers.Subtract()([w, np.array(points)])
  15. t = tf.keras.layers.Dot(axes=(1,1))([d,d])
  16. # As the concrete distance is not needed and sqrt(x) is strict monotone
  17. # we avoid here unneccessary calculating of expensive roots.
  18. return t.numpy()
  19. def calculateCenter(points):
  20. if points.shape[0] == 1:
  21. return points[0]
  22. return tf.keras.layers.Average()(list(points)).numpy()
  23. class MaxHeap:
  24. def __init__(self, maxSize=None, isGreaterThan=None, smalestValue=(-1,0.0)):
  25. self.heap = []
  26. self.size = 0
  27. self.maxSize = maxSize
  28. self.isGreaterThan = isGreaterThan if isGreaterThan is not None else (lambda a, b: a > b)
  29. self.smalestValue = smalestValue
  30. self.indices = set()
  31. self.wasChanged = False
  32. self.insert(smalestValue)
  33. def insert(self, v):
  34. if self.maxSize is not None and self.size >= self.maxSize:
  35. return self.replaceMax(v)
  36. if v[0] in self.indices:
  37. return False
  38. self.indices.add(v[0])
  39. pos = self.size
  40. self.size += 1
  41. self.heap.append(v)
  42. while pos > 0:
  43. w = self.heap[pos // 2]
  44. if not self.isGreaterThan(v, w):
  45. break
  46. self.heap[pos] = w
  47. pos = pos // 2
  48. self.heap[pos] = v
  49. self.wasChanged = True
  50. return True
  51. def childPos(self, pos):
  52. c = (pos + 1) * 2
  53. return (c - 1, c)
  54. def removeMax(self):
  55. if self.heap == []:
  56. self.size = 0
  57. return False
  58. x = self.heap[0]
  59. self.indices.remove(x[0])
  60. self.heap[0] = self.heap[-1]
  61. self.heap = self.heap[:-1]
  62. self.size -= 1
  63. x = self.heap[0]
  64. pos = 0
  65. size = self.size
  66. while pos < size:
  67. (left, right) = self.childPos(pos)
  68. if left >= size:
  69. break
  70. y = self.heap[left]
  71. if right >= size:
  72. if self.isGreaterThan(y, x):
  73. self.heap[pos] = y
  74. self.heap[left] = x
  75. break
  76. z = self.heap[right]
  77. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  78. if not self.isGreaterThan(v, x):
  79. break
  80. self.heap[pos] = v
  81. self.heap[best] = x
  82. pos = best
  83. self.wasChanged = True
  84. return True
  85. def replaceMax(self, x):
  86. if self.heap == []:
  87. self.heap = [x]
  88. self.size = 1
  89. self.indices.add(x[0])
  90. self.wasChanged = True
  91. return True
  92. if x[0] in self.indices:
  93. return False
  94. if self.isGreaterThan(x, self.heap[0]):
  95. return False
  96. self.indices.remove((self.heap[0])[0])
  97. self.indices.add(x[0])
  98. self.heap[0] = x
  99. pos = 0
  100. size = len(self.heap)
  101. while pos < size:
  102. (left, right) = self.childPos(pos)
  103. if left >= size:
  104. break
  105. y = self.heap[left]
  106. if right >= size:
  107. if self.isGreaterThan(y, x):
  108. self.heap[pos] = y
  109. self.heap[left] = x
  110. break
  111. z = self.heap[right]
  112. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  113. if not self.isGreaterThan(v, x):
  114. break
  115. self.heap[pos] = v
  116. self.heap[best] = x
  117. pos = best
  118. self.wasChanged = True
  119. return True
  120. def getMax(self):
  121. if self.heap == []:
  122. return self.smalestValue
  123. return self.heap[0]
  124. def setMaxSize(self, maxSize):
  125. self.maxSize = maxSize
  126. while self.size > maxSize:
  127. self.removeMax()
  128. def toArray(self, mapFn=None):
  129. return list(self.indices)
  130. def length(self):
  131. return self.size
  132. class Ball:
  133. def __init__(self, points=None, indices=None, parent=None, center=None):
  134. if center is not None:
  135. self.center = center
  136. elif points is not None:
  137. self.center = calculateCenter(np.array(points))
  138. else:
  139. raise ParameterError("Missing points or center")
  140. self.radius = 0
  141. self.points = []
  142. self.indices = set()
  143. self.childs = []
  144. self.parent = parent
  145. if points is not None and indices is not None:
  146. for (i, r) in zip(indices, distancesToPoint(self.center, points)):
  147. self.add(i, r)
  148. elif points is not None:
  149. raise ParameterError("Missing indices")
  150. def findIt(self, r):
  151. if r == self.points[0][1]:
  152. return 0
  153. upper = len(self.points) - 1
  154. lower = 0
  155. while upper > lower + 1:
  156. h = (upper + lower) // 2
  157. if self.points[h][1] >= r:
  158. upper = h
  159. else:
  160. lower = h
  161. return upper
  162. def add(self, xi, r):
  163. if xi in self.indices:
  164. return False
  165. # Here we know, that x is not in points.
  166. newEntry = (xi, r)
  167. self.indices.add(xi)
  168. # Special case: empty list or new element will extend the radius:
  169. # Here we can avoid the search
  170. if self.points == [] or r >= self.radius:
  171. self.points.append(newEntry)
  172. self.radius = r
  173. return True
  174. # Special case: r <= min radius
  175. # Here we can avoid the search
  176. if self.points[0][1] >= r:
  177. self.points = [newEntry] + self.points
  178. return True
  179. # Here we know that min radius < r < max radius.
  180. # So len(points) >= 2.
  181. pos = self.findIt(r)
  182. # here shoul be: r(pos) >= r > r(pos - 1)
  183. self.points = self.points[:pos] + [newEntry] + self.points[pos: ]
  184. return True
  185. def remove(self, xi, r):
  186. if xi not in self.indices:
  187. return False
  188. # special case: remove the element with the heighest radius:
  189. if self.points[-1][0] == xi:
  190. self.indices.remove(xi)
  191. self.points = self.points[: -1]
  192. if self.points == []:
  193. self.radius = 0.0
  194. else:
  195. self.radius = self.points[-1][1]
  196. return True
  197. pos = self.findIt(r)
  198. nPoints = len(self.points)
  199. # pos is the smalest position with:
  200. # r(pos - 1) < r <= r(pos)
  201. # Just in case: check that we remove the correct index.
  202. while pos < nPoints and r == self.points[pos]:
  203. if self.points[pos][0] == xi:
  204. self.indices.remove(xi)
  205. self.points = self.points[:pos] + self.points[(pos + 1): ]
  206. return True
  207. pos += 1
  208. return False
  209. def divideBall(self, X):
  210. indices = [i for (i, _r) in self.points]
  211. points = np.array([X[i] for i in indices])
  212. distances = tf.keras.layers.Dot(axes=(1,1))([points, points]).numpy()
  213. ball = Ball(points, indices, center=np.zeros(len(points[0])))
  214. h = len(ball) // 2
  215. indicesA = [i for (i, _) in ball.points[:h]]
  216. indicesB = [i for (i, _) in ball.points[h:]]
  217. pointsA = np.array([X[i] for i in indicesA])
  218. pointsB = np.array([X[i] for i in indicesB])
  219. self.childs = []
  220. print(f"{len(points)} -> <{len(pointsA)}|{len(pointsB)}>")
  221. self.childs.append(Ball(pointsA, indicesA, self))
  222. self.childs.append(Ball(pointsB, indicesB, self))
  223. return self.childs
  224. def __len__(self):
  225. return len(self.points)
  226. def smalestBallFor(self, i):
  227. if i not in self.indices:
  228. return None
  229. for c in self.childs:
  230. b = c.smalestBallFor(i)
  231. if b is not None:
  232. return b
  233. return self
  234. class NNSearch:
  235. def __init__(self, nebSize=5):
  236. self.nebSize = nebSize
  237. self.neighbourhoods = []
  238. def neighbourhoodOfItem(self, i):
  239. return self.neighbourhoods[i]
  240. def fit(self, X, nebSize=None):
  241. self.fit_bruteForce_np(X, nebSize)
  242. def fit_optimized(self, X, nebSize=None):
  243. if nebSize == None:
  244. nebSize = self.nebSize
  245. nPoints = len(X)
  246. nFeatures = len(X[0])
  247. if nFeatures > 15 or nebSize >= (nPoints // 2):
  248. print("Using brute force")
  249. self.fit_bruteForce_np(X, nebSize)
  250. else:
  251. print("Using chained")
  252. self.fit_chained(X, nebSize)
  253. def fit_bruteForce(self, X, nebSize=None):
  254. if nebSize == None:
  255. nebSize = self.nebSize
  256. isGreaterThan = lambda x, y: x[1] > y[1]
  257. self.neighbourhoods = [MaxHeap(nebSize, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  258. for (i, x) in enumerate(X):
  259. nbh = self.neighbourhoods[i]
  260. for (j, y) in enumerate(X[i+1:]):
  261. j += i + 1
  262. d = dist(x,y)
  263. nbh.insert((j,d))
  264. self.neighbourhoods[j].insert((i,d))
  265. self.neighbourhoods[i] = nbh.toArray(lambda v: v[0])
  266. def fit_bruteForce_np(self, X, nebSize=None):
  267. numOfPoints = len(X)
  268. nFeatures = len(X[0])
  269. tX = tf.convert_to_tensor(X)
  270. def distancesTo(x):
  271. w = np.repeat([x], numOfPoints, axis=0)
  272. d = tf.keras.layers.Subtract()([w,tX])
  273. t = tf.keras.layers.Dot(axes=(1,1))([d,d])
  274. return t.numpy()
  275. if nebSize == None:
  276. nebSize = self.nebSize
  277. isGreaterThan = lambda x, y: x[1] > y[1]
  278. self.neighbourhoods = [MaxHeap(nebSize, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  279. for (i, x) in enumerate(X):
  280. distances = distancesTo(x)
  281. nbh = self.neighbourhoods[i]
  282. for (j, y) in enumerate(X[i+1:]):
  283. j += i + 1
  284. d = distances[j]
  285. nbh.insert((j,d))
  286. self.neighbourhoods[j].insert((i,d))
  287. self.neighbourhoods[i] = nbh.toArray(lambda v: v[0])
  288. def fit_chained(self, X, nebSize=None):
  289. if nebSize == None:
  290. nebSize = self.nebSize
  291. nPoints = len(X)
  292. nFeatures = len(X[0])
  293. neigh = NearestNeighbors(n_neighbors=nebSize)
  294. neigh.fit(X)
  295. self.neighbourhoods = [
  296. (neigh.kneighbors([x], nebSize, return_distance=False))[0]
  297. for (i, x) in enumerate(X)
  298. ]
  299. # ===============================================================
  300. # Heuristic search
  301. # ===============================================================
  302. def fit_heuristic(self, X, nebSize=None):
  303. if nebSize == None:
  304. nebSize = self.nebSize
  305. nPoints = len(X)
  306. def walkUp(nbh, ball, x, i):
  307. while ball.parent is not None:
  308. print(f"{i}: up (r: {nbh.getMax()})")
  309. oldBall = ball
  310. ball = ball.parent
  311. for c in ball.childs:
  312. if c != oldBall:
  313. walkDown(nbh, c, x)
  314. def walkDown(nbh, ball, x):
  315. if ball is None:
  316. return
  317. print(f"{i}: down (r: {nbh.getMax()})")
  318. if dist(x, ball.center) - ball.radius < nbh.getMax()[1]:
  319. if ball.childs == []:
  320. for (j, _) in ball.points:
  321. nbh.insert((j, dist(x, X[j])))
  322. else:
  323. for c in ball.childs:
  324. walkDown(nbh, c, x)
  325. def countBoles(b):
  326. if b is None:
  327. return 0
  328. return 1 + sum(map(countBoles, b.childs))
  329. root = Ball(X, range(len(X)))
  330. queue = [root]
  331. while queue != []:
  332. ball = queue[0]
  333. queue = queue[1:]
  334. if len(ball) <= nebSize:
  335. continue
  336. queue = ball.divideBall(X) + queue
  337. isGreaterThan = lambda x, y: x[1] > y[1]
  338. self.neighbourhoods = [MaxHeap(nPoints, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  339. print("#B: " + str(countBoles(root)))
  340. exit()
  341. z = X[0]
  342. for (i, x) in enumerate(X):
  343. nbh = self.neighbourhoods[i]
  344. b = root.smalestBallFor(i)
  345. if b.parent is not None:
  346. b = b.parent
  347. for (j, _) in b.points:
  348. d = dist(x, X[j])
  349. nbh.insert((j, d))
  350. walkUp(nbh, b, x, i)