Prechádzať zdrojové kódy

Renamed parameter withMajorhoodNbSearch to maj_proximal

Kristian Schultz 3 rokov pred
rodič
commit
9998296e23

+ 1 - 1
convGeN-predict.ipynb

@@ -55,7 +55,7 @@
     "    print(f\"======[ {descTrainCount} ]======\")\n",
     "    t = timing(f\"train with {descTrainCount} extra rounds\")\n",
     "    t.start()\n",
-    "    g = ConvGeN(data.data1.shape[1], neb_epochs=10, withMajorhoodNbSearch=True)\n",
+    "    g = ConvGeN(data.data1.shape[1], neb_epochs=10, maj_proximal=True)\n",
     "    g.reset(data)\n",
     "    g.train(data, descTrainCount)\n",
     "    t.stop()\n",

+ 2 - 2
library/analysis.py

@@ -210,6 +210,6 @@ generators = { "Repeater":                lambda _data: Repeater()
              , "CTAB-GAN":                lambda _data: CtabGan()
              , "ConvGeN-majority-5":      lambda data: ConvGeN(data.data0.shape[1], neb=5, gen=5)
              , "ConvGeN-majority-full":   lambda data: ConvGeN(data.data0.shape[1], neb=None)
-             , "ConvGeN-proximity-5":     lambda data: ConvGeN(data.data0.shape[1], neb=5, gen=5, withMajorhoodNbSearch=True)
-             , "ConvGeN-proximity-full":  lambda data: ConvGeN(data.data0.shape[1], neb=None, withMajorhoodNbSearch=True)
+             , "ConvGeN-proximity-5":     lambda data: ConvGeN(data.data0.shape[1], neb=5, gen=5, maj_proximal=True)
+             , "ConvGeN-proximity-full":  lambda data: ConvGeN(data.data0.shape[1], neb=None, maj_proximal=True)
              }

+ 3 - 3
library/generators/ConvGeN.py

@@ -35,7 +35,7 @@ class ConvGeN(GanBaseClass):
     This is a toy example of a GAN.
     It repeats the first point of the training-data-set.
     """
-    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
+    def __init__(self, n_feat, neb=5, gen=None, neb_epochs=10, maj_proximal=False, debug=False):
         self.isTrained = False
         self.n_feat = n_feat
         self.neb = neb
@@ -48,7 +48,7 @@ class ConvGeN(GanBaseClass):
         self.minSetSize = 0
         self.conv_sample_generator = None
         self.maj_min_discriminator = None
-        self.withMajorhoodNbSearch = withMajorhoodNbSearch
+        self.maj_proximal = maj_proximal
         self.cg = None
         self.canPredict = True
 
@@ -110,7 +110,7 @@ class ConvGeN(GanBaseClass):
 
         # Precalculate neighborhoods
         self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
-        if self.withMajorhoodNbSearch:
+        if self.maj_proximal:
             self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
         else:
             self.nmbMaj = None