瀏覽代碼

Fixed indexing bug in Repeater.

Kristian Schultz 4 年之前
父節點
當前提交
d7668cfb1e
共有 2 個文件被更改,包括 9 次插入8 次删除
  1. 3 2
      library/Repeater.py
  2. 6 6
      library/analysis.py

+ 3 - 2
library/Repeater.py

@@ -52,11 +52,12 @@ class Repeater(GanBaseClass):
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained Re.")
 
-        i = self.nextIndex
-        self.nextIndex += 1
         if self.nextIndex >= self.exampleItems.shape[0]:
             self.nextIndex = 0
 
+        i = self.nextIndex
+        self.nextIndex += 1
+
         return self.exampleItems[i]
 
 

+ 6 - 6
library/analysis.py

@@ -57,14 +57,14 @@ def runExerciseForSimpleGAN(datasetName):
     print(f"// Running {ganName} on {datasetName}")
     print("///////////////////////////////////////////")
     print()
-    data = loadDataset(datasetName)
+    data = loadDataset(f"data_input/{datasetName}")
     gan = SimpleGan(numOfFeatures=data.data0.shape[1])
     random.seed(2021)
     shuffler = genShuffler()
     exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
     exercise.run(gan, data)
-    exercise.saveResultsTo(f"{datasetName}-{ganName}.csv")
-    exercise.saveResultsTo(f"{ganName}-{datasetName}.csv")
+    exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
+    exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
     
     
 def runExerciseForRepeater(datasetName):
@@ -75,14 +75,14 @@ def runExerciseForRepeater(datasetName):
     print(f"// Running {ganName} on {datasetName}")
     print("///////////////////////////////////////////")
     print()
-    data = loadDataset(datasetName)
+    data = loadDataset(f"data_input/{datasetName}")
     gan = Repeater()
     random.seed(2021)
     shuffler = genShuffler()
     exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
     exercise.run(gan, data)
-    exercise.saveResultsTo(f"{datasetName}-{ganName}.csv")
-    exercise.saveResultsTo(f"{ganName}-{datasetName}.csv")
+    exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
+    exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
     
 testSets = [
     "folding_abalone_17_vs_7_8_9_10",