소스 검색

Replaced shuffle by faster version.

Kristian Schultz 3 년 전
부모
커밋
948e38c7b4
2개의 변경된 파일, 39개의 추가 및 7개의 삭제
  1. 32 3
      library/NNSearch.py
  2. 7 4
      library/generators/NextConvGeN.py

+ 32 - 3
library/NNSearch.py

@@ -1,5 +1,5 @@
 import math
-
+import random
 import tensorflow as tf
 import numpy as np
 from sklearn.neighbors import NearestNeighbors
@@ -7,6 +7,34 @@ from sklearn.utils import shuffle
 from library.timing import timing
 
 
+
def randomIndices(size, outputSize=None, indicesToIgnore=None):
    """Return distinct indices drawn uniformly at random from ``range(size)``.

    Args:
        size: Upper bound (exclusive) of the index range to draw from.
        outputSize: Number of indices to return. ``None``, or a value larger
            than the number of available indices, returns every available
            index in random order. Non-positive values give an empty list.
        indicesToIgnore: Optional iterable of indices that must not appear in
            the result. Entries are compared as integers; duplicates and
            values outside ``range(size)`` are ignored instead of raising.

    Returns:
        A list of distinct Python ints in random order.
    """
    if indicesToIgnore is None:
        candidates = list(range(size))
    else:
        # A set gives O(1) membership tests and tolerates duplicate or
        # out-of-range entries, which list.remove() would reject with
        # ValueError. int() also makes NumPy integer scalars hashable keys.
        ignored = {int(x) for x in indicesToIgnore}
        candidates = [i for i in range(size) if i not in ignored]

    available = len(candidates)
    if outputSize is None or outputSize > available:
        outputSize = available

    # random.sample draws without replacement, so every index appears at most
    # once; clamping keeps a negative request from raising.
    return random.sample(candidates, max(outputSize, 0))
+
+
 class NNSearch:
     def __init__(self, nebSize=5, timingDict=None):
         self.nebSize = nebSize
@@ -34,9 +62,10 @@ class NNSearch:
         return self.getPointsFromIndices(self.neighbourhoodOfItem(index))
 
     def getPointsFromIndices(self, indices):
-        nmbi = shuffle(np.array([indices]))
+        permutation = randomIndices(len(indices))
+        nmbi = np.array(indices)[permutation]
         nmb = self.basePoints[nmbi]
-        return tf.convert_to_tensor(nmb[0])
+        return tf.convert_to_tensor(nmb)
 
     def neighbourhoodOfItemList(self, items, maxCount=None):
         nbhIndices = set()

+ 7 - 4
library/generators/NextConvGeN.py

@@ -16,7 +16,7 @@ from tensorflow.keras.layers import Lambda
 
 from sklearn.utils import shuffle
 
-from library.NNSearch import NNSearch
+from library.NNSearch import NNSearch, randomIndices
 
 import warnings
 warnings.filterwarnings("ignore")
@@ -347,7 +347,7 @@ class NextConvGeN(GanBaseClass):
         def indexToBatches(min_idx):
             self.timing["NBH"].start()
             ## generate a minority neighbourhood batch for each minority-class sample, by index
-            min_batch_indices = shuffle(self.nmbMin.neighbourhoodOfItem(min_idx))
+            min_batch_indices = self.nmbMin.neighbourhoodOfItem(min_idx)
             min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
 
             ## generate random proximal majority batch
@@ -468,8 +468,11 @@ class NextConvGeN(GanBaseClass):
         ## min_idxs -> indices of points in minority class
         ## gen -> convex combinations generated from each neighbourhood
         self.timing["BMB"].start()
-        indices = [i for i in range(self.minSetSize) if i not in min_idxs]
-        r = self.nmbMin.basePoints[shuffle(indices)[0:self.gen]]
+        # indices = [i for i in range(self.minSetSize) if i not in min_idxs]
+        # r = self.nmbMin.basePoints[shuffle(indices)[0:self.gen]]
+
+        indices = randomIndices(self.minSetSize, outputSize=self.gen, indicesToIgnore=min_idxs)
+        r = self.nmbMin.basePoints[indices]
         self.timing["BMB"].stop()
         return r