analysis.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. from library.exercise import Exercise
  2. from library.dataset import DataSet, TrainTestData
  3. from library.GanExamples import StupidToyListGan
  4. from library.SimpleGan import SimpleGan
  5. from library.Repeater import Repeater
  6. from library.SpheredNoise import SpheredNoise
  7. from library.convGAN import ConvGAN
  8. import pickle
  9. import numpy as np
  10. import random
  11. from imblearn.datasets import fetch_datasets
  12. def loadDataset(datasetName):
  13. def isSame(xs, ys):
  14. for (x, y) in zip(xs, ys):
  15. if x != y:
  16. return False
  17. return True
  18. def isIn(ys):
  19. def f(x):
  20. for y in ys:
  21. if isSame(x,y):
  22. return True
  23. return False
  24. return f
  25. def isNotIn(ys):
  26. def f(x):
  27. for y in ys:
  28. if isSame(x,y):
  29. return False
  30. return True
  31. return f
  32. pickle_in = open(f"{datasetName}.pickle", "rb")
  33. pickle_dict = pickle.load(pickle_in)
  34. myData = pickle_dict["folding"]
  35. k = myData[0]
  36. labels = np.concatenate((k[1], k[3]), axis=0).astype(float)
  37. features = np.concatenate((k[0], k[2]), axis=0).astype(float)
  38. label_1 = list(np.where(labels == 1)[0])
  39. label_0 = list(np.where(labels == 0)[0])
  40. features_1 = features[label_1]
  41. features_0 = features[label_0]
  42. cut = np.array(list(filter(isIn(features_0), features_1)))
  43. if len(cut) > 0:
  44. print(f"non empty cut in {datasetName}! ({len(cut)} points)")
  45. # print(f"{len(features_0)}/{len(features_1)} point before")
  46. # features_0 = np.array(list(filter(isNotIn(cut), features_0)))
  47. # features_1 = np.array(list(filter(isNotIn(cut), features_1)))
  48. # print(f"{len(features_0)}/{len(features_1)} points after")
  49. return DataSet(data0=features_0, data1=features_1)
  50. def getRandGen(initValue, incValue=257, multValue=101, modulus=65537):
  51. value = initValue
  52. while True:
  53. value = ((multValue * value) + incValue) % modulus
  54. yield value
  55. def genShuffler():
  56. randGen = getRandGen(2021)
  57. def shuffler(data):
  58. data = list(data)
  59. size = len(data)
  60. shuffled = []
  61. while size > 0:
  62. p = next(randGen) % size
  63. size -= 1
  64. shuffled.append(data[p])
  65. data = data[0:p] + data[(p + 1):]
  66. return np.array(shuffled)
  67. return shuffler
  68. def runExerciseForSimpleGAN(datasetName):
  69. ganName = "SimpleGAN"
  70. print()
  71. print()
  72. print("///////////////////////////////////////////")
  73. print(f"// Running {ganName} on {datasetName}")
  74. print("///////////////////////////////////////////")
  75. print()
  76. data = loadDataset(f"data_input/{datasetName}")
  77. gan = SimpleGan(numOfFeatures=data.data0.shape[1])
  78. random.seed(2021)
  79. shuffler = genShuffler()
  80. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  81. exercise.run(gan, data)
  82. exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  83. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  84. def runExerciseForRepeater(datasetName):
  85. ganName = "Repeater"
  86. print()
  87. print()
  88. print("///////////////////////////////////////////")
  89. print(f"// Running {ganName} on {datasetName}")
  90. print("///////////////////////////////////////////")
  91. print()
  92. data = loadDataset(f"data_input/{datasetName}")
  93. gan = Repeater()
  94. random.seed(2021)
  95. shuffler = genShuffler()
  96. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  97. exercise.run(gan, data)
  98. exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  99. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  100. def runExerciseForSpheredNoise(datasetName, resultList=None):
  101. ganName = "SpheredNoise"
  102. print()
  103. print()
  104. print("///////////////////////////////////////////")
  105. print(f"// Running {ganName} on {datasetName}")
  106. print("///////////////////////////////////////////")
  107. print()
  108. data = loadDataset(f"data_input/{datasetName}")
  109. gan = SpheredNoise()
  110. random.seed(2021)
  111. shuffler = genShuffler()
  112. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  113. exercise.run(gan, data)
  114. avg = exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  115. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  116. if resultList is not None:
  117. resultList[datasetName] = avg
  118. def runExerciseForConvGAN(datasetName, resultList=None):
  119. ganName = "convGAN"
  120. print()
  121. print()
  122. print("///////////////////////////////////////////")
  123. print(f"// Running {ganName} on {datasetName}")
  124. print("///////////////////////////////////////////")
  125. print()
  126. data = loadDataset(f"data_input/{datasetName}")
  127. gan = ConvGAN(data.data0.shape[1])
  128. random.seed(2021)
  129. shuffler = genShuffler()
  130. exercise = Exercise(shuffleFunction=shuffler, numOfShuffles=5, numOfSlices=5)
  131. exercise.run(gan, data)
  132. avg = exercise.saveResultsTo(f"data_result/{datasetName}-{ganName}.csv")
  133. exercise.saveResultsTo(f"data_result/{ganName}-{datasetName}.csv")
  134. if resultList is not None:
  135. resultList[datasetName] = avg
  136. testSets = [
  137. "folding_abalone_17_vs_7_8_9_10",
  138. "folding_abalone9-18",
  139. "folding_car_good",
  140. "folding_car-vgood",
  141. "folding_flare-F",
  142. "folding_hypothyroid",
  143. "folding_kddcup-guess_passwd_vs_satan",
  144. "folding_kr-vs-k-three_vs_eleven",
  145. "folding_kr-vs-k-zero-one_vs_draw",
  146. "folding_shuttle-2_vs_5",
  147. "folding_winequality-red-4",
  148. "folding_yeast4",
  149. "folding_yeast5",
  150. "folding_yeast6"
  151. ]
  152. def runAllTestSets(dataSetList):
  153. for dsFileName in dataSetList:
  154. runExerciseForSimpleGAN(dataSetList)
  155. runExerciseForRepeater(dataSetList)