Procházet zdrojové kódy

Transfered the optimization steps to the convGAN.py code.

Kristian Schultz před 4 roky
rodič
revize
a7e321213f
2 změnil soubory, kde provedl 21 přidání a 21 odebrání
  1. 19 20
      library/convGAN.py
  2. 2 1
      library/convGAN2.py

+ 19 - 20
library/convGAN.py

@@ -25,6 +25,8 @@ import tensorflow as tf
 from tensorflow.keras.optimizers import Adam
 from tensorflow.keras.layers import Lambda
 
+from library.NNSearch import NNSearch
+
 import warnings
 warnings.filterwarnings("ignore")
 
@@ -96,6 +98,7 @@ class ConvGAN(GanBaseClass):
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         self.dataSet = dataSet
+        self.nmb = self._NMB_prepare(dataSet.data1)
         self._rough_learning(dataSet.data1, dataSet.data0)
         self.isTrained = True
 
@@ -123,11 +126,12 @@ class ConvGAN(GanBaseClass):
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
         for i in range(len(data_min)):
-            synth_set.extend(self._generate_data_for_min_point(data_min, i, synth_num))
+            synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
-        synth_set = synth_set[:numOfSamples] ## extract the exact number of synthetic samples needed to exactly balance the two classes
+        ## extract the exact number of synthetic samples needed to exactly balance the two classes
+        synth_set = np.array(synth_set[:numOfSamples]) 
 
-        return np.array(synth_set)
+        return synth_set
 
     # ###############################################################
     # Hidden internal functions
@@ -249,7 +253,7 @@ class ConvGAN(GanBaseClass):
         return model
 
     # Create synthetic points
-    def _generate_data_for_min_point(self, data_min, index, synth_num):
+    def _generate_data_for_min_point(self, index, synth_num):
         """
         generate synth_num synthetic points for a particular minoity sample
         synth_num -> required number of data points that can be generated from a neighbourhood
@@ -261,10 +265,9 @@ class ConvGAN(GanBaseClass):
         runs = int(synth_num / self.neb) + 1
         synth_set = []
         for _run in range(runs):
-            batch = self._NMB_guided(data_min, index)
+            batch = self._NMB_guided(index)
             synth_batch = self.conv_sample_generator.predict(batch)
-            for x in synth_batch:
-                synth_set.append(x)
+            synth_set.extend(synth_batch)
 
         return synth_set[:synth_num]
 
@@ -283,7 +286,7 @@ class ConvGAN(GanBaseClass):
 
         for step in range(self.neb_epochs * len(data_min)):
             ## generate minority neighbourhood batch for every minority class sampls by index
-            min_batch = self._NMB_guided(data_min, min_idx)
+            min_batch = self._NMB_guided(min_idx)
             min_idx = min_idx + 1
             ## generate random proximal majority batch
             maj_batch = self._BMB(data_min, data_maj)
@@ -345,20 +348,17 @@ class ConvGAN(GanBaseClass):
         ## neb -> oversampling neighbourhood
         ## gen -> convex combinations generated from each neighbourhood
 
-        neigh = NearestNeighbors(self.neb)
-        neigh.fit(data_maj)
-        bmbi = [
-            neigh.kneighbors([data_min[i]], self.neb, return_distance=False)
-            for i in range(len(data_min))
-            ]
-        bmbi = np.unique(np.array(bmbi).flatten())
-        bmbi = shuffle(bmbi)
         return tf.convert_to_tensor(
             data_maj[np.random.randint(len(data_maj), size=self.gen)]
             )
 
+    def _NMB_prepare(self, data_min):
+        neigh = NNSearch(self.neb)
+        neigh.fit(data_min)
+        return (data_min, neigh)
+
 
-    def _NMB_guided(self, data_min, index):
+    def _NMB_guided(self, index):
 
         ## generate a minority neighbourhood batch for a particular minority sample
         ## we need this for minority data generation
@@ -366,10 +366,9 @@ class ConvGAN(GanBaseClass):
         ## index -> index of the minority sample in a training data whose neighbourhood we want to obtain
         ## data_min -> minority class data
         ## neb -> oversampling neighbourhood
+        (data_min, neigh) = self.nmb
 
-        neigh = NearestNeighbors(self.neb)
-        neigh.fit(data_min)
-        nmbi = neigh.kneighbors([data_min[index]], self.neb, return_distance=False)
+        nmbi = np.array([neigh.neighbourhoodOfItem(index)])
         nmbi = shuffle(nmbi)
         nmb = data_min[nmbi]
         nmb = tf.convert_to_tensor(nmb[0])

+ 2 - 1
library/convGAN2.py

@@ -141,7 +141,8 @@ class ConvGAN2(GanBaseClass):
         for i in range(len(data_min)):
             synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
-        synth_set = np.array(synth_set[:numOfSamples]) ## extract the exact number of synthetic samples needed to exactly balance the two classes
+        ## extract the exact number of synthetic samples needed to exactly balance the two classes
+        synth_set = np.array(synth_set[:numOfSamples]) 
         self.timing["create points"].stop()
 
         return synth_set