cache.py 692 B

12345678910111213141516171819202122232425
  1. import os.path
  2. import json
  3. def dataCache(fileName, dataGenerator, x=None):
  4. def flatten(z):
  5. if str(type(z)) == "<class 'numpy.ndarray'>":
  6. return [flatten(x) for x in z]
  7. else:
  8. return float(z)
  9. if fileName is not None and os.path.exists(fileName):
  10. print(f"load data from previous session '{fileName}'")
  11. with open(fileName) as f:
  12. return json.load(f)
  13. else:
  14. d = dataGenerator(x)
  15. if fileName is not None:
  16. print(f"save data for '{fileName}'")
  17. with open(fileName, 'w') as f:
  18. json.dump({k: flatten(d[k]) for k in d.keys() }, f)
  19. return d