|
|
@@ -45,7 +45,7 @@ class ConvGAN(GanBaseClass):
|
|
|
This is a toy example of a GAN.
|
|
|
It repeats the first point of the training-data-set.
|
|
|
"""
|
|
|
- def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, debug=True):
|
|
|
+ def __init__(self, n_feat, neb=5, gen=5, neb_epochs=10, withMajorhoodNbSearch=False, debug=False):
|
|
|
self.isTrained = False
|
|
|
self.n_feat = n_feat
|
|
|
self.neb = neb
|
|
|
@@ -56,6 +56,7 @@ class ConvGAN(GanBaseClass):
|
|
|
self.dataSet = None
|
|
|
self.conv_sample_generator = None
|
|
|
self.maj_min_discriminator = None
|
|
|
+ self.withMajorhoodNbSearch = withMajorhoodNbSearch
|
|
|
self.cg = None
|
|
|
|
|
|
if neb > gen:
|
|
|
@@ -98,7 +99,11 @@ class ConvGAN(GanBaseClass):
|
|
|
raise AttributeError("Train: Expected data class 1 to contain at least one point.")
|
|
|
|
|
|
self.dataSet = dataSet
|
|
|
- self.nmb = self._NMB_prepare(dataSet.data1)
|
|
|
+ self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
|
|
|
+ if self.withMajorhoodNbSearch:
|
|
|
+ self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
|
|
|
+ else:
|
|
|
+ self.nmbMaj = None
|
|
|
self._rough_learning(dataSet.data1, dataSet.data0)
|
|
|
self.isTrained = True
|
|
|
|
|
|
@@ -265,7 +270,7 @@ class ConvGAN(GanBaseClass):
|
|
|
runs = int(synth_num / self.neb) + 1
|
|
|
synth_set = []
|
|
|
for _run in range(runs):
|
|
|
- batch = self._NMB_guided(index)
|
|
|
+ batch = self.nmbMin.getNbhPointsOfItem(index)
|
|
|
synth_batch = self.conv_sample_generator.predict(batch)
|
|
|
synth_set.extend(synth_batch)
|
|
|
|
|
|
@@ -286,10 +291,11 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
for step in range(self.neb_epochs * len(data_min)):
|
|
|
## generate minority neighbourhood batch for every minority class sampls by index
|
|
|
- min_batch = self._NMB_guided(min_idx)
|
|
|
+ min_batch_indices = self.nmbMin.neighbourhoodOfItem(min_idx)
|
|
|
+ min_batch = self.nmbMin.getPointsFromIndices(min_batch_indices)
|
|
|
min_idx = min_idx + 1
|
|
|
## generate random proximal majority batch
|
|
|
- maj_batch = self._BMB(data_min, data_maj)
|
|
|
+ maj_batch = self._BMB(data_maj, min_batch_indices)
|
|
|
|
|
|
## generate synthetic samples from convex space
|
|
|
## of minority neighbourhood batch using generator
|
|
|
@@ -340,7 +346,7 @@ class ConvGAN(GanBaseClass):
|
|
|
|
|
|
|
|
|
## convGAN
|
|
|
- def _BMB(self, data_min, data_maj):
|
|
|
+ def _BMB(self, data_maj, min_idxs):
|
|
|
|
|
|
## Generate a borderline majority batch
|
|
|
## data_min -> minority class data
|
|
|
@@ -348,29 +354,9 @@ class ConvGAN(GanBaseClass):
|
|
|
## neb -> oversampling neighbourhood
|
|
|
## gen -> convex combinations generated from each neighbourhood
|
|
|
|
|
|
- return tf.convert_to_tensor(
|
|
|
- data_maj[np.random.randint(len(data_maj), size=self.gen)]
|
|
|
- )
|
|
|
-
|
|
|
- def _NMB_prepare(self, data_min):
|
|
|
- neigh = NNSearch(self.neb)
|
|
|
- neigh.fit(data_min)
|
|
|
- return (data_min, neigh)
|
|
|
-
|
|
|
-
|
|
|
- def _NMB_guided(self, index):
|
|
|
-
|
|
|
- ## generate a minority neighbourhood batch for a particular minority sample
|
|
|
- ## we need this for minority data generation
|
|
|
- ## we will generate synthetic samples for each training data neighbourhood
|
|
|
- ## index -> index of the minority sample in a training data whose neighbourhood we want to obtain
|
|
|
- ## data_min -> minority class data
|
|
|
- ## neb -> oversampling neighbourhood
|
|
|
- (data_min, neigh) = self.nmb
|
|
|
-
|
|
|
- nmbi = np.array([neigh.neighbourhoodOfItem(index)])
|
|
|
- nmbi = shuffle(nmbi)
|
|
|
- nmb = data_min[nmbi]
|
|
|
- nmb = tf.convert_to_tensor(nmb[0])
|
|
|
- return nmb
|
|
|
-
|
|
|
+ if self.nmbMaj is not None:
|
|
|
+ return self.nmbMaj.neighbourhoodOfItemList(min_idxs, maxCount=self.gen)
|
|
|
+ else:
|
|
|
+ return tf.convert_to_tensor(
|
|
|
+ data_maj[np.random.randint(len(data_maj), size=self.gen)]
|
|
|
+ )
|