GanExamples.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
"""
This module contains some example Generative Adversarial Networks (GANs) for testing.
The classes StupidToyPointGan and StupidToyListGan are not real networks; they are only
used to test the interface. Hopefully your actual GAN will perform better than these two.
The class SimpleGan is a simple standard Generative Adversarial Network.
"""
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. class StupidToyPointGan(GanBaseClass):
  10. """
  11. This is a toy example of a GAN.
  12. It repeats the first point of the training-data-set.
  13. """
  14. def __init__(self):
  15. """
  16. Initializes the class and mark it as untrained.
  17. """
  18. self.isTrained = False
  19. self.exampleItem = None
  20. def reset(self, _dataSet):
  21. """
  22. Resets the trained GAN to an random state.
  23. """
  24. self.isTrained = False
  25. self.exampleItem = None
  26. def train(self, dataSet):
  27. """
  28. Trains the GAN.
  29. It stores the first data-point in the training data-set and mark the GAN as trained.
  30. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  31. We are only interested in the class 1.
  32. """
  33. if dataSet.data1.shape[0] <= 0:
  34. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  35. self.isTrained = True
  36. self.exampleItem = dataSet.data1[0].copy()
  37. def generateDataPoint(self):
  38. """
  39. Generates one synthetic data-point by copying the stored data point.
  40. """
  41. if not self.isTrained:
  42. raise ValueError("Try to generate data with untrained GAN.")
  43. return self.exampleItem
  44. def generateData(self, numOfSamples=1):
  45. """
  46. Generates a list of synthetic data-points.
  47. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  48. """
  49. numOfSamples = int(numOfSamples)
  50. if numOfSamples < 1:
  51. raise AttributeError("Expected numOfSamples to be > 0")
  52. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
  53. class StupidToyListGan(GanBaseClass):
  54. """
  55. This is a toy example of a GAN.
  56. It repeats the first point of the training-data-set.
  57. """
  58. def __init__(self, maxListSize=100):
  59. self.isTrained = False
  60. self.exampleItems = None
  61. self.nextIndex = 0
  62. self.maxListSize = int(maxListSize)
  63. if self.maxListSize < 1:
  64. raise AttributeError("Expected maxListSize to be > 0 but got " + str(self.maxListSize))
  65. def reset(self, _dataSet):
  66. """
  67. Resets the trained GAN to an random state.
  68. """
  69. self.isTrained = False
  70. self.exampleItems = None
  71. def train(self, dataSet):
  72. """
  73. Trains the GAN.
  74. It stores the first data-point in the training data-set and mark the GAN as trained.
  75. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  76. We are only interested in the first *maxListSize* points in class 1.
  77. """
  78. if dataSet.data1.shape[0] <= 0:
  79. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  80. self.isTrained = True
  81. self.exampleItems = dataSet.data1[: self.maxListSize].copy()
  82. def generateDataPoint(self):
  83. """
  84. Returns one synthetic data point by repeating the stored list.
  85. """
  86. if not self.isTrained:
  87. raise ValueError("Try to generate data with untrained GAN.")
  88. i = self.nextIndex
  89. self.nextIndex += 1
  90. if self.nextIndex >= self.exampleItems.shape[0]:
  91. self.nextIndex = 0
  92. return self.exampleItems[i]
  93. def generateData(self, numOfSamples=1):
  94. """
  95. Generates a list of synthetic data-points.
  96. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  97. """
  98. numOfSamples = int(numOfSamples)
  99. if numOfSamples < 1:
  100. raise AttributeError("Expected numOfSamples to be > 0")
  101. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])