Jelajahi Sumber

Added automatic guessing of discrete columns.

Kristian Schultz 4 tahun lalu
induk
melakukan
20d0ad081b
2 file diubah dengan 23 penambahan dan 2 penghapusan
  1. 22 1
      library/generators/ctgan.py
  2. 1 1
      run_all_exercises.ipynb

+ 22 - 1
library/generators/ctgan.py

@@ -1,5 +1,6 @@
 import numpy as np
 import ctgan
+import math
 
 from library.interfaces import GanBaseClass
 from library.dataset import DataSet
@@ -36,7 +37,12 @@ class CtGAN(GanBaseClass):
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
-        self.ctgan.fit(dataSet.data1)
+        discreteColumns = self.findDiscreteColumns(dataSet.data1)
+
+        if discreteColumns != []:
+            self.ctgan.fit(dataSet.data1, discreteColumns)
+        else:
+            self.ctgan.fit(dataSet.data1)
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -56,3 +62,18 @@ class CtGAN(GanBaseClass):
             raise ValueError("Try to generate data with untrained Re.")
 
         return self.ctgan.sample(numOfSamples)
+
+
+    def findDiscreteColumns(self, data):
+        """
+        Guess which columns of data contain only whole-number values.
+
+        A column is reported as discrete when every inspected value in it is
+        finite and equal to its own floor.
+
+        data: two-dimensional table that exposes shape[1], yields rows when
+              iterated, and whose rows can be indexed by column position
+              (presumably a numpy array -- confirm against the caller).
+
+        Returns a sorted list of the column indices judged discrete. A list is
+        returned (not a set) because the caller compares the result with [];
+        a set never compares equal to a list, so an empty set would wrongly
+        count as "discrete columns found".
+        """
+        columns = set(range(data.shape[1]))
+
+        for row in data:
+            # Iterate over a snapshot so candidates can be dropped mid-loop.
+            for c in list(columns):
+                x = row[c]
+                # NaN and infinities cannot be floored (math.floor raises on
+                # them), so such a value rules the column out instead of
+                # aborting the whole scan.
+                if not math.isfinite(x) or math.floor(x) != x:
+                    columns.discard(c)
+
+            # Nothing left to rule out: the remaining rows need no scan.
+            if not columns:
+                break
+
+        return sorted(columns)

+ 1 - 1
run_all_exercises.ipynb

@@ -13,7 +13,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ab26d06a",
+   "id": "f1f637ca",
    "metadata": {},
    "outputs": [],
    "source": [