| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import pydicom
- import numpy as np
- import matplotlib.pyplot as plt
- import os
- import csv
- import math
- import wavelet
# ---------------------------------------------------------------------------
# Load prediction.csv.
# The first row is a header; every later row is one sample:
#   labels[i]     -> [float(CSV column 1)]                 (the true value)
#   prediction[i] -> [float(CSV columns 7, 8, ..., 12)]    (6 model outputs)
# Rows too short to contain every indexed column are skipped.
# ---------------------------------------------------------------------------
LABEL_COLUMN = 1
PREDICTION_COLUMNS = (7, 8, 9, 10, 11, 12)

pos = 0            # number of samples loaded
first = True       # stays True only when the file is completely empty
firstImage = True  # not referenced in this script; kept for compatibility
labels = []
prediction = []
titles = []        # header names of CSV columns 8..11 (reassigned later in the script)

# newline="" is required by the csv module so quoted fields containing
# newlines are parsed correctly.
with open("prediction.csv", newline="") as f:
    reader = csv.reader(f, delimiter=",")
    header = next(reader, None)
    if header is not None:
        first = False
        titles = header[8:12]
    # Every column we index must exist. The previous `len(row) < 9` guard let
    # rows with 9-12 columns through, which then raised IndexError at row[12].
    min_len = max(LABEL_COLUMN, *PREDICTION_COLUMNS) + 1
    for row in reader:
        if len(row) < min_len:
            continue
        y = [float(row[LABEL_COLUMN])]
        z = [float(row[c]) for c in PREDICTION_COLUMNS]
        # Disabled experiment: clamp z[0] to the smallest one-hot class index
        # (1..5) whose score is >= 0.2.
        # z[0] = float(min([z[0]] + [n for n in [1, 2, 3, 4, 5] if z[n] >= 0.2]))
        labels.append(y)
        prediction.append(z)
        pos += 1

# NOTE(review): nothing has been plotted in this script yet, so this presumably
# shows nothing, unless an imported module (e.g. `wavelet`) created figures on
# import. Confirm before removing.
plt.show()
# ---------------------------------------------------------------------------
# Box plot of prediction column n grouped by the true label, plus a count of
# samples whose prediction lies strictly within 0.5 of their label.
# ---------------------------------------------------------------------------
n = 0  # column 0: CSV column 1 (label) vs. CSV column 7 (prediction)
# NOTE: this section was once wrapped in `for n in range(4):`; only column 0
# is analysed now.
keys = {x[n] for x in labels}  # distinct label values; currently unused in this script

N_CLASSES = 5
boxes = [[] for _ in range(N_CLASSES)]
skipped = 0
for y, z in zip(labels, prediction):
    k = int(y[n]) - 1  # labels 1..5 map to boxes 0..4
    # Guard the index: a label < 1 used to give k = -1 and silently land in
    # the last box, and a label > 5 raised IndexError.
    if 0 <= k < N_CLASSES:
        boxes[k].append(z[n])
    else:
        skipped += 1
if skipped:
    print(f"{n}: skipped {skipped} sample(s) with labels outside 1..{N_CLASSES}")

# Light horizontal lines at the rounding boundaries between adjacent classes.
for boundary in (1.5, 2.5, 3.5, 4.5):
    plt.plot([1, 5], [boundary, boundary], color="#cccccc")
plt.boxplot(boxes)
# savefig does not create missing directories.
os.makedirs("graphics", exist_ok=True)
plt.savefig(f"graphics/prediction_{n}.pdf")
plt.close()

# diffs[i] is True when sample i's prediction lies strictly within 0.5 of its
# label. The per-class bar chart later in the script reuses it.
diffs = [abs(y[n] - z[n]) < 0.5 for y, z in zip(labels, prediction)]
good = sum(diffs)
bad = len(diffs) - good
print(f"{n}: {good} <-> {bad}")
# ---------------------------------------------------------------------------
# One scatter plot per one-hot output column (prediction columns 1..5):
# true label on x, that column's score on y, coloured by the label rounded
# half-up to an integer.
# ---------------------------------------------------------------------------
label_axis = [sample[0] for sample in labels]
label_colours = [int(sample[0] + 0.5) for sample in labels]
for n in range(1, 6):
    scores = [pred[n] for pred in prediction]
    plt.scatter(label_axis, scores, c=label_colours)
    plt.savefig(f"graphics/one_hot_{n}.pdf")
    plt.close()

# Overlay the one-hot profile (columns 1..5) of every sample whose true label
# is below 4 but whose prediction column 0 misses it by 0.5 or more.
for label, pred in zip(labels, prediction):
    missed = label[0] < 4 and abs(label[0] - pred[0]) >= 0.5
    if missed:
        plt.plot(list(range(1, 6)), pred[1:6])
plt.savefig("graphics/one_hot_all.pdf")
plt.close()
# ---------------------------------------------------------------------------
# Per-class agreement. For each true label 1..5, the percentage of samples
# whose prediction was within 0.5 of the label ("Match", from `diffs`) versus
# not ("Mismatch"). Classes with no samples get 0 for both.
# ---------------------------------------------------------------------------
barY = []
barN = []
titles = [1.0, 2.0, 3.0, 4.0, 5.0]
for n in titles:
    outcomes = [d for label, d in zip(labels, diffs) if label[0] == n]
    s = len(outcomes)
    yes = outcomes.count(True)
    no = s - yes
    if s:
        yes, no = yes * 100 / s, no * 100 / s
    barY.append(yes)
    barN.append(no)

width = 0.4
# Horizontal guides: darker every 20 %, lighter at the 10 % midpoints.
for level in (20, 40, 60, 80, 100):
    plt.plot([1, 5.8], [level, level], color="#cccccc")
    plt.plot([1, 5.8], [level - 10, level - 10], color="#eeeeee")
plt.bar(titles, barN, width, label="Mismatch")
shifted = [t + width for t in titles]
plt.bar(shifted, barY, width, label="Match")
plt.legend()
plt.savefig("graphics/bar.pdf")
plt.close()

print(f"Yes: {barY}")
print(f"No : {barN}")
|