Ver código fonte

Added mapping to nomial/ordinal data.

Kristian Schultz 3 anos atrás
pai
commit
cd676be0d9
1 arquivos alterados com 30 adições e 1 exclusões
  1. 30 1
      library/generators/NextConvGeN.py

+ 30 - 1
library/generators/NextConvGeN.py

@@ -310,6 +310,7 @@ class NextConvGeN(GanBaseClass):
         for _run in range(runs):
             batch = self.nmbMin.getNbhPointsOfItem(index)
             synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
+            synth_batch = self.correct_feature_types(batch, synth_batch)
             synth_set.extend(synth_batch)
 
         return synth_set[:synth_num]
@@ -447,4 +448,32 @@ class NextConvGeN(GanBaseClass):
         
         s = [bar(v) for v in x]
         print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
-        
+        
+    def correct_feature_types(self, batch, synth_batch):
+
+        if self.fdc is None:
+            return synth_batch
+
+        def bestMatchOf(referenceValues, value):
+            best = referenceValues[0]
+            d = abs(best - value)
+            for x in referenceValues:
+                dx = abs(x - value)
+                if dx < d:
+                    best = x
+                    d = dx
+            return best
+
+        synth_batch = list(synth_batch)
+        for i in (self.fdc.nom_list or []):
+            referenceValues = list(set(list(batch[:, i].numpy())))
+            for x in synth_batch:
+                y = x[i]
+                x[i] = bestMatchOf(referenceValues, y)
+
+        for i in (self.fdc.ord_list or []):
+            referenceValues = list(set(list(batch[:, i].numpy())))
+            for x in synth_batch:
+                x[i] = bestMatchOf(referenceValues, x[i])
+
+        return np.array(synth_batch)