Преглед изворни кода

Added comments to training.

Kristian Schultz пре 3 година
родитељ
комит
53b8eb1cab
1 измењених фајлова са 21 додато и 11 уклоњено
  1. 21 11
      library/generators/NextConvGeN.py

+ 21 - 11
library/generators/NextConvGeN.py

@@ -340,6 +340,7 @@ class NextConvGeN(GanBaseClass):
         loss_history = [] ## this is for stroring the loss for every run
         minSetSize = len(data)
 
+        ## Create labels for one neighborhood training.
         nLabels = 2 * self.gen
         labels = np.array(create01Labels(nLabels, self.gen))
         labelsGeN = np.array([labels])
@@ -402,34 +403,45 @@ class NextConvGeN(GanBaseClass):
         for neb_epoch_count in range(self.neb_epochs):
             self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 0.5, 0.5])
 
+            ## Training of the discriminator.
+            #
+            # Get all neighborhoods and synthetic points as data stream.
             a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32).repeat().take(discTrainCount * self.minSetSize)
             a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
+
+            # Get all labels as data stream.
             b = tf.data.Dataset.from_tensor_slices(labels).repeat()
+
+            # Zip data and matching labels together for training. 
             samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
 
+            # train the discriminator with the concatenated samples and the one-hot encoded labels
             self.timing["Fit"].start()
-            ## switch on discriminator training
             discriminator.trainable = True
-            ## train the discriminator with the concatenated samples and the one-hot encoded labels
             discriminator.fit(x=samples, verbose=0)
-            ## switch off the discriminator training again
             discriminator.trainable = False
             self.timing["Fit"].stop()
 
-            # <<<<<<<<<<
-            src = tf.data.Dataset.from_generator(genSamplesForGeN, output_types=tf.float32)
-
             ## use the complete network to make the generator learn on the decisions
             ## made by the previous discriminator training
-            a = src.map(lambda x: [[tf.concat([x[0], padd], axis=0), x[1]]])
+            #
+            # Get all neighborhoods as data stream.
+            a = (tf.data.Dataset
+                .from_generator(genSamplesForGeN, output_types=tf.float32)
+                .map(lambda x: [[tf.concat([x[0], padd], axis=0), x[1]]]))
+
+            # Get all labels as data stream.
             b = tf.data.Dataset.from_tensor_slices(labelsGeN).repeat()
+
+            # Zip data and matching labels together for training. 
             samples = tf.data.Dataset.zip((a, b)).batch(batchSize)
 
+            # Train with the data stream. Store the loss for later usage.
             gen_loss_history = convGeN.fit(samples, verbose=0, batch_size=batchSize)
             loss_history.append(gen_loss_history.history['loss'])
-            # >>>>>>>>>>
 
 
+        ## When done: print some statistics.
         if self.debug:
             run_range = range(1, len(loss_history) + 1)
             plt.rcParams["figure.figsize"] = (16,10)
@@ -441,9 +453,7 @@ class NextConvGeN(GanBaseClass):
             plt.plot(run_range, loss_history)
             plt.show()
 
-        self.conv_sample_generator = generator
-        self.maj_min_discriminator = discriminator
-        self.cg = convGeN
+        ## When done: print some statistics.
         self.loss_history = loss_history