Repeater.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. """
This module contains some example Generative Adversarial Networks for testing.
The classes StupidToyPointGan and StupidToyListGan are not real networks. These classes are used
for testing the interface. Hopefully your actual GAN will perform better than these two.
The class SimpleGan is a simple standard Generative Adversarial Network.
The class Repeater is a toy GAN that cycles through its stored class-1 training points.
  6. """
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. class Repeater(GanBaseClass):
  10. """
  11. This is a toy example of a GAN.
  12. It repeats the first point of the training-data-set.
  13. """
  14. def __init__(self):
  15. self.canPredict = False
  16. self.isTrained = False
  17. self.exampleItems = None
  18. self.nextIndex = 0
  19. def reset(self, _dataSet):
  20. """
  21. Resets the trained GAN to an random state.
  22. """
  23. self.isTrained = False
  24. self.exampleItems = None
  25. def train(self, dataSet):
  26. """
  27. Trains the GAN.
  28. It stores the data points in the training data set and mark as trained.
  29. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  30. We are only interested in the first *maxListSize* points in class 1.
  31. """
  32. if dataSet.data1.shape[0] <= 0:
  33. raise AttributeError("Train: Expected data class 1 to contain at least one point.")
  34. self.isTrained = True
  35. self.exampleItems = dataSet.data1.copy()
  36. def generateDataPoint(self):
  37. """
  38. Returns one synthetic data point by repeating the stored list.
  39. """
  40. if not self.isTrained:
  41. raise ValueError("Try to generate data with untrained Re.")
  42. if self.nextIndex >= self.exampleItems.shape[0]:
  43. self.nextIndex = 0
  44. i = self.nextIndex
  45. self.nextIndex += 1
  46. return self.exampleItems[i]
  47. def generateData(self, numOfSamples=1):
  48. """
  49. Generates a list of synthetic data-points.
  50. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  51. """
  52. numOfSamples = int(numOfSamples)
  53. if numOfSamples < 1:
  54. raise AttributeError("Expected numOfSamples to be > 0")
  55. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])