Forráskód Böngészése

Merge branch 'master' of git.dg0nks.de/fyrr/LoGAN

Kristian Schultz 3 éve
szülő
commit
4da12be094
4 módosított fájl, 26 hozzáadás és 17 törlés
  1. 1 1
      library/analysis.py
  2. 10 7
      library/generators/convGAN.py
  3. 1 0
      library/generators/ctab.py
  4. 14 9
      run_all_exercises.py

+ 1 - 1
library/analysis.py

@@ -271,5 +271,5 @@ generators = { "ProWRAS":                 lambda _data: ProWRAS()
              , "convGAN-majority-5":      lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5)
              , "convGAN-majority-full":   lambda data: ConvGAN(data.data0.shape[1], neb=None)
              , "convGAN-proximary-5":     lambda data: ConvGAN(data.data0.shape[1], neb=5, gen=5, withMajorhoodNbSearch=True)
-             , "convGAN-proxymary-full":  lambda data: ConvGAN(data.data0.shape[1], neb=None, withMajorhoodNbSearch=True)
+             , "convGAN-proximary-full":  lambda data: ConvGAN(data.data0.shape[1], neb=None, withMajorhoodNbSearch=True)
              }

+ 10 - 7
library/generators/convGAN.py

@@ -63,7 +63,7 @@ class ConvGAN(GanBaseClass):
             nMinoryPoints = dataSet.data1.shape[0]
             if self.nebInitial is None:
                 self.neb = nMinoryPoints
-            else
+            else:
                 self.neb = min(self.nebInitial, nMinoryPoints)
         else:
             self.neb = self.nebInitial
@@ -80,6 +80,8 @@ class ConvGAN(GanBaseClass):
         self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
 
         if self.debug:
+            print(f"neb={self.neb}, gen={self.gen}")
+
             print(self.conv_sample_generator.summary())
             print('\n')
             
@@ -285,7 +287,7 @@ class ConvGAN(GanBaseClass):
         synth_set = []
         for _run in range(runs):
             batch = self.nmbMin.getNbhPointsOfItem(index)
-            synth_batch = self.conv_sample_generator.predict(batch)
+            synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
             synth_set.extend(synth_batch)
 
         return synth_set[:synth_num]
@@ -302,6 +304,7 @@ class ConvGAN(GanBaseClass):
         minSetSize = len(data_min)
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
+        nLabels = 2 * self.gen
 
         for neb_epoch_count in range(self.neb_epochs):
             if discTrainCount > 0:
@@ -315,14 +318,14 @@ class ConvGAN(GanBaseClass):
 
                         ## generate synthetic samples from convex space
                         ## of minority neighbourhood batch using generator
-                        conv_samples = generator.predict(min_batch)
+                        conv_samples = generator.predict(min_batch, batch_size=self.neb)
                         ## concatenate them with the majority batch
                         concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
 
                         ## switch on discriminator training
                         discriminator.trainable = True
                         ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                        discriminator.fit(x=concat_sample, y=labels, verbose=0)
+                        discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
                         ## switch off the discriminator training again
                         discriminator.trainable = False
 
@@ -335,21 +338,21 @@ class ConvGAN(GanBaseClass):
 
                 ## generate synthetic samples from convex space
                 ## of minority neighbourhood batch using generator
-                conv_samples = generator.predict(min_batch)
+                conv_samples = generator.predict(min_batch, batch_size=self.neb)
                 ## concatenate them with the majority batch
                 concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
 
                 ## switch on discriminator training
                 discriminator.trainable = True
                 ## train the discriminator with the concatenated samples and the one-hot encoded labels
-                discriminator.fit(x=concat_sample, y=labels, verbose=0)
+                discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
                 ## switch off the discriminator training again
                 discriminator.trainable = False
 
                 ## use the GAN to make the generator learn on the decisions
                 ## made by the previous discriminator training
                 ##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
-                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
+                gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
 
                 ## store the loss for the step
                 loss_history.append(gan_loss_history.history['loss'])

+ 1 - 0
library/generators/ctab.py

@@ -17,6 +17,7 @@ class CtabGan(GanBaseClass):
     def __init__(self, epochs=10, debug=True):
         self.isTrained = False
         self.epochs = epochs
+        self.canPredict = False
 
     def reset(self, _dataSet):
         """

+ 14 - 9
run_all_exercises.py

@@ -2,21 +2,26 @@ from library.analysis import testSets, generators, runExercise
 import os
 import threading
 
+maxWorkers = 6
+doMultitask = False
 
 nWorker = 0
 
 for dataset in testSets:
     for name in generators.keys():
-        nWorker += 1
-        if 0 == os.fork():
-            print(f"#{nWorker}: start: {name}({dataset})")
-            runExercise(dataset, None, name, generators[name])
-            print(f"#{nWorker}: end.")
-            exit()
+        if doMultitask:
+            nWorker += 1
+            if 0 == os.fork():
+                print(f"#{nWorker}: start: {name}({dataset})")
+                runExercise(dataset, None, name, generators[name])
+                print(f"#{nWorker}: end.")
+                exit()
+            else:
+                if nWorker >= 6:
+                    os.wait()
+                    nWorker -= 1
         else:
-            if nWorker >= 2:
-                os.wait()
-                nWorker -= 1
+            runExercise(dataset, None, name, generators[name])
 
 while nWorker > 0:
     os.wait()