|
|
@@ -1,5 +1,6 @@
|
|
|
import numpy as np
|
|
|
import ctgan
|
|
|
+import math
|
|
|
|
|
|
from library.interfaces import GanBaseClass
|
|
|
from library.dataset import DataSet
|
|
|
@@ -36,7 +37,12 @@ class CtGAN(GanBaseClass):
|
|
|
if dataSet.data1.shape[0] <= 0:
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
- self.ctgan.fit(dataSet.data1)
|
|
|
+ discreteColumns = self.findDiscreteColumns(dataSet.data1)
|
|
|
+
|
|
|
+ if discreteColumns != []:
|
|
|
+ self.ctgan.fit(dataSet.data1, discreteColumns)
|
|
|
+ else:
|
|
|
+ self.ctgan.fit(dataSet.data1)
|
|
|
self.isTrained = True
|
|
|
|
|
|
def generateDataPoint(self):
|
|
|
@@ -56,3 +62,18 @@ class CtGAN(GanBaseClass):
|
|
|
raise ValueError("Try to generate data with untrained Re.")
|
|
|
|
|
|
return self.ctgan.sample(numOfSamples)
|
|
|
+
|
|
|
+
|
|
|
def findDiscreteColumns(self, data):
    """Return the indices of all columns that hold only whole numbers.

    A column counts as discrete when every value in it is finite and
    equal to its own floor (e.g. 0.0, 3.0, -2.0). Columns containing a
    fractional value, NaN or +/-inf are not reported as discrete.

    Parameters
    ----------
    data : array-like, 2-D, numeric
        Table to inspect; rows are data points, columns are features.

    Returns
    -------
    set of int
        Column indices whose values are all whole numbers. With zero
        rows every column is reported, because no value contradicts it.
        An empty result is ``set()``; test it by truthiness, since a
        set never compares equal to ``[]``.

    Raises
    ------
    ValueError
        If ``data`` is not two-dimensional.
    """
    values = np.asarray(data, dtype=float)
    if values.ndim != 2:
        raise ValueError(
            "findDiscreteColumns: expected a 2-dimensional array, got "
            f"{values.ndim} dimension(s)."
        )

    # math.floor raises on NaN / inf; the vectorized check instead marks
    # such cells as "not whole" so a single bad value cannot crash training.
    isWhole = np.isfinite(values) & (np.floor(values) == values)

    # A column is discrete only if every one of its rows is whole.
    discreteMask = isWhole.all(axis=0)

    # Convert to builtin ints: consumers that validate with
    # isinstance(c, int) would reject numpy integer scalars.
    return {int(c) for c in np.flatnonzero(discreteMask)}
|