Kristian Schultz 3 лет назад
Родитель
Commit
1dee8f0aaf
2 изменённых файла: 214 добавлений и 86 удалений
  1. 133 19
      XConvGeN-Example.ipynb
  2. 81 67
      library/generators/XConvGeN.py

Разница между файлами не показана из-за большого размера изменений
+ 133 - 19
XConvGeN-Example.ipynb


+ 81 - 67
library/generators/XConvGeN.py

@@ -172,20 +172,21 @@ class XConvGeN(GanBaseClass):
             .from_generator(neighborhoodGenerator, output_types=tf.float32)
             .from_generator(neighborhoodGenerator, output_types=tf.float32)
             .repeat()
             .repeat()
             )
             )
-        batch = neighborhoods.take(runs * self.minSetSize).batch(32)
+        batch = neighborhoods.take(runs * self.minSetSize)
 
 
-        synth_batch = self.conv_sample_generator.predict(batch)
+        synth_batch = self.conv_sample_generator.predict(batch.batch(32), verbose=0)
 
 
-        n = 0
-        synth_set = []
-        for (x,y) in zip(neighborhoods, synth_batch):
-            synth_set.extend(self.correct_feature_types(x.numpy(), y))
-            n += len(y)
-            if n >= numOfSamples:
-                break
+        pairs = tf.data.Dataset.zip(
+            ( batch
+            , tf.data.Dataset.from_tensor_slices(synth_batch)
+            ))
+
+        corrected = pairs.map(self.correct_feature_types())
 
 
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
-        return np.array(synth_set[:numOfSamples])
+        r = np.concatenate(np.array(list(corrected.take(1 + (numOfSamples // self.gen)))), axis=0)[:numOfSamples]
+
+        return r
 
 
     def predictReal(self, data):
     def predictReal(self, data):
         """
         """
@@ -370,7 +371,11 @@ class XConvGeN(GanBaseClass):
         nLabels = 2 * self.gen
         nLabels = 2 * self.gen
         labels = np.array(create01Labels(nLabels, self.gen))
         labels = np.array(create01Labels(nLabels, self.gen))
         labelsGeN = np.array([labels])
         labelsGeN = np.array([labels])
-        
+
+        def getNeighborhoods():
+            for index in range(self.minSetSize):
+                yield indexToBatches(index)
+
         def indexToBatches(min_idx):
         def indexToBatches(min_idx):
             self.timing["NBH"].start()
             self.timing["NBH"].start()
             ## generate minority neighbourhood batch for every minority class sampls by index
             ## generate minority neighbourhood batch for every minority class sampls by index
@@ -383,33 +388,6 @@ class XConvGeN(GanBaseClass):
 
 
             return (min_batch, maj_batch)
             return (min_batch, maj_batch)
 
 
-        def createSamples(min_idx):
-            min_batch, maj_batch = indexToBatches(min_idx)
-
-            self.timing["GenSamples"].start()
-            ## generate synthetic samples from convex space
-            ## of minority neighbourhood batch using generator
-            conv_samples = generator.predict(np.array([min_batch]), batch_size=self.neb, verbose=0)
-            conv_samples = tf.reshape(conv_samples, shape=(self.gen, self.n_feat))
-            self.timing["GenSamples"].stop()
-
-            self.timing["FixType"].start()
-            ## Fix feature types
-            conv_samples = self.correct_feature_types(min_batch.numpy(), conv_samples)
-            self.timing["FixType"].stop()
-
-            ## concatenate them with the majority batch
-            conv_samples = [conv_samples, maj_batch]
-            return conv_samples
-
-        def genSamplesForDisc():
-            for min_idx in range(minSetSize):
-                yield createSamples(min_idx)
-
-        def genSamplesForGeN():
-            for min_idx in range(minSetSize):
-                yield indexToBatches(min_idx)
-
         def unbatch(rows):
         def unbatch(rows):
             def fn():
             def fn():
                 for row in rows:
                 for row in rows:
@@ -432,8 +410,19 @@ class XConvGeN(GanBaseClass):
             ## Training of the discriminator.
             ## Training of the discriminator.
             #
             #
             # Get all neighborhoods and synthetic points as data stream.
             # Get all neighborhoods and synthetic points as data stream.
-            a = tf.data.Dataset.from_generator(genSamplesForDisc, output_types=tf.float32).repeat().take(discTrainCount * self.minSetSize)
-            a = tf.data.Dataset.from_generator(unbatch(a), output_types=tf.float32)
+            nbhPairs = tf.data.Dataset.from_generator(getNeighborhoods, output_types=tf.float32).repeat().take(discTrainCount * self.minSetSize)
+            nbhMin = nbhPairs.map(lambda x: x[0])
+            batchMaj = nbhPairs.map(lambda x: x[1])
+
+            fnCt = self.correct_feature_types()
+            synth_batch = self.conv_sample_generator.predict(nbhMin.batch(32), verbose=0)
+            pairMinMaj = tf.data.Dataset.zip(
+                ( nbhMin
+                , tf.data.Dataset.from_tensor_slices(synth_batch)
+                , batchMaj
+                )).map(lambda x, y, z: [fnCt(x,y), z])
+            
+            a = tf.data.Dataset.from_generator(unbatch(pairMinMaj), output_types=tf.float32)
 
 
             # Get all labels as data stream.
             # Get all labels as data stream.
             b = tf.data.Dataset.from_tensor_slices(labels).repeat()
             b = tf.data.Dataset.from_tensor_slices(labels).repeat()
@@ -453,7 +442,7 @@ class XConvGeN(GanBaseClass):
             #
             #
             # Get all neighborhoods as data stream.
             # Get all neighborhoods as data stream.
             a = (tf.data.Dataset
             a = (tf.data.Dataset
-                .from_generator(genSamplesForGeN, output_types=tf.float32)
+                .from_generator(getNeighborhoods, output_types=tf.float32)
                 .map(lambda x: [[tf.concat([x[0], padd], axis=0), x[1]]]))
                 .map(lambda x: [[tf.concat([x[0], padd], axis=0), x[1]]]))
 
 
             # Get all labels as data stream.
             # Get all labels as data stream.
@@ -519,33 +508,58 @@ class XConvGeN(GanBaseClass):
         s = [bar(v) for v in x]
         s = [bar(v) for v in x]
         print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
         print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
         
         
-    def correct_feature_types(self, batch, synth_batch):
+    def correct_feature_types(self):
+        # batch[0] = original points (gen x n_feat)
+        # batch[1] = synthetic points (gen x n_feat)
+        
+        @tf.function
+        def voidFunction(reference, synth):
+            return synth
+    
         if self.fdc is None:
         if self.fdc is None:
-            return synth_batch
+            return voidFunction
         
         
-        def bestMatchOf(referenceValues, value):
-            if referenceValues is not None:
-                best = referenceValues[0]
-                d = abs(best - value)
-                for x in referenceValues:
-                    dx = abs(x - value)
-                    if dx < d:
-                        best = x
-                        d = dx
-                return best
-            else:
-                return value
+        columns = set(self.fdc.nom_list or [])
+        for y in (self.fdc.ord_list or []):
+            columns.add(y)
+        columns = list(columns)
+        
+        if len(columns) == 0:
+            return voidFunction
         
         
-        def correctVector(referenceLists, v):
-            return np.array([bestMatchOf(referenceLists[i], v[i]) for i in range(len(v))])
+        neb = self.neb
+        n_feat = self.n_feat
+        nn = tf.constant([(1.0 if x in columns else 0.0) for x in range(n_feat)])
+        if n_feat is None:
+            print("ERRROR n_feat is None")
+
+        if nn is None:
+            print("ERRROR nn is None")
+
+        @tf.function
+        def bestMatchOf(vi):
+            value = vi[0]
+            c = vi[1][0]
+            r = vi[2]
+            if c != 0.0:
+                d = tf.abs(value - r)
+                return r[tf.math.argmin(d)]
+            else:
+                return value[0]
             
             
-        referenceLists = [None for _ in range(self.n_feat)]
-        for i in (self.fdc.nom_list or []):
-            referenceLists[i] = list(set(list(batch[:, i])))
-
-        for i in (self.fdc.ord_list or []):
-            referenceLists[i] = list(set(list(batch[:, i])))
-
-        # print(batch.shape, synth_batch.shape)
-
-        return Lambda(lambda x: np.array([correctVector(referenceLists, y) for y in x]))(synth_batch)
+        @tf.function
+        def indexted(v, rt):
+            vv = tf.reshape(tf.repeat([v], neb, axis=1), (n_feat, neb))
+            vn = tf.reshape(tf.repeat([nn], neb, axis=1), (n_feat, neb))
+            return tf.stack((vv, vn, rt), axis=1)
+        
+        @tf.function
+        def correctVector(v, rt):
+            return tf.map_fn(lambda x: bestMatchOf(x), indexted(v, rt))
+
+        @tf.function
+        def fn(reference, synth):
+            rt = tf.transpose(reference)
+            return tf.map_fn(lambda x: correctVector(x, rt), synth)
+        
+        return fn

Некоторые файлы не были показаны из-за большого количества измененных файлов