interfaces.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. """
  2. This module contains the interfaces used for testing Generative Adversarial Networks (GANs).
  3. """
  4. import numpy as np
  5. class GanBaseClass:
  6. """
  7. Base class for the Generative Adversarial Network.
  8. It defines the interface used by the Exercise class.
  9. """
  10. def __init__(self):
  11. """
  12. Initializes the class.
  13. """
  14. self.canPredict = False
  15. def reset(self, dataSet):
  16. """
  17. Resets the trained GAN to an random state.
  18. """
  19. raise NotImplementedError
  20. def train(self, dataSet):
  21. """
  22. Trains the GAN.
  23. """
  24. raise NotImplementedError
  25. def generateDataPoint(self):
  26. """
  27. Generates one synthetic data-point.
  28. """
  29. return self.generateData(1)[0]
  30. def generateData(self, numOfSamples=1):
  31. """
  32. Generates a list of synthetic data-points.
  33. *numOfSamples* is an integer > 0. It gives the number of generated samples.
  34. """
  35. raise NotImplementedError
  36. def predict(self, data, limit=0.5):
  37. """
  38. Takes a list (numpy array) of data points.
  39. Returns a list with real values in [0,1] for the propapility
  40. that a point is in the minority dataset. With:
  41. 0.0: point is in majority set
  42. 1.0: point is in minority set
  43. """
  44. return np.array([max(0, min(1, int(x + 1.0 - limit))) for x in self.predictReal(data)])
  45. def predictReal(self, data):
  46. raise NotImplemented