| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import pydicom
- import numpy as np
- import matplotlib.pyplot as plt
- import os
- import csv
- import math
- import wavelet
- import network
- import keras
- import tensorflow as tf
- import random
# Sample-count setting (the trailing comment records an alternative of 8740).
# NOTE(review): not referenced anywhere in this script — confirm whether it was
# meant to cap the number of rows read from the CSV.
maxPos = 4000  # 8740

# (dicom_path, target_vector) pairs, consumed by genDs().
genData = []
model = network.createModel()
# NOTE(review): unused at module level; genDs() does not read it.
pos = 0

# Each CSV data row supplies:
#   row[0]        -> the "Proband ..." directory and the file-name prefix
#   row[1]        -> appended to "SE00000" to form the series directory
#   row[2]        -> file index, left-padded with '0' to at least 4 characters
#   row[7..10]    -> the four target values (parsed as floats)
# newline="" is what the csv module expects for files passed to csv.reader.
with open("mimx.csv", newline="") as f:
    reader = csv.reader(f, delimiter=",")
    next(reader, None)  # skip the header row (no-op on an empty file)
    for row in reader:
        # The highest column read is row[10], so a usable row needs at least
        # 11 fields. The previous "< 9" guard let 9- and 10-field rows through,
        # and they then raised IndexError on row[9]/row[10].
        if len(row) < 11:
            continue
        n = row[2].rjust(4, "0")
        fileName = f"Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
        y = np.array([float(value) for value in row[7:11]])
        genData.append((fileName, y))
def genDs():
    """Yield one training sample for each entry of the module-level ``genData``.

    For every ``(fileName, y)`` pair, the DICOM file is read with pydicom and
    its ``pixel_array`` is passed through ``wavelet.refine``. The result is
    flattened to a vector of 512 * 512 elements and multiplied by 1.0 so that
    integer pixel data is promoted to floating point.

    Yields:
        tuple: ``(x, y)`` where ``x`` has shape ``(512 * 512,)`` and ``y`` is
        the target array stored alongside the path in ``genData``.

    Note:
        The reshape fails unless the refined image holds exactly 512 * 512
        elements.
        NOTE(review): assumes ``wavelet.refine`` returns an array-like with a
        NumPy-style ``reshape`` — confirm.
    """
    for fileName, y in genData:
        img = pydicom.dcmread(fileName).pixel_array
        w = wavelet.refine(img)
        yield (1.0 * w.reshape((512 * 512,)), y)
# Stream (flattened image, target) pairs from genDs(). Each spec must match
# what genDs() yields. TensorFlow maps Python ``float`` to tf.float32; that
# dtype is spelled out here explicitly.
ds = tf.data.Dataset.from_generator(
    genDs,
    output_signature=(
        tf.TensorSpec(shape=(512 * 512,), dtype=tf.float32),
        tf.TensorSpec(shape=(4,), dtype=tf.float32),
    ),
)
# Prefetching lets the generator prepare the next batch (DICOM read and
# wavelet refinement) while the current batch trains. Element order is
# unchanged.
ds = ds.batch(32).prefetch(tf.data.AUTOTUNE)
print()
print("Start Train")
checkpoint_filepath = 'checkpoint.weights.h5'
# fit() is not given any validation data, so no 'val_*' metric is ever logged.
# A checkpoint monitoring 'val_accuracy' with save_best_only=True is therefore
# never written (Keras warns and skips). The training loss is always reported,
# so monitor that instead; lower is better.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='min',
    save_best_only=True,
)
model.fit(ds, epochs=32, callbacks=[model_checkpoint_callback])
# NOTE(review): after fit() the in-memory model holds the final-epoch weights,
# not necessarily those of the best checkpoint written above. Confirm which
# one network.save should persist.
network.save(model)
print("done")
|