|
|
@@ -310,6 +310,7 @@ class NextConvGeN(GanBaseClass):
|
|
|
for _run in range(runs):
|
|
|
batch = self.nmbMin.getNbhPointsOfItem(index)
|
|
|
synth_batch = self.conv_sample_generator.predict(batch, batch_size=self.neb)
|
|
|
+ synth_batch = self.correct_feature_types(batch, synth_batch)
|
|
|
synth_set.extend(synth_batch)
|
|
|
|
|
|
return synth_set[:synth_num]
|
|
|
@@ -447,4 +448,32 @@ class NextConvGeN(GanBaseClass):
|
|
|
|
|
|
s = [bar(v) for v in x]
|
|
|
print(f"[{s[0]}] [{s[1]}] [{s[2]}]", end="\r")
|
|
|
-
|
|
|
+
|
|
|
+ def correct_feature_types(self, batch, synth_batch):
|
|
|
+
|
|
|
+ if self.fdc is None:
|
|
|
+ return synth_batch
|
|
|
+
|
|
|
+ def bestMatchOf(referenceValues, value):
|
|
|
+ best = referenceValues[0]
|
|
|
+ d = abs(best - value)
|
|
|
+ for x in referenceValues:
|
|
|
+ dx = abs(x - value)
|
|
|
+ if dx < d:
|
|
|
+ best = x
|
|
|
+ d = dx
|
|
|
+ return best
|
|
|
+
|
|
|
+ synth_batch = list(synth_batch)
|
|
|
+ for i in (self.fdc.nom_list or []):
|
|
|
+ referenceValues = list(set(list(batch[:, i].numpy())))
|
|
|
+ for x in synth_batch:
|
|
|
+ y = x[i]
|
|
|
+ x[i] = bestMatchOf(referenceValues, y)
|
|
|
+
|
|
|
+ for i in (self.fdc.ord_list or []):
|
|
|
+ referenceValues = list(set(list(batch[:, i].numpy())))
|
|
|
+ for x in synth_batch:
|
|
|
+ x[i] = bestMatchOf(referenceValues, x[i])
|
|
|
+
|
|
|
+ return np.array(synth_batch)
|