testers.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
"""
This module contains test functions for datasets using the logistic regression, the
k-nearest-neighbour, the gradient boosting and the random forest algorithm. Additionally it
contains a class for storing the results of the tests.
"""
  6. import sklearn
  7. # needed in function lr
  8. from sklearn.ensemble import RandomForestClassifier
  9. from sklearn.neighbors import KNeighborsClassifier
  10. from sklearn.linear_model import LogisticRegression
  11. from sklearn.metrics import confusion_matrix
  12. from sklearn.metrics import average_precision_score
  13. from sklearn.metrics import f1_score
  14. from sklearn.metrics import cohen_kappa_score
  15. from sklearn.metrics import RocCurveDisplay
  16. from sklearn.metrics import PrecisionRecallDisplay
  17. from sklearn.ensemble import GradientBoostingClassifier
  18. from imblearn.metrics import geometric_mean_score
  19. from library.cache import dataCache
  20. _tF1 = "f1 score"
  21. _tTN = "TN"
  22. _tTP = "TP"
  23. _tFN = "FN"
  24. _tFP = "FP"
  25. _tFP = "RF"
  26. _tAps = "average precision score"
  27. _tCks = "cohens kappa score"
  28. _tGMean = "G-Mean score"
  29. class TestResult:
  30. """
  31. This class represents the result of one test.
  32. It stores its *title*, a confusion matrix (*con_mat*), the balanced accuracy score (*bal_acc*)
  33. and the f1 score (*f1*). If given the average precision score is also stored (*aps*).
  34. """
  35. def __init__(self, title, labels=None, prediction=None, aps=None):
  36. """
  37. Creates an instance of this class. The stored data will be generated from the given values.
  38. *title* is a text to identify this result.
  39. *labels* is a /numpy.array/ containing the labels of the test-data-set.
  40. *prediction* is a /numpy.array/ containing the done prediction for the test-data-set.
  41. *aps* is a real number representing the average precision score.
  42. """
  43. self.title = title
  44. self.heading = [_tTN, _tTP, _tFN, _tFP, _tF1, _tCks, _tAps, _tGMean]
  45. self.data = { n: 0.0 for n in self.heading }
  46. self.labels = labels
  47. self.prediction = prediction
  48. if labels is not None and prediction is not None:
  49. self.data[_tF1] = f1_score(labels, prediction)
  50. self.data[_tCks] = cohen_kappa_score(labels, prediction)
  51. conMat = self._enshureConfusionMatrix(confusion_matrix(labels, prediction))
  52. [[tn, fp], [fn, tp]] = conMat
  53. self.data[_tTN] = tn
  54. self.data[_tTP] = tp
  55. self.data[_tFN] = fn
  56. self.data[_tFP] = fp
  57. self.data[_tGMean] = geometric_mean_score(labels, prediction)
  58. if aps is None:
  59. self.data[_tAps] = average_precision_score(labels, prediction)
  60. if aps is not None:
  61. self.data[_tAps] = aps
  62. def __str__(self):
  63. """
  64. Generates a text representing this result.
  65. """
  66. text = ""
  67. tn = self.data[_tTN]
  68. tp = self.data[_tTP]
  69. fn = self.data[_tFN]
  70. fp = self.data[_tFP]
  71. text += f"{self.title} tn, fp: {tn}, {fp}\n"
  72. text += f"{self.title} fn, tp: {fn}, {tp}\n"
  73. for k in self.heading:
  74. if k not in [_tTP, _tTN, _tFP, _tFN]:
  75. text += f"{self.title} {k}: {self.data[k]:.3f}\n"
  76. return text
  77. def csvHeading(self):
  78. return ";".join(self.heading)
  79. def toCSV(self):
  80. return ";".join(map(lambda k: f"{self.data[k]:0.3f}", self.heading))
  81. @staticmethod
  82. def _enshureConfusionMatrix(c):
  83. c0 = [0.0, 0.0]
  84. c1 = [0.0, 0.0]
  85. if len(c) > 0:
  86. if len(c[0]) > 0:
  87. c0[0] = c[0][0]
  88. if len(c[0]) > 1:
  89. c0[1] = c[0][1]
  90. if len(c) > 1 and len(c[1]) > 1:
  91. c1[0] = c[1][0]
  92. c1[1] = c[1][1]
  93. return [c0, c1]
  94. def copy(self):
  95. r = TestResult(self.title)
  96. r.data = self.data.copy()
  97. r.heading = self.heading.copy()
  98. return r
  99. def addMinMaxAvg(self, mma=None):
  100. if mma is None:
  101. return (1, self.copy(), self.copy(), self.copy())
  102. (n, mi, mx, a) = mma
  103. for k in a.heading:
  104. if k in self.heading:
  105. a.data[k] += self.data[k]
  106. for k in mi.heading:
  107. if k in self.heading:
  108. mi.data[k] = min(mi.data[k], self.data[k])
  109. for k in mx.heading:
  110. if k in self.heading:
  111. mx.data[k] = max(mx.data[k], self.data[k])
  112. return (n + 1, mi, mx, a)
  113. @staticmethod
  114. def finishMinMaxAvg(mma):
  115. if mma is None:
  116. return (TestResult("?"), TestResult("?"), TestResult("?"))
  117. else:
  118. (n, mi, ma, a) = mma
  119. for k in a.heading:
  120. if n > 0:
  121. a.data[k] = a.data[k] / n
  122. else:
  123. a.data[k] = 0.0
  124. return (mi, ma, a)
  125. def plotPR(self, ax):
  126. PrecisionRecallDisplay.from_predictions(self.labels, self.prediction, name=self.title, ax=ax)
  127. def plotROC(self, ax):
  128. RocCurveDisplay.from_predictions(self.labels, self.prediction, name=self.title, ax=ax)
  129. def lr(ttd, jsonFileName=None):
  130. """
  131. Runs a test for a dataset with the logistic regression algorithm.
  132. It returns a /TestResult./
  133. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  134. """
  135. def g(nothing):
  136. checkType(ttd)
  137. logreg = LogisticRegression(
  138. C=1e5,
  139. solver='lbfgs',
  140. max_iter=10000,
  141. multi_class='multinomial',
  142. class_weight={0: 1, 1: 1.3}
  143. )
  144. logreg.fit(ttd.train.data, ttd.train.labels)
  145. prediction = logreg.predict(ttd.test.data)
  146. prob_lr = logreg.predict_proba(ttd.test.data)
  147. aps_lr = average_precision_score(ttd.test.labels, prob_lr[:,1])
  148. return {
  149. "labels": ttd.test.labels,
  150. "prediction": prediction,
  151. "aps_lr": aps_lr
  152. }
  153. d = dataCache(jsonFileName, g)
  154. return TestResult("LR", d["labels"], d["prediction"], d["aps_lr"])
  155. def knn(ttd, jsonFileName=None):
  156. """
  157. Runs a test for a dataset with the k-next neighbourhood algorithm.
  158. It returns a /TestResult./
  159. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  160. """
  161. knnTester = KNeighborsClassifier(n_neighbors=10)
  162. return runTester(ttd, knnTester, "KNN", jsonFileName)
  163. def gb(ttd, jsonFileName=None):
  164. """
  165. Runs a test for a dataset with the gradient boosting algorithm.
  166. It returns a /TestResult./
  167. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  168. """
  169. tester = GradientBoostingClassifier()
  170. return runTester(ttd, tester, "GB", jsonFileName)
  171. def rf(ttd, jsonFileName=None):
  172. """
  173. Runs a test for a dataset with the random forest algorithm.
  174. It returns a /TestResult./
  175. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  176. """
  177. tester = RandomForestClassifier()
  178. return runTester(ttd, tester, "RF", jsonFileName)
  179. def runTester(ttd, tester, name="GAN", jsonFileName=None):
  180. def g(nothing):
  181. checkType(ttd)
  182. tester.fit(ttd.train.data, ttd.train.labels)
  183. return {
  184. "labels": ttd.test.labels,
  185. "prediction": tester.predict(ttd.test.data)
  186. }
  187. d = dataCache(jsonFileName, g)
  188. return TestResult(name, d["labels"], d["prediction"])
  189. def checkType(t):
  190. if str(type(t)) == "<class 'numpy.ndarray'>":
  191. return t.shape[0] > 0 and all(map(checkType, t))
  192. elif str(type(t)) == "<class 'list'>":
  193. return len(t) > 0 and all(map(checkType, t))
  194. elif str(type(t)) in ["<class 'int'>", "<class 'float'>", "<class 'numpy.float64'>"]:
  195. return True
  196. elif str(type(t)) == "<class 'library.dataset.DataSet'>":
  197. return checkType(t.data0) and checkType(t.data1)
  198. elif str(type(t)) == "<class 'library.dataset.TrainTestData'>":
  199. return checkType(t.train) and checkType(t.test)
  200. else:
  201. raise ValueError("expected int, float, or list, dataset of int, float but got " + str(type(t)))
  202. return False