|
@@ -35,7 +35,7 @@ class ConvGeN(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
"""
|
|
|
- def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
|
|
|
|
|
|
|
+ def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
|
|
|
self.isTrained = False
|
|
self.isTrained = False
|
|
|
self.n_feat = n_feat
|
|
self.n_feat = n_feat
|
|
|
self.neb = neb
|
|
self.neb = neb
|
|
@@ -48,7 +48,7 @@ class ConvGeN(GanBaseClass):
|
|
|
self.minSetSize = 0
|
|
self.minSetSize = 0
|
|
|
self.conv_sample_generator = None
|
|
self.conv_sample_generator = None
|
|
|
self.maj_min_discriminator = None
|
|
self.maj_min_discriminator = None
|
|
|
- self.withMajorhoodNbSearch = withMajorhoodNbSearch
|
|
|
|
|
|
|
+ self.maj_proximal = maj_proximal
|
|
|
self.cg = None
|
|
self.cg = None
|
|
|
self.canPredict = True
|
|
self.canPredict = True
|
|
|
|
|
|
|
@@ -110,7 +110,7 @@ class ConvGeN(GanBaseClass):
|
|
|
|
|
|
|
|
# Precalculate neighborhoods
|
|
# Precalculate neighborhoods
|
|
|
self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
|
|
self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
|
|
|
- if self.withMajorhoodNbSearch:
|
|
|
|
|
|
|
+ if self.maj_proximal:
|
|
|
self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
|
|
self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
|
|
|
else:
|
|
else:
|
|
|
self.nmbMaj = None
|
|
self.nmbMaj = None
|