|
@@ -408,27 +408,24 @@ class NextConvGeN(GanBaseClass):
|
|
|
yield x
|
|
yield x
|
|
|
|
|
|
|
|
padd = np.zeros((self.gen - self.neb, self.n_feat))
|
|
padd = np.zeros((self.gen - self.neb, self.n_feat))
|
|
|
-
|
|
|
|
|
|
|
+ discTrainCount = 1 + max(0, discTrainCount)
|
|
|
|
|
|
|
|
for neb_epoch_count in range(self.neb_epochs):
|
|
for neb_epoch_count in range(self.neb_epochs):
|
|
|
- shape = (self.gen, self.n_feat)
|
|
|
|
|
-
|
|
|
|
|
- for n in range(1 + max(0,discTrainCount)):
|
|
|
|
|
- self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, 0.5])
|
|
|
|
|
-
|
|
|
|
|
- a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32)
|
|
|
|
|
- a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
|
|
|
|
|
- b = tf.data.Dataset.from_tensor_slices(labels).repeat()
|
|
|
|
|
- samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
|
|
|
|
|
-
|
|
|
|
|
- self.timing["Fit"].start()
|
|
|
|
|
- ## switch on discriminator training
|
|
|
|
|
- discriminator.trainable = True
|
|
|
|
|
- ## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
|
|
- discriminator.fit(x=samples, verbose=0)
|
|
|
|
|
- ## switch off the discriminator training again
|
|
|
|
|
- discriminator.trainable = False
|
|
|
|
|
- self.timing["Fit"].stop()
|
|
|
|
|
|
|
+ self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 0.5, 0.5])
|
|
|
|
|
+
|
|
|
|
|
+ a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32).repeat().take(discTrainCount * self.minSetSize)
|
|
|
|
|
+ a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
|
|
|
|
|
+ b = tf.data.Dataset.from_tensor_slices(labels).repeat()
|
|
|
|
|
+ samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
|
|
|
|
|
+
|
|
|
|
|
+ self.timing["Fit"].start()
|
|
|
|
|
+ ## switch on discriminator training
|
|
|
|
|
+ discriminator.trainable = True
|
|
|
|
|
+ ## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
|
|
+ discriminator.fit(x=samples, verbose=0)
|
|
|
|
|
+ ## switch off the discriminator training again
|
|
|
|
|
+ discriminator.trainable = False
|
|
|
|
|
+ self.timing["Fit"].stop()
|
|
|
|
|
|
|
|
# <<<<<<<<<<
|
|
# <<<<<<<<<<
|
|
|
src = tf.data.Dataset.from_generator(genSamplesForGeN, output_types=tf.float32)
|
|
src = tf.data.Dataset.from_generator(genSamplesForGeN, output_types=tf.float32)
|