瀏覽代碼

Reorganized training function.

Kristian Schultz 4 年之前
父節點
當前提交
35bb41b681
共有 1 個文件被更改,包括 66 次插入66 次删除
  1. 66 66
      library/convGAN.py

+ 66 - 66
library/convGAN.py

@@ -60,11 +60,12 @@ class ConvGAN(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self, neb, gen):
+    def __init__(self, neb, gen, debug=False):
         self.isTrained = False
         self.neb = neb
         self.gen = gen
         self.loss_history = None
+        self.debug = debug
 
     def reset(self):
         """
@@ -93,15 +94,7 @@ class ConvGAN(GanBaseClass):
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         # TODO: do actually training
-        self.conv_sample_generator, self.maj_min_discriminator_r , self.cg , self.loss_history = rough_learning(
-                neb_epochs,
-                dataSet.data1,
-                dataSet.data0,
-                self.neb,
-                self.gen,
-                self.conv_sample_generator,
-                self.maj_min_discriminator,
-                self.cg)
+        self.rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
 
         self.isTrained = True
 
@@ -127,6 +120,69 @@ class ConvGAN(GanBaseClass):
         return np.array(syntheticPoints)
 
 
+    # Hidden internal functions
+
+    # Training
+    def _rough_learning(self, neb_epochs, data_min):
+        generator = self.conv_sample_generator
+        discriminator = self.maj_min_discriminator
+        GAN = self.cg
+        loss_history=[] ## this is for stroring the loss for every run
+        min_idx = 0
+        neb_epoch_count = 1
+        
+        labels = []
+        for i in range(2 * self.gen):
+            if i < gen:
+                labels.append(np.array([1,0]))
+            else:
+                labels.append(np.array([0,1]))
+        labels = np.array(labels)
+        labels = tf.convert_to_tensor(labels)
+        
+        
+        for step in range(neb_epochs * len(data_min)):
+            min_batch = NMB_guided(data_min, self.neb, min_idx) ## generate minority neighbourhood batch for every minority class sampls by index
+            min_idx = min_idx + 1 
+            maj_batch = BMB(data_min,data_maj, self.neb, self.gen) ## generate random proximal majority batch 
+
+            conv_samples = generator.predict(min_batch) ## generate synthetic samples from convex space of minority neighbourhood batch using generator
+            concat_sample = tf.concat([conv_samples, maj_batch], axis=0) ## concatenate them with the majority batch
+
+            discriminator.trainable = True ## switch on discriminator training
+            discriminator.fit(x=concat_sample, y=labels, verbose=0) ## train the discriminator with the concatenated samples and the one-hot encoded labels 
+            discriminator.trainable = False ## switch off the discriminator training again
+
+            gan_loss_history = GAN.fit(concat_sample, y=labels, verbose=0) ## use the GAN to make the generator learn on the decisions made by the previous discriminator training
+
+            loss_history.append(gan_loss_history.history['loss']) ## store the loss for the step
+
+            if self.debug and ((step + 1) % 10 == 0):
+                print(f"{step + 1} neighbourhood batches trained; running neighbourhood epoch {neb_epoch_count}")
+
+            if min_idx == len(data_min) - 1:
+                if self.debug:
+                    print(f"Neighbourhood epoch {neb_epoch_count} complete")
+                neb_epoch_count = neb_epoch_count + 1
+                min_idx = 0
+
+        if self.debug:
+            run_range = range(1, len(loss_history) + 1)
+            plt.rcParams["figure.figsize"] = (16,10)
+            plt.xticks(fontsize=20)
+            plt.yticks(fontsize=20)
+            plt.xlabel('runs', fontsize=25)
+            plt.ylabel('loss', fontsize=25)
+            plt.title('Rough learning loss for discriminator', fontsize=25)
+            plt.plot(run_range, loss_history)
+            plt.show()
+
+        self.conv_sample_generator = generator
+        self.maj_min_discriminator = discriminator
+        self.cg = GAN
+        self.loss_history = loss_history
+
+
 
 ## convGAN
 def unison_shuffled_copies(a, b,seed_perm):
@@ -244,62 +300,6 @@ def convGAN(generator,discriminator):
 ## this is the first training phase for the discriminator and the only training phase for the generator.
 
 
-def rough_learning(neb_epochs,data_min,data_maj,neb,gen,generator, discriminator,GAN):
-
-    
-    step=1
-    loss_history=[] ## this is for stroring the loss for every run
-    min_idx=0
-    neb_epoch_count=1
-    
-    labels=[]
-    for i in range(2*gen):
-        if i<gen:
-            labels.append(np.array([1,0]))
-        else:
-            labels.append(np.array([0,1]))
-    labels=np.array(labels)
-    labels=tf.convert_to_tensor(labels)
-    
-    
-    while step<(neb_epochs*len(data_min)):
-
-        
-        min_batch=NMB_guided(data_min, neb, min_idx) ## generate minority neighbourhood batch for every minority class sampls by index
-        min_idx=min_idx+1 
-        maj_batch=BMB(data_min,data_maj,neb,gen) ## generate random proximal majority batch 
-
-        conv_samples=generator.predict(min_batch) ## generate synthetic samples from convex space of minority neighbourhood batch using generator
-        concat_sample=tf.concat([conv_samples,maj_batch],axis=0) ## concatenate them with the majority batch
-
-        discriminator.trainable=True ## switch on discriminator training
-        discriminator.fit(x=concat_sample,y=labels,verbose=0) ## train the discriminator with the concatenated samples and the one-hot encoded labels 
-        discriminator.trainable=False ## switch off the discriminator training again
-
-        gan_loss_history=GAN.fit(concat_sample,y=labels,verbose=0) ## use the GAN to make the generator learn on the decisions made by the previous discriminator training
-
-        loss_history.append(gan_loss_history.history['loss']) ## store the loss for the step
-
-        if step%10 == 0:
-            print(str(step)+' neighbourhood batches trained; running neighbourhood epoch ' + str(neb_epoch_count))
-
-        if min_idx==len(data_min)-1:
-            print(str('Neighbourhood epoch '+ str(neb_epoch_count) +' complete'))
-            neb_epoch_count=neb_epoch_count+1
-            min_idx=0
-
-
-        step=step+1
-    run_range=range(1,len(loss_history)+1)
-    plt.rcParams["figure.figsize"] = (16,10)
-    plt.xticks(fontsize=20)
-    plt.yticks(fontsize=20)
-    plt.xlabel('runs',fontsize=25)
-    plt.ylabel('loss', fontsize=25)
-    plt.title('Rough learning loss for discriminator', fontsize=25)
-    plt.plot(run_range, loss_history)
-    plt.show()
-    return generator, discriminator, GAN, loss_history