Ver Fonte

Optimized SimpleGAN parameters.

Kristian Schultz há 3 anos atrás
pai
commit
d664198ed0
1 ficheiros alterados com 8 adições e 6 exclusões
  1. library/generators/SimpleGan.py — 8 adições, 6 exclusões

+ 8 - 6
library/generators/SimpleGan.py

@@ -23,16 +23,16 @@ class SimpleGan(GanBaseClass):
     """
     A class for a simple GAN.
     """
-    def __init__(self, numOfFeatures=786, noiseSize=None, epochs=3, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
+    def __init__(self, numOfFeatures=786, noiseSize=None, epochs=10, batchSize=128, withTanh=False, gLayers=None, dLayers=None):
         self.isTrained = False
-        self.noiseSize = noiseSize if noiseSize is not None else numOfFeatures
+        self.noiseSize = noiseSize if noiseSize is not None else (numOfFeatures * 16)
         self.numOfFeatures = numOfFeatures
         self.epochs = epochs
         self.batchSize = batchSize
         self.scaler = 1.0
         self.withTanh = withTanh
-        self.dLayers = dLayers if dLayers is not None else [1024, 512, 256]
-        self.gLayers = gLayers if gLayers is not None else [256, 512, 1024]
+        self.dLayers = dLayers if dLayers is not None else [numOfFeatures * 40, numOfFeatures * 20, numOfFeatures * 10]
+        self.gLayers = gLayers if gLayers is not None else [noiseSize * 2, numOfFeatures * 4, numOfFeatures * 2]
 
     def reset(self, _dataSet):
         """
@@ -106,6 +106,8 @@ class SimpleGan(GanBaseClass):
             self.scaler = max(1.0, 1.1 * tf.reduce_max(tf.abs(trainData)).numpy())
             scaleDown = 1.0 / self.scaler
 
+        trainData = scaleDown * trainData
+
         for e in range(self.epochs):
             print(f"Epoch {e + 1}/{self.epochs}")
             for _ in range(self.batchSize):
@@ -116,12 +118,12 @@ class SimpleGan(GanBaseClass):
                 syntheticBatch = self.generator.predict(noise)
 
                 # Get a random set of  real images
-                realBatch = dataset.data1[
+                realBatch = trainData[
                     np.random.randint(low=0, high=trainDataSize, size=self.batchSize)
                     ]
 
                 #Construct different batches of  real and fake data
-                X = np.concatenate([scaleDown * realBatch, syntheticBatch])
+                X = np.concatenate([realBatch, syntheticBatch])
 
                 # Labels for generated and real data
                 y_dis=np.zeros(2 * self.batchSize)