|
|
@@ -1,7 +1,8 @@
|
|
|
from library.interfaces import GanBaseClass
|
|
|
from library.dataset import DataSet
|
|
|
|
|
|
-from model.ctabgan import CTABGAN
|
|
|
+from model.synthesizer.ctabgan_synthesizer import CTABGANSynthesizer
|
|
|
+import pandas as pd
|
|
|
|
|
|
import warnings
|
|
|
warnings.filterwarnings("ignore")
|
|
|
@@ -13,16 +14,16 @@ class CtabGan(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
|
- def __init__(self, n_feat, epochs=10, debug=True):
|
|
|
+ def __init__(self, epochs=10, debug=True):
|
|
|
self.isTrained = False
|
|
|
- self.epochs = 10
|
|
|
+ self.epochs = epochs
|
|
|
|
|
|
def reset(self):
|
|
|
"""
|
|
|
Resets the trained GAN to an random state.
|
|
|
"""
|
|
|
self.isTrained = False
|
|
|
- self.synthesizer = CTABGAN(epochs = self.epochs)
|
|
|
+ self.synthesizer = CTABGANSynthesizer(epochs = self.epochs)
|
|
|
|
|
|
def train(self, dataSet):
|
|
|
"""
|
|
|
@@ -36,7 +37,7 @@ class CtabGan(GanBaseClass):
|
|
|
if dataSet.data1.shape[0] <= 0:
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
- self.synthesizer.fit(train_data=dataSet.data1)
|
|
|
+ self.synthesizer.fit(train_data=pd.DataFrame(dataSet.data1))
|
|
|
self.isTrained = True
|
|
|
|
|
|
def generateDataPoint(self):
|
|
|
@@ -55,4 +56,4 @@ class CtabGan(GanBaseClass):
|
|
|
if not self.isTrained:
|
|
|
raise ValueError("Try to generate data with untrained Re.")
|
|
|
|
|
|
- return self.synthesizer.generate_samples(numOfSamples)
|
|
|
+ return self.synthesizer.sample(numOfSamples)
|