Kristian Schultz, 3 years ago
parent commit c4ab2275d5
1 changed file with 6 additions and 33 deletions

library/generators/NextConvGeN.py (+6, -33)

@@ -372,7 +372,7 @@ class NextConvGeN(GanBaseClass):
             self.timing["FixType"].stop()
 
             ## concatenate them with the majority batch
-            conv_samples = [conv_samples, maj_batch, min_batch]
+            conv_samples = [conv_samples, maj_batch]
             return conv_samples
 
         def trainDiscriminator(samples):
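
The first hunk drops the real minority batch from the list handed back for discriminator training, so the discriminator now sees only the generated samples plus the majority batch. A minimal NumPy sketch of that assembly (shapes and variable names are illustrative, not taken from the repository):

    import numpy as np

    rng = np.random.default_rng(0)
    n_features = 4
    conv_samples = rng.normal(size=(8, n_features))  # stand-in for generator output
    maj_batch = rng.normal(size=(8, n_features))     # real majority-class rows

    # Before this commit the list also carried the minority batch;
    # now only two parts are returned for the discriminator step.
    samples = [conv_samples, maj_batch]
    stacked = np.concatenate(samples, axis=0)
    print(stacked.shape)                             # (16, 4)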
@@ -394,11 +394,11 @@ class NextConvGeN(GanBaseClass):
             for min_idx in range(minSetSize):
                 yield indexToBatches(min_idx)
 
-        def unbatch(indices, rows):
+        def unbatch(rows):
             def fn():
-                for arr in rows:
-                    for i in indices:
-                        for x in arr[i]:
+                for row in rows:
+                    for part in row:
+                        for x in part:
                             yield x
             return fn
 
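The simplified unbatch() no longer selects parts by index; it flattens every part of every row in order. A self-contained sketch of the pattern, with toy arrays standing in for the neighbourhood batches:

    import numpy as np

    def unbatch(rows):
        # Closure returning a generator that flattens row -> part -> sample,
        # suitable as the callable for tf.data.Dataset.from_generator.
        def fn():
            for row in rows:
                for part in row:
                    for x in part:
                        yield x
        return fn

    rows = [[np.zeros((2, 3)), np.ones((2, 3))]]  # one row with two parts
    flat = list(unbatch(rows)())
    print(len(flat), flat[0].shape)               # 4 (3,)
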
@@ -417,7 +417,7 @@ class NextConvGeN(GanBaseClass):
                 self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, 0.5])
 
                 a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32)
-                a = tf.data.Dataset.from_generator(unbatch([0,1], a), output_types=tf.float32)
+                a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
                 b = tf.data.Dataset.from_tensor_slices(labels).repeat()
                 samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
 
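For context, the pipeline built in this hunk zips the unbatched sample stream with endlessly repeated one-hot labels and re-batches the pair. A hedged sketch with placeholder sizes (the generator, label array, and batch size below are stand-ins, not the repository's values):

    import numpy as np
    import tensorflow as tf

    def sample_gen():                       # stand-in for unbatch(...)
        for _ in range(8):
            yield np.random.rand(3).astype(np.float32)

    labels = np.eye(2, dtype=np.float32)    # one-hot real/fake labels

    a = tf.data.Dataset.from_generator(sample_gen, output_types=tf.float32)
    b = tf.data.Dataset.from_tensor_slices(labels).repeat()
    samples = tf.data.Dataset.zip((a, b)).batch(4)

    for x, y in samples:
        print(x.shape, y.shape)             # (4, 3) (4, 2)
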
@@ -430,35 +430,8 @@ class NextConvGeN(GanBaseClass):
                 discriminator.trainable = False
                 self.timing["Fit"].stop()
 
-            # for min_idx in range(minSetSize):
-            #     self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
-
-            #     samples = createSamples(min_idx)
-            #     trainDiscriminator(samples)
-
-            #     ## use the complete network to make the generator learn on the decisions
-            #     ## made by the previous discriminator training
-            #     samples = np.array([[samples[2], samples[1]]])
-            #     gen_loss_history = convGeN.fit(samples, y=labelsGeN, verbose=0, batch_size=nLabels)
-
-            #     ## store the loss for the step
-            #     loss_history.append(gen_loss_history.history['loss'])
-
-
             # <<<<<<<<<<
             src = tf.data.Dataset.from_generator(genSamplesForGeN, output_types=tf.float32)
-            #a = tf.data.Dataset.from_generator(unbatch([0,1], src), output_types=tf.float32)
-            #b = tf.data.Dataset.from_tensor_slices(labels).repeat()
-            #samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
-
-            #self.timing["Fit"].start()
-            ### switch on discriminator training
-            #discriminator.trainable = True
-            ### train the discriminator with the concatenated samples and the one-hot encoded labels
-            #discriminator.fit(x=samples, verbose=0)
-            ### switch off the discriminator training again
-            #discriminator.trainable = False
-            #self.timing["Fit"].stop()
 
             ## use the complete network to make the generator learn on the decisions
             ## made by the previous discriminator training
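
The surrounding context keeps the usual alternating schedule: the discriminator is trainable only during its own fit() call and is frozen while the combined network updates the generator. An illustrative sketch with toy Keras models (not the repository's architecture); since Keras captures the trainable flag when compile() is called, the classic pattern compiles the combined model with the discriminator already frozen:

    import numpy as np
    import tensorflow as tf

    disc = tf.keras.Sequential([tf.keras.layers.Dense(2, activation="softmax")])
    gen = tf.keras.Sequential([tf.keras.layers.Dense(3)])

    disc.compile(optimizer="adam", loss="categorical_crossentropy")

    disc.trainable = False                  # frozen inside the combined model
    combined = tf.keras.Sequential([gen, disc])
    combined.compile(optimizer="adam", loss="categorical_crossentropy")

    x = np.random.rand(8, 3).astype(np.float32)
    y = np.eye(2, dtype=np.float32)[np.random.randint(0, 2, 8)]

    disc.fit(x, y, verbose=0)               # discriminator phase: weights move
    combined.fit(x, y, verbose=0)           # generator phase: disc stays fixed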