瀏覽代碼

Memory optimization.

Kristian Schultz 4 年之前
父節點
當前提交
ed67b93985
共有 1 個文件被更改,包括 13 次插入、20 次刪除
  1. 13 20
      library/generators/convGAN.py

+ 13 - 20
library/generators/convGAN.py

@@ -1,21 +1,9 @@
 import numpy as np
 import numpy as np
-from numpy.random import seed
-import pandas as pd
 import matplotlib.pyplot as plt
 import matplotlib.pyplot as plt
 
 
 from library.interfaces import GanBaseClass
 from library.interfaces import GanBaseClass
 from library.dataset import DataSet
 from library.dataset import DataSet
 
 
-from sklearn.decomposition import PCA
-from sklearn.metrics import confusion_matrix
-from sklearn.metrics import f1_score
-from sklearn.metrics import cohen_kappa_score
-from sklearn.metrics import precision_score
-from sklearn.metrics import recall_score
-from sklearn.neighbors import NearestNeighbors
-from sklearn.utils import shuffle
-from imblearn.datasets import fetch_datasets
-
 from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
 from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
 from keras.models import Model
 from keras.models import Model
 from keras import backend as K
 from keras import backend as K
@@ -53,7 +41,7 @@ class ConvGAN(GanBaseClass):
         self.neb_epochs = 10
         self.neb_epochs = 10
         self.loss_history = None
         self.loss_history = None
         self.debug = debug
         self.debug = debug
-        self.dataSet = None
+        self.minSetSize = 0
         self.conv_sample_generator = None
         self.conv_sample_generator = None
         self.maj_min_discriminator = None
         self.maj_min_discriminator = None
         self.withMajorhoodNbSearch = withMajorhoodNbSearch
         self.withMajorhoodNbSearch = withMajorhoodNbSearch
@@ -98,13 +86,21 @@ class ConvGAN(GanBaseClass):
         if dataSet.data1.shape[0] <= 0:
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
 
-        self.dataSet = dataSet
+        # Store size of minority class. This is needed during point generation.
+        self.minSetSize = dataSet.data1.shape[0]
+
+        # Precalculate neighborhoods
         self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
         self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
         if self.withMajorhoodNbSearch:
         if self.withMajorhoodNbSearch:
             self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
             self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
         else:
         else:
             self.nmbMaj = None
             self.nmbMaj = None
+
+        # Do the training.
         self._rough_learning(dataSet.data1, dataSet.data0)
         self._rough_learning(dataSet.data1, dataSet.data0)
+        
+        # Neighborhood in majority class is no longer needed. So save memory.
+        self.nmbMaj = None
         self.isTrained = True
         self.isTrained = True
 
 
     def generateDataPoint(self):
     def generateDataPoint(self):
@@ -123,14 +119,12 @@ class ConvGAN(GanBaseClass):
         if not self.isTrained:
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained Re.")
             raise ValueError("Try to generate data with untrained Re.")
 
 
-        data_min = self.dataSet.data1
-
         ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
         ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
-        synth_num = (numOfSamples // len(data_min)) + 1
+        synth_num = (numOfSamples // self.minSetSize) + 1
 
 
         ## generate synth_num synthetic samples from each minority neighbourhood
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
         synth_set=[]
-        for i in range(len(data_min)):
+        for i in range(self.minSetSize):
             synth_set.extend(self._generate_data_for_min_point(i, synth_num))
             synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
 
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
@@ -349,9 +343,8 @@ class ConvGAN(GanBaseClass):
     def _BMB(self, data_maj, min_idxs):
     def _BMB(self, data_maj, min_idxs):
 
 
         ## Generate a borderline majority batch
         ## Generate a borderline majority batch
-        ## data_min -> minority class data
         ## data_maj -> majority class data
         ## data_maj -> majority class data
-        ## neb -> oversampling neighbourhood
+        ## min_idxs -> indices of points in minority class
         ## gen -> convex combinations generated from each neighbourhood
         ## gen -> convex combinations generated from each neighbourhood
 
 
         if self.nmbMaj is not None:
         if self.nmbMaj is not None: