testers.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """
This module contains test functions for datasets using logistic regression, a support vector
machine and the k-nearest-neighbours algorithm. Additionally it contains a class for storing the
results of the tests.
  5. """
import numbers

import sklearn
# needed in function lr
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.metrics import average_precision_score
from sklearn.metrics import f1_score
from sklearn.metrics import balanced_accuracy_score
  14. class TestResult:
  15. """
  16. This class represents the result of one test.
  17. It stores its *title*, a confusion matrix (*con_mat*), the balanced accuracy score (*bal_acc*)
  18. and the f1 score (*f1*). If given the average precision score is also stored (*aps*).
  19. """
  20. def __init__(self, title, labels, prediction, aps=None):
  21. """
  22. Creates an instance of this class. The stored data will be generated from the given values.
  23. *title* is a text to identify this result.
  24. *labels* is a /numpy.array/ containing the labels of the test-data-set.
  25. *prediction* is a /numpy.array/ containing the done prediction for the test-data-set.
  26. *aps* is a real number representing the average precision score.
  27. """
  28. self.title = title
  29. self.con_mat = confusion_matrix(labels, prediction)
  30. self.bal_acc = balanced_accuracy_score(labels, prediction)
  31. self.f1 = f1_score(labels, prediction)
  32. self.aps = aps
  33. def __str__(self):
  34. """
  35. Generates a text representing this result.
  36. """
  37. #tn, fp, fn, tp = con_mat.ravel()
  38. r = self.con_mat.ravel()
  39. text = f"tn, fp, fn, tp: {r}"
  40. if self.aps is not None:
  41. text += f"\naverage_pr_score: {self.aps}"
  42. text += f"\nf1 score_{self.title}: {self.f1}"
  43. text += f"\nbalanced accuracy_{self.title}: {self.bal_acc}"
  44. text += f"\nconfusion matrix_{self.title}\n {self.con_mat}"
  45. return text
  46. def csvHeading():
  47. r = [
  48. "F1 score",
  49. "balanced accuracy",
  50. "TN",
  51. "FP",
  52. "FN",
  53. "TP"
  54. ]
  55. if self.aps is not None:
  56. r.append(self.aps)
  57. return ";".join(r)
  58. def toCSV():
  59. r = [
  60. self.f1,
  61. self.bal_acc,
  62. self.con_mat[0] if len(self.con_mat) > 0 else float(self.con_mat),
  63. self.con_mat[1] if len(self.con_mat) > 1 else 0,
  64. self.con_mat[2] if len(self.con_mat) > 2 else 0,
  65. self.con_mat[3] if len(self.con_mat) > 3 else 0
  66. ]
  67. if self.aps is not None:
  68. r.append(self.aps)
  69. return ";".join(r)
  70. def lr(ttd):
  71. """
  72. Runs a test for a dataset with the logistic regression algorithm.
  73. It returns a /TestResult./
  74. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  75. """
  76. checkType(ttd)
  77. logreg = LogisticRegression(
  78. C=1e5,
  79. solver='lbfgs',
  80. multi_class='multinomial',
  81. class_weight={0: 1, 1: 1.3}
  82. )
  83. logreg.fit(ttd.train.data, ttd.train.labels)
  84. prediction = logreg.predict(ttd.test.data)
  85. prob_lr = logreg.predict_proba(ttd.test.data)
  86. aps_lr = average_precision_score(ttd.test.labels, prob_lr[:,1])
  87. return TestResult("LR", ttd.test.labels, prediction, aps_lr)
  88. def svm(ttd):
  89. """
  90. Runs a test for a dataset with the support vector machine algorithm.
  91. It returns a /TestResult./
  92. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  93. """
  94. checkType(ttd)
  95. svmTester = sklearn.svm.SVC(
  96. kernel='linear',
  97. decision_function_shape='ovo',
  98. class_weight={0: 1., 1: 1.},
  99. probability=True
  100. )
  101. svmTester.fit(ttd.train.data, ttd.train.labels)
  102. prediction = svmTester.predict(ttd.test.data)
  103. return TestResult("SVM", ttd.test.labels, prediction)
  104. def knn(ttd):
  105. """
  106. Runs a test for a dataset with the k-next neighbourhood algorithm.
  107. It returns a /TestResult./
  108. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  109. """
  110. checkType(ttd)
  111. knnTester = KNeighborsClassifier(n_neighbors=10)
  112. knnTester.fit(ttd.train.data, ttd.train.labels)
  113. prediction = knnTester.predict(ttd.test.data)
  114. return TestResult("KNN", ttd.test.labels, prediction)
  115. def checkType(t):
  116. if str(type(t)) == "<class 'numpy.ndarray'>":
  117. return t.shape[0] > 0 and all(map(checkType, t))
  118. elif str(type(t)) == "<class 'list'>":
  119. return len(t) > 0 and all(map(checkType, t))
  120. elif str(type(t)) in ["<class 'int'>", "<class 'float'>", "<class 'numpy.float64'>"]:
  121. return True
  122. elif str(type(t)) == "<class 'library.dataset.DataSet'>":
  123. return checkType(t.data0) and checkType(t.data1)
  124. elif str(type(t)) == "<class 'library.dataset.TrainTestData'>":
  125. return checkType(t.train) and checkType(t.test)
  126. else:
  127. raise ValueError("expected int, float, or list, dataset of int, float but got " + str(type(t)))
  128. return False