# createGraphics.py (2.1 KB)
import csv
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pydicom

import wavelet
  8. pos = 0
  9. first = True
  10. firstImage = True
  11. labels = []
  12. prediction = []
  13. titles = []
  14. with open("prediction.csv") as f:
  15. for row in csv.reader(f, delimiter=","):
  16. if first:
  17. first = False
  18. titles = row[8:12]
  19. continue
  20. if len(row) < 9:
  21. continue
  22. y = [float(row[1])]
  23. z = [float(row[n]) for n in [7,8,9,10,11]]
  24. z = [float( min([z[0]] + [n for n in [1,2,3,4,5] if z[n] >= 0.2]))] + z
  25. labels.append(y)
  26. prediction.append(z)
  27. pos += 1
  28. plt.show()
  29. n = 0
  30. #for n in range(4):
  31. keys = set([x[n] for x in labels])
  32. boxes = [[] for _ in range(5)]
  33. for y, z in zip(labels, prediction):
  34. k = int(y[n]) - 1
  35. boxes[k].append(z[n])
  36. for y in [1.5,2.5,3.5,4.5]:
  37. plt.plot([1,5], [y, y], color="#cccccc")
  38. plt.boxplot(boxes)
  39. plt.savefig(f"graphics/prediction_{n}.pdf")
  40. plt.close()
  41. diffs = [abs(y[n] - z[n]) < 0.5 for y, z in zip(labels, prediction)]
  42. good = 0
  43. bad = 0
  44. for x in diffs:
  45. if x:
  46. good += 1
  47. else:
  48. bad += 1
  49. print(f"{n}: {good} <-> {bad}")
  50. for n in [1,2,3,4,5]:
  51. plt.scatter([x[0] for x in labels], [x[n] for x in prediction], c=[int(x[0] + 0.5) for x in labels])
  52. plt.savefig(f"graphics/one_hot_{n}.pdf")
  53. plt.close()
  54. for y, z in zip(labels, prediction):
  55. if y[0] < 4 and abs(y[0] - z[0]) >= 0.5:
  56. plt.plot([1,2,3,4,5], [z[n] for n in [1,2,3,4,5]])
  57. plt.savefig(f"graphics/one_hot_all.pdf")
  58. plt.close()
  59. barY = []
  60. barN = []
  61. titles = [1.0,2.0,3.0,4.0,5.0]
  62. for n in titles:
  63. yes = 0
  64. no = 0
  65. for y, d in zip(labels, diffs):
  66. if y[0] == n:
  67. if d:
  68. yes += 1
  69. else:
  70. no += 1
  71. s = yes + no
  72. if s > 0:
  73. yes = yes * 100 / s
  74. no = no * 100 / s
  75. barY.append(yes)
  76. barN.append(no)
  77. width = 0.4
  78. for y in [20, 40, 60, 80, 100]:
  79. plt.plot([1,5.8], [y, y], color="#cccccc")
  80. plt.plot([1,5.8], [y - 10, y - 10], color="#eeeeee")
  81. plt.bar(titles, barN, width, label="Mismatch")
  82. plt.bar([x + width for x in titles], barY, width, label="Match")
  83. plt.legend()
  84. plt.savefig(f"graphics/bar.pdf")
  85. plt.close()
  86. print(f"Yes: {barY}")
  87. print(f"No : {barN}")