# exercise.py
  1. """
  2. Class for testing the performance of Generative Adversarial Networks
  3. in generating synthetic samples for datasets with a minority class.
  4. """
  5. import numpy as np
  6. import pandas as pd
  7. import seaborn as sns
  8. from sklearn.decomposition import PCA
  9. from sklearn.preprocessing import StandardScaler
  10. import matplotlib.pyplot as plt
  11. from library.dataset import DataSet, TrainTestData
  12. from library.testers import lr, svm, knn
  13. class Exercise:
  14. """
  15. Exercising a test for a minority class extension class.
  16. """
  17. def __init__(self, testFunctions=None, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
  18. """
  19. Creates a instance of this class.
  20. *testFunctions* is a dictionary /(String : Function)/ of functions for testing
  21. a generated dataset. The functions have the signature:
  22. /(TrainTestData, TrainTestData) -> TestResult/
  23. *shuffleFunction* is either None or a function /numpy.array -> numpy.array/
  24. that shuffles a given array.
  25. *numOfSlices* is an integer > 0. The dataset given for the run function
  26. will be divided in such many slices.
  27. *numOfShuffles* is an integer > 0. It gives the number of exercised tests.
  28. The GAN will be trained and tested (numOfShuffles * numOfSlices) times.
  29. """
  30. self.numOfSlices = int(numOfSlices)
  31. self.numOfShuffles = int(numOfShuffles)
  32. self.shuffleFunction = shuffleFunction
  33. self.debug = print
  34. self.testFunctions = testFunctions
  35. if self.testFunctions is None:
  36. self.testFunctions = {
  37. "LR": lr,
  38. "SVM": svm,
  39. "KNN": knn
  40. }
  41. self.results = { name: [] for name in self.testFunctions }
  42. # Check if the given values are in valid range.
  43. if self.numOfSlices < 0:
  44. raise AttributeError(f"Expected numOfSlices to be > 0 but got {self.numOfSlices}")
  45. if self.numOfShuffles < 0:
  46. raise AttributeError(f"Expected numOfShuffles to be > 0 but got {self.numOfShuffles}")
  47. def run(self, gan, dataset):
  48. """
  49. Exercise all tests for a given GAN.
  50. *gan* is a implemention of library.interfaces.GanBaseClass.
  51. It defines the GAN to test.
  52. *dataset* is a library.dataset.DataSet that contains the majority class
  53. (dataset.data0) and the minority class (dataset.data1) of data
  54. for training and testing.
  55. """
  56. # Check if the given values are in valid range.
  57. if len(dataset.data1) > len(dataset.data0):
  58. raise AttributeError(
  59. "Expected class 1 to be the minority class but class 1 is bigger than class 0.")
  60. # Reset results array.
  61. self.results = { name: [] for name in self.testFunctions }
  62. # Repeat numOfShuffles times
  63. self.debug("### Start exercise for synthetic point generator")
  64. for shuffleStep in range(self.numOfShuffles):
  65. stepTitle = f"Step {shuffleStep + 1}/{self.numOfShuffles}"
  66. self.debug(f"\n====== {stepTitle} =======")
  67. # If a shuffle fuction is given then shuffle the data before the next
  68. # exercise starts.
  69. if self.shuffleFunction is not None:
  70. self.debug("-> Shuffling data")
  71. dataset.shuffleWith(self.shuffleFunction)
  72. # Split the (shuffled) data into numOfSlices slices.
  73. # dataSlices is a list of TrainTestData instances.
  74. #
  75. # If numOfSlices=3 then the data will be splited in D1, D2, D3.
  76. # dataSlices will contain:
  77. # [(train=D2+D3, test=D1), (train=D1+D3, test=D2), (train=D1+D2, test=D3)]
  78. self.debug("-> Spliting data to slices")
  79. dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)
  80. # Do a exercise for every slice.
  81. for (sliceNr, sliceData) in enumerate(dataSlices):
  82. sliceTitle = f"Slice {sliceNr + 1}/{self.numOfSlices}"
  83. self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
  84. self._exerciseWithDataSlice(gan, sliceData)
  85. self.debug("### Exercise is done.")
  86. def _exerciseWithDataSlice(self, gan, dataSlice):
  87. """
  88. Runs one test for the given gan and dataSlice.
  89. *gan* is a implemention of library.interfaces.GanBaseClass.
  90. It defines the GAN to test.
  91. *dataSlice* is a library.dataset.TrainTestData instance that contains
  92. one data slice with training and testing data.
  93. """
  94. # Train the gan so it can produce synthetic samples.
  95. self.debug("-> Train generator for synthetic samples")
  96. gan.train(dataSlice.train)
  97. # Count how many syhthetic samples are needed.
  98. numOfNeededSamples = dataSlice.train.size0 - dataSlice.train.size1
  99. # Add synthetic samples (generated by the GAN) to the minority class.
  100. if numOfNeededSamples > 0:
  101. self.debug(f"-> create {numOfNeededSamples} synthetic samples")
  102. newSamples = gan.generateData(numOfNeededSamples)
  103. dataSlice.train = DataSet(
  104. data0=dataSlice.train.data0,
  105. data1=np.concatenate((dataSlice.train.data1, newSamples))
  106. )
  107. # Print out an overview of the new dataset.
  108. plotCloud(dataSlice.train)
  109. # Test this dataset with every given test-function.
  110. # The results are printed out and stored to the results dictionary.
  111. for testerName in self.testFunctions:
  112. self.debug(f"-> test with '{testerName}'")
  113. testResult = (self.testFunctions[testerName])(dataSlice)
  114. self.debug(str(testResult))
  115. self.results[testerName].append(testResult)
  116. def saveResultsTo(self, fileName):
  117. with open(fileName, "w") as f:
  118. for name in self.results:
  119. f.write(name + "\n")
  120. isFirst = True
  121. for result in self.results[name]:
  122. if isFirst:
  123. isFirst = False
  124. f.write(result.csvHeading() + "\n")
  125. f.write(result.toCSV() + "\n")
  126. def plotCloud(dataset):
  127. """
  128. Does a PCA analysis of the given data and plot the both important axis.
  129. """
  130. # Normalizes the data.
  131. data_t = StandardScaler().fit_transform(dataset.data)
  132. # Run the PCA analysis.
  133. pca = PCA(n_components=2)
  134. pc = pca.fit_transform(data_t)
  135. # Create a DataFrame for plotting.
  136. result = pd.DataFrame(data=pc, columns=['PCA0', 'PCA1'])
  137. result['Cluster'] = dataset.labels
  138. # Plot the analysis results.
  139. sns.set( font_scale=1.2)
  140. sns.lmplot( x="PCA0", y="PCA1",
  141. data=result,
  142. fit_reg=False,
  143. hue='Cluster', # color by cluster
  144. legend=False,
  145. scatter_kws={"s": 3}, palette="Set1") # specify the point size
  146. plt.legend(title='', loc='upper left', labels=['0', '1'])
  147. plt.show()