interfaces.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
"""
This module contains the interfaces used for testing Generative Adversarial Networks.
"""
  4. class GanBaseClass:
  5. """
  6. Base class for the Generative Adversarial Network.
  7. It defines the interface used by the Exercise class.
  8. """
  9. def __init__(self):
  10. """
  11. Initializes the class.
  12. """
  13. def reset(self):
  14. """
  15. Resets the trained GAN to an random state.
  16. """
  17. raise NotImplementedError
  18. def train(self, dataSet):
  19. """
  20. Trains the GAN.
  21. """
  22. raise NotImplementedError
  23. def generateDataPoint(self):
  24. """
  25. Generates one synthetic data-point.
  26. """
  27. return self.generateData(1)[0]
  28. def generateData(self, numOfSamples=1):
  29. """
  30. Generates a list of synthetic data-points.
  31. *numOfSamples* is an integer > 0. It gives the number of generated samples.
  32. """
  33. raise NotImplementedError
  34. def predict(self, data):
  35. """
  36. Takes a list (numpy array) of data points.
  37. Returns a list with real values in [0,1] for the propapility
  38. that a point is in the minority dataset. With:
  39. 0.0: point is in majority set
  40. 1.0: point is in minority set
  41. """
  42. raise NotImplemented