Bladeren bron

Added synthetic point generation.

Kristian Schultz 4 jaren geleden
bovenliggende
commit
a738349d46
1 gewijzigde bestanden met toevoegingen van 47 en 37 verwijderingen
  1. 47 37
      library/convGAN.py

+ 47 - 37
library/convGAN.py

@@ -68,6 +68,7 @@ class ConvGAN(GanBaseClass):
         self.gen = gen
         self.loss_history = None
         self.debug = debug
+        self.dataSet = None
 
     def reset(self):
         """
@@ -95,9 +96,8 @@ class ConvGAN(GanBaseClass):
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
-        # TODO: do actually training
+        self.dataSet = dataSet
         self._rough_learning(neb_epochs, dataSet.data1, dataSet.data0)
-
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -116,13 +116,48 @@ class ConvGAN(GanBaseClass):
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained Re.")
 
+        data_min = self.dataSet.data1
+        data_maj = self.dataSet.data0
+        neb = self.neb
+
+        # ---
 
-        syntheticPoints = [] # TODO
+    ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
+        synth_num = (numOfSamples // len(data_min)) + 1
 
-        return np.array(syntheticPoints)
+        ## generate synth_num synthetic samples from each minority neighbourhood
+        synth_set=[]
+        for i in range(len(data_min)):
+            synth_set.extend(self._generate_data_for_min_point(data_min, i, synth_num))
+    
+        synth_set = synth_set[:numOfSamples] ## extract the exact number of synthetic samples needed to exactly balance the two classes
 
+        return np.array(synth_set)
 
+    # ###############################################################
     # Hidden internal functions
+    # ###############################################################
+
+    def _generate_data_for_min_point(self, data_min, index, synth_num, generator):
+        """
+        generate synth_num synthetic points for a particular minoity sample 
+        synth_num -> required number of data points that can be generated from a neighbourhood
+        data_min -> minority class data
+        neb -> oversampling neighbourhood
+        index -> index of the minority sample in a training data whose neighbourhood we want to obtain
+        """
+
+        runs = int(synth_num / self.neb) + 1
+        synth_set = []
+        for run in range(runs):
+            batch = self._NMB_guided(data_min, index)
+            synth_batch = self.conv_sample_generator.predict(batch)
+            for x in synth_batch:
+                synth_set.append(x)
+        
+        return synth_set[:synth_num]
+
+
 
     # Training
     def _rough_learning(self, neb_epochs, data_min, data_maj):
@@ -220,8 +255,7 @@ class ConvGAN(GanBaseClass):
         
         neigh = NearestNeighbors(self.neb)
         neigh.fit(data_min)
-        ind = index
-        nmbi = neigh.kneighbors([data_min[ind]], self.neb, return_distance=False)
+        nmbi = neigh.kneighbors([data_min[index]], self.neb, return_distance=False)
         nmbi = shuffle(nmbi)
         nmb = data_min[nmbi]
         nmb = tf.convert_to_tensor(nmb[0])
@@ -314,39 +348,15 @@ def rough_learning_predictions(discriminator,test_data_numpy,test_labels_numpy):
     return c,f,pr,rc,k
 
 
-def generate_data_for_min_point(data_min,neb,index,synth_num,generator):
-    
-    ## generate synth_num synthetic points for a particular minoity sample 
-    ## synth_num -> required number of data points that can be generated from a neighbourhood
-    ## data_min -> minority class data
-    ## neb -> oversampling neighbourhood
-    ## index -> index of the minority sample in a training data whose neighbourhood we want to obtain
-    
-    runs=int(synth_num/neb)+1
-    synth_set=[]
-    for run in range(runs):
-        batch=NMB_guided(data_min, neb, index)
-        synth_batch=generator.predict(batch)
-        for i in range(len(synth_batch)):
-            synth_set.append(synth_batch[i])
-    synth_set=synth_set[:synth_num]
-    synth_set=np.array(synth_set)
-    return(synth_set)
-
-
-def generate_synthetic_data(data_min,data_maj,neb,generator):
-    
+
+
+def generate_synthetic_data(gan, data_min, data_maj):
    ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
     synth_num=((len(data_maj)-len(data_min))//len(data_min))+1
 
     ## generate synth_num synthetic samples from each minority neighbourhood
-    synth_set=[]
-    for i in range(len(data_min)):
-        synth_i=generate_data_for_min_point(data_min,neb,i,synth_num,generator)
-        for k in range(len(synth_i)):
-            synth_set.append(synth_i[k])
-    synth_set=synth_set[:(len(data_maj)-len(data_min))] ## extract the exact number of synthetic samples needed to exactly balance the two classes
-    synth_set=np.array(synth_set)
+    synth_set = gan.generateData(synth_num)
+
     ovs_min_class=np.concatenate((data_min,synth_set),axis=0)
     ovs_training_dataset=np.concatenate((ovs_min_class,data_maj),axis=0)
     ovs_pca_labels=np.concatenate((np.zeros(len(data_min)),np.zeros(len(synth_set))+1,np.zeros(len(data_maj))+2))
@@ -461,11 +471,11 @@ def convGAN_train_end_to_end(training_data,training_labels,test_data,test_labels
     print('\n')
     
     ## rough learning results
-    c_r,f_r,pr_r,rc_r,k_r=rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
+    c_r,f_r,pr_r,rc_r,k_r = rough_learning_predictions(gan.maj_min_discriminator_r, test_data, test_labels)
     print('\n')
     
     ## generate synthetic data
-    ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh=generate_synthetic_data(data_min, data_maj, gan.neb, gan.conv_sample_generator)
+    ovs_training_dataset, ovs_pca_labels, ovs_training_labels_oh = generate_synthetic_data(gan, data_min, data_maj)
     print('\n')
     
     ## final training results