|
|
@@ -0,0 +1,73 @@
|
|
|
+"""
|
|
|
+This module contains some example Generative Adversarial Networks for testing.
|
|
|
+
|
|
|
+The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
|
|
|
+for testing the interface. Hope your actually GAN will perform better than this two.
|
|
|
+
|
|
|
+The class SimpleGan is a simple standard Generative Adversarial Network.
|
|
|
+"""
|
|
|
+
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+from library.interfaces import GanBaseClass
|
|
|
+
|
|
|
+
|
|
|
+class Repeater(GanBaseClass):
|
|
|
+ """
|
|
|
+ This is a toy example of a GAN.
|
|
|
+ It repeats the first point of the training-data-set.
|
|
|
+ """
|
|
|
+ def __init__(self):
|
|
|
+ self.isTrained = False
|
|
|
+ self.exampleItems = None
|
|
|
+ self.nextIndex = 0
|
|
|
+
|
|
|
+ def reset(self):
|
|
|
+ """
|
|
|
+ Resets the trained GAN to an random state.
|
|
|
+ """
|
|
|
+ self.isTrained = False
|
|
|
+ self.exampleItems = None
|
|
|
+
|
|
|
+ def train(self, dataSet):
|
|
|
+ """
|
|
|
+ Trains the GAN.
|
|
|
+
|
|
|
+ It stores the data points in the training data set and mark as trained.
|
|
|
+
|
|
|
+ *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
|
|
|
+ We are only interested in the first *maxListSize* points in class 1.
|
|
|
+ """
|
|
|
+ if dataSet.data1.shape[0] <= 0:
|
|
|
+ raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
+
|
|
|
+ self.isTrained = True
|
|
|
+ self.exampleItems = dataSet.data1.copy()
|
|
|
+
|
|
|
+ def generateDataPoint(self):
|
|
|
+ """
|
|
|
+ Returns one synthetic data point by repeating the stored list.
|
|
|
+ """
|
|
|
+ if not self.isTrained:
|
|
|
+ raise ValueError("Try to generate data with untrained Re.")
|
|
|
+
|
|
|
+ i = self.nextIndex
|
|
|
+ self.nextIndex += 1
|
|
|
+ if self.nextIndex >= self.exampleItems.shape[0]:
|
|
|
+ self.nextIndex = 0
|
|
|
+
|
|
|
+ return self.exampleItems[i]
|
|
|
+
|
|
|
+
|
|
|
+ def generateData(self, numOfSamples=1):
|
|
|
+ """
|
|
|
+ Generates a list of synthetic data-points.
|
|
|
+
|
|
|
+ *numOfSamples* is a integer > 0. It gives the number of new generated samples.
|
|
|
+ """
|
|
|
+ numOfSamples = int(numOfSamples)
|
|
|
+ if numOfSamples < 1:
|
|
|
+ raise AttributeError("Expected numOfSamples to be > 0")
|
|
|
+
|
|
|
+ return np.array([self.generateDataPoint() for _ in range(numOfSamples)])
|