LoRAS_ProWRAS.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from library.ext_prowras import ProWRAS_gen
  2. from library.interfaces import GanBaseClass
  3. class ProWRAS(GanBaseClass):
  4. """
  5. This is a toy example of a GAN.
  6. It repeats the first point of the training-data-set.
  7. """
  8. def __init__(self
  9. , max_levels = 5
  10. , convex_nbd = 5
  11. , n_neighbors = 5
  12. , max_concov = None
  13. , theta = 1.0
  14. , shadow = 100
  15. , sigma = 0.000001
  16. , n_jobs = 1
  17. , debug = False
  18. ):
  19. """
  20. Initializes the class and mark it as untrained.
  21. """
  22. self.data = None
  23. self.max_levels = max_levels
  24. self.convex_nbd = convex_nbd
  25. self.n_neighbors = n_neighbors
  26. self.max_concov = max_concov
  27. self.theta = theta
  28. self.shadow = shadow
  29. self.sigma = sigma
  30. self.n_jobs = n_jobs
  31. self.debug = debug
  32. def reset(self):
  33. """
  34. Resets the trained GAN to an random state.
  35. """
  36. pass
  37. def train(self, dataSet):
  38. """
  39. Trains the GAN.
  40. It stores the first data-point in the training data-set and mark the GAN as trained.
  41. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  42. We are only interested in the class 1.
  43. """
  44. self.data = dataSet
  45. def generateDataPoint(self):
  46. """
  47. Generates one synthetic data-point by copying the stored data point.
  48. """
  49. return self.generateData(1)[0]
  50. def generateData(self, numOfSamples=1):
  51. """
  52. Generates a list of synthetic data-points.
  53. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  54. """
  55. if self.max_concov is not None:
  56. max_concov = self.max_concov
  57. else:
  58. max_concov = self.data.data.shape[0]
  59. return ProWRAS_gen(
  60. data = self.data.data,
  61. labels = self.data.labels,
  62. max_levels = self.max_levels,
  63. convex_nbd = self.convex_nbd,
  64. n_neighbors = self.n_neighbors,
  65. max_concov = max_concov,
  66. num_samples_to_generate = numOfSamples,
  67. theta = self.theta,
  68. shadow = self.shadow,
  69. sigma = self.sigma,
  70. n_jobs = self.n_jobs,
  71. enableDebug = self.debug)[0][:numOfSamples]