Przeglądaj źródła

Fixed calls with locally neb variable.

Kristian Schultz 4 lat temu
rodzic
commit
f73b89e22a
1 zmienionych plików z 7 dodań i 7 usunięć
  1. 7 7
      library/convGAN2.py

+ 7 - 7
library/convGAN2.py

@@ -139,7 +139,7 @@ class ConvGAN2(GanBaseClass):
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
         for i in range(len(data_min)):
-            synth_set.extend(self._generate_data_for_min_point(self.nmb, i, synth_num))
+            synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
         synth_set = np.array(synth_set[:numOfSamples]) ## extract the exact number of synthetic samples needed to exactly balance the two classes
         self.timing["create points"].stop()
@@ -266,7 +266,7 @@ class ConvGAN2(GanBaseClass):
         return model
 
     # Create synthetic points
-    def _generate_data_for_min_point(self, nmb, index, synth_num):
+    def _generate_data_for_min_point(self, index, synth_num):
         """
         generate synth_num synthetic points for a particular minoity sample
         synth_num -> required number of data points that can be generated from a neighbourhood
@@ -279,7 +279,7 @@ class ConvGAN2(GanBaseClass):
         runs = int(synth_num / self.neb) + 1
         synth_set = []
         for _run in range(runs):
-            batch = self._NMB_guided(nmb, index)
+            batch = self._NMB_guided(index)
             self.timing["predict"].start()
             synth_batch = self.conv_sample_generator.predict(batch)
             self.timing["predict"].stop()
@@ -304,7 +304,7 @@ class ConvGAN2(GanBaseClass):
 
         for step in range(self.neb_epochs * len(data_min)):
             ## generate minority neighbourhood batch for every minority class sampls by index
-            min_batch = self._NMB_guided(nmb, min_idx)
+            min_batch = self._NMB_guided(min_idx)
             min_idx = min_idx + 1
             ## generate random proximal majority batch
             maj_batch = self._BMB(data_min, data_maj)
@@ -376,7 +376,7 @@ class ConvGAN2(GanBaseClass):
     def _NMB_prepare(self, data_min):
         self.timing["NMB"].start()
         t = time.time()
-        neigh = NNSearch(self.neb)
+        neigh = NNSearch(self.neb, timingDict=self.timing)
         neigh.fit(data_min)
         self.tNbhFit += (time.time() - t)
         self.nNbhFit += 1
@@ -384,7 +384,7 @@ class ConvGAN2(GanBaseClass):
         return (data_min, neigh)
 
 
-    def _NMB_guided(self, nmb, index):
+    def _NMB_guided(self, index):
 
         ## generate a minority neighbourhood batch for a particular minority sample
         ## we need this for minority data generation
@@ -393,7 +393,7 @@ class ConvGAN2(GanBaseClass):
         ## data_min -> minority class data
         ## neb -> oversampling neighbourhood
         self.timing["NMB"].start()
-        (data_min, neigh) = nmb
+        (data_min, neigh) = self.nmb
 
         t = time.time()
         nmbi = np.array([neigh.neighbourhoodOfItem(index)])