|
|
@@ -60,16 +60,27 @@ class ConvGAN(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
|
- def __init__(self):
|
|
|
+ def __init__(self, neb, gen):
|
|
|
self.isTrained = False
|
|
|
+ self.neb = neb
|
|
|
+ self.gen = gen
|
|
|
+ self.loss_history = None
|
|
|
|
|
|
def reset(self):
|
|
|
"""
|
|
|
Resets the trained GAN to an random state.
|
|
|
"""
|
|
|
self.isTrained = False
|
|
|
+ ## instantiate generator network and visualize architecture
|
|
|
+ self.conv_sample_generator = conv_sample_gen()
|
|
|
|
|
|
- def train(self, dataSet):
|
|
|
+ ## instantiate discriminator network and visualize architecture
|
|
|
+ self.maj_min_discriminator = maj_min_disc()
|
|
|
+
|
|
|
+ ## instantiate network and visualize architecture
|
|
|
+ self.cg = convGAN(self.conv_sample_generator, self.maj_min_discriminator)
|
|
|
+
|
|
|
+ def train(self, dataSet, neb_epochs=5):
|
|
|
"""
|
|
|
Trains the GAN.
|
|
|
|
|
|
@@ -82,6 +93,15 @@ class ConvGAN(GanBaseClass):
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
# TODO: do actually training
|
|
|
+ self.conv_sample_generator, self.maj_min_discriminator_r , self.cg , self.loss_history = rough_learning(
|
|
|
+ neb_epochs,
|
|
|
+ dataSet.data1,
|
|
|
+ dataSet.data0,
|
|
|
+ self.neb,
|
|
|
+ self.gen,
|
|
|
+ self.conv_sample_generator,
|
|
|
+ self.maj_min_discriminator,
|
|
|
+ self.cg)
|
|
|
|
|
|
self.isTrained = True
|
|
|
|
|
|
@@ -423,19 +443,23 @@ def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels
|
|
|
##majority class
|
|
|
data_maj=training_data[np.where(training_labels == 0)[0]]
|
|
|
|
|
|
+ dataSet = DataSet(data0=data_maj, data1=data_min)
|
|
|
+
|
|
|
+ gan = ConvGAN(neb, gen)
|
|
|
+ gan.reset()
|
|
|
|
|
|
## instanciate generator network and visualize architecture
|
|
|
- conv_sample_generator=conv_sample_gen()
|
|
|
+ conv_sample_generator = gan.conv_sample_generator
|
|
|
print(conv_sample_generator.summary())
|
|
|
print('\n')
|
|
|
|
|
|
## instanciate discriminator network and visualize architecture
|
|
|
- maj_min_discriminator=maj_min_disc()
|
|
|
+ maj_min_discriminator = gan.maj_min_discriminator
|
|
|
print(maj_min_discriminator.summary())
|
|
|
print('\n')
|
|
|
|
|
|
## instanciate network and visualize architecture
|
|
|
- cg=convGAN(conv_sample_generator, maj_min_discriminator)
|
|
|
+ cg = gan.cg
|
|
|
print(cg.summary())
|
|
|
print('\n')
|
|
|
|
|
|
@@ -443,19 +467,19 @@ def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels
|
|
|
print('\n')
|
|
|
|
|
|
## train gan generator ## rough_train_discriminator
|
|
|
- conv_sample_generator, maj_min_discriminator_r ,cg , loss_history=rough_learning(neb_epochs, data_min,data_maj, neb, gen, conv_sample_generator, maj_min_discriminator, cg)
|
|
|
+ gan.train(dataSet, neb_epochs)
|
|
|
print('\n')
|
|
|
|
|
|
## rough learning results
|
|
|
- c_r,f_r,pr_r,rc_r,k_r=rough_learning_predictions(maj_min_discriminator_r, test_data,test_labels)
|
|
|
+ c_r,f_r,pr_r,rc_r,k_r=rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
|
|
|
print('\n')
|
|
|
|
|
|
## generate synthetic data
|
|
|
- ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh=generate_synthetic_data(data_min, data_maj, neb, conv_sample_generator)
|
|
|
+ ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh=generate_synthetic_data(data_min, data_maj, gan.neb, gan.conv_sample_generator)
|
|
|
print('\n')
|
|
|
|
|
|
## final training results
|
|
|
- c,f,pr,rc,k=final_learning(maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
|
|
|
+ c,f,pr,rc,k=final_learning(gan.maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
|
|
|
|
|
|
return ((c_r,f_r,pr_r,rc_r,k_r),(c,f,pr,rc,k))
|
|
|
|