Ver Fonte

Seperated loop.

Kristian Schultz há 4 anos atrás
pai
commit
e8c23c68bf
1 ficheiros alterados com 37 adições e 39 exclusões
  1. 37 39
      library/generators/convGAN.py

+ 37 - 39
library/generators/convGAN.py

@@ -282,48 +282,46 @@ class ConvGAN(GanBaseClass):
         discriminator = self.maj_min_discriminator
         GAN = self.cg
         loss_history = [] ## this is for stroring the loss for every run
-        min_idx = 0
-        neb_epoch_count = 1
+        step = 0
+        minSetSize = len(data_min)
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
 
-        for step in range(self.neb_epochs * len(data_min)):
-            ## generate minority neighbourhood batch for every minority class sampls by index
-            min_batch_indices = self.nmbMin.neighbourhoodOfItem(min_idx)
-            min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
-            min_idx = min_idx + 1
-            ## generate random proximal majority batch
-            maj_batch = self._BMB(data_maj, min_batch_indices)
-
-            ## generate synthetic samples from convex space
-            ## of minority neighbourhood batch using generator
-            conv_samples = generator.predict(min_batch)
-            ## concatenate them with the majority batch
-            concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
-
-            ## switch on discriminator training
-            discriminator.trainable = True
-            ## train the discriminator with the concatenated samples and the one-hot encoded labels
-            discriminator.fit(x=concat_sample, y=labels, verbose=0)
-            ## switch off the discriminator training again
-            discriminator.trainable = False
-
-            ## use the GAN to make the generator learn on the decisions
-            ## made by the previous discriminator training
-            ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
-            gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
-
-            ## store the loss for the step
-            loss_history.append(gan_loss_history.history['loss'])
-
-            if self.debug and ((step + 1) % 10 == 0):
-                print(f"{step + 1} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
-
-            if min_idx == len(data_min) - 1:
-                if self.debug:
-                    print(f"Neighbourhood epoch {neb_epoch_count} complete")
-                neb_epoch_count = neb_epoch_count + 1
-                min_idx = 0
+        for neb_epoch_count in range(self.neb_epochs):
+            for min_idx in range(minSetSize):
+                ## generate minority neighbourhood batch for every minority class sampls by index
+                min_batch_indices = self.nmbMin.neighbourhoodOfItem(min_idx)
+                min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
+                ## generate random proximal majority batch
+                maj_batch = self._BMB(data_maj, min_batch_indices)
+
+                ## generate synthetic samples from convex space
+                ## of minority neighbourhood batch using generator
+                conv_samples = generator.predict(min_batch)
+                ## concatenate them with the majority batch
+                concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
+
+                ## switch on discriminator training
+                discriminator.trainable = True
+                ## train the discriminator with the concatenated samples and the one-hot encoded labels
+                discriminator.fit(x=concat_sample, y=labels, verbose=0)
+                ## switch off the discriminator training again
+                discriminator.trainable = False
+
+                ## use the GAN to make the generator learn on the decisions
+                ## made by the previous discriminator training
+                ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
+                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
+
+                ## store the loss for the step
+                loss_history.append(gan_loss_history.history['loss'])
+
+                step += 1
+                if self.debug and (step % 10 == 0):
+                    print(f"{step} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
+
+            if self.debug:
+                print(f"Neighbourhood epoch {neb_epoch_count + 1} complete")
 
         if self.debug:
             run_range = range(1, len(loss_history) + 1)