| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- """
- Class for testing the performance of Generative Adversarial Networks
- in generating synthetic samples for datasets with a minority class.
- """
- import numpy as np
- from sklearn.decomposition import PCA
- from sklearn.preprocessing import StandardScaler
- import matplotlib.pyplot as plt
- from library.dataset import DataSet, TrainTestData
- from library.testers import lr,knn, gb, TestResult
class Exercise:
    """
    Exercising a test for a minority class extension class.

    The dataset is split into slices. For every slice the generator is reset,
    trained on the training part, and used to fill up the minority class of
    the training part with synthetic samples. The result is then evaluated with
    every configured test function.
    """

    def __init__(self, testFunctions=None, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
        """
        Creates an instance of this class.

        *testFunctions* is a dictionary /(String : Function)/ of functions for
        testing a generated dataset. Each function is called with one argument
        and has the signature:
        /TrainTestData -> TestResult/
        If None, the testers "LR", "GB" and "KNN" from library.testers are used.

        *shuffleFunction* is either None or a function /numpy.array -> numpy.array/
        that shuffles a given array.

        *numOfSlices* is an integer > 0. The dataset given to the run function
        will be divided into that many slices.

        *numOfShuffles* is an integer > 0. It gives the number of exercised tests.
        The GAN will be trained and tested (numOfShuffles * numOfSlices) times.

        Raises AttributeError if numOfSlices or numOfShuffles is smaller than 1.
        """
        self.numOfSlices = int(numOfSlices)
        self.numOfShuffles = int(numOfShuffles)

        # Check if the given values are in valid range. Both counts must be at
        # least 1, as documented above; 0 is rejected as well.
        if self.numOfSlices < 1:
            raise AttributeError(f"Expected numOfSlices to be > 0 but got {self.numOfSlices}")
        if self.numOfShuffles < 1:
            raise AttributeError(f"Expected numOfShuffles to be > 0 but got {self.numOfShuffles}")

        self.shuffleFunction = shuffleFunction
        # All progress and summary output of this class goes through this callable.
        self.debug = print
        self.testFunctions = testFunctions
        if self.testFunctions is None:
            self.testFunctions = {
                "LR": lr,
                "GB": gb,
                "KNN": knn
            }
        # Maps every test name to the list of TestResults collected by run().
        self.results = {name: [] for name in self.testFunctions}

    @staticmethod
    def _minMaxAvg(results):
        """
        Folds a list of TestResult objects into a (min, max, avg) triple using
        TestResult.addMinMaxAvg / TestResult.finishMinMaxAvg.
        """
        stats = None
        for result in results:
            stats = result.addMinMaxAvg(stats)
        return TestResult.finishMinMaxAvg(stats)

    def run(self, gan, dataset):
        """
        Exercise all tests for a given GAN.

        *gan* is an implementation of library.interfaces.GanBaseClass.
        It defines the GAN to test.

        *dataset* is a library.dataset.DataSet that contains the majority class
        (dataset.data0) and the minority class (dataset.data1) of data
        for training and testing. If a shuffle function is configured, the
        dataset is shuffled via dataset.shuffleWith.

        The results are stored in self.results (reset at the start of every run),
        and a max/average/min summary per test function is reported via self.debug.

        Raises AttributeError if class 1 has more samples than class 0.
        """
        # Check if the given values are in valid range.
        if len(dataset.data1) > len(dataset.data0):
            raise AttributeError(
                "Expected class 1 to be the minority class but class 1 is bigger than class 0.")

        # Reset results array.
        self.results = {name: [] for name in self.testFunctions}

        # If a shuffle function is given then shuffle the data (three times)
        # before the exercise starts.
        if self.shuffleFunction is not None:
            self.debug("-> Shuffling data")
            for _ in range(3):
                dataset.shuffleWith(self.shuffleFunction)

        # Repeat numOfShuffles times
        self.debug("### Start exercise for synthetic point generator")
        for shuffleStep in range(self.numOfShuffles):
            stepTitle = f"Step {shuffleStep + 1}/{self.numOfShuffles}"
            self.debug(f"\n====== {stepTitle} =======")

            # Shuffle again before every step, so each step splits a freshly
            # shuffled dataset.
            if self.shuffleFunction is not None:
                self.debug("-> Shuffling data")
                dataset.shuffleWith(self.shuffleFunction)

            # Split the (shuffled) data into numOfSlices slices.
            # dataSlices is a list of TrainTestData instances.
            #
            # If numOfSlices=3 then the data will be split into D1, D2, D3.
            # dataSlices will contain:
            # [(train=D2+D3, test=D1), (train=D1+D3, test=D2), (train=D1+D2, test=D3)]
            self.debug("-> Spliting data to slices")
            dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)

            # Do an exercise for every slice.
            for sliceNr, sliceData in enumerate(dataSlices):
                sliceTitle = f"Slice {sliceNr + 1}/{self.numOfSlices}"
                self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
                self._exerciseWithDataSlice(gan, sliceData)

        self.debug("### Exercise is done.")

        # Report a max/average/min summary for every test function.
        for results in self.results.values():
            (mi, mx, avg) = self._minMaxAvg(results)
            self.debug("")
            self.debug(f"-----[ {avg.title} ]-----")
            self.debug("maximum:")
            self.debug(str(mx))
            self.debug("")
            self.debug("average:")
            self.debug(str(avg))
            self.debug("")
            self.debug("minimum:")
            self.debug(str(mi))

    def _exerciseWithDataSlice(self, gan, dataSlice):
        """
        Runs one test for the given gan and dataSlice.

        *gan* is an implementation of library.interfaces.GanBaseClass.
        It defines the GAN to test.

        *dataSlice* is a library.dataset.TrainTestData instance that contains
        one data slice with training and testing data. Its *train* attribute is
        replaced by a DataSet whose minority class is extended by the synthetic
        samples.
        """
        # Start over with a new GAN instance.
        self.debug("-> Reset the GAN")
        gan.reset()

        # Train the gan so it can produce synthetic samples.
        self.debug("-> Train generator for synthetic samples")
        gan.train(dataSlice.train)

        # Count how many synthetic samples are needed to balance both classes.
        numOfNeededSamples = dataSlice.train.size0 - dataSlice.train.size1

        # Add synthetic samples (generated by the GAN) to the minority class.
        if numOfNeededSamples > 0:
            self.debug(f"-> create {numOfNeededSamples} synthetic samples")
            newSamples = gan.generateData(numOfNeededSamples)

            # Print out an overview of the new dataset.
            plotCloud(dataSlice.train.data0, dataSlice.train.data1, newSamples)

            dataSlice.train = DataSet(
                data0=dataSlice.train.data0,
                data1=np.concatenate((dataSlice.train.data1, newSamples))
            )

        # Test this dataset with every given test function.
        # The results are printed out and stored to the results dictionary.
        for testerName, testFunction in self.testFunctions.items():
            self.debug(f"-> test with '{testerName}'")
            testResult = testFunction(dataSlice)
            self.debug(str(testResult))
            self.results[testerName].append(testResult)

    def saveResultsTo(self, fileName):
        """
        Writes all collected results to *fileName* as semicolon-separated text.

        For every test function the file contains its name, a heading line, one
        numbered row per result and the max/avg/min summary rows. Sections of
        different test functions are separated by a line "---".

        Returns a dictionary that maps every test name to its average TestResult.
        """
        avgResults = {}
        with open(fileName, "w", encoding="utf-8") as f:
            for n, (name, results) in enumerate(self.results.items()):
                if n > 0:
                    f.write("---\n")

                f.write(name + "\n")
                for m, result in enumerate(results):
                    # The heading is taken from the first result of this test.
                    if m == 0:
                        f.write("Nr.;" + result.csvHeading() + "\n")
                    f.write(f"{m + 1};" + result.toCSV() + "\n")

                (mi, mx, avg) = self._minMaxAvg(results)
                f.write("max;" + mx.toCSV() + "\n")
                f.write("avg;" + avg.toCSV() + "\n")
                f.write("min;" + mi.toCSV() + "\n")
                avgResults[name] = avg
        return avgResults
def plotCloud(data0, data1, dataNew=None, outputFile=None, title=""):
    """
    Does a PCA analysis of the given data and plots the two most important axes.

    *data0* holds the majority class samples (drawn gray), *data1* the minority
    class samples (drawn red) and *dataNew*, if not None, the synthetic minority
    samples (drawn blue). All given arrays are concatenated, so they need the
    same number of columns.
    All samples are standardized together (StandardScaler) before a
    2-component PCA is fitted on them.
    *outputFile*, if not None, is the file the figure is saved to.
    *title* is used as the title of the plot.
    """
    groups = [(data0, "gray", "majority"), (data1, "red", "minority")]
    if dataNew is not None:
        groups.append((dataNew, "blue", "synthetic minority"))

    # Normalize all samples together so every group shares one coordinate system.
    data_t = StandardScaler().fit_transform(np.concatenate([data for (data, _, _) in groups]))

    # Run the PCA analysis.
    pc = PCA(n_components=2).fit_transform(data_t)

    fig, ax = plt.subplots(sharex=True, sharey=True)
    fig.set_dpi(600)
    fig.set_figwidth(10)
    fig.set_figheight(10)
    fig.set_facecolor("white")
    ax.set_title(title)

    # The rows of pc keep the order of the concatenated input, so every group
    # occupies one contiguous block of rows.
    start = 0
    for (data, color, label) in groups:
        end = start + len(data)
        ax.scatter(pc[start:end, 0], pc[start:end, 1], c=color, label=label)
        start = end

    ax.legend(title="", loc="upper left")
    ax.set_xlabel("PCA0")
    ax.set_ylabel("PCA1")

    # Save before showing: a blocking plt.show() closes the figure when it
    # returns, and saving afterwards can produce an empty image.
    if outputFile is not None:
        fig.savefig(outputFile)
    plt.show()
|