Kristian Schultz 3 лет назад
Родитель
Сommit
487b8da57f
1 измененных файлов с 32 добавлено и 3 удалено
  1. 32 3
      library/generators/NextConvGeN.py

+ 32 - 3
library/generators/NextConvGeN.py

@@ -52,9 +52,10 @@ class NextConvGeN(GanBaseClass):
         self.cg = None
         self.canPredict = True
         self.fdc = fdc
+        self.lastProgress = (-1,-1,-1)
         
         self.timing = { n: timing(n) for n in [
-            "Train", "BMB", "NbhSearch"
+            "Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit"
             ] }
 
         if self.neb is not None and self.gen is not None and self.neb > self.gen:
@@ -90,6 +91,7 @@ class NextConvGeN(GanBaseClass):
         ## instanciate network and visualize architecture
         self.cg = self._convGeN(self.conv_sample_generator, self.maj_min_discriminator)
 
+        self.lastProgress = (-1,-1,-1)
         if self.debug:
             print(f"neb={self.neb}, gen={self.gen}")
 
@@ -120,6 +122,9 @@ class NextConvGeN(GanBaseClass):
         normalizedData = data
         if self.fdc is not None:
             normalizedData = self.fdc.normalize(data)
+            
+        print(f"|N| = {normalizedData.shape}")
+        print(f"|D| = {data.shape}")
         
         self.timing["NbhSearch"].start()
         # Precalculate neighborhoods
@@ -324,31 +329,37 @@ class NextConvGeN(GanBaseClass):
         nLabels = 2 * self.gen
 
         for neb_epoch_count in range(self.neb_epochs):
-            print(f"NEB EPOCH #{neb_epoch_count + 1} / {self.neb_epochs}")
             if discTrainCount > 0:
                 for n in range(discTrainCount):
-                    print(f"discTrain #{n + 1} / {discTrainCount}")
                     for min_idx in range(minSetSize):
+                        self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, (min_idx + 1) / minSetSize])
+                        self.timing["NBH"].start()
                         ## generate minority neighbourhood batch for every minority class sampls by index
                         min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                         min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
                         ## generate random proximal majority batch
                         maj_batch = self._BMB(min_batch_indices)
+                        self.timing["NBH"].stop()
 
+                        self.timing["GenSamples"].start()
                         ## generate synthetic samples from convex space
                         ## of minority neighbourhood batch using generator
                         conv_samples = generator.predict(min_batch, batch_size=self.neb)
                         ## concatenate them with the majority batch
                         concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
+                        self.timing["GenSamples"].stop()
 
+                        self.timing["Fit"].start()
                         ## switch on discriminator training
                         discriminator.trainable = True
                         ## train the discriminator with the concatenated samples and the one-hot encoded labels
                         discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=20)
                         ## switch off the discriminator training again
                         discriminator.trainable = False
+                        self.timing["Fit"].stop()
 
             for min_idx in range(minSetSize):
+                self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
                 ## generate minority neighbourhood batch for every minority class sampls by index
                 min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
                 min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
@@ -419,3 +430,21 @@ class NextConvGeN(GanBaseClass):
         labels = np.array([ [x, 1 - x] for x in labels])
         self.maj_min_discriminator.fit(x=data, y=labels, batch_size=20, epochs=self.neb_epochs)
         self.maj_min_discriminator.trainable = False
+
+    def progressBar(self, x):
+        x = [int(v * 10) for v in x]
+        if True not in [self.lastProgress[i] != x[i] for i in range(len(self.lastProgress))]:
+            return
+        
+        def bar(v):   
+            r = ""
+            for n in range(10):
+                if n > v:
+                    r += " "
+                else:
+                    r += "="
+            return r
+        
+        s = [bar(v) for v in x]
+        print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
+