|
@@ -63,7 +63,7 @@ class ConvGAN(GanBaseClass):
|
|
|
nMinoryPoints = dataSet.data1.shape[0]
|
|
nMinoryPoints = dataSet.data1.shape[0]
|
|
|
if self.nebInitial is None:
|
|
if self.nebInitial is None:
|
|
|
self.neb = nMinoryPoints
|
|
self.neb = nMinoryPoints
|
|
|
- else
|
|
|
|
|
|
|
+ else:
|
|
|
self.neb = min(self.nebInitial, nMinoryPoints)
|
|
self.neb = min(self.nebInitial, nMinoryPoints)
|
|
|
else:
|
|
else:
|
|
|
self.neb = self.nebInitial
|
|
self.neb = self.nebInitial
|
|
@@ -80,6 +80,8 @@ class ConvGAN(GanBaseClass):
|
|
|
self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
self.cg = self._convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
|
|
|
|
|
|
if self.debug:
|
|
if self.debug:
|
|
|
|
|
+ print(f"neb={self.neb}, gen={self.gen}")
|
|
|
|
|
+
|
|
|
print(self.conv_sample_generator.summary())
|
|
print(self.conv_sample_generator.summary())
|
|
|
print('\n')
|
|
print('\n')
|
|
|
|
|
|
|
@@ -285,7 +287,7 @@ class ConvGAN(GanBaseClass):
|
|
|
synth_set = []
|
|
synth_set = []
|
|
|
for _run in range(runs):
|
|
for _run in range(runs):
|
|
|
batch = self.nmbMin.getNbhPointsOfItem(index)
|
|
batch = self.nmbMin.getNbhPointsOfItem(index)
|
|
|
- synth_batch = self.conv_sample_generator.predict(batch)
|
|
|
|
|
|
|
+ synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
|
|
|
synth_set.extend(synth_batch)
|
|
synth_set.extend(synth_batch)
|
|
|
|
|
|
|
|
return synth_set[:synth_num]
|
|
return synth_set[:synth_num]
|
|
@@ -302,6 +304,7 @@ class ConvGAN(GanBaseClass):
|
|
|
minSetSize = len(data_min)
|
|
minSetSize = len(data_min)
|
|
|
|
|
|
|
|
labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
|
|
labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
|
|
|
|
|
+ nLabels = 2 * self.gen
|
|
|
|
|
|
|
|
for neb_epoch_count in range(self.neb_epochs):
|
|
for neb_epoch_count in range(self.neb_epochs):
|
|
|
if discTrainCount > 0:
|
|
if discTrainCount > 0:
|
|
@@ -315,14 +318,14 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
|
|
## generate synthetic samples from convex space
|
|
## generate synthetic samples from convex space
|
|
|
## of minority neighbourhood batch using generator
|
|
## of minority neighbourhood batch using generator
|
|
|
- conv_samples = generator.predict(min_batch)
|
|
|
|
|
|
|
+ conv_samples = generator.predict(min_batch, batch_size=self.neb)
|
|
|
## concatenate them with the majority batch
|
|
## concatenate them with the majority batch
|
|
|
concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
|
|
concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
|
|
|
|
|
|
|
|
## switch on discriminator training
|
|
## switch on discriminator training
|
|
|
discriminator.trainable = True
|
|
discriminator.trainable = True
|
|
|
## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
- discriminator.fit(x=concat_sample, y=labels, verbose=0)
|
|
|
|
|
|
|
+ discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
|
|
|
## switch off the discriminator training again
|
|
## switch off the discriminator training again
|
|
|
discriminator.trainable = False
|
|
discriminator.trainable = False
|
|
|
|
|
|
|
@@ -335,21 +338,21 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
|
|
## generate synthetic samples from convex space
|
|
## generate synthetic samples from convex space
|
|
|
## of minority neighbourhood batch using generator
|
|
## of minority neighbourhood batch using generator
|
|
|
- conv_samples = generator.predict(min_batch)
|
|
|
|
|
|
|
+ conv_samples = generator.predict(min_batch, batch_size=self.neb)
|
|
|
## concatenate them with the majority batch
|
|
## concatenate them with the majority batch
|
|
|
concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
|
|
concat_sample = tf.concat([conv_samples, maj_batch], axis=0)
|
|
|
|
|
|
|
|
## switch on discriminator training
|
|
## switch on discriminator training
|
|
|
discriminator.trainable = True
|
|
discriminator.trainable = True
|
|
|
## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
## train the discriminator with the concatenated samples and the one-hot encoded labels
|
|
|
- discriminator.fit(x=concat_sample, y=labels, verbose=0)
|
|
|
|
|
|
|
+ discriminator.fit(x=concat_sample, y=labels, verbose=0, batch_size=nLabels)
|
|
|
## switch off the discriminator training again
|
|
## switch off the discriminator training again
|
|
|
discriminator.trainable = False
|
|
discriminator.trainable = False
|
|
|
|
|
|
|
|
## use the GAN to make the generator learn on the decisions
|
|
## use the GAN to make the generator learn on the decisions
|
|
|
## made by the previous discriminator training
|
|
## made by the previous discriminator training
|
|
|
##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
|
|
##- print(f"concat sample shape: {concat_sample.shape}/{labels.shape}")
|
|
|
- gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0)
|
|
|
|
|
|
|
+ gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0, batch_size=nLabels)
|
|
|
|
|
|
|
|
## store the loss for the step
|
|
## store the loss for the step
|
|
|
loss_history.append(gan_loss_history.history['loss'])
|
|
loss_history.append(gan_loss_history.history['loss'])
|