"""Train the network from ``network.createModel()`` on cached slice images.

Reads sample metadata from ``mimx.csv``, streams the per-file
``<dicom path>_images.npy`` arrays through a ``tf.data`` pipeline and fits
the model one epoch at a time, calling ``network.save(model)`` whenever the
epoch's training loss improves on the best seen so far.
"""
import pydicom
import numpy as np
import matplotlib.pyplot as plt
import os
import csv
import math
import wavelet
import network
import keras
import tensorflow as tf
import random

# Augmentation variants passed to wavelet.rotate (1..7). Only referenced by
# the commented-out on-the-fly pipeline inside genDs; kept for that code path.
actions = [
    lambda img: wavelet.rotate(img, 1),
    lambda img: wavelet.rotate(img, 2),
    lambda img: wavelet.rotate(img, 3),
    lambda img: wavelet.rotate(img, 4),
    lambda img: wavelet.rotate(img, 5),
    lambda img: wavelet.rotate(img, 6),
    lambda img: wavelet.rotate(img, 7),
]

# Former experiment: randomly keep only four of the actions.
# while len(actions) > 4:
#     p = random.randint(0, len(actions) - 1)
#     del actions[p]

CSV_PATH = "mimx.csv"
# Columns 7..10 are read as the four targets, so a row needs >= 11 fields.
# (The old guard was `< 9`, which let 9- and 10-column rows crash on row[10].)
MIN_COLUMNS = 11
IMAGE_SIZE = 512 * 512  # length of each flattened sample fed to the model
BATCH_SIZE = 32
EPOCHS = 128


def _dicom_path(row):
    """Return the relative DICOM path for one CSV row.

    Column 0 appears in the ``Proband`` directory and the file prefix,
    column 1 in the ``SE00000`` directory, and column 2 (left-padded with
    zeros to at least four characters) forms the file suffix.
    """
    image_no = row[2].rjust(4, "0")
    return f"Proband {row[0]}/SE00000{row[1]}/{row[0]}_{image_no}.dcm"


# (dicom path, target vector) pairs consumed by genDs.
genData = []
model = network.createModel()

# newline="" is required by the csv module for correct quoted-field handling.
with open(CSV_PATH, newline="") as f:
    reader = csv.reader(f, delimiter=",")
    next(reader, None)  # skip the header row
    for row in reader:
        if len(row) < MIN_COLUMNS:
            continue  # too short to hold all four target columns
        y = np.array([float(value) for value in row[7:11]])
        genData.append((_dicom_path(row), y))


def genDs():
    """Yield ``(image, target)`` pairs for every cached image of every sample.

    For each ``(fileName, y)`` in ``genData`` this loads
    ``f"{fileName}_images.npy"`` and yields each entry along its first axis
    together with ``y``. Each entry must match the ``(IMAGE_SIZE,)`` spec
    of the dataset built below.
    """
    for fileName, y in genData:
        # allow_pickle=False refuses object arrays, so a crafted .npy cannot
        # run code on load.
        images = np.load(f"{fileName}_images.npy", allow_pickle=False)
        # Alternative on-the-fly pipeline, kept for reference:
        #   img = pydicom.dcmread(fileName).pixel_array
        #   for action in actions:
        #       w = wavelet.refine(action(img))
        #       yield (1.0 * w.reshape((512 * 512,)), y)
        for img in images:
            yield (img, y)


ds = tf.data.Dataset.from_generator(
    genDs,
    output_signature=(
        tf.TensorSpec(shape=(IMAGE_SIZE,), dtype=float),
        tf.TensorSpec(shape=(4,), dtype=float),
    ),
)
ds = ds.batch(BATCH_SIZE)
print()
print("Start Train")

checkpoint_filepath = '/home/kristian/Dokumente/SBI/2025/UmrMri/checkpoint.weights.h5'
# fit() is called without validation data, so "val_accuracy" never exists
# and save_best_only=True would skip every save. Accuracy is also not
# meaningful for continuous float targets, so track the training loss.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="loss",
    mode="min",
    save_best_only=True,
)

bestLoss = None
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}")
    h = model.fit(ds, epochs=1, callbacks=[model_checkpoint_callback]).history
    # history["loss"] is a per-epoch list; with epochs=1 it has one entry.
    # Use the float itself so the comparison and log message are numeric.
    loss = h["loss"][-1]
    # NaN compares False against everything; storing it as bestLoss would
    # block every later improvement, so skip such epochs.
    if math.isnan(loss):
        continue
    if bestLoss is None or loss < bestLoss:
        bestLoss = loss
        print(f"Update best loss to {bestLoss} and save model")
        network.save(model)
print("done")