train.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import pydicom
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. import csv
  6. import math
  7. import wavelet
  8. import network
  9. import keras
  10. import tensorflow as tf
  11. import random
  12. actions = [ lambda img: wavelet.rotate(img, 1)
  13. , lambda img: wavelet.rotate(img, 2)
  14. , lambda img: wavelet.rotate(img, 3)
  15. , lambda img: wavelet.rotate(img, 4)
  16. , lambda img: wavelet.rotate(img, 5)
  17. , lambda img: wavelet.rotate(img, 6)
  18. , lambda img: wavelet.rotate(img, 7)
  19. ]
  20. #while len(actions) > 4:
  21. # p = random.randint(0, len(actions) - 1)
  22. # del actions[p]
  23. genData = []
  24. labels = []
  25. data = []
  26. hSize = 1000
  27. model = network.createModelHistogram(1000)
  28. pos = 0
  29. first = True
  30. with open("mimx.csv") as f:
  31. for row in csv.reader(f, delimiter=","):
  32. if first or len(row) < 9:
  33. first = False
  34. continue
  35. n = f"{row[2]}"
  36. while len(n) < 4:
  37. n = f"0{n}"
  38. fileName = f"../Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm_blured_histo_diff.npy"
  39. print(pos, end="\r")
  40. pos += 1
  41. x = np.load(fileName, allow_pickle=False)[:hSize]
  42. y = network.toOneHot([1.0,2.0,3.0,4.0,5.0], float(row[8]))
  43. if int(row[7]) == 4:
  44. n = 30
  45. elif int(row[7]) < 4:
  46. n = 60
  47. else:
  48. n = 10
  49. for _ in range(n):
  50. noise = np.array([float(random.randint(0,4)) for _ in x])
  51. labels.append(y)
  52. data.append(x + noise)
  53. print()
  54. print(f"Start Train with {len(data)} items.")
  55. data = np.array(data)
  56. labels = np.array(labels)
  57. with open("histogram.log", "wt") as f:
  58. bestLoss = None
  59. for epoch in range(128):
  60. print(f"Epoch {epoch + 1}")
  61. h = model.fit(data, labels, epochs=1)
  62. h = h.history
  63. loss = h['loss']
  64. if bestLoss is None or loss < bestLoss:
  65. bestLoss = loss
  66. print(f"Update best loss to {bestLoss} and save model")
  67. f.write(f"Update best loss to {bestLoss} and save model\n")
  68. network.save(model, "model.keras")
  69. print("done")