testers.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
"""
This module contains test functions that evaluate datasets with logistic regression,
a support vector machine and the k-nearest-neighbours algorithm. Additionally, it
contains a class for storing the results of the tests.
"""
  6. import sklearn
  7. # needed in function lr
  8. from sklearn.neighbors import KNeighborsClassifier
  9. from sklearn.linear_model import LogisticRegression
  10. from sklearn.metrics import confusion_matrix
  11. from sklearn.metrics import average_precision_score
  12. from sklearn.metrics import f1_score
  13. from sklearn.metrics import balanced_accuracy_score
  14. class TestResult:
  15. """
  16. This class represents the result of one test.
  17. It stores its *title*, a confusion matrix (*con_mat*), the balanced accuracy score (*bal_acc*)
  18. and the f1 score (*f1*). If given the average precision score is also stored (*aps*).
  19. """
  20. def __init__(self, title, labels, prediction, aps=None):
  21. """
  22. Creates an instance of this class. The stored data will be generated from the given values.
  23. *title* is a text to identify this result.
  24. *labels* is a /numpy.array/ containing the labels of the test-data-set.
  25. *prediction* is a /numpy.array/ containing the done prediction for the test-data-set.
  26. *aps* is a real number representing the average precision score.
  27. """
  28. self.title = title
  29. self.con_mat = confusion_matrix(labels, prediction)
  30. self.bal_acc = balanced_accuracy_score(labels, prediction)
  31. self.f1 = f1_score(labels, prediction)
  32. self.aps = aps
  33. def __str__(self):
  34. """
  35. Generates a text representing this result.
  36. """
  37. #tn, fp, fn, tp = con_mat.ravel()
  38. r = self.con_mat.ravel()
  39. text = f"tn, fp, fn, tp: {r}"
  40. if self.aps is not None:
  41. text += f"\naverage_pr_score: {self.aps}"
  42. text += f"\nf1 score_{self.title}: {self.f1}"
  43. text += f"\nbalanced accuracy_{self.title}: {self.bal_acc}"
  44. text += f"\nconfusion matrix_{self.title}\n {self.con_mat}"
  45. return text
  46. def lr(ttd):
  47. """
  48. Runs a test for a dataset with the logistic regression algorithm.
  49. It returns a /TestResult./
  50. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  51. """
  52. logreg = LogisticRegression(
  53. C=1e5,
  54. solver='lbfgs',
  55. multi_class='multinomial',
  56. class_weight={0: 1, 1: 1.3}
  57. )
  58. logreg.fit(ttd.train.data, ttd.train.labels)
  59. prediction = logreg.predict(ttd.test.data)
  60. prob_lr = logreg.predict_proba(ttd.test.data)
  61. aps_lr = average_precision_score(ttd.test.labels, prob_lr[:,1])
  62. return TestResult("LR", ttd.test.labels, prediction, aps_lr)
  63. def svm(ttd):
  64. """
  65. Runs a test for a dataset with the support vector machine algorithm.
  66. It returns a /TestResult./
  67. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  68. """
  69. svmTester = sklearn.svm.SVC(
  70. kernel='linear',
  71. decision_function_shape='ovo',
  72. class_weight={0: 1., 1: 1.},
  73. probability=True
  74. )
  75. svmTester.fit(ttd.train.data, ttd.train.labels)
  76. prediction = svmTester.predict(ttd.test.data)
  77. return TestResult("SVM", ttd.test.labels, prediction)
  78. def knn(ttd):
  79. """
  80. Runs a test for a dataset with the k-next neighbourhood algorithm.
  81. It returns a /TestResult./
  82. *ttd* is a /library.dataset.TrainTestData/ instance containing data to test.
  83. """
  84. knnTester = KNeighborsClassifier(n_neighbors=10)
  85. knnTester.fit(ttd.train.data, ttd.train.labels)
  86. prediction = knnTester.predict(ttd.test.data)
  87. return TestResult("KNN", ttd.test.labels, prediction)