Ver Fonte

First try of ctab-gan generator class.

Kristian Schultz há 4 anos atrás
pai
commit
cf0c619c5f
1 ficheiros alterados com 58 adições e 0 exclusões
  1. 58 0
      library/generators/ctab.py

+ 58 - 0
library/generators/ctab.py

@@ -0,0 +1,58 @@
+from library.interfaces import GanBaseClass
+from library.dataset import DataSet
+
+from model.ctabgan import CTABGAN
+
+import warnings
+warnings.filterwarnings("ignore")
+
+
+
+class CtabGan(GanBaseClass):
+    """
+    This is a toy example of a GAN.
+    It repeats the first point of the training-data-set.
+    """
+    def __init__(self, n_feat, epochs=10, debug=True):
+        self.isTrained = False
+        self.epochs = 10
+
+    def reset(self):
+        """
+        Resets the trained GAN to an random state.
+        """
+        self.isTrained = False
+        self.synthesizer = CTABGAN(epochs = self.epochs) 
+
+    def train(self, dataSet):
+        """
+        Trains the GAN.
+
+        It stores the data points in the training data set and mark as trained.
+
+        *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
+        We are only interested in the first *maxListSize* points in class 1.
+        """
+        if dataSet.data1.shape[0] <= 0:
+            raise AttributeError("Train: Expected data class 1 to contain at least one point.")
+
+        self.synthesizer.fit(train_data=dataSet.data1)
+        self.isTrained = True
+
+    def generateDataPoint(self):
+        """
+        Returns one synthetic data point by repeating the stored list.
+        """
+        return (self.generateData(1))[0]
+
+
+    def generateData(self, numOfSamples=1):
+        """
+        Generates a list of synthetic data-points.
+
+        *numOfSamples* is a integer > 0. It gives the number of new generated samples.
+        """
+        if not self.isTrained:
+            raise ValueError("Try to generate data with untrained Re.")
+
+        return self.synthesizer.generate_samples(numOfSamples)