GanExamples.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not real networks. These classes are used
  4. for testing the interface. Hopefully your actual GAN will perform better than these two.
  5. The class SimpleGan is meant to be a simple standard Generative Adversarial Network; in this file it is currently only a commented-out draft.
  6. """
  7. import numpy as np
  8. from library.interfaces import GanBaseClass
  9. class StupidToyPointGan(GanBaseClass):
  10. """
  11. This is a toy example of a GAN.
  12. It repeats the first point of the training-data-set.
  13. """
  14. def __init__(self):
  15. """
  16. Initializes the class and mark it as untrained.
  17. """
  18. self.isTrained = False
  19. self.exampleItem = None
  20. def train(self, dataSet):
  21. """
  22. Trains the GAN.
  23. It stores the first data-point in the training data-set and mark the GAN as trained.
  24. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  25. We are only interested in the class 1.
  26. """
  27. if dataSet.data1.shape[0] <= 0:
  28. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  29. self.isTrained = True
  30. self.exampleItem = dataSet.data1[0].copy()
  31. def generateDataPoint(self):
  32. """
  33. Generates one synthetic data-point by copying the stored data point.
  34. """
  35. if not self.isTrained:
  36. raise ValueError("Try to generate data with untrained GAN.")
  37. return self.exampleItem
  38. def generateData(self, numOfSamples=1):
  39. """
  40. Generates a list of synthetic data-points.
  41. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  42. """
  43. numOfSamples = int(numOfSamples)
  44. if numOfSamples < 1:
  45. raise AttributeError("Expected numOfSamples to be > 0")
  46. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
  47. class StupidToyListGan(GanBaseClass):
  48. """
  49. This is a toy example of a GAN.
  50. It repeats the first point of the training-data-set.
  51. """
  52. def __init__(self, maxListSize=100):
  53. self.isTrained = False
  54. self.exampleItems = None
  55. self.nextIndex = 0
  56. self.maxListSize = int(maxListSize)
  57. if self.maxListSize < 1:
  58. raise AttributeError("Expected maxListSize to be > 0 but got " + str(self.maxListSize))
  59. def train(self, dataSet):
  60. """
  61. Trains the GAN.
  62. It stores the first data-point in the training data-set and mark the GAN as trained.
  63. *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
  64. We are only interested in the first *maxListSize* points in class 1.
  65. """
  66. if dataSet.data1.shape[0] <= 0:
  67. raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  68. self.isTrained = True
  69. self.exampleItems = dataSet.data1[: self.maxListSize].copy()
  70. def generateDataPoint(self):
  71. """
  72. Returns one synthetic data point by repeating the stored list.
  73. """
  74. if not self.isTrained:
  75. raise ValueError("Try to generate data with untrained GAN.")
  76. i = self.nextIndex
  77. self.nextIndex += 1
  78. if self.nextIndex >= self.exampleItems.shape[0]:
  79. self.nextIndex = 0
  80. return self.exampleItems[i]
  81. def generateData(self, numOfSamples=1):
  82. """
  83. Generates a list of synthetic data-points.
  84. *numOfSamples* is a integer > 0. It gives the number of new generated samples.
  85. """
  86. numOfSamples = int(numOfSamples)
  87. if numOfSamples < 1:
  88. raise AttributeError("Expected numOfSamples to be > 0")
  89. return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
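  90. # class SimpleGan(GanBaseClass):  # NOTE(review): this draft is a copy of StupidToyListGan (it cycles through stored points), not a real GAN yet; its generateDataPoint also takes an unused numOfSamples argument.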
  91. # def __init__(self, maxListSize=100):
  92. # self.isTrained = False
  93. # self.exampleItems = None
  94. # self.nextIndex = 0
  95. # self.maxListSize = int(maxListSize)
  96. # if self.maxListSize < 1:
  97. # raise AttributeError(f"Expected maxListSize to be > 0 but got {self.maxListSize}")
  98. #
  99. #
  100. # def train(self, dataSet):
  101. # if dataSet.data1.shape[0] <= 0:
  102. # raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
  103. #
  104. # self.isTrained = True
  105. # self.exampleItems = dataSet.data1[: self.maxListSize].copy()
  106. #
  107. # def generateDataPoint(self, numOfSamples=1):
  108. # if not self.isTrained:
  109. # raise ValueError("Try to generate data with untrained GAN.")
  110. #
  111. # i = self.nextIndex
  112. # self.nextIndex += 1
  113. # if self.nextIndex >= self.exampleItems.shape[0]:
  114. # self.nextIndex = 0
  115. #
  116. # return self.exampleItems[i]
  117. #
  118. #
  119. # def generateData(self, numOfSamples=1):
  120. # numOfSamples = int(numOfSamples)
  121. # if numOfSamples < 1:
  122. # raise AttributeError("Expected numOfSamples to be > 0")
  123. #
  124. # return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
  125. #