NNSearch.py 1.3 KB

(48 lines)
  1. import math
  2. import tensorflow as tf
  3. import numpy as np
  4. from sklearn.neighbors import NearestNeighbors
  5. from library.timing import timing
  6. class NNSearch:
  7. def __init__(self, nebSize=5, timingDict=None):
  8. self.nebSize = nebSize
  9. self.neighbourhoods = []
  10. self.timingDict = timingDict
  11. def timerStart(self, name):
  12. if self.timingDict is not None:
  13. if name not in self.timingDict:
  14. self.timingDict[name] = timing(name)
  15. self.timingDict[name].start()
  16. def timerStop(self, name):
  17. if self.timingDict is not None:
  18. if name in self.timingDict:
  19. self.timingDict[name].stop()
  20. def neighbourhoodOfItem(self, i):
  21. return self.neighbourhoods[i]
  22. def fit(self, X, nebSize=None):
  23. self.timerStart("NN_fit_chained_init")
  24. if nebSize == None:
  25. nebSize = self.nebSize
  26. nPoints = len(X)
  27. nFeatures = len(X[0])
  28. neigh = NearestNeighbors(n_neighbors=nebSize)
  29. neigh.fit(X)
  30. self.timerStop("NN_fit_chained_init")
  31. self.timerStart("NN_fit_chained_toList")
  32. self.neighbourhoods = [
  33. (neigh.kneighbors([x], nebSize, return_distance=False))[0]
  34. for (i, x) in enumerate(X)
  35. ]
  36. self.timerStop("NN_fit_chained_toList")