|
|
@@ -104,7 +104,7 @@ class NextConvGeN(GanBaseClass):
|
|
|
print(self.cg.summary())
|
|
|
print('\n')
|
|
|
|
|
|
- def train(self, data, discTrainCount=5):
|
|
|
+ def train(self, data, discTrainCount=5, batchSize=8):
|
|
|
"""
|
|
|
Trains the Network.
|
|
|
|
|
|
@@ -129,11 +129,11 @@ class NextConvGeN(GanBaseClass):
|
|
|
self.timing["NbhSearch"].start()
|
|
|
# Precalculate neighborhoods
|
|
|
self.nmbMin = NNSearch(self.neb).fit(haystack=normalizedData)
|
|
|
- self.nmbMin.basePoints = data
|
|
|
+ self.nmbMin.basePoints = np.array([ [x.astype(np.float32) for x in p] for p in data])
|
|
|
self.timing["NbhSearch"].stop()
|
|
|
|
|
|
# Do the training.
|
|
|
- self._rough_learning(data, discTrainCount)
|
|
|
+ self._rough_learning(data, discTrainCount, batchSize=batchSize)
|
|
|
|
|
|
# Neighborhood in majority class is no longer needed. So save memory.
|
|
|
self.isTrained = True
|
|
|
@@ -318,7 +318,7 @@ class NextConvGeN(GanBaseClass):
|
|
|
|
|
|
|
|
|
# Training
|
|
|
- def _rough_learning(self, data, discTrainCount):
|
|
|
+ def _rough_learning(self, data, discTrainCount, batchSize=8):
|
|
|
generator = self.conv_sample_generator
|
|
|
discriminator = self.maj_min_discriminator
|
|
|
convGeN = self.cg
|
|
|
@@ -367,14 +367,39 @@ class NextConvGeN(GanBaseClass):
|
|
|
discriminator.trainable = False
|
|
|
self.timing["Fit"].stop()
|
|
|
|
|
|
-
|
|
|
+ def genSamples():
|
|
|
+ for min_idx in range(minSetSize):
|
|
|
+ samples = createSamples(min_idx)
|
|
|
+ for x in samples[0]:
|
|
|
+ yield x
|
|
|
+
|
|
|
+ for x in samples[1]:
|
|
|
+ yield x
|
|
|
+
|
|
|
+ def genLabels():
|
|
|
+ for min_idx in range(minSetSize):
|
|
|
+ for x in labels:
|
|
|
+ yield x
|
|
|
+
|
|
|
+
|
|
|
|
|
|
for neb_epoch_count in range(self.neb_epochs):
|
|
|
- if discTrainCount > 0:
|
|
|
- for n in range(discTrainCount):
|
|
|
- for min_idx in range(minSetSize):
|
|
|
- self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, (min_idx + 1) / minSetSize])
|
|
|
- trainDiscriminator(createSamples(min_idx))
|
|
|
+ for n in range(max(0,discTrainCount)):
|
|
|
+ self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, n / discTrainCount, 0.5])
|
|
|
+ samples = genSamples()
|
|
|
+
|
|
|
+ a = tf.data.Dataset.from_generator(genSamples, output_types=tf.float32)
|
|
|
+ b = tf.data.Dataset.from_generator(genLabels, output_types=tf.float32)
|
|
|
+ samples = tf.data.Dataset.zip((a, b)).batch(batchSize * 2 * self.gen)
|
|
|
+
|
|
|
+ self.timing["Fit"].start()
|
|
|
+ ## switch on discriminator training
|
|
|
+ discriminator.trainable = True
|
|
|
+ ## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
+ discriminator.fit(x=samples, verbose=0)
|
|
|
+ ## switch off the discriminator training again
|
|
|
+ discriminator.trainable = False
|
|
|
+ self.timing["Fit"].stop()
|
|
|
|
|
|
for min_idx in range(minSetSize):
|
|
|
self.progressBar([(neb_epoch_count + 1) / self.neb_epochs, 1.0, (min_idx + 1) / minSetSize])
|
|
|
@@ -415,7 +440,7 @@ class NextConvGeN(GanBaseClass):
|
|
|
## gen -> convex combinations generated from each neighbourhood
|
|
|
self.timing["BMB"].start()
|
|
|
indices = [i for i in range(self.minSetSize) if i not in min_idxs]
|
|
|
- r = np.array([ [x.astype(np.float32) for x in self.nmbMin.basePoints[i]] for i in shuffle(indices)[0:self.gen]])
|
|
|
+ r = self.nmbMin.basePoints[shuffle(indices)[0:self.gen]]
|
|
|
self.timing["BMB"].stop()
|
|
|
return r
|
|
|
|