"""Two-phase training script.

Phase 1 trains the autoencoder ``ae`` to reconstruct images.
Phase 2 trains ``model`` on one-hot targets read from ``mimx.csv``.
After every epoch in which the training loss improves, the model is saved
through ``network.save`` and a line is appended to ``histo.log``.
"""
import pydicom
import numpy as np
import matplotlib.pyplot as plt
import os
import csv
import math
import wavelet
import network
import keras
import tensorflow as tf
import random

# Number of single-epoch fit() calls in each training phase.
EPOCHS = 128
# Mini-batch size for both tf.data pipelines.
BATCH_SIZE = 32
# Length of one input sample. Assumes the *_images.npy files hold flattened
# 512x512 images (matches the TensorSpec the original script used).
IMAGE_SIZE = 512 * 512
# Values of CSV column 8 that network.toOneHot encodes. Its length is the
# width of the classifier target.
LABELS = [1.0, 2.0, 3.0, 4.0, 5.0]

# (file name, one-hot target) pairs, one per usable CSV row.
genData = []

print("Create Model")
# Both networks come from network.createModelAe. Presumably `model` reuses the
# `ae` layers, since ae.trainable is toggled between phases (TODO confirm).
ae, model = network.createModelAe()

print("Load Data")
# The csv module documents newline="" for files passed to csv.reader.
with open("mimx.csv", newline="") as f:
    reader = csv.reader(f, delimiter=",")
    next(reader, None)  # the first row is always skipped (header)
    for row in reader:
        if len(row) < 9:  # column 8 (the label) is required
            continue
        # Instance number, zero-padded to at least 4 characters.
        n = row[2].zfill(4)
        fileName = f"../Proband {row[0]}/SE00000{row[1]}/{row[0]}_{n}.dcm"
        y = np.array(network.toOneHot(LABELS, float(row[8])))
        genData.append((fileName, y))


def _iterImages():
    """Yield ``(img, y)`` for every image in every ``<fileName>_images.npy``.

    The arrays are re-read from disk on every pass, so the whole dataset is
    never held in memory at once.
    """
    for fileName, y in genData:
        images = np.load(f"{fileName}_images.npy", allow_pickle=False)
        for img in images:
            yield img, y


def genDsAe():
    """Yield ``(img, img)`` pairs for autoencoder reconstruction training."""
    for img, _ in _iterImages():
        yield (img, img)


def genDsDisc():
    """Yield ``(img, one_hot_target)`` pairs for the second training phase."""
    yield from _iterImages()


def _makeDataset(generator, targetShape):
    """Wrap ``generator`` in a tf.data pipeline batched by ``BATCH_SIZE``.

    Args:
        generator: zero-argument callable that returns an iterator of
            ``(input, target)`` pairs.
        targetShape: shape of a single target tensor.
    """
    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(IMAGE_SIZE,), dtype=float),
            tf.TensorSpec(shape=targetShape, dtype=float),
        ),
    )
    return ds.batch(BATCH_SIZE)


def _trainKeepBest(net, ds, logFile):
    """Fit ``net`` one epoch at a time and save ``model`` when the loss improves.

    Args:
        net: the Keras model whose ``fit`` is called.
        ds: the batched training dataset.
        logFile: open text file that receives the "best loss" messages.
    """
    bestLoss = None
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}")
        history = net.fit(ds, epochs=1).history
        # history["loss"] is a list with one entry per epoch of this fit()
        # call. Take the last (only) value so a number, not a list, is
        # compared and logged.
        loss = history["loss"][-1]
        if bestLoss is None or loss < bestLoss:
            bestLoss = loss
            msg = f"Update best loss to {bestLoss} and save model"
            print(msg)
            print(msg, file=logFile)
            # The original script saves `model` in both phases, including the
            # autoencoder phase; kept as-is.
            network.save(model)


with open("histo.log", "wt") as fHisto:
    print()
    print("Start Train AE")
    print("Start Train AE", file=fHisto)
    ds = _makeDataset(genDsAe, (IMAGE_SIZE,))
    ae.trainable = True
    _trainKeepBest(ae, ds, fHisto)

    print()
    print("Start Train Disc")
    print("Start Train Disc", file=fHisto)
    # NOTE(review): Keras only applies a change to `trainable` when the model
    # is (re)compiled. Unless `model` is compiled after this point, the ae
    # weights may still be updated in this phase. TODO confirm against
    # network.createModelAe and recompile `model` here if needed.
    ae.trainable = False
    ds = _makeDataset(genDsDisc, (len(LABELS),))
    _trainKeepBest(model, ds, fHisto)

print("done")