visualize.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import seaborn as sns
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  def plotCluster(data, clusterName="cluster", xName="FDC_1", yName="FDC_2", stroke=20):
      """Scatter-plot ``data`` in the (xName, yName) plane, coloured by cluster.

      Args:
          data: DataFrame containing the columns ``xName``, ``yName`` and
              ``clusterName``.
          clusterName: Column whose values select the point colour (``hue``).
          xName: Column used for the x coordinate.
          yName: Column used for the y coordinate.
          stroke: Marker size, forwarded as ``scatter_kws={"s": stroke}``.

      Side effects: replaces seaborn's global colour palette with the colours
      listed below and calls ``plt.show()``.
      """
      cluster_colors = [
          'lightgray', 'lightcoral', 'cornflowerblue', 'orange', 'mediumorchid',
          'lightseagreen', 'olive', 'chocolate', 'steelblue', 'paleturquoise',
          'lightgreen', 'burlywood', 'lightsteelblue',
      ]
      # sns.set_palette returns None, so its result must not be forwarded as
      # ``palette=``. It acts by changing the global colour cycle, which
      # lmplot uses because no explicit palette is passed below.
      sns.set_palette(sns.color_palette(cluster_colors))
      sns.lmplot(
          x=xName,
          y=yName,
          data=data,
          fit_reg=False,
          legend=True,
          hue=clusterName,
          scatter_kws={"s": stroke},
      )
      plt.show()
  def plotMapping(data, xName="UMAP_0", yName="UMAP_1"):
      """Scatter-plot column ``xName`` against ``yName`` of ``data`` (no hue).

      Points are drawn with size 3, without a regression fit and without a
      legend.

      Args:
          data: DataFrame containing the columns ``xName`` and ``yName``.
          xName: Column used for the x coordinate.
          yName: Column used for the y coordinate.

      Side effects: replaces seaborn's global colour palette with the colours
      listed below and calls ``plt.show()``.
      """
      mapping_colors = [
          "lightcoral", "lightseagreen", "mediumorchid", "orange", "burlywood",
          "cornflowerblue", "plum", "yellowgreen",
      ]
      # sns.set_palette returns None, so there is nothing to forward as
      # ``palette=``; it acts by changing the global colour cycle. With no
      # ``hue`` the point colour presumably comes from that cycle - confirm.
      sns.set_palette(sns.color_palette(mapping_colors))
      sns.lmplot(
          x=xName,
          y=yName,
          data=data,
          fit_reg=False,
          legend=False,
          scatter_kws={"s": 3},
      )
      plt.show()
  def vizx(feature_list, cluster_df_list, main_data, umap_data, cont_features,
           rev_dict, xName="FDC_1", yName="FDC_2", vizlimit=15):
      """Print and plot how each feature in ``feature_list`` varies across clusters.

      For every feature, in order:

      1. If ``main_data[feature]`` has at most ``vizlimit`` distinct values,
         print each cluster's value counts and draw one bar chart per cluster
         in a grid with three columns.
      2. If the feature is in ``cont_features``, print each cluster's mean,
         standard deviation and median, then draw one bar per cluster showing
         the mean with a standard-deviation error bar.
      3. Otherwise, copy the feature's values from ``main_data`` into
         ``umap_data`` and scatter-plot ``xName`` vs ``yName`` coloured by
         that feature.

      Args:
          feature_list: Names of the columns to visualise.
          cluster_df_list: One DataFrame per cluster; each must contain every
              column in ``feature_list``. Clusters are numbered from 1 in list
              order.
          main_data: DataFrame with all rows. Used for the distinct-value count
              and as the source of the values copied into ``umap_data``.
          umap_data: DataFrame holding the ``xName``/``yName`` coordinates.
              Modified in place: a column named after each non-continuous
              feature is added. Values are copied by position, so this
              presumably has the same row order as ``main_data`` - confirm
              with callers.
          cont_features: Collection of feature names treated as continuous.
          rev_dict: Per-feature decoding tables. For a feature present here,
              ``rev_dict[feature]`` is a dict whose values are the values
              stored in the column and whose keys are the labels displayed
              instead. Features absent from it are displayed as stored.
          xName: x-coordinate column of ``umap_data``.
          yName: y-coordinate column of ``umap_data``.
          vizlimit: Largest number of distinct values for which per-cluster
              frequency tables and bar charts are produced (default 15).

      Side effects: prints to stdout, sets ``plt.rcParams["figure.figsize"]``
      and seaborn's global colour palette, and calls ``plt.show()``.
      """
      plt.rcParams["figure.figsize"] = (12, 6)
      bar_colors = sns.color_palette("Set2")
      # sns.set_palette returns None; it only changes the global colour cycle,
      # which the UMAP lmplot relies on because it gets no explicit palette.
      # The other charts pass explicit colours, so setting it once is enough.
      sns.set_palette(sns.color_palette([
          'lightgray', 'lightcoral', 'cornflowerblue', 'orange', 'mediumorchid',
          'lightseagreen', 'olive', 'chocolate', 'steelblue', 'paleturquoise',
          'lightgreen', 'burlywood', 'lightsteelblue',
      ]))

      for feature in feature_list:
          print('Feature name:', feature.upper())
          print('\n')
          if main_data[feature].nunique() <= vizlimit:
              _vizx_frequencies(feature, cluster_df_list, rev_dict, bar_colors)
          if feature in cont_features:
              _vizx_continuous_summary(feature, cluster_df_list)
          print('\n')
          print('\n')
          if feature not in cont_features:
              _vizx_umap(feature, main_data, umap_data, rev_dict, xName, yName)
          print('\n')
          print('\n')


  def _vizx_decoded(df, feature, rev_dict):
      """Return ``df[feature]``, with stored values replaced by labels from ``rev_dict``."""
      column = df[feature]
      if feature not in rev_dict:
          return column
      # rev_dict[feature] maps label -> stored value; invert it to decode.
      decode = {stored: label for label, stored in rev_dict[feature].items()}
      return column.replace(decode)


  def _vizx_frequencies(feature, cluster_df_list, rev_dict, bar_colors,
                        columns=3, min_rows=3):
      """Print each cluster's value counts for ``feature`` and plot them as a bar grid."""
      bars = []
      for number, cluster in enumerate(cluster_df_list, start=1):
          print('Cluster ' + str(number) + ' frequency distribution')
          counts = _vizx_decoded(cluster, feature, rev_dict).value_counts()
          print(counts)
          print('\n')
          # Raw values are plotted in sorted order; decoded labels keep the
          # most-frequent-first order of value_counts.
          if feature not in rev_dict:
              counts = counts.sort_index()
          # Keep (labels, heights) as separate arrays per cluster. Stacking
          # them into one np.array fails when clusters have different numbers
          # of distinct values, and turns heights into strings for text labels.
          bars.append((np.asarray(counts.index), np.asarray(counts)))
      print('\n')
      print('\n')

      # Never smaller than the 3x3 grid; add rows instead of silently
      # dropping clusters beyond the ninth.
      n_rows = max(min_rows, (len(bars) + columns - 1) // columns)
      fig, axes = plt.subplots(n_rows, columns, squeeze=False)
      fig.set_size_inches(10.5, 9.5 * n_rows)  # 28.5 in tall for 3 rows
      for index, ax in enumerate(axes.flat):
          if index >= len(bars):
              ax.set_axis_off()  # grid cell with no cluster to show
              continue
          labels, heights = bars[index]
          ax.bar(labels, heights, color=bar_colors)
          ax.tick_params(axis='x', which='major', labelsize=8, rotation=90)
          ax.set_title('Cluster: ' + str(index + 1))
      plt.show()


  def _vizx_continuous_summary(feature, cluster_df_list):
      """Print per-cluster mean/std/median of ``feature`` and plot means with std error bars."""
      labels, means, sds = [], [], []
      for number, cluster in enumerate(cluster_df_list, start=1):
          values = cluster[feature]
          mean, sd = values.mean(), values.std()
          print('Cluster ' + str(number) + ' summary statistics')
          print('\n')
          print('feature mean:', mean)
          print('feature standard deviation:', sd)
          print('feature median:', values.median())
          print('\n')
          labels.append('C' + str(number))
          means.append(mean)
          sds.append(sd)
      print('\n')
      print('Distribution of feature across clusters')
      _, ax = plt.subplots()
      ax.bar(np.array(labels), np.array(means), yerr=np.array(sds),
             color=sns.color_palette("Set3"))
      ax.tick_params(axis='both', which='major', labelsize=10)
      ax.set_xlabel(feature, fontsize=15)
      plt.show()


  def _vizx_umap(feature, main_data, umap_data, rev_dict, xName, yName):
      """Add ``feature`` to ``umap_data`` in place and scatter-plot the embedding coloured by it."""
      print('Feature distribution in UMAP embedding')
      # np.array drops main_data's index, so values are assigned by position
      # (and copied, so umap_data does not share memory with main_data).
      umap_data[feature] = np.array(_vizx_decoded(main_data, feature, rev_dict))
      sns.lmplot(
          x=xName,
          y=yName,
          data=umap_data,
          fit_reg=False,
          legend=True,
          hue=feature,  # colour points by this feature's values
          scatter_kws={"s": 20},  # point size
      )
      plt.show()