Repeater.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
"""
This module contains some example Generative Adversarial Networks for testing.
The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
for testing the interface. Hopefully your actual GAN will perform better than these two.
The class SimpleGan is a simple standard Generative Adversarial Network.
The class Repeater defined here is a toy GAN that cycles through the points of its training data.
"""
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. class Repeater(GanBaseClass):
  10. """
  11. This is a toy example of a GAN.
  12. It repeats the first point of the training-data-set.
  13. """
  14. def __init__(self):
  15. self.isTrained = False
  16. self.exampleItems = None
  17. self.nextIndex = 0
  18. def reset(self, _dataSet):
  19. """
  20. Resets the trained GAN to an random state.
  21. """
  22. self.isTrained = False
  23. self.exampleItems = None
  24. def train(self, dataSet):
  25. """
  26. Trains the GAN.
  27. It stores the data points in the training data set and mark as trained.
  28. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  29. We are only interested in the first *maxListSize* points in class 1.
  30. """
  31. if dataSet.data1.shape[0] <= 0:
  32. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  33. self.isTrained = True
  34. self.exampleItems = dataSet.data1.copy()
  35. def generateDataPoint(self):
  36. """
  37. Returns one synthetic data point by repeating the stored list.
  38. """
  39. if not self.isTrained:
  40. raise ValueError("Try to generate data with untrained Re.")
  41. if self.nextIndex >= self.exampleItems.shape[0]:
  42. self.nextIndex = 0
  43. i = self.nextIndex
  44. self.nextIndex += 1
  45. return self.exampleItems[i]
  46. def generateData(self, numOfSamples=1):
  47. """
  48. Generates a list of synthetic data-points.
  49. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  50. """
  51. numOfSamples = int(numOfSamples)
  52. if numOfSamples < 1:
  53. raise AttributeError("Expected numOfSamples to be > 0")
  54. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])