|
@@ -118,7 +118,7 @@ class XConvGeN(GanBaseClass):
|
|
|
self.cg = None
|
|
self.cg = None
|
|
|
self.canPredict = True
|
|
self.canPredict = True
|
|
|
self.fdc = fdc
|
|
self.fdc = fdc
|
|
|
- self.lastProgress = (-1,-1,-1)
|
|
|
|
|
|
|
+ self.lastProgress = -1
|
|
|
|
|
|
|
|
self.timing = { n: timing(n) for n in [
|
|
self.timing = { n: timing(n) for n in [
|
|
|
"Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit", "FixType"
|
|
"Train", "BMB", "NbhSearch", "NBH", "GenSamples", "Fit", "FixType"
|
|
@@ -464,11 +464,11 @@ class XConvGeN(GanBaseClass):
|
|
|
|
|
|
|
|
return (min_batch, maj_batch)
|
|
return (min_batch, maj_batch)
|
|
|
|
|
|
|
|
- def unbatch(rows):
|
|
|
|
|
|
|
+ def unbatch(parts):
|
|
|
def fn():
|
|
def fn():
|
|
|
- for row in rows:
|
|
|
|
|
- for part in row:
|
|
|
|
|
- for x in part:
|
|
|
|
|
|
|
+ for part in parts:
|
|
|
|
|
+ for neighborhood in part:
|
|
|
|
|
+ for x in neighborhood:
|
|
|
yield x
|
|
yield x
|
|
|
return fn
|
|
return fn
|
|
|
|
|
|
|
@@ -481,7 +481,7 @@ class XConvGeN(GanBaseClass):
|
|
|
discTrainCount = 1 + max(0, discTrainCount)
|
|
discTrainCount = 1 + max(0, discTrainCount)
|
|
|
|
|
|
|
|
for neb_epoch_count in range(self.config.neb_epochs):
|
|
for neb_epoch_count in range(self.config.neb_epochs):
|
|
|
- self.progressBar([(neb_epoch_count + 1) / self.config.neb_epochs, 0.5, 0.5])
|
|
|
|
|
|
|
+ self.progressBar(neb_epoch_count / self.config.neb_epochs)
|
|
|
|
|
|
|
|
## Training of the discriminator.
|
|
## Training of the discriminator.
|
|
|
#
|
|
#
|
|
@@ -531,6 +531,7 @@ class XConvGeN(GanBaseClass):
|
|
|
gen_loss_history = convGeN.fit(samples, verbose=0, batch_size=batchSize)
|
|
gen_loss_history = convGeN.fit(samples, verbose=0, batch_size=batchSize)
|
|
|
loss_history.append(gen_loss_history.history['loss'])
|
|
loss_history.append(gen_loss_history.history['loss'])
|
|
|
|
|
|
|
|
|
|
+ self.progressBar(1.0)
|
|
|
|
|
|
|
|
## When done: print some statistics.
|
|
## When done: print some statistics.
|
|
|
if self.debug:
|
|
if self.debug:
|
|
@@ -568,21 +569,18 @@ class XConvGeN(GanBaseClass):
|
|
|
self.maj_min_discriminator.trainable = False
|
|
self.maj_min_discriminator.trainable = False
|
|
|
|
|
|
|
|
def progressBar(self, x):
|
|
def progressBar(self, x):
|
|
|
- x = [int(v * 10) for v in x]
|
|
|
|
|
- if True not in [self.lastProgress[i] != x[i] for i in range(len(self.lastProgress))]:
|
|
|
|
|
|
|
+ barWidth = 40
|
|
|
|
|
+
|
|
|
|
|
+ x = int(x * barWidth)
|
|
|
|
|
+ if self.lastProgress == x:
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
def bar(v):
|
|
def bar(v):
|
|
|
- r = ""
|
|
|
|
|
- for n in range(10):
|
|
|
|
|
- if n > v:
|
|
|
|
|
- r += " "
|
|
|
|
|
- else:
|
|
|
|
|
- r += "="
|
|
|
|
|
|
|
+ v = min(v, barWidth)
|
|
|
|
|
+ r = ("=" * v) + (" " * (barWidth - v))
|
|
|
return r
|
|
return r
|
|
|
|
|
|
|
|
- s = [bar(v) for v in x]
|
|
|
|
|
- print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
|
|
|
|
|
|
|
+ print(f"[{bar(x)}]", end="\r")
|
|
|
|
|
|
|
|
def correct_feature_types(self):
|
|
def correct_feature_types(self):
|
|
|
# batch[0] = original points (gen x n_feat)
|
|
# batch[0] = original points (gen x n_feat)
|