Kaynağa Gözat

Fixed error in convGAN if gen is too big.

Kristian Schultz 3 yıl önce
ebeveyn
işleme
9526cd5452
1 değiştirilmiş dosya ile 6 ekleme ve 3 silme
  1. 6 3
      library/generators/convGAN.py

+ 6 - 3
library/generators/convGAN.py

@@ -80,6 +80,8 @@ class ConvGAN(GanBaseClass):
         self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
 
         if self.debug:
+            print(f"neb={self.neb}, gen={self.gen}")
+
             print(self.conv_sample_generator.summary())
             print('\n')
             
@@ -302,6 +304,7 @@ class ConvGAN(GanBaseClass):
         minSetSize = len(data_min)
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
+        nLabels = 2 * self.gen
 
         for neb_epoch_count in range(self.neb_epochs):
             if discTrainCount > 0:
@@ -322,7 +325,7 @@ class ConvGAN(GanBaseClass):
                         ## switch on discriminator training
                         discriminator.trainable = True
                         ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                        discriminator.fit(x=concat_sample, y=labels, verbose=0)
+                        discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
                         ## switch off the discriminator training again
                         discriminator.trainable = False
 
@@ -342,14 +345,14 @@ class ConvGAN(GanBaseClass):
                 ## switch on discriminator training
                 discriminator.trainable = True
                 ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                discriminator.fit(x=concat_sample, y=labels, verbose=0)
+                discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
                 ## switch off the discriminator training again
                 discriminator.trainable = False
 
                 ## use the GAN to make the generator learn on the decisions
                 ## made by the previous discriminator training
                 ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
-                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
+                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
 
                 ## store the loss for the step
                 loss_history.append(gan_loss_history.history['loss'])