Преглед изворни кода

The fiiting for the nearest neighbourhood search is only once needed.

Kristian Schultz пре 4 година
родитељ
комит
5be80037b6
1 измењених фајлова са 2 додато и 7 уклоњено
  1. 2 7
      library/convGAN2.py

+ 2 - 7
library/convGAN2.py

@@ -109,6 +109,7 @@ class ConvGAN2(GanBaseClass):
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
         self.dataSet = dataSet
+        self.nmb = self._NMB_prepare(dataSet.data1)
         self._rough_learning(dataSet.data1, dataSet.data0)
         self.isTrained = True
         self.timing["train"].stop()
@@ -137,9 +138,8 @@ class ConvGAN2(GanBaseClass):
 
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
-        nmb = self._NMB_prepare(data_min)
         for i in range(len(data_min)):
-            synth_set.extend(self._generate_data_for_min_point(nmb, i, synth_num))
+            synth_set.extend(self._generate_data_for_min_point(self.nmb, i, synth_num))
 
         synth_set = np.array(synth_set[:numOfSamples]) ## extract the exact number of synthetic samples needed to exactly balance the two classes
         self.timing["create points"].stop()
@@ -284,8 +284,6 @@ class ConvGAN2(GanBaseClass):
             synth_batch = self.conv_sample_generator.predict(batch)
             self.timing["predict"].stop()
             synth_set.extend(synth_batch)
-            #for x in synth_batch:
-            #    synth_set.append(x)
 
         self.timing["_generate_data_for_min_point"].stop()
 
@@ -304,7 +302,6 @@ class ConvGAN2(GanBaseClass):
 
         labels = tf.convert_to_tensor(create01Labels(2 * self.gen, self.gen))
 
-        nmb = self._NMB_prepare(data_min)
         for step in range(self.neb_epochs * len(data_min)):
             ## generate minority neighbourhood batch for every minority class sampls by index
             min_batch = self._NMB_guided(nmb, min_idx)
@@ -380,7 +377,6 @@ class ConvGAN2(GanBaseClass):
         self.timing["NMB"].start()
         t = time.time()
         neigh = NNSearch(self.neb)
-        #neigh = NearestNeighbors(self.neb)
         neigh.fit(data_min)
         self.tNbhFit += (time.time() - t)
         self.nNbhFit += 1
@@ -400,7 +396,6 @@ class ConvGAN2(GanBaseClass):
         (data_min, neigh) = nmb
 
         t = time.time()
-        #nmbi = neigh.kneighbors([data_min[index]], self.neb, return_distance=False)
         nmbi = np.array([neigh.neighbourhoodOfItem(index)])
         self.tNbhSearch += (time.time() - t)
         self.nNbhSearch += 1