Przeglądaj źródła

Added Repeater-dummy GAN.

Kristian Schultz 4 lat temu
rodzic
commit
1010f737d4
1 zmienionych plików z 73 dodań i 0 usunięć
  1. 73 0
      library/Repeater.py

+ 73 - 0
library/Repeater.py

@@ -0,0 +1,73 @@
+"""
+This module contains some example Generative Adversarial Networks for testing.
+
+The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
+for testing the interface. Hope your actually GAN will perform better than this two.
+
+The class SimpleGan is a simple standard Generative Adversarial Network.
+"""
+
+
+import numpy as np
+
+from library.interfaces import GanBaseClass
+
+
+class Repeater(GanBaseClass):
+    """
+    This is a toy example of a GAN.
+    It repeats the first point of the training-data-set.
+    """
+    def __init__(self):
+        self.isTrained = False
+        self.exampleItems = None
+        self.nextIndex = 0
+
+    def reset(self):
+        """
+        Resets the trained GAN to an random state.
+        """
+        self.isTrained = False
+        self.exampleItems = None
+
+    def train(self, dataSet):
+        """
+        Trains the GAN.
+
+        It stores the data points in the training data set and mark as trained.
+
+        *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
+        We are only interested in the first *maxListSize* points in class 1.
+        """
+        if dataSet.data1.shape[0] <= 0:
+            raise AttributeError("Train: Expected data class 1 to contain at least one point.")
+
+        self.isTrained = True
+        self.exampleItems = dataSet.data1.copy()
+
+    def generateDataPoint(self):
+        """
+        Returns one synthetic data point by repeating the stored list.
+        """
+        if not self.isTrained:
+            raise ValueError("Try to generate data with untrained Re.")
+
+        i = self.nextIndex
+        self.nextIndex += 1
+        if self.nextIndex >= self.exampleItems.shape[0]:
+            self.nextIndex = 0
+
+        return self.exampleItems[i]
+
+
+    def generateData(self, numOfSamples=1):
+        """
+        Generates a list of synthetic data-points.
+
+        *numOfSamples* is a integer > 0. It gives the number of new generated samples.
+        """
+        numOfSamples = int(numOfSamples)
+        if numOfSamples < 1:
+            raise AttributeError("Expected numOfSamples to be > 0")
+
+        return np.array([self.generateDataPoint() for _ in range(numOfSamples)])