# test.py
  1. import pydicom
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. import csv
  6. import math
  7. import wavelet
  8. import network
  9. import keras
  10. import tensorflow as tf
  11. maxPos = 4000 # 8740
  12. genData = []
  13. hSize = 1000
  14. model = network.load("model.keras")
  15. pos = 0
  16. first = True
  17. with open("prediction.csv", "wt") as fout:
  18. wtr = csv.writer(fout)
  19. with open("mimx.csv") as f:
  20. for row in csv.reader(f, delimiter=","):
  21. if first or len(row) < 9:
  22. first = False
  23. wtr.writerow(["image", row[8]] + [1,2,3,4,5,] + [x + "_predicted" for x in ["v", "1", "2", "3", "4", "5"]])
  24. continue
  25. n = f"{row[2]}"
  26. while len(n) < 4:
  27. n = f"0{n}"
  28. fileNameImg = f"../Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
  29. fileName = f"../Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm_blured_histo_diff.npy"
  30. print(f"load '{fileName}' -> {pos}", end="\r")
  31. pos += 1
  32. y = network.toOneHot([1.0, 2.0, 3.0, 4.0, 5.0], float(row[8]))
  33. w = np.load(fileName, allow_pickle=False)[:hSize]
  34. prediction = model.predict(np.array([w]), verbose=0)
  35. p = prediction[0]
  36. s = sum(prediction[0])
  37. if s > 0.0:
  38. p = (1.0 / s) * p
  39. v = sum(np.array([1.0,2.0,3.0,4.0,5.0]) * p)
  40. wtr.writerow([fileNameImg, row[7]] + y + [v] + list(prediction[0]))
  41. print()
  42. print("done")