"""
Class for testing the performance of Generative Adversarial Networks
in generating synthetic samples for datasets with a minority class.
"""
import os
import os.path
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from library.dataset import DataSet, TrainTestData
from library.testers import lr, knn, gb, rf, TestResult, runTester
from library.cache import dataCache
import json
  16. class Exercise:
  17. """
  18. Exercising a test for a minority class extension class.
  19. """
  20. def __init__(self, testFunctions=None, shuffleFunction=None, numOfSlices=5, numOfShuffles=5):
  21. """
  22. Creates a instance of this class.
  23. *testFunctions* is a dictionary /(String : Function)/ of functions for testing
  24. a generated dataset. The functions have the signature:
  25. /(TrainTestData, TrainTestData) -> TestResult/
  26. *shuffleFunction* is either None or a function /numpy.array -> numpy.array/
  27. that shuffles a given array.
  28. *numOfSlices* is an integer > 0. The dataset given for the run function
  29. will be divided in such many slices.
  30. *numOfShuffles* is an integer > 0. It gives the number of exercised tests.
  31. The GAN will be trained and tested (numOfShuffles * numOfSlices) times.
  32. """
  33. self.numOfSlices = int(numOfSlices)
  34. self.numOfShuffles = int(numOfShuffles)
  35. self.shuffleFunction = shuffleFunction
  36. self.debug = print
  37. self.testFunctions = testFunctions
  38. if self.testFunctions is None:
  39. self.testFunctions = {
  40. "LR": lr,
  41. "RF": rf,
  42. "GB": gb,
  43. "KNN": knn
  44. }
  45. self.results = { name: [] for name in self.testFunctions }
  46. # Check if the given values are in valid range.
  47. if self.numOfSlices < 0:
  48. raise AttributeError(f"Expected numOfSlices to be > 0 but got {self.numOfSlices}")
  49. if self.numOfShuffles < 0:
  50. raise AttributeError(f"Expected numOfShuffles to be > 0 but got {self.numOfShuffles}")
  51. def run(self, gan, dataset, resultsFileName=None):
  52. """
  53. Exercise all tests for a given GAN.
  54. *gan* is a implemention of library.interfaces.GanBaseClass.
  55. It defines the GAN to test.
  56. *dataset* is a library.dataset.DataSet that contains the majority class
  57. (dataset.data0) and the minority class (dataset.data1) of data
  58. for training and testing.
  59. """
  60. # Check if the given values are in valid range.
  61. if len(dataset.data1) > len(dataset.data0):
  62. raise AttributeError(
  63. "Expected class 1 to be the minority class but class 1 is bigger than class 0.")
  64. # Prepare Folder for Images
  65. if resultsFileName is not None:
  66. try:
  67. os.mkdir(resultsFileName)
  68. except FileExistsError as e:
  69. pass
  70. # Reset results array.
  71. self.results = { name: [] for name in self.testFunctions }
  72. if gan.canPredict and "GAN" not in self.testFunctions.keys():
  73. self.results["GAN"] = []
  74. # If a shuffle function is given then shuffle the data before the
  75. # exercise starts.
  76. if self.shuffleFunction is not None:
  77. self.debug("-> Shuffling data")
  78. for _n in range(3):
  79. dataset.shuffleWith(self.shuffleFunction)
  80. # Repeat numOfShuffles times
  81. self.debug("### Start exercise for synthetic point generator")
  82. for shuffleStep in range(self.numOfShuffles):
  83. stepTitle = f"Step {shuffleStep + 1}/{self.numOfShuffles}"
  84. self.debug(f"\n====== {stepTitle} =======")
  85. # If a shuffle function is given then shuffle the data before the next
  86. # exercise starts.
  87. if self.shuffleFunction is not None:
  88. self.debug("-> Shuffling data")
  89. dataset.shuffleWith(self.shuffleFunction)
  90. # Split the (shuffled) data into numOfSlices slices.
  91. # dataSlices is a list of TrainTestData instances.
  92. #
  93. # If numOfSlices=3 then the data will be splited in D1, D2, D3.
  94. # dataSlices will contain:
  95. # [(train=D2+D3, test=D1), (train=D1+D3, test=D2), (train=D1+D2, test=D3)]
  96. self.debug("-> Spliting data to slices")
  97. dataSlices = TrainTestData.splitDataToSlices(dataset, self.numOfSlices)
  98. # Do a exercise for every slice.
  99. for (sliceNr, sliceData) in enumerate(dataSlices):
  100. sliceTitle = f"Slice {sliceNr + 1}/{self.numOfSlices}"
  101. self.debug(f"\n------ {stepTitle}: {sliceTitle} -------")
  102. imageFileName = None
  103. jsonFileName = None
  104. if resultsFileName is not None:
  105. imageFileName = f"{resultsFileName}/Step{shuffleStep + 1}_Slice{sliceNr + 1}"
  106. self._exerciseWithDataSlice(gan, sliceData, imageFileName)
  107. self.debug("### Exercise is done.")
  108. for (n, name) in enumerate(self.results):
  109. stats = None
  110. for (m, result) in enumerate(self.results[name]):
  111. stats = result.addMinMaxAvg(stats)
  112. (mi, mx, avg) = TestResult.finishMinMaxAvg(stats)
  113. self.debug("")
  114. self.debug(f"-----[ {avg.title} ]-----")
  115. self.debug("maximum:")
  116. self.debug(str(mx))
  117. self.debug("")
  118. self.debug("average:")
  119. self.debug(str(avg))
  120. self.debug("")
  121. self.debug("minimum:")
  122. self.debug(str(mi))
  123. if resultsFileName is not None:
  124. return self.saveResultsTo(resultsFileName + ".csv")
  125. return {}
  126. def _exerciseWithDataSlice(self, gan, dataSlice, imageFileName=None):
  127. """
  128. Runs one test for the given gan and dataSlice.
  129. *gan* is a implemention of library.interfaces.GanBaseClass.
  130. It defines the GAN to test.
  131. *dataSlice* is a library.dataset.TrainTestData instance that contains
  132. one data slice with training and testing data.
  133. """
  134. jsonFileName = f"{imageFileName}.json"
  135. # Count how many syhthetic samples are needed.
  136. numOfNeededSamples = dataSlice.train.size0 - dataSlice.train.size1
  137. # Start over with a new GAN instance.
  138. self.debug("-> Reset the GAN")
  139. gan.reset(dataSlice.train)
  140. # Add synthetic samples (generated by the GAN) to the minority class.
  141. if numOfNeededSamples > 0:
  142. def synth(params):
  143. me = params["self"]
  144. train = params["train"]
  145. # Train the gan so it can produce synthetic samples.
  146. me.debug("-> Train generator for synthetic samples")
  147. gan.train(train)
  148. me.debug(f"-> create {numOfNeededSamples} synthetic samples")
  149. newSamples = gan.generateData(numOfNeededSamples)
  150. # Print out an overview of the new dataset.
  151. plotCloud(train.data0, train.data1, newSamples, outputFile=imageFileName, doShow=False)
  152. return {
  153. "majority": train.data0,
  154. "minority": train.data1,
  155. "synthetic": newSamples
  156. }
  157. j = dataCache(jsonFileName, synth, {"self": self, "train":dataSlice.train})
  158. dataSlice.train = DataSet(
  159. data0=j["majority"],
  160. data1=np.concatenate((j["minority"], j["synthetic"]))
  161. )
  162. j = None
  163. if imageFileName is not None:
  164. fig_pr, ax_pr = plt.subplots()
  165. fig_roc, ax_roc = plt.subplots()
  166. # Test this dataset with every given test-function.
  167. # The results are printed out and stored to the results dictionary.
  168. if gan.canPredict and "GAN" not in self.testFunctions.keys():
  169. self.debug(f"-> test with 'GAN'")
  170. testResult = runTester(dataSlice, gan, "GAN", f"{imageFileName}-GAN.json")
  171. self.debug(str(testResult))
  172. self.results["GAN"].append(testResult)
  173. if imageFileName is not None:
  174. testResult.plotPR(ax_pr)
  175. testResult.plotROC(ax_roc)
  176. for testerName in self.testFunctions:
  177. self.debug(f"-> test with '{testerName}'")
  178. testResult = (self.testFunctions[testerName])(dataSlice, f"{imageFileName}-{testerName}.json")
  179. self.debug(str(testResult))
  180. self.results[testerName].append(testResult)
  181. if imageFileName is not None:
  182. testResult.plotPR(ax_pr)
  183. testResult.plotROC(ax_roc)
  184. if imageFileName is not None:
  185. fig_pr.savefig(imageFileName + "_PR.pdf")
  186. fig_roc.savefig(imageFileName + "_ROC.pdf")
  187. def saveResultsTo(self, fileName):
  188. avgResults = {}
  189. with open(fileName, "w") as f:
  190. for (n, name) in enumerate(self.results):
  191. if n > 0:
  192. f.write("---\n")
  193. f.write(name + "\n")
  194. isFirst = True
  195. stats = None
  196. for (m, result) in enumerate(self.results[name]):
  197. if isFirst:
  198. isFirst = False
  199. f.write("Nr.;" + result.csvHeading() + "\n")
  200. stats = result.addMinMaxAvg(stats)
  201. f.write(f"{m + 1};" + result.toCSV() + "\n")
  202. (mi, mx, avg) = TestResult.finishMinMaxAvg(stats)
  203. f.write(f"max;" + mx.toCSV() + "\n")
  204. f.write(f"avg;" + avg.toCSV() + "\n")
  205. f.write(f"min;" + mi.toCSV() + "\n")
  206. avgResults[name] = avg
  207. return avgResults
  208. def plotCloud(data0, data1, dataNew=None, outputFile=None, title="", doShow=True):
  209. """
  210. Does a PCA analysis of the given data and plot the both important axis.
  211. """
  212. if data0.shape[0] > 0:
  213. if data1.shape[0] > 0:
  214. data = np.concatenate([data0, data1])
  215. else:
  216. data = data0
  217. else:
  218. data = data1
  219. # Normalizes the data.
  220. if dataNew is None:
  221. data_t = StandardScaler().fit_transform(data)
  222. else:
  223. data_t = StandardScaler().fit_transform(np.concatenate([data, dataNew]))
  224. # Run the PCA analysis.
  225. pca = PCA(n_components=2)
  226. pc = pca.fit_transform(data_t)
  227. fig, ax = plt.subplots(sharex=True, sharey=True)
  228. fig.set_dpi(600)
  229. fig.set_figwidth(10)
  230. fig.set_figheight(10)
  231. fig.set_facecolor("white")
  232. ax.set_title(title)
  233. def doSubplot(m, n, c):
  234. pca0 = [x[0] for x in pc[m : m + n]]
  235. pca1 = [x[1] for x in pc[m : m + n]]
  236. s = ax.scatter(pca0, pca1, c=c)
  237. m = 0
  238. n = len(data0)
  239. labels = []
  240. if n > 0:
  241. labels = ["majority", "minority"]
  242. doSubplot(m, n, "gray")
  243. else:
  244. labels = ["data"]
  245. m += n
  246. n = len(data1)
  247. doSubplot(m, n, "red")
  248. if dataNew is not None:
  249. m += n
  250. n = len(dataNew)
  251. labels.append("synthetic")
  252. doSubplot(m, n, "blue")
  253. ax.legend(title="", loc='upper left', labels=labels)
  254. ax.set_xlabel("PCA0")
  255. ax.set_ylabel("PCA1")
  256. if doShow:
  257. plt.show()
  258. if outputFile is not None:
  259. fig.savefig(outputFile + ".pdf")