# train.py — train the regression network on precomputed wavelet image stacks.
  1. import pydicom
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. import csv
  6. import math
  7. import wavelet
  8. import network
  9. import keras
  10. import tensorflow as tf
  11. import random
  12. actions = [ lambda img: wavelet.rotate(img, 1)
  13. , lambda img: wavelet.rotate(img, 2)
  14. , lambda img: wavelet.rotate(img, 3)
  15. , lambda img: wavelet.rotate(img, 4)
  16. , lambda img: wavelet.rotate(img, 5)
  17. , lambda img: wavelet.rotate(img, 6)
  18. , lambda img: wavelet.rotate(img, 7)
  19. ]
  20. #while len(actions) > 4:
  21. # p = random.randint(0, len(actions) - 1)
  22. # del actions[p]
  23. genData = []
  24. model = network.createModel()
  25. pos = 0
  26. first = True
  27. with open("mimx.csv") as f:
  28. for row in csv.reader(f, delimiter=","):
  29. if first or len(row) < 9:
  30. first = False
  31. continue
  32. n = f"{row[2]}"
  33. while len(n) < 4:
  34. n = f"0{n}"
  35. fileName = f"Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
  36. y = np.array([float(row[7]), float(row[8]), float(row[9]), float(row[10])])
  37. genData.append( (fileName, y) )
  38. def genDs():
  39. pos = 0
  40. for xy in genData:
  41. fileName = xy[0]
  42. #print(f"load '{fileName}' -> {pos}", end="\r")
  43. pos += 1
  44. images = np.load(f"{fileName}_images.npy", allow_pickle=False)
  45. #img = pydicom.dcmread(fileName).pixel_array
  46. #for action in actions:
  47. # w = wavelet.refine(action(img))
  48. # yield (1.0 * w.reshape((512 * 512,)) , xy[1])
  49. for img in images:
  50. yield (img, xy[1])
  51. ds = tf.data.Dataset.from_generator(genDs, output_signature=(tf.TensorSpec(shape=(512*512,), dtype=float), tf.TensorSpec(shape=(4,), dtype=float)))
  52. ds = ds.batch(32)
  53. print()
  54. print("Start Train")
  55. checkpoint_filepath = '/home/kristian/Dokumente/SBI/2025/UmrMri/checkpoint.weights.h5'
  56. model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
  57. filepath=checkpoint_filepath,
  58. save_weights_only=True,
  59. monitor='val_accuracy',
  60. mode='max',
  61. save_best_only=True)
  62. bestLoss = None
  63. for epoch in range(128):
  64. print(f"Epoch {epoch + 1}")
  65. h = model.fit(ds, epochs=1, callbacks=[model_checkpoint_callback])
  66. h = h.history
  67. loss = h['loss']
  68. if bestLoss is None or loss < bestLoss:
  69. bestLoss = loss
  70. print(f"Update best loss to {bestLoss} and save model")
  71. network.save(model)
  72. print("done")