NNSearch.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import math
  2. import random
  3. import tensorflow as tf
  4. import numpy as np
  5. from sklearn.neighbors import NearestNeighbors
  6. from sklearn.utils import shuffle
  7. from library.timing import timing
  8. def randomIndices(size, outputSize=None, indicesToIgnore=None):
  9. indices = list(range(size))
  10. if indicesToIgnore is not None:
  11. for x in indicesToIgnore:
  12. indices.remove(x)
  13. size = len(indices)
  14. if outputSize is None or outputSize > size:
  15. outputSize = size
  16. r = []
  17. for _ in range(outputSize):
  18. size -= 1
  19. if size < 0:
  20. break
  21. if size == 0:
  22. r.append(indices[0])
  23. else:
  24. p = random.randint(0, size)
  25. x = indices[p]
  26. r.append(x)
  27. indices.remove(x)
  28. return r
  29. class NNSearch:
  30. def __init__(self, nebSize=5, timingDict=None):
  31. self.nebSize = nebSize
  32. self.neighbourhoods = []
  33. self.timingDict = timingDict
  34. self.basePoints = []
  35. def timerStart(self, name):
  36. if self.timingDict is not None:
  37. if name not in self.timingDict:
  38. self.timingDict[name] = timing(name)
  39. self.timingDict[name].start()
  40. def timerStop(self, name):
  41. if self.timingDict is not None:
  42. if name in self.timingDict:
  43. self.timingDict[name].stop()
  44. def neighbourhoodOfItem(self, i):
  45. return self.neighbourhoods[i]
  46. def getNbhPointsOfItem(self, index):
  47. return self.getPointsFromIndices(self.neighbourhoodOfItem(index))
  48. def getPointsFromIndices(self, indices):
  49. permutation = randomIndices(len(indices))
  50. nmbi = np.array(indices)[permutation]
  51. nmb = self.basePoints[nmbi]
  52. return tf.convert_to_tensor(nmb)
  53. def neighbourhoodOfItemList(self, items, maxCount=None):
  54. nbhIndices = set()
  55. duplicates = []
  56. for i in items:
  57. for x in self.neighbourhoodOfItem(i):
  58. if x in nbhIndices:
  59. duplicates.append(x)
  60. else:
  61. nbhIndices.add(x)
  62. nbhIndices = list(nbhIndices)
  63. if maxCount is not None:
  64. if len(nbhIndices) < maxCount:
  65. nbhIndices.extend(duplicates)
  66. nbhIndices = nbhIndices[0:maxCount]
  67. return self.getPointsFromIndices(nbhIndices)
  68. def fit(self, haystack, needles=None, nebSize=None):
  69. self.timerStart("NN_fit_chained_init")
  70. if nebSize == None:
  71. nebSize = self.nebSize
  72. if needles is None:
  73. needles = haystack
  74. self.basePoints = haystack
  75. neigh = NearestNeighbors(n_neighbors=nebSize)
  76. neigh.fit(haystack)
  77. self.timerStop("NN_fit_chained_init")
  78. self.timerStart("NN_fit_chained_toList")
  79. self.neighbourhoods = [
  80. (neigh.kneighbors([x], nebSize, return_distance=False))[0]
  81. for (i, x) in enumerate(needles)
  82. ]
  83. self.timerStop("NN_fit_chained_toList")
  84. return self