Forráskód Böngészése

Merge branch 'ctab' (Added CTAB-GAN)

Kristian Schultz 4 éve
szülő
commit
0e80636ed8
4 módosított fájl, 63 hozzáadás és 0 törlés
  1. 3 0
      .gitmodules
  2. 1 0
      external/CTAB-GAN
  3. 58 0
      library/generators/ctab.py
  4. 1 0
      model

+ 3 - 0
.gitmodules

@@ -4,3 +4,6 @@
 [submodule "external/LoRAS"]
 	path = external/LoRAS
 	url = https://github.com/COSPOV/LoRAS.git
+[submodule "external/CTAB-GAN"]
+	path = external/CTAB-GAN
+	url = https://github.com/Team-TUD/CTAB-GAN.git

+ 1 - 0
external/CTAB-GAN

@@ -0,0 +1 @@
+Subproject commit 3acfd9cb8f9ae795ab0a5d701df3b0c2ec52a4ed

+ 58 - 0
library/generators/ctab.py

@@ -0,0 +1,58 @@
+from library.interfaces import GanBaseClass
+from library.dataset import DataSet
+
+from model.ctabgan import CTABGAN
+
+import warnings
+warnings.filterwarnings("ignore")
+
+
+
+class CtabGan(GanBaseClass):
+    """
+    This is a toy example of a GAN.
+    It repeats the first point of the training-data-set.
+    """
+    def __init__(self, n_feat, epochs=10, debug=True):
+        self.isTrained = False
+        self.epochs = 10
+
+    def reset(self):
+        """
+        Resets the trained GAN to an random state.
+        """
+        self.isTrained = False
+        self.synthesizer = CTABGAN(epochs = self.epochs) 
+
+    def train(self, dataSet):
+        """
+        Trains the GAN.
+
+        It stores the data points in the training data set and mark as trained.
+
+        *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
+        We are only interested in the first *maxListSize* points in class 1.
+        """
+        if dataSet.data1.shape[0] <= 0:
+            raise AttributeError("Train: Expected data class 1 to contain at least one point.")
+
+        self.synthesizer.fit(train_data=dataSet.data1)
+        self.isTrained = True
+
+    def generateDataPoint(self):
+        """
+        Returns one synthetic data point by repeating the stored list.
+        """
+        return (self.generateData(1))[0]
+
+
+    def generateData(self, numOfSamples=1):
+        """
+        Generates a list of synthetic data-points.
+
+        *numOfSamples* is a integer > 0. It gives the number of new generated samples.
+        """
+        if not self.isTrained:
+            raise ValueError("Try to generate data with untrained Re.")
+
+        return self.synthesizer.generate_samples(numOfSamples)

+ 1 - 0
model

@@ -0,0 +1 @@
+external/CTAB-GAN/model