LoRAS_ProWRAS.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from library.ext_prowras import ProWRAS_gen
  2. from library.interfaces import GanBaseClass
  3. class ProWRAS(GanBaseClass):
  4. """
  5. This is a toy example of a GAN.
  6. It repeats the first point of the training-data-set.
  7. """
  8. def __init__(self
  9. , max_levels = 5
  10. , convex_nbd = 5
  11. , n_neighbors = 5
  12. , max_concov = None
  13. , theta = 1.0
  14. , shadow = 100
  15. , sigma = 0.000001
  16. , n_jobs = 1
  17. , debug = False
  18. ):
  19. """
  20. Initializes the class and mark it as untrained.
  21. """
  22. self.data = None
  23. self.max_levels = max_levels
  24. self.convex_nbd = convex_nbd
  25. self.n_neighbors = n_neighbors
  26. self.max_concov = max_concov
  27. self.theta = theta
  28. self.shadow = shadow
  29. self.sigma = sigma
  30. self.n_jobs = n_jobs
  31. self.debug = debug
  32. self.canPredict = False
  33. def reset(self, _dataSet):
  34. """
  35. Resets the trained GAN to an random state.
  36. """
  37. pass
  38. def train(self, dataSet):
  39. """
  40. Trains the GAN.
  41. It stores the first data-point in the training data-set and mark the GAN as trained.
  42. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  43. We are only interested in the class 1.
  44. """
  45. self.data = dataSet
  46. def generateDataPoint(self):
  47. """
  48. Generates one synthetic data-point by copying the stored data point.
  49. """
  50. return self.generateData(1)[0]
  51. def generateData(self, numOfSamples=1):
  52. """
  53. Generates a list of synthetic data-points.
  54. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  55. """
  56. if self.max_concov is not None:
  57. max_concov = self.max_concov
  58. else:
  59. max_concov = self.data.data.shape[0]
  60. return ProWRAS_gen(
  61. data = self.data.data,
  62. labels = self.data.labels,
  63. max_levels = self.max_levels,
  64. convex_nbd = self.convex_nbd,
  65. n_neighbors = self.n_neighbors,
  66. max_concov = max_concov,
  67. num_samples_to_generate = numOfSamples,
  68. theta = self.theta,
  69. shadow = self.shadow,
  70. sigma = self.sigma,
  71. n_jobs = self.n_jobs,
  72. enableDebug = self.debug)[0][:numOfSamples]