# train.py
  1. import pydicom
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. import os
  5. import csv
  6. import math
  7. import wavelet
  8. import network
  9. import keras
  10. import tensorflow as tf
  11. import random
  12. maxPos = 4000 # 8740
  13. genData = []
  14. model = network.createModel()
  15. pos = 0
  16. first = True
  17. with open("mimx.csv") as f:
  18. for row in csv.reader(f, delimiter=","):
  19. if first or len(row) < 9:
  20. first = False
  21. continue
  22. n = f"{row[2]}"
  23. while len(n) < 4:
  24. n = f"0{n}"
  25. fileName = f"Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
  26. y = np.array([float(row[7]), float(row[8]), float(row[9]), float(row[10])])
  27. genData.append( (fileName, y) )
  28. def genDs():
  29. pos = 0
  30. for xy in genData:
  31. fileName = xy[0]
  32. #print(f"load '{fileName}' -> {pos}", end="\r")
  33. pos += 1
  34. img = pydicom.dcmread(fileName).pixel_array
  35. w = wavelet.refine(img)
  36. yield (1.0 * w.reshape((512 * 512,)) , xy[1])
  37. ds = tf.data.Dataset.from_generator(genDs, output_signature=(tf.TensorSpec(shape=(512*512,), dtype=float), tf.TensorSpec(shape=(4,), dtype=float)))
  38. ds = ds.batch(32)
  39. print()
  40. print("Start Train")
  41. checkpoint_filepath = 'checkpoint.weights.h5'
  42. model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
  43. filepath=checkpoint_filepath,
  44. save_weights_only=True,
  45. monitor='val_accuracy',
  46. mode='max',
  47. save_best_only=True)
  48. model.fit(ds, epochs=32, callbacks=[model_checkpoint_callback])
  49. network.save(model)
  50. print("done")