Ver código fonte

Collected generators in list.

Kristian Schultz 4 anos atrás
pai
commit
589a8aa3ca
3 arquivos alterados com 19 adições e 32 exclusões
  1. 12 2
      library/analysis.py
  2. 4 19
      run_all_exercises.ipynb
  3. 3 11
      run_all_exercises.py

+ 12 - 2
library/analysis.py

@@ -196,8 +196,8 @@ def runExerciseForCtGAN(datasetName, resultList=None, debug=False):
     runExercise(datasetName, resultList, "ctGAN", lambda data: CtGAN(data.data0.shape[1], debug=debug))
 
 
-def runExerciseForConvGAN(datasetName, resultList=None, debug=False):
-    runExercise(datasetName, resultList, "convGAN", lambda data: ConvGAN(data.data0.shape[1], debug=debug))
+def runExerciseForConvGAN(datasetName, resultList=None, neb=5, debug=False):
+    runExercise(datasetName, resultList, "convGAN", lambda data: ConvGAN(data.data0.shape[1], neb=neb, gen=neb, debug=debug))
 
 def runExerciseForConvGANfull(datasetName, resultList=None, debug=False):
     runExercise(datasetName, resultList, "convGAN-full", lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1], debug=debug))
@@ -257,3 +257,13 @@ def runAllTestSets(dataSetList):
         runExerciseForSimpleGAN(dataset)
         runExerciseForConvGAN(dataset)
         runExerciseForConvGANfull(dataset)
+
+
+
+generators = [ ("Repeater",      lambda _data: Repeater())
+             #, ("SpheredNoise",  lambda _data: SpheredNoise())
+             , ("SimpleGAN",     lambda data: SimpleGan(numOfFeatures=data.data0.shape[1]))
+             , ("convGAN",       lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5))
+             , ("convGAN-full",  lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1]))
+             , ("ctGAN",         lambda data: CtGAN(data.data0.shape[1]))
+             ]

+ 4 - 19
run_all_exercises.ipynb

@@ -10,21 +10,6 @@
     "from library.analysis import *"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "8da890b0",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "fns = [# runExerciseForRepeater\n",
-    "      #, runExerciseForSpheredNoise\n",
-    "      #, runExerciseForSimpleGAN\n",
-    "      #, runExerciseForConvGAN\n",
-    "      runExerciseForCtGAN\n",
-    "      ]\n"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -35,14 +20,14 @@
    "outputs": [],
    "source": [
     "for dataset in testSets:\n",
-    "    for f in fns:\n",
-    "        f(dataset)"
+    "    for f in generators:\n",
+    "        runExercise(dataset, None, name, f)"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -56,7 +41,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.8.5"
   }
  },
  "nbformat": 4,

+ 3 - 11
run_all_exercises.py

@@ -1,24 +1,16 @@
-from library.analysis import testSets
-from library.analysis import runExerciseForSpheredNoise, runExerciseForRepeater
-from library.analysis import runExerciseForSimpleGAN, runExerciseForConvGAN
+from library.analysis import testSets, generators, runExercise
 import os
 import threading
 
 
 nWorker = 0
 
-fns = [ ("Repeater",     runExerciseForRepeater)
-      , ("SpheredNoise", runExerciseForSpheredNoise)
-      , ("SimpleGAN",    runExerciseForSimpleGAN)
-      , ("convGAN",      runExerciseForConvGAN)
-      ]
-
 for dataset in testSets:
-    for (name, f) in fns:
+    for (name, f) in generators:
         nWorker += 1
         if 0 == os.fork():
             print(f"#{nWorker}: start: {name}({dataset})")
-            f(dataset)
+            runExercise(dataset, None, name, f)
             print(f"#{nWorker}: end.")
             exit()
         else: