|
|
@@ -0,0 +1,111 @@
|
|
|
+"""
|
|
|
+This module contains some example Generative Adversarial Networks for testing.
|
|
|
+
|
|
|
+The classes StupidToyPointGan and StupidToyListGan are not really Networks. This classes are used
|
|
|
+for testing the interface. Hope your actually GAN will perform better than this two.
|
|
|
+
|
|
|
+The class SimpleGan is a simple standard Generative Adversarial Network.
|
|
|
+"""
|
|
|
+
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import tensorflow as tf
|
|
|
+
|
|
|
+from library.interfaces import GanBaseClass
|
|
|
+
|
|
|
+
|
|
|
+def dist(x, y):
|
|
|
+ return tf.sqrt(tf.reduce_sum(tf.square(x - y)))
|
|
|
+
|
|
|
+def minDistPointToSet(x, setB):
|
|
|
+ m = None
|
|
|
+ for y in setB:
|
|
|
+ d = dist(x,y)
|
|
|
+ if m is None or m > d:
|
|
|
+ m = d
|
|
|
+ return m
|
|
|
+
|
|
|
+def minDistSetToSet(setA, setB):
|
|
|
+ m = None
|
|
|
+ for x in setA:
|
|
|
+ d = minDistPointToSet(x,setB)
|
|
|
+ if m is None or m > d:
|
|
|
+ m = d
|
|
|
+ return m
|
|
|
+
|
|
|
+def createSquare(pointCount, noiseSize):
|
|
|
+ noise = []
|
|
|
+ while len(noise) < pointCount:
|
|
|
+ nPointsToAdd = max(100, pointCount - len(noise))
|
|
|
+ noiseDimension = [nPointsToAdd, noiseSize]
|
|
|
+ noise.extend(list(filter(
|
|
|
+ lambda x: tf.reduce_max(tf.square(x)) < 1,
|
|
|
+ np.random.normal(0, 1, noiseDimension))))
|
|
|
+ return np.array(noise[0:pointCount])
|
|
|
+
|
|
|
+def createDisc(pointCount, noiseSize):
|
|
|
+ noise = []
|
|
|
+ while len(noise) < pointCount:
|
|
|
+ nPointsToAdd = max(100, pointCount - len(noise))
|
|
|
+ noiseDimension = [nPointsToAdd, noiseSize]
|
|
|
+ noise.extend(list(filter(
|
|
|
+ lambda x: tf.reduce_sum(tf.square(x)) < 1,
|
|
|
+ np.random.normal(0, 1, noiseDimension))))
|
|
|
+ return np.array(noise[0:pointCount])
|
|
|
+
|
|
|
+class SpheredNoise(GanBaseClass):
|
|
|
+ """
|
|
|
+ A class for a simple GAN.
|
|
|
+ """
|
|
|
+ def __init__(self, noiseSize=101):
|
|
|
+ self.isTrained = False
|
|
|
+ self.noiseSize = noiseSize
|
|
|
+ self.disc = []
|
|
|
+ self.reset()
|
|
|
+
|
|
|
+ def reset(self):
|
|
|
+ """
|
|
|
+ Resets the trained GAN to an random state.
|
|
|
+ """
|
|
|
+ self.pointDists = []
|
|
|
+ self.nextId = 0
|
|
|
+ self.numPoints = 0
|
|
|
+ self.nextDiscPoint = 0
|
|
|
+
|
|
|
+ def train(self, dataset):
|
|
|
+ majoritySet = dataset.data0
|
|
|
+ minoritySet = dataset.data1
|
|
|
+ trainDataSize = minoritySet.shape[0]
|
|
|
+ numOfFeatures = minoritySet.shape[1]
|
|
|
+
|
|
|
+ if minoritySet.shape[0] <= 0 or majoritySet.shape[0] <= 0:
|
|
|
+ raise AttributeError("Train: Expected each data class to contain at least one point.")
|
|
|
+
|
|
|
+ if numOfFeatures <= 0:
|
|
|
+ raise AttributeError("Train: Expected at least one feature.")
|
|
|
+
|
|
|
+ self.disc = createDisc(self.noiseSize, minoritySet.shape[1])
|
|
|
+ self.pointDists = [(x, minDistPointToSet(x, majoritySet)) for x in minoritySet]
|
|
|
+ self.nextId = 0
|
|
|
+ self.numPoints = len(self.pointDists)
|
|
|
+ self.isTrained = True
|
|
|
+ minD = None
|
|
|
+ maxD = None
|
|
|
+ for (x, d) in self.pointDists:
|
|
|
+ if minD is None or minD > d:
|
|
|
+ minD = d
|
|
|
+ if maxD is None or maxD < d:
|
|
|
+ maxD = d
|
|
|
+ print(f"trained {trainDataSize} points min:{minD} max:{maxD}")
|
|
|
+
|
|
|
+ def generateDataPoint(self):
|
|
|
+ (x, d) = self.pointDists[self.nextId]
|
|
|
+ self.nextId = (self.nextId + 1) % self.numPoints
|
|
|
+ disc = (0.5 * d) * self.disc
|
|
|
+ p = disc[self.nextDiscPoint]
|
|
|
+ self.nextDiscPoint = (self.nextDiscPoint + 1) % disc.shape[0]
|
|
|
+ return p
|
|
|
+
|
|
|
+
|
|
|
+ def generateData(self, numOfSamples=1):
|
|
|
+ return np.array([self.generateDataPoint() for n in range(numOfSamples)])
|