ソースを参照

Memory optimization.

Kristian Schultz 4 年 前
コミット
ed67b93985
1 ファイル変更13 行追加20 行削除
  1 file changed: 13 additions, 20 deletions
      library/generators/convGAN.py

+ 13 - 20
library/generators/convGAN.py

@@ -1,21 +1,9 @@
 import numpy as np
-from numpy.random import seed
-import pandas as pd
 import matplotlib.pyplot as plt
 
 from library.interfaces import GanBaseClass
 from library.dataset import DataSet
 
-from sklearn.decomposition import PCA
-from sklearn.metrics import confusion_matrix
-from sklearn.metrics import f1_score
-from sklearn.metrics import cohen_kappa_score
-from sklearn.metrics import precision_score
-from sklearn.metrics import recall_score
-from sklearn.neighbors import NearestNeighbors
-from sklearn.utils import shuffle
-from imblearn.datasets import fetch_datasets
-
 from keras.layers import Dense, Input, Multiply, Flatten, Conv1D, Reshape
 from keras.models import Model
 from keras import backend as K
@@ -53,7 +41,7 @@ class ConvGAN(GanBaseClass):
         self.neb_epochs = 10
         self.loss_history = None
         self.debug = debug
-        self.dataSet = None
+        self.minSetSize = 0
         self.conv_sample_generator = None
         self.maj_min_discriminator = None
         self.withMajorhoodNbSearch = withMajorhoodNbSearch
@@ -98,13 +86,21 @@ class ConvGAN(GanBaseClass):
         if dataSet.data1.shape[0] <= 0:
             raise AttributeError("Train: Expected data class 1 to contain at least one point.")
 
-        self.dataSet = dataSet
+        # Store size of minority class. This is needed during point generation.
+        self.minSetSize = dataSet.data1.shape[0]
+
+        # Precalculate neighborhoods
         self.nmbMin = NNSearch(self.neb).fit(haystack=dataSet.data1)
         if self.withMajorhoodNbSearch:
             self.nmbMaj = NNSearch(self.neb).fit(haystack=dataSet.data0, needles=dataSet.data1)
         else:
             self.nmbMaj = None
+
+        # Do the training.
         self._rough_learning(dataSet.data1, dataSet.data0)
+        
+        # Neighborhood in majority class is no longer needed. So save memory.
+        self.nmbMaj = None
         self.isTrained = True
 
     def generateDataPoint(self):
@@ -123,14 +119,12 @@ class ConvGAN(GanBaseClass):
         if not self.isTrained:
             raise ValueError("Try to generate data with untrained Re.")
 
-        data_min = self.dataSet.data1
-
         ## roughly calculate the upper bound of the synthetic samples to be generated from each neighbourhood
-        synth_num = (numOfSamples // len(data_min)) + 1
+        synth_num = (numOfSamples // self.minSetSize) + 1
 
         ## generate synth_num synthetic samples from each minority neighbourhood
         synth_set=[]
-        for i in range(len(data_min)):
+        for i in range(self.minSetSize):
             synth_set.extend(self._generate_data_for_min_point(i, synth_num))
 
         ## extract the exact number of synthetic samples needed to exactly balance the two classes
@@ -349,9 +343,8 @@ class ConvGAN(GanBaseClass):
     def _BMB(self, data_maj, min_idxs):
 
         ## Generate a borderline majority batch
-        ## data_min -> minority class data
         ## data_maj -> majority class data
-        ## neb -> oversampling neighbourhood
+        ## min_idxs -> indices of points in minority class
         ## gen -> convex combinations generated from each neighbourhood
 
         if self.nmbMaj is not None: