Ver Fonte

During a reset it is now possible to get the training dataset sizes.

Kristian Schultz há 3 anos atrás
pai
commit
201bbd9785

+ 3 - 3
library/analysis.py

@@ -1,6 +1,6 @@
 from library.exercise import Exercise
 from library.dataset import DataSet, TrainTestData
-from library.generators import ProWRAS, SimpleGan, Repeater, SpheredNoise, ConvGAN, StupidToyListGan, CtGAN
+from library.generators import ProWRAS, SimpleGan, Repeater, SpheredNoise, ConvGAN, StupidToyListGan, CtGAN, CtabGan
 
 import pickle
 import numpy as np
@@ -269,7 +269,7 @@ generators = { "ProWRAS":                 lambda _data: ProWRAS()
              , "convGAN-old-5":      lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5)
              , "convGAN-old-full":   lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1])
              , "convGAN-majority-5":      lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5)
-             , "convGAN-majority-full":   lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1])
+             , "convGAN-majority-full":   lambda data: ConvGAN(data.data0.shape[1], neb=None)
              , "convGAN-proximary-5":     lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5, withMajorhoodNbSearch=True)
-             , "convGAN-proxymary-full":  lambda data: ConvGAN(data.data0.shape[1], neb=data.data0.shape[1], gen=data.data0.shape[1], withMajorhoodNbSearch=True)
+             , "convGAN-proxymary-full":  lambda data: ConvGAN(data.data0.shape[1], neb=None, withMajorhoodNbSearch=True)
              }

+ 1 - 1
library/exercise.py

@@ -163,7 +163,7 @@ class Exercise:
 
         # Start over with a new GAN instance.
         self.debug("-> Reset the GAN")
-        gan.reset()
+        gan.reset(dataSlice.train)
 
         # Train the gan so it can produce synthetic samples.
         self.debug("-> Train generator for synthetic samples")

+ 2 - 2
library/generators/GanExamples.py

@@ -26,7 +26,7 @@ class StupidToyPointGan(GanBaseClass):
         self.isTrained = False
         self.exampleItem = None
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """
@@ -83,7 +83,7 @@ class StupidToyListGan(GanBaseClass):
         if self.maxListSize < 1:
             raise AttributeError("Expected maxListSize to be > 0 but got " + str(self.maxListSize))
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/LoRAS_ProWRAS.py

@@ -33,7 +33,7 @@ class ProWRAS(GanBaseClass):
         self.n_jobs = n_jobs
         self.debug = debug
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/Repeater.py

@@ -23,7 +23,7 @@ class Repeater(GanBaseClass):
         self.exampleItems = None
         self.nextIndex = 0
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/SimpleGan.py

@@ -34,7 +34,7 @@ class SimpleGan(GanBaseClass):
         self.dLayers = dLayers if dLayers is not None else [1024, 512, 256]
         self.gLayers = gLayers if gLayers is not None else [256, 512, 1024]
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/SpheredNoise.py

@@ -92,7 +92,7 @@ class SpheredNoise(GanBaseClass):
         self.disc = []
         self.reset()
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/autoencoder.py

@@ -63,7 +63,7 @@ class Autoencoder(GanBaseClass):
         self.lossFn = lossFunction #"mse"
         self.lossFn = "mean_squared_logarithmic_error"
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 18 - 4
library/generators/convGAN.py

@@ -33,11 +33,13 @@ class ConvGAN(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
+    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
         self.isTrained = False
         self.n_feat = n_feat
         self.neb = neb
-        self.gen = gen
+        self.nebInitial = neb
+        self.genInitial = gen
+        self.gen = gen if gen is not None else self.neb
         self.neb_epochs = 10
         self.loss_history = None
         self.debug = debug
@@ -48,14 +50,26 @@ class ConvGAN(GanBaseClass):
         self.cg = None
         self.canPredict = True
 
-        if neb > gen:
+        if self.neb is not None and self.gen is not None and self.neb > self.gen:
             raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
 
-    def reset(self):
+    def reset(self, dataSet):
         """
         Resets the trained GAN to an random state.
         """
         self.isTrained = False
+
+        if dataSet is not None:
+            nMinoryPoints = dataSet.data1.shape[0]
+            if self.nebInitial is None:
+                self.neb = nMinoryPoints
+            else
+                self.neb = min(self.nebInitial, nMinoryPoints)
+        else:
+            self.neb = self.nebInitial
+
+        self.gen = self.genInitial if self.genInitial is not None else self.neb
+
         ## instanciate generator network and visualize architecture
         self.conv_sample_generator = self._conv_sample_gen()
 

+ 1 - 1
library/generators/convGAN_experimental.py

@@ -69,7 +69,7 @@ class ConvGAN_experimental(GanBaseClass):
         if neb > gen:
             raise ValueError(f"Expected neb <= gen but got neb={neb} and gen={gen}.")
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/ctab.py

@@ -18,7 +18,7 @@ class CtabGan(GanBaseClass):
         self.isTrained = False
         self.epochs = epochs
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/generators/ctgan.py

@@ -17,7 +17,7 @@ class CtGAN(GanBaseClass):
         self.debug = debug
         self.ctgan = None
 
-    def reset(self):
+    def reset(self, _dataSet):
         """
         Resets the trained GAN to an random state.
         """

+ 1 - 1
library/interfaces.py

@@ -16,7 +16,7 @@ class GanBaseClass:
         """
         self.canPredict = False
 
-    def reset(self):
+    def reset(self, dataSet):
         """
         Resets the trained GAN to an random state.
         """