SpheredNoise.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. """
  2. This module contains some example Generative Adversarial Networks for testing.
  3. The classes StupidToyPointGan and StupidToyListGan are not really networks. These classes are used
  4. for testing the interface. Hopefully your actual GAN will perform better than these two.
  5. The class SimpleGan is a simple standard Generative Adversarial Network.
  6. """
  7. import numpy as np
  8. import math
  9. from library.interfaces import GanBaseClass
  10. def square(x):
  11. return (x*x)
  12. def fold0(f, xs):
  13. if xs == []:
  14. return None
  15. s = xs[0]
  16. for x in xs[1:]:
  17. s = f(s, x)
  18. return s
  19. def fold1(f, s0, xs):
  20. if xs == []:
  21. return None
  22. s = s0
  23. for x in xs:
  24. s = f(s, x)
  25. return s
  26. def dist(x, y):
  27. return math.sqrt(fold1(lambda s, a: s + square(a[0] - a[1]), 0, list(zip(x, y))))
  28. def minDistPointToSet(x, setB):
  29. return fold0(lambda m,y: min(m,y), [dist(x,y) for y in setB])
  30. def minDistSetToSet(setA, setB):
  31. return fold0(lambda m,x: min(m, minDistPointToSet(x,setB)), setA)
  32. def normInf(xs):
  33. return fold0(lambda m, x: max(m, abs(x)), xs)
  34. def norm2Sq(xs):
  35. return fold0(lambda s, x: s + (x*x), xs)
  36. def norm2(xs):
  37. return math.sqrt(norm2Sq(xs))
  38. def minmax(xs):
  39. if xs == []:
  40. return None
  41. (mi, mx) = (xs[0][1], xs[0][1])
  42. for x in xs[1:]:
  43. mi = min(mi, x[1])
  44. mx = max(mx, x[1])
  45. return (mi, mx)
  46. def createSquare(pointCount, noiseSize):
  47. noise = [
  48. [np.random.uniform(-1.0, 1.0) for n in range(noiseSize)]
  49. for m in range(pointCount)
  50. ]
  51. return np.array(noise)
  52. def createDisc(pointCount, noiseSize):
  53. noise = []
  54. for n in range(pointCount):
  55. p = [np.random.uniform(-1.0, 1.0)]
  56. for m in range(noiseSize - 1):
  57. d = norm2Sq(p)
  58. d = math.sqrt(1.0 - d)
  59. p.append(np.random.uniform(0.0 - d, d))
  60. noise.append(p)
  61. return np.array(noise)
  62. class SpheredNoise(GanBaseClass):
  63. """
  64. A class for a simple GAN.
  65. """
  66. def __init__(self, noiseSize=101):
  67. self.isTrained = False
  68. self.noiseSize = noiseSize
  69. self.disc = []
  70. self.reset()
  71. def reset(self):
  72. """
  73. Resets the trained GAN to an random state.
  74. """
  75. self.pointDists = []
  76. self.nextId = 0
  77. self.numPoints = 0
  78. self.nextDiscPoint = 0
  79. self.minDist = 0.0
  80. def train(self, dataset):
  81. majoritySet = dataset.data0
  82. minoritySet = dataset.data1
  83. trainDataSize = minoritySet.shape[0]
  84. numOfFeatures = minoritySet.shape[1]
  85. print(f"Train {majoritySet.shape[0]}/{trainDataSize} points")
  86. if minoritySet.shape[0] <= 0 or majoritySet.shape[0] <= 0:
  87. raise AttributeError("Train: Expected each data class to contain at least one point.")
  88. if numOfFeatures <= 0:
  89. raise AttributeError("Train: Expected at least one feature.")
  90. print("-> new disc")
  91. self.disc = createDisc(self.noiseSize, minoritySet.shape[1])
  92. print("-> calc distances")
  93. self.pointDists = list(filter(lambda x: x[1] > 0.0, [(x, minDistPointToSet(x, majoritySet)) for x in minoritySet]))
  94. print("-> statistics")
  95. self.nextId = 0
  96. self.numPoints = len(self.pointDists)
  97. self.isTrained = True
  98. (minD, maxD) = minmax(self.pointDists)
  99. self.minDist = minD
  100. print(f"trained {trainDataSize} points min:{minD} max:{maxD}")
  101. def generateDataPoint(self):
  102. (x, d) = self.pointDists[self.nextId]
  103. self.nextId = (self.nextId + 1) % self.numPoints
  104. disc = (0.5 * self.minDist) * self.disc
  105. p = disc[self.nextDiscPoint]
  106. self.nextDiscPoint = (self.nextDiscPoint + 1) % disc.shape[0]
  107. return p
  108. def generateData(self, numOfSamples=1):
  109. return np.array([self.generateDataPoint() for n in range(numOfSamples)])