Explorar o código

Some cleanup.

Kristian Schultz %!s(int64=3) %!d(string=hai) anos
pai
achega
d82b45654a
Modificáronse 2 ficheiros con 43 adicións e 44 borrados
  1. 29 28
      XConvGeN-Example.ipynb
  2. 14 16
      library/generators/XConvGeN.py

A diferenza do arquivo foi suprimida porque é demasiado grande
+ 29 - 28
XConvGeN-Example.ipynb


+ 14 - 16
library/generators/XConvGeN.py

@@ -118,7 +118,7 @@ class XConvGeN(GanBaseClass):
         self.cg = None
         self.canPredict = True
         self.fdc = fdc
-        self.lastProgress = (-1,-1,-1)
+        self.lastProgress = -1
         
         self.timing = { n: timing(n) for n in [
             "Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit", "FixType"
@@ -464,11 +464,11 @@ class XConvGeN(GanBaseClass):
 
             return (min_batch, maj_batch)
 
-        def unbatch(rows):
+        def unbatch(parts):
             def fn():
-                for row in rows:
-                    for part in row:
-                        for x in part:
+                for part in parts:
+                    for neighborhood in part:
+                        for x in neighborhood:
                             yield x
             return fn
 
@@ -481,7 +481,7 @@ class XConvGeN(GanBaseClass):
         discTrainCount = 1 + max(0, discTrainCount)    
 
         for neb_epoch_count in range(self.config.neb_epochs):
-            self.progressBar([(neb_epoch_count + 1) / self.config.neb_epochs, 0.5, 0.5])
+            self.progressBar(neb_epoch_count / self.config.neb_epochs)
 
             ## Training of the discriminator.
             #
@@ -531,6 +531,7 @@ class XConvGeN(GanBaseClass):
             gen_loss_history = convGeN.fit(samples, verbose=0, batch_size=batchSize)
             loss_history.append(gen_loss_history.history['loss'])
 
+        self.progressBar(1.0)
 
         ## When done: print some statistics.
         if self.debug:
@@ -568,21 +569,18 @@ class XConvGeN(GanBaseClass):
         self.maj_min_discriminator.trainable = False
 
     def progressBar(self, x):
-        x = [int(v * 10) for v in x]
-        if True not in [self.lastProgress[i] != x[i] for i in range(len(self.lastProgress))]:
+        barWidth = 40
+
+        x = int(x * barWidth)
+        if self.lastProgress == x:
             return
         
         def bar(v):   
-            r = ""
-            for n in range(10):
-                if n > v:
-                    r += " "
-                else:
-                    r += "="
+            v = min(v, barWidth)
+            r = ("=" * v) + (" " * (barWidth - v))
             return r
         
-        s = [bar(v) for v in x]
-        print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
+        print(f"[{bar(x)}]", end="\r")
         
     def correct_feature_types(self):
         # batch[0] = original points (gen x n_feat)

Algúns arquivos non se mostraron porque demasiados arquivos cambiaron neste cambio