Kristian Schultz пре 4 година
родитељ
комит
41399bb0ed
1 измењених фајлова са 33 додато и 9 уклоњено
  1. 33 9
      library/convGAN.py

+ 33 - 9
library/convGAN.py

@@ -60,16 +60,27 @@ class ConvGAN(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self):
+    def __init__(self, neb, gen):
         self.isTrained = False
+        self.neb = neb
+        self.gen = gen
+        self.loss_history = None
 
     def reset(self):
         """
         Resets the trained GAN to an random state.
         """
         self.isTrained = False
+        ## instanciate generator network and visualize architecture
+        self.conv_sample_generator = conv_sample_gen()
 
-    def train(self, dataSet):
+        ## instanciate discriminator network and visualize architecture
+        self.maj_min_discriminator = maj_min_disc()
+
+        ## instanciate network and visualize architecture
+        self.cg = convGAN(self.conv_sample_generator, self.maj_min_discriminator)
+
+    def train(self, dataSet, neb_epochs=5):
         """
         Trains the GAN.
 
@@ -82,6 +93,15 @@ class ConvGAN(GanBaseClass):
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         # TODO: do actually training
+        self.conv_sample_generator, self.maj_min_discriminator_r , self.cg , self.loss_history = rough_learning(
+                neb_epochs,
+                dataSet.data1,
+                dataSet.data0,
+                self.neb,
+                self.gen,
+                self.conv_sample_generator,
+                self.maj_min_discriminator,
+                self.cg)
 
         self.isTrained = True
 
@@ -423,19 +443,23 @@ def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels
     ##majority class
     data_maj=training_data[np.where(training_labels == 0)[0]]
 
+    dataSet = DataSet(data0=data_maj, data1=data_min)
+
+    gan = ConvGAN(neb, gen)
+    gan.reset()
    
     ## instanciate generator network and visualize architecture
-    conv_sample_generator=conv_sample_gen()
+    conv_sample_generator = gan.conv_sample_generator 
     print(conv_sample_generator.summary())
     print('\n')
 
     ## instanciate discriminator network and visualize architecture
-    maj_min_discriminator=maj_min_disc()
+    maj_min_discriminator = self.maj_min_discriminator 
     print(maj_min_discriminator.summary())
     print('\n')
 
     ## instanciate network and visualize architecture
-    cg=convGAN(conv_sample_generator, maj_min_discriminator)
+    cg = self.cg
     print(cg.summary())
     print('\n')
     
@@ -443,19 +467,19 @@ def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels
     print('\n')
 
     ## train gan generator ## rough_train_discriminator
-    conv_sample_generator, maj_min_discriminator_r ,cg , loss_history=rough_learning(neb_epochs, data_min,data_maj, neb, gen, conv_sample_generator, maj_min_discriminator, cg)
+    gan.train(dataSet, neb_epochs)
     print('\n')
     
     ## rough learning results
-    c_r,f_r,pr_r,rc_r,k_r=rough_learning_predictions(maj_min_discriminator_r, test_data,test_labels)
+    c_r,f_r,pr_r,rc_r,k_r=rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
     print('\n')
     
     ## generate synthetic data
-    ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh=generate_synthetic_data(data_min, data_maj, neb, conv_sample_generator)
+    ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh=generate_synthetic_data(data_min, data_maj, gan.neb, gan.conv_sample_generator)
     print('\n')
     
     ## final training results
-    c,f,pr,rc,k=final_learning(maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
+    c,f,pr,rc,k=final_learning(gan.maj_min_discriminator, ovs_training_dataset, ovs_training_labels_oh, test_data, test_labels, epochs_retrain_disc)
     
     return ((c_r,f_r,pr_r,rc_r,k_r),(c,f,pr,rc,k))