Jelajahi Sumber

Merge branch 'master' with server

Kristian Schultz 4 tahun lalu
induk
melakukan
d97b81f922
3 mengubah file dengan 9 tambahan dan 6 penghapusan
  1. 1 0
      library/analysis.py
  2. 1 0
      library/generators/__init__.py
  3. 7 6
      library/generators/ctab.py

+ 1 - 0
library/analysis.py

@@ -265,6 +265,7 @@ generators = [ ("ProWRAS",       lambda _data: ProWRAS())
              #, ("SpheredNoise",  lambda _data: SpheredNoise())
              , ("SimpleGAN",     lambda data: SimpleGan(numOfFeatures=data.data0.shape[1]))
              , ("ctGAN",         lambda data: CtGAN(data.data0.shape[1]))
+             , ("CTAB-GAN",      lambda _data: CtabGan())
              , ("convGAN",       lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5))
              , ("convGAN-full",  lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1]))
              ]

+ 1 - 0
library/generators/__init__.py

@@ -6,3 +6,4 @@ from library.generators.Repeater import Repeater
 from library.generators.SpheredNoise import SpheredNoise
 from library.generators.GanExamples import StupidToyListGan, StupidToyPointGan
 from library.generators.ctgan import CtGAN
+from library.generators.ctab import CtabGan

+ 7 - 6
library/generators/ctab.py

@@ -1,7 +1,8 @@
 from library.interfaces import GanBaseClass
 from library.dataset import DataSet
 
-from model.ctabgan import CTABGAN
+from model.synthesizer.ctabgan_synthesizer import CTABGANSynthesizer 
+import pandas as pd
 
 import warnings
 warnings.filterwarnings("ignore")
@@ -13,16 +14,16 @@ class CtabGan(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self, n_feat, epochs=10, debug=True):
+    def __init__(self, epochs=10, debug=True):
         self.isTrained = False
-        self.epochs = 10
+        self.epochs = epochs
 
     def reset(self):
         """
         Resets the trained GAN to an random state.
         """
         self.isTrained = False
-        self.synthesizer = CTABGAN(epochs = self.epochs) 
+        self.synthesizer = CTABGANSynthesizer(epochs = self.epochs) 
 
     def train(self, dataSet):
         """
@@ -36,7 +37,7 @@ class CtabGan(GanBaseClass):
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
-        self.synthesizer.fit(train_data=dataSet.data1)
+        self.synthesizer.fit(train_data=pd.DataFrame(dataSet.data1))
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -55,4 +56,4 @@ class CtabGan(GanBaseClass):
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained Re.")
 
-        return self.synthesizer.generate_samples(numOfSamples)
+        return self.synthesizer.sample(numOfSamples)