SpheredNoise.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
"""
This module contains SpheredNoise, a simple non-learning oversampler that
implements the GanBaseClass interface and is useful for testing that
interface, together with small helpers for point distances and for sampling
noise inside a cube or a ball.
SpheredNoise builds its samples from a fixed cloud of noise points inside the
unit ball, scaled per minority-class point by half of that point's distance
to the nearest majority-class point.
"""
  7. import numpy as np
  8. import tensorflow as tf
  9. from library.interfaces import GanBaseClass
  10. def dist(x, y):
  11. return tf.sqrt(tf.reduce_sum(tf.square(x - y)))
  12. def minDistPointToSet(x, setB):
  13. m = None
  14. for y in setB:
  15. d = dist(x,y)
  16. if m is None or m > d:
  17. m = d
  18. return m
  19. def minDistSetToSet(setA, setB):
  20. m = None
  21. for x in setA:
  22. d = minDistPointToSet(x,setB)
  23. if m is None or m > d:
  24. m = d
  25. return m
  26. def createSquare(pointCount, noiseSize):
  27. noise = []
  28. while len(noise) < pointCount:
  29. nPointsToAdd = max(100, pointCount - len(noise))
  30. noiseDimension = [nPointsToAdd, noiseSize]
  31. noise.extend(list(filter(
  32. lambda x: tf.reduce_max(tf.square(x)) < 1,
  33. np.random.normal(0, 1, noiseDimension))))
  34. return np.array(noise[0:pointCount])
  35. def createDisc(pointCount, noiseSize):
  36. noise = []
  37. while len(noise) < pointCount:
  38. nPointsToAdd = max(100, pointCount - len(noise))
  39. noiseDimension = [nPointsToAdd, noiseSize]
  40. noise.extend(list(filter(
  41. lambda x: tf.reduce_sum(tf.square(x)) < 1,
  42. np.random.normal(0, 1, noiseDimension))))
  43. return np.array(noise[0:pointCount])
  44. class SpheredNoise(GanBaseClass):
  45. """
  46. A class for a simple GAN.
  47. """
  48. def __init__(self, noiseSize=101):
  49. self.isTrained = False
  50. self.noiseSize = noiseSize
  51. self.disc = []
  52. self.reset()
  53. def reset(self):
  54. """
  55. Resets the trained GAN to an random state.
  56. """
  57. self.pointDists = []
  58. self.nextId = 0
  59. self.numPoints = 0
  60. self.nextDiscPoint = 0
  61. def train(self, dataset):
  62. majoritySet = dataset.data0
  63. minoritySet = dataset.data1
  64. trainDataSize = minoritySet.shape[0]
  65. numOfFeatures = minoritySet.shape[1]
  66. if minoritySet.shape[0] <= 0 or majoritySet.shape[0] <= 0:
  67. raise AttributeError("Train: Expected each data class to contain at least one point.")
  68. if numOfFeatures <= 0:
  69. raise AttributeError("Train: Expected at least one feature.")
  70. self.disc = createDisc(self.noiseSize, minoritySet.shape[1])
  71. self.pointDists = [(x, minDistPointToSet(x, majoritySet)) for x in minoritySet]
  72. self.nextId = 0
  73. self.numPoints = len(self.pointDists)
  74. self.isTrained = True
  75. minD = None
  76. maxD = None
  77. for (x, d) in self.pointDists:
  78. if minD is None or minD > d:
  79. minD = d
  80. if maxD is None or maxD < d:
  81. maxD = d
  82. print(f"trained {trainDataSize} points min:{minD} max:{maxD}")
  83. def generateDataPoint(self):
  84. (x, d) = self.pointDists[self.nextId]
  85. self.nextId = (self.nextId + 1) % self.numPoints
  86. disc = (0.5 * d) * self.disc
  87. p = disc[self.nextDiscPoint]
  88. self.nextDiscPoint = (self.nextDiscPoint + 1) % disc.shape[0]
  89. return p
  90. def generateData(self, numOfSamples=1):
  91. return np.array([self.generateDataPoint() for n in range(numOfSamples)])