NNSearch.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import math
  2. import tensorflow as tf
  3. import numpy as np
  4. from sklearn.neighbors import NearestNeighbors
  5. from sklearn.utils import shuffle
  6. from library.timing import timing
  7. class NNSearch:
  8. def __init__(self, nebSize=5, timingDict=None):
  9. self.nebSize = nebSize
  10. self.neighbourhoods = []
  11. self.timingDict = timingDict
  12. self.basePoints = []
  13. def timerStart(self, name):
  14. if self.timingDict is not None:
  15. if name not in self.timingDict:
  16. self.timingDict[name] = timing(name)
  17. self.timingDict[name].start()
  18. def timerStop(self, name):
  19. if self.timingDict is not None:
  20. if name in self.timingDict:
  21. self.timingDict[name].stop()
  22. def neighbourhoodOfItem(self, i):
  23. return self.neighbourhoods[i]
  24. def getNbhPointsOfItem(self, index):
  25. return self.getPointsFromIndices(self.neighbourhoodOfItem(index))
  26. def getPointsFromIndices(self, indices):
  27. nmbi = shuffle(np.array([indices]))
  28. nmb = self.basePoints[nmbi]
  29. return tf.convert_to_tensor(nmb[0])
  30. def neighbourhoodOfItemList(self, items, maxCount=None):
  31. nbhIndices = set()
  32. duplicates = []
  33. for i in items:
  34. for x in self.neighbourhoodOfItem(i):
  35. if x in nbhIndices:
  36. duplicates.append(x)
  37. else:
  38. nbhIndices.add(x)
  39. nbhIndices = list(nbhIndices)
  40. if maxCount is not None:
  41. if len(nbhIndices) < maxCount:
  42. nbhIndices.extend(duplicates)
  43. nbhIndices = nbhIndices[0:maxCount]
  44. return self.getPointsFromIndices(nbhIndices)
  45. def fit(self, haystack, needles=None, nebSize=None):
  46. self.timerStart("NN_fit_chained_init")
  47. if nebSize == None:
  48. nebSize = self.nebSize
  49. if needles is None:
  50. needles = haystack
  51. self.basePoints = haystack
  52. neigh = NearestNeighbors(n_neighbors=nebSize)
  53. neigh.fit(haystack)
  54. self.timerStop("NN_fit_chained_init")
  55. self.timerStart("NN_fit_chained_toList")
  56. self.neighbourhoods = [
  57. (neigh.kneighbors([x], nebSize, return_distance=False))[0]
  58. for (i, x) in enumerate(needles)
  59. ]
  60. self.timerStop("NN_fit_chained_toList")
  61. return self