浏览代码

Added notebook for running tests with convGAN.

Kristian Schultz 4 年之前
父节点
当前提交
a49a41f7c1
共有 3 个文件被更改,包括 158 次插入7 次删除
  1. 22 1
      library/analysis.py
  2. 7 6
      library/convGAN.py
  3. 129 0
      run_all_with_convGan.ipynb

+ 22 - 1
library/analysis.py

@@ -4,6 +4,7 @@ from library.GanExamples import StupidToyListGan
 from library.SimpleGan import SimpleGan
 from library.Repeater import Repeater
 from library.SpheredNoise import SpheredNoise
+from library.convGAN import ConvGAN
 
 import pickle
 import numpy as np
@@ -132,6 +133,26 @@ def runExerciseForSpheredNoise(datasetName, resultList=None):
     exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
     if resultList is not None:
         resultList[datasetName] = avg
+
+
def runExerciseForConvGAN(datasetName, resultList=None):
    """Run the ConvGAN oversampling exercise on a single dataset.

    Parameters
    ----------
    datasetName : str
        File name of the dataset under ``data_input/``.
    resultList : dict, optional
        If given, the average result is stored under key ``datasetName``.
    """
    ganName = "convGAN"
    print()
    print()
    print("///////////////////////////////////////////")
    print(f"// Running {ganName} on {datasetName}")
    print("///////////////////////////////////////////")
    print()
    data = loadDataset(f"data_input/{datasetName}")
    gan = ConvGAN(data.data0.shape[1])
    random.seed(2021)  # fixed seed so the shuffles are reproducible across runs
    shuffler = genShuffler()
    exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
    exercise.run(gan, data)
    # BUG FIX: the results were previously saved twice, to both
    # "{datasetName}-{ganName}.csv" and "{ganName}-{datasetName}.csv" —
    # an apparent copy-paste leftover. Keep the single
    # "{ganName}-{datasetName}" ordering used by the other runner functions.
    avg = exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
    if resultList is not None:
        resultList[datasetName] = avg
     
 testSets = [
     "folding_abalone_17_vs_7_8_9_10",
@@ -153,4 +174,4 @@ testSets = [
 def runAllTestSets(dataSetList):
     for dsFileName in dataSetList:
         runExerciseForSimpleGAN(dataSetList)
-        runExerciseForRepeater(dataSetList)
+        runExerciseForRepeater(dataSetList)

+ 7 - 6
library/convGAN.py

@@ -43,11 +43,12 @@ class ConvGAN(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self, n_feat, neb, gen, debug=True):
+    def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, debug=True):
         self.isTrained = False
         self.n_feat = n_feat
         self.neb = neb
         self.gen = gen
+        self.neb_epochs = 10  # BUG(review): hard-coded 10 ignores the new neb_epochs parameter; should be `self.neb_epochs = neb_epochs`
         self.loss_history = None
         self.debug = debug
         self.dataSet = None
@@ -69,7 +70,7 @@ class ConvGAN(GanBaseClass):
         ## instanciate network and visualize architecture
         self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
 
-    def train(self, dataSet, neb_epochs=5):
+    def train(self, dataSet):
         """
         Trains the GAN.
 
@@ -82,7 +83,7 @@ class ConvGAN(GanBaseClass):
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         self.dataSet = dataSet
-        self._rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
+        self._rough_learning(dataSet.data1, dataSet.data0)
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -109,7 +110,7 @@ class ConvGAN(GanBaseClass):
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
         for i in range(len(data_min)):
-            synth_set.extend(self.generate_data_for_min_point(data_min, i, synth_num))
+            synth_set.extend(self._generate_data_for_min_point(data_min, i, synth_num))
 
         synth_set = synth_set[:numOfSamples] ## extract the exact number of synthetic samples needed to exactly balance the two classes
 
@@ -253,7 +254,7 @@ class ConvGAN(GanBaseClass):
 
 
     # Training
-    def _rough_learning(self, neb_epochs, data_min, data_maj):
+    def _rough_learning(self, data_min, data_maj):
         generator = self.conv_sample_generator
         discriminator = self.maj_min_discriminator
         GAN = self.cg
@@ -263,7 +264,7 @@ class ConvGAN(GanBaseClass):
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
 
-        for step in range(neb_epochs * len(data_min)):
+        for step in range(self.neb_epochs * len(data_min)):
             ## generate minority neighbourhood batch for every minority class sampls by index
             min_batch = self._NMB_guided(data_min, min_idx)
             min_idx = min_idx + 1

文件差异内容过多而无法显示
+ 129 - 0
run_all_with_convGan.ipynb


部分文件因为文件数量过多而无法显示