Преглед на файлове

Replaced loop by data stream.

Kristian Schultz преди 3 години
родител
ревизия
1116c9938d
променени са 1 файла, в които са добавени 16 реда и са изтрити 19 реда
  1. 16 19
      library/generators/NextConvGeN.py

+ 16 - 19
library/generators/NextConvGeN.py

@@ -408,27 +408,24 @@ class NextConvGeN(GanBaseClass):
                     yield x
         
         padd = np.zeros((self.gen - self.neb, self.n_feat))
-                
+        discTrainCount = 1 + max(0, discTrainCount)    
 
         for neb_epoch_count in range(self.neb_epochs):
-            shape = (self.gen, self.n_feat)
-
-            for n in range(1 + max(0,discTrainCount)):
-                self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, 0.5])
-
-                a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32)
-                a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
-                b = tf.data.Dataset.from_tensor_slices(labels).repeat()
-                samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
-
-                self.timing["Fit"].start()
-                ## switch on discriminator training
-                discriminator.trainable = True
-                ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                discriminator.fit(x=samples, verbose=0)
-                ## switch off the discriminator training again
-                discriminator.trainable = False
-                self.timing["Fit"].stop()
+            self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 0.5, 0.5])
+
+            a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32).repeat().take(discTrainCount * self.minSetSize)
+            a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
+            b = tf.data.Dataset.from_tensor_slices(labels).repeat()
+            samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
+
+            self.timing["Fit"].start()
+            ## switch on discriminator training
+            discriminator.trainable = True
+            ## train the discriminator with the concatenated samples and the one-hot encoded labels
+            discriminator.fit(x=samples, verbose=0)
+            ## switch off the discriminator training again
+            discriminator.trainable = False
+            self.timing["Fit"].stop()
 
             # <<<<<<<<<<
             src = tf.data.Dataset.from_generator(genSamplesForGeN, output_types=tf.float32)