NNSearch.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. import math
  2. import tensorflow as tf
  3. import numpy as np
  4. from sklearn.neighbors import NearestNeighbors
  5. from library.timing import timing
  6. def dist(x,y):
  7. return sum(map(lambda z: (z[0] - z[1])*(z[0] - z[1]), zip(x, y)))
  8. def maxby(data, fn, startValue=0.0):
  9. m = startValue
  10. for v in data:
  11. m = max(m, fn(v))
  12. return m
  13. def distancesToPoint(p, points):
  14. w = np.array(np.repeat([p], len(points), axis=0))
  15. d = tf.keras.layers.Subtract()([w, np.array(points)])
  16. t = tf.keras.layers.Dot(axes=(1,1))([d,d])
  17. # As the concrete distance is not needed and sqrt(x) is strict monotone
  18. # we avoid here unneccessary calculating of expensive roots.
  19. return t.numpy()
  20. def calculateCenter(points):
  21. if points.shape[0] == 1:
  22. return points[0]
  23. return tf.keras.layers.Average()(list(points)).numpy()
  24. class MaxHeap:
  25. def __init__(self, maxSize=None, isGreaterThan=None, smalestValue=(-1,0.0)):
  26. self.heap = []
  27. self.size = 0
  28. self.maxSize = maxSize
  29. self.isGreaterThan = isGreaterThan if isGreaterThan is not None else (lambda a, b: a > b)
  30. self.smalestValue = smalestValue
  31. self.indices = set()
  32. self.wasChanged = False
  33. self.insert(smalestValue)
  34. def insert(self, v):
  35. if self.maxSize is not None and self.size >= self.maxSize:
  36. return self.replaceMax(v)
  37. if v[0] in self.indices:
  38. return False
  39. self.indices.add(v[0])
  40. pos = self.size
  41. self.size += 1
  42. self.heap.append(v)
  43. while pos > 0:
  44. w = self.heap[pos // 2]
  45. if not self.isGreaterThan(v, w):
  46. break
  47. self.heap[pos] = w
  48. pos = pos // 2
  49. self.heap[pos] = v
  50. self.wasChanged = True
  51. return True
  52. def childPos(self, pos):
  53. c = (pos + 1) * 2
  54. return (c - 1, c)
  55. def removeMax(self):
  56. if self.heap == []:
  57. self.size = 0
  58. return False
  59. x = self.heap[0]
  60. self.indices.remove(x[0])
  61. self.heap[0] = self.heap[-1]
  62. self.heap = self.heap[:-1]
  63. self.size -= 1
  64. x = self.heap[0]
  65. pos = 0
  66. size = self.size
  67. while pos < size:
  68. (left, right) = self.childPos(pos)
  69. if left >= size:
  70. break
  71. y = self.heap[left]
  72. if right >= size:
  73. if self.isGreaterThan(y, x):
  74. self.heap[pos] = y
  75. self.heap[left] = x
  76. break
  77. z = self.heap[right]
  78. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  79. if not self.isGreaterThan(v, x):
  80. break
  81. self.heap[pos] = v
  82. self.heap[best] = x
  83. pos = best
  84. self.wasChanged = True
  85. return True
  86. def replaceMax(self, x):
  87. if self.heap == []:
  88. self.heap = [x]
  89. self.size = 1
  90. self.indices.add(x[0])
  91. self.wasChanged = True
  92. return True
  93. if x[0] in self.indices:
  94. return False
  95. if self.isGreaterThan(x, self.heap[0]):
  96. return False
  97. self.indices.remove((self.heap[0])[0])
  98. self.indices.add(x[0])
  99. self.heap[0] = x
  100. pos = 0
  101. size = len(self.heap)
  102. while pos < size:
  103. (left, right) = self.childPos(pos)
  104. if left >= size:
  105. break
  106. y = self.heap[left]
  107. if right >= size:
  108. if self.isGreaterThan(y, x):
  109. self.heap[pos] = y
  110. self.heap[left] = x
  111. break
  112. z = self.heap[right]
  113. (best, v) = (left, y) if self.isGreaterThan(y, z) else (right, z)
  114. if not self.isGreaterThan(v, x):
  115. break
  116. self.heap[pos] = v
  117. self.heap[best] = x
  118. pos = best
  119. self.wasChanged = True
  120. return True
  121. def getMax(self):
  122. if self.heap == []:
  123. return self.smalestValue
  124. return self.heap[0]
  125. def setMaxSize(self, maxSize):
  126. self.maxSize = maxSize
  127. while self.size > maxSize:
  128. self.removeMax()
  129. def toArray(self, mapFn=None):
  130. return list(self.indices)
  131. def length(self):
  132. return self.size
  133. class Ball:
  134. def __init__(self, points=None, indices=None, parent=None, center=None):
  135. if center is not None:
  136. self.center = center
  137. elif points is not None:
  138. self.center = calculateCenter(np.array(points))
  139. else:
  140. raise ParameterError("Missing points or center")
  141. self.radius = 0
  142. self.points = []
  143. self.indices = set()
  144. self.childs = []
  145. self.parent = parent
  146. if points is not None and indices is not None:
  147. for (i, r) in zip(indices, distancesToPoint(self.center, points)):
  148. self.add(i, r)
  149. elif points is not None:
  150. raise ParameterError("Missing indices")
  151. def findIt(self, r):
  152. if r == self.points[0][1]:
  153. return 0
  154. upper = len(self.points) - 1
  155. lower = 0
  156. while upper > lower + 1:
  157. h = (upper + lower) // 2
  158. if self.points[h][1] >= r:
  159. upper = h
  160. else:
  161. lower = h
  162. return upper
  163. def add(self, xi, r):
  164. if xi in self.indices:
  165. return False
  166. # Here we know, that x is not in points.
  167. newEntry = (xi, r)
  168. self.indices.add(xi)
  169. # Special case: empty list or new element will extend the radius:
  170. # Here we can avoid the search
  171. if self.points == [] or r >= self.radius:
  172. self.points.append(newEntry)
  173. self.radius = r
  174. return True
  175. # Special case: r <= min radius
  176. # Here we can avoid the search
  177. if self.points[0][1] >= r:
  178. self.points = [newEntry] + self.points
  179. return True
  180. # Here we know that min radius < r < max radius.
  181. # So len(points) >= 2.
  182. pos = self.findIt(r)
  183. # here shoul be: r(pos) >= r > r(pos - 1)
  184. self.points = self.points[:pos] + [newEntry] + self.points[pos: ]
  185. return True
  186. def remove(self, xi, r):
  187. if xi not in self.indices:
  188. return False
  189. # special case: remove the element with the heighest radius:
  190. if self.points[-1][0] == xi:
  191. self.indices.remove(xi)
  192. self.points = self.points[: -1]
  193. if self.points == []:
  194. self.radius = 0.0
  195. else:
  196. self.radius = self.points[-1][1]
  197. return True
  198. pos = self.findIt(r)
  199. nPoints = len(self.points)
  200. # pos is the smalest position with:
  201. # r(pos - 1) < r <= r(pos)
  202. # Just in case: check that we remove the correct index.
  203. while pos < nPoints and r == self.points[pos]:
  204. if self.points[pos][0] == xi:
  205. self.indices.remove(xi)
  206. self.points = self.points[:pos] + self.points[(pos + 1): ]
  207. return True
  208. pos += 1
  209. return False
  210. def divideBall(self, X):
  211. indices = [i for (i, _r) in self.points]
  212. points = np.array([X[i] for i in indices])
  213. distances = tf.keras.layers.Dot(axes=(1,1))([points, points]).numpy()
  214. ball = Ball(points, indices, center=np.zeros(len(points[0])))
  215. h = len(ball) // 2
  216. indicesA = [i for (i, _) in ball.points[:h]]
  217. indicesB = [i for (i, _) in ball.points[h:]]
  218. pointsA = np.array([X[i] for i in indicesA])
  219. pointsB = np.array([X[i] for i in indicesB])
  220. self.childs = []
  221. print(f"{len(points)} -> <{len(pointsA)}|{len(pointsB)}>")
  222. self.childs.append(Ball(pointsA, indicesA, self))
  223. self.childs.append(Ball(pointsB, indicesB, self))
  224. return self.childs
  225. def __len__(self):
  226. return len(self.points)
  227. def smalestBallFor(self, i):
  228. if i not in self.indices:
  229. return None
  230. for c in self.childs:
  231. b = c.smalestBallFor(i)
  232. if b is not None:
  233. return b
  234. return self
  235. class NNSearch:
  236. def __init__(self, nebSize=5, timingDict=None):
  237. self.nebSize = nebSize
  238. self.neighbourhoods = []
  239. self.timingDict = timingDict
  240. def timerStart(self, name):
  241. if self.timingDict is not None:
  242. if name not in self.timingDict:
  243. self.timingDict[name] = timing(name)
  244. self.timingDict[name].start()
  245. def timerStop(self, name):
  246. if self.timingDict is not None:
  247. if name in self.timingDict:
  248. self.timingDict[name].stop()
  249. def neighbourhoodOfItem(self, i):
  250. return self.neighbourhoods[i]
  251. def fit(self, X, nebSize=None):
  252. self.timerStart("NN_fit_all")
  253. self.fit_chained(X, nebSize)
  254. #self.fit_bruteForce_np(X, nebSize)
  255. self.timerStop("NN_fit_all")
  256. def fit_optimized(self, X, nebSize=None):
  257. self.timerStart("NN_fit_all")
  258. if nebSize == None:
  259. nebSize = self.nebSize
  260. nPoints = len(X)
  261. nFeatures = len(X[0])
  262. if nFeatures > 15 or nebSize >= (nPoints // 2):
  263. print("Using brute force")
  264. self.fit_bruteForce_np(X, nebSize)
  265. else:
  266. print("Using chained")
  267. self.fit_chained(X, nebSize)
  268. self.timerStop("NN_fit_all")
  269. def fit_bruteForce(self, X, nebSize=None):
  270. if nebSize == None:
  271. nebSize = self.nebSize
  272. self.timerStart("NN_fit_bf_init")
  273. isGreaterThan = lambda x, y: x[1] > y[1]
  274. self.neighbourhoods = [MaxHeap(nebSize, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  275. self.timerStop("NN_fit_bf_init")
  276. self.timerStart("NN_fit_bf_loop")
  277. for (i, x) in enumerate(X):
  278. nbh = self.neighbourhoods[i]
  279. for (j, y) in enumerate(X[i+1:]):
  280. j += i + 1
  281. self.timerStart("NN_fit_bf_dist")
  282. d = dist(x,y)
  283. self.timerStop("NN_fit_bf_dist")
  284. self.timerStart("NN_fit_bf_insert")
  285. nbh.insert((j,d))
  286. self.neighbourhoods[j].insert((i,d))
  287. self.timerStop("NN_fit_bf_insert")
  288. self.timerStart("NN_fit_bf_toList")
  289. self.neighbourhoods[i] = nbh.toArray(lambda v: v[0])
  290. self.timerStop("NN_fit_bf_toList")
  291. self.timerStop("NN_fit_bf_loop")
  292. def fit_bruteForce_np(self, X, nebSize=None):
  293. self.timerStart("NN_fit_bfnp_init")
  294. numOfPoints = len(X)
  295. nFeatures = len(X[0])
  296. tX = tf.convert_to_tensor(X)
  297. def distancesTo(x):
  298. w = np.repeat([x], numOfPoints, axis=0)
  299. d = tf.keras.layers.Subtract()([w,tX])
  300. t = tf.keras.layers.Dot(axes=(1,1))([d,d])
  301. return t.numpy()
  302. if nebSize == None:
  303. nebSize = self.nebSize
  304. isGreaterThan = lambda x, y: x[1] > y[1]
  305. self.neighbourhoods = [MaxHeap(nebSize, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  306. self.timerStop("NN_fit_bfnp_init")
  307. self.timerStart("NN_fit_bfnp_loop")
  308. for (i, x) in enumerate(X):
  309. self.timerStart("NN_fit_bfnp_dist")
  310. distances = distancesTo(x)
  311. self.timerStop("NN_fit_bfnp_dist")
  312. nbh = self.neighbourhoods[i]
  313. for (j, y) in enumerate(X[i+1:]):
  314. j += i + 1
  315. d = distances[j]
  316. self.timerStart("NN_fit_bfnp_insert")
  317. nbh.insert((j,d))
  318. self.neighbourhoods[j].insert((i,d))
  319. self.timerStop("NN_fit_bfnp_insert")
  320. self.timerStart("NN_fit_bfnp_toList")
  321. self.neighbourhoods[i] = nbh.toArray(lambda v: v[0])
  322. self.timerStop("NN_fit_bfnp_toList")
  323. self.timerStop("NN_fit_bfnp_loop")
  324. def fit_chained(self, X, nebSize=None):
  325. self.timerStart("NN_fit_chained_init")
  326. if nebSize == None:
  327. nebSize = self.nebSize
  328. nPoints = len(X)
  329. nFeatures = len(X[0])
  330. neigh = NearestNeighbors(n_neighbors=nebSize)
  331. neigh.fit(X)
  332. self.timerStop("NN_fit_chained_init")
  333. self.timerStart("NN_fit_chained_toList")
  334. self.neighbourhoods = [
  335. (neigh.kneighbors([x], nebSize, return_distance=False))[0]
  336. for (i, x) in enumerate(X)
  337. ]
  338. self.timerStop("NN_fit_chained_toList")
  339. # ===============================================================
  340. # Heuristic search
  341. # ===============================================================
  342. def fit_heuristic(self, X, nebSize=None):
  343. if nebSize == None:
  344. nebSize = self.nebSize
  345. nPoints = len(X)
  346. def walkUp(nbh, ball, x, i):
  347. while ball.parent is not None:
  348. print(f"{i}: up (r: {nbh.getMax()})")
  349. oldBall = ball
  350. ball = ball.parent
  351. for c in ball.childs:
  352. if c != oldBall:
  353. walkDown(nbh, c, x)
  354. def walkDown(nbh, ball, x):
  355. if ball is None:
  356. return
  357. print(f"{i}: down (r: {nbh.getMax()})")
  358. if dist(x, ball.center) - ball.radius < nbh.getMax()[1]:
  359. if ball.childs == []:
  360. for (j, _) in ball.points:
  361. nbh.insert((j, dist(x, X[j])))
  362. else:
  363. for c in ball.childs:
  364. walkDown(nbh, c, x)
  365. def countBoles(b):
  366. if b is None:
  367. return 0
  368. return 1 + sum(map(countBoles, b.childs))
  369. root = Ball(X, range(len(X)))
  370. queue = [root]
  371. while queue != []:
  372. ball = queue[0]
  373. queue = queue[1:]
  374. if len(ball) <= nebSize:
  375. continue
  376. queue = ball.divideBall(X) + queue
  377. isGreaterThan = lambda x, y: x[1] > y[1]
  378. self.neighbourhoods = [MaxHeap(nPoints, isGreaterThan, (i, 0.0)) for i in range(len(X))]
  379. print("#B: " + str(countBoles(root)))
  380. exit()
  381. z = X[0]
  382. for (i, x) in enumerate(X):
  383. nbh = self.neighbourhoods[i]
  384. b = root.smalestBallFor(i)
  385. if b.parent is not None:
  386. b = b.parent
  387. for (j, _) in b.points:
  388. d = dist(x, X[j])
  389. nbh.insert((j, d))
  390. walkUp(nbh, b, x, i)