Sfoglia il codice sorgente

Added ctgan generator.

Kristian Schultz 4 anni fa
parent
commit
1b0edfffda

+ 5 - 1
library/analysis.py

@@ -1,6 +1,6 @@
 from library.exercise import Exercise
 from library.dataset import DataSet, TrainTestData
-from library.generators import SimpleGan, Repeater, SpheredNoise, ConvGAN, StupidToyListGan
+from library.generators import SimpleGan, Repeater, SpheredNoise, ConvGAN, StupidToyListGan, CtGAN
 
 import pickle
 import numpy as np
@@ -192,6 +192,10 @@ def runExerciseForSpheredNoise(datasetName, resultList=None):
     runExercise(datasetName, resultList, "SpheredNoise", lambda _data: SpheredNoise())
 
 
+def runExerciseForCtGAN(datasetName, resultList=None, debug=False):
+    runExercise(datasetName, resultList, "ctGAN", lambda data: CtGAN(data.data0.shape[1], debug=debug))
+
+
 def runExerciseForConvGAN(datasetName, resultList=None, debug=False):
     runExercise(datasetName, resultList, "convGAN", lambda data: ConvGAN(data.data0.shape[1], debug=debug))
 

+ 34 - 0
library/distance.py

@@ -0,0 +1,34 @@
+import numpy as np
+
+
+def normSquared(v):
+    s = 0
+    for x in v:
+        s += x * x
+    return s
+
+def distSquared(u, v):
+    return normSquared(u - v)
+
+def distToCloud(v, cloud):
+    di = None
+    for p in cloud:
+        d = distSquared(v, p)
+        if di is None:
+            di = d 
+        else:
+            di = min(di, d)
+    return di
+
+def cloudDist(cloudA, cloudB):
+    di = None
+    dx = None
+    for v in cloudA:
+        d = distToCloud(v, cloudB)
+        if di is None:
+            di = d
+            dx = d
+        else:
+            di = min(di, d)
+            dx = max(dx, d)
+    return (di, dx)

+ 1 - 0
library/generators/__init__.py

@@ -5,3 +5,4 @@ from library.generators.LoRAS_ProWRAS import ProWRAS
 from library.generators.Repeater import Repeater
 from library.generators.SpheredNoise import SpheredNoise
 from library.generators.GanExamples import StupidToyListGan, StupidToyPointGan
+from library.generators.ctgan import CtGAN

+ 58 - 0
library/generators/ctgan.py

@@ -0,0 +1,58 @@
+import numpy as np
+import ctgan
+
+from library.interfaces import GanBaseClass
+from library.dataset import DataSet
+
+
+class CtGAN(GanBaseClass):
+    """
+    This is a toy example of a GAN.
+    It repeats the first point of the training-data-set.
+    """
+    def __init__(self, epochs=10, debug=False):
+        self.isTrained = False
+        self.epochs = epochs
+        self.debug = debug
+        self.ctgan = None
+
+    def reset(self):
+        """
+        Resets the trained GAN to an random state.
+        """
+        self.isTrained = False
+        ## instanciate generator network and visualize architecture
+        self.ctgan = ctgan.CTGANSynthesizer(epochs=self.epochs) 
+
+    def train(self, dataSet):
+        """
+        Trains the GAN.
+
+        It stores the data points in the training data set and mark as trained.
+
+        *dataSet* is a instance of /library.dataset.DataSet/. It contains the training dataset.
+        We are only interested in the first *maxListSize* points in class 1.
+        """
+        if dataSet.data1.shape[0] <= 0:
+            raise AttributeError("Train: Expected data class 1 to contain at least one point.")
+
+        self.ctgan.fit(dataSet.data1)
+        self.isTrained = True
+
+    def generateDataPoint(self):
+        """
+        Returns one synthetic data point by repeating the stored list.
+        """
+        return (self.generateData(1))[0]
+
+
+    def generateData(self, numOfSamples=1):
+        """
+        Generates a list of synthetic data-points.
+
+        *numOfSamples* is a integer > 0. It gives the number of new generated samples.
+        """
+        if not self.isTrained:
+            raise ValueError("Try to generate data with untrained Re.")
+
+        return self.ctgan.sample(numOfSamples)

+ 1 - 0
requirements.txt

@@ -5,3 +5,4 @@ pip:
 - sklearn
 - imblearn
 - matplotlib
+- ctgan

+ 20 - 11
run_all_exercises.ipynb

@@ -7,9 +7,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from library.analysis import testSets\n",
-    "from library.analysis import runExerciseForSpheredNoise, runExerciseForRepeater\n",
-    "from library.analysis import runExerciseForSimpleGAN, runExerciseForConvGAN"
+    "from library.analysis import *"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ab26d06a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fns = [ runExerciseForRepeater\n",
+    "      , runExerciseForSpheredNoise\n",
+    "      , runExerciseForSimpleGAN\n",
+    "      , runExerciseForConvGAN\n",
+    "      , runExerciseForCtGAN\n",
+    "      ]\n"
    ]
   },
   {
@@ -21,19 +34,15 @@
    },
    "outputs": [],
    "source": [
-    "fns = [runExerciseForRepeater, runExerciseForSpheredNoise, runExerciseForSimpleGAN, runExerciseForConvGAN]\n",
-    "    \n",
     "for dataset in testSets:\n",
-    "    runExerciseForRepeater(dataset)\n",
-    "    runExerciseForSpheredNoise(dataset)\n",
-    "    runExerciseForSimpleGAN(dataset)\n",
-    "    runExerciseForConvGAN(dataset)"
+    "    for f in fns:\n",
+    "        f(dataset)"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -47,7 +56,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.9.7"
   }
  },
  "nbformat": 4,