train.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. import pydicom
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. import csv
  6. import math
  7. import wavelet
  8. import network
  9. import keras
  10. import tensorflow as tf
  11. import random
  12. genData = []
  13. print("Create Model")
  14. ae, model = network.createModelAe()
  15. pos = 0
  16. first = True
  17. print("Load Data")
  18. with open("mimx.csv") as f:
  19. for row in csv.reader(f, delimiter=","):
  20. if first or len(row) < 9:
  21. first = False
  22. continue
  23. n = f"{row[2]}"
  24. while len(n) < 4:
  25. n = f"0{n}"
  26. fileName = f"../Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
  27. y = np.array(network.toOneHot([1.0,2.0,3.0,4.0,5.0], float(row[8])))
  28. genData.append( (fileName, y) )
  29. def genDsAe():
  30. for xy in genData:
  31. fileName = xy[0]
  32. images = np.load(f"{fileName}_images.npy", allow_pickle=False)
  33. for img in images:
  34. yield (img, img)
  35. def genDsDisc():
  36. for xy in genData:
  37. fileName = xy[0]
  38. images = np.load(f"{fileName}_images.npy", allow_pickle=False)
  39. for img in images:
  40. yield (img, xy[1])
  41. with open("histo.log", "wt") as fHisto:
  42. print()
  43. print("Start Train AE")
  44. print("Start Train AE", file=fHisto)
  45. ds = tf.data.Dataset.from_generator(genDsAe, output_signature=(tf.TensorSpec(shape=(512*512,), dtype=float), tf.TensorSpec(shape=(512*512,), dtype=float)))
  46. ds = ds.batch(32)
  47. ae.trainable = True
  48. bestLoss = None
  49. for epoch in range(128):
  50. print(f"Epoch {epoch + 1}")
  51. h = ae.fit(ds, epochs=1)
  52. h = h.history
  53. loss = h['loss']
  54. if bestLoss is None or loss < bestLoss:
  55. bestLoss = loss
  56. print(f"Update best loss to {bestLoss} and save model")
  57. print(f"Update best loss to {bestLoss} and save model", file=fHisto)
  58. network.save(model)
  59. print()
  60. print("Start Train Disc")
  61. print("Start Train Disc", file=fHisto)
  62. ae.trainable = False
  63. ds = tf.data.Dataset.from_generator(genDsDisc, output_signature=(tf.TensorSpec(shape=(512*512,), dtype=float), tf.TensorSpec(shape=(5,), dtype=float)))
  64. ds = ds.batch(32)
  65. bestLoss = None
  66. for epoch in range(128):
  67. print(f"Epoch {epoch + 1}")
  68. h = model.fit(ds, epochs=1)
  69. h = h.history
  70. loss = h['loss']
  71. if bestLoss is None or loss < bestLoss:
  72. bestLoss = loss
  73. print(f"Update best loss to {bestLoss} and save model")
  74. print(f"Update best loss to {bestLoss} and save model", file=fHisto)
  75. network.save(model)
  76. print("done")