Bläddra i källkod

Fixed class 0/1 confusion.

Kristian Schultz 4 år sedan
förälder
incheckning
bf875e9a0e
3 ändrade filer med 68 tillägg och 43 borttagningar
  1. 52 33
      Example Toy Exercise.ipynb
  2. 5 5
      library/exercise.py
  3. 11 5
      library/interfaces.py

Filskillnaden har hållts tillbaka eftersom den är för stor
+ 52 - 33
Example Toy Exercise.ipynb


+ 5 - 5
library/exercise.py

@@ -33,8 +33,8 @@ class Exercise:
         self.debug = print
 
     def run(self, gan, dataset):
-        if len(dataset.data0) > len(dataset.data1):
-            raise AttributeError("Expected class 0 to be the minority class but class 0 is bigger than class 1.")
+        if len(dataset.data1) > len(dataset.data0):
+            raise AttributeError("Expected class 1 to be the minority class but class 1 is bigger than class 0.")
 
         self.debug("### Start exercise for synthetic point generator")
         for shuffleStep in range(self.numOfShuffles):
@@ -58,14 +58,14 @@ class Exercise:
         self.debug("-> Train generator for synthetic samples")
         gan.train(dataSlice.train)
 
-        numOfNeededSamples = dataSlice.train.size1 - dataSlice.train.size0
+        numOfNeededSamples = dataSlice.train.size0 - dataSlice.train.size1
 
         if numOfNeededSamples > 0:
             self.debug(f"-> create {numOfNeededSamples} synthetic samples")
             newSamples = np.asarray([gan.generateData() for _ in range(numOfNeededSamples)])
             train = DataSet(
-                data0=np.concatenate((dataSlice.train.data0, newSamples)),
-                data1=dataSlice.train.data1
+                data0=dataSlice.train.data0,
+                data1=np.concatenate((dataSlice.train.data1, newSamples))
                 )
         else:
             train = dataSlice.train

+ 11 - 5
library/interfaces.py

@@ -3,25 +3,31 @@ import numpy as np
 class GanBaseClass:
     def __init__(self):
         self.isTrained = False
-        self.exampleItem = None
+        self.exampleItems = None
+        self.nextIndex = 0
         pass
 
     def train(self, dataSet):
-        if dataSet.data0.shape[0] <= 0:
-            raise AttributeError("Train GAN: Expected data class 0 to contain at least one point.")
+        if dataSet.data1.shape[0] <= 0:
+            raise AttributeError("Train GAN: Expected data class 1 to contain at least one point.")
 
         print(
             "Train GAN with |class 0|=%d, |class 1|=%d"
             % (dataSet.data0.shape[0], dataSet.data1.shape[0])
             )
         self.isTrained = True
-        self.exampleItem = dataSet.data0[0].copy()
+        self.exampleItems = dataSet.data1.copy()
 
     def generateData(self):
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained GAN.")
 
-        return self.exampleItem
+        i = self.nextIndex
+        self.nextIndex += 1
+        if self.nextIndex >= self.exampleItems.shape[0]:
+            self.nextIndex = 0
+
+        return self.exampleItems[i]
 
 
 class TesterNetworkBaseClass:

Vissa filer visades inte eftersom för många filer har ändrats