# visualize.py (4.8 KB)

  1. import seaborn as sns
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. def plotCluster(data, clusterName="cluster", xName="FDC_1", yName="FDC_2", stroke=20):
  5. colors_set = [
  6. 'lightgray', 'lightcoral', 'cornflowerblue', 'orange','mediumorchid', 'lightseagreen'
  7. , 'olive', 'chocolate', 'steelblue', 'paleturquoise', 'lightgreen'
  8. , 'burlywood', 'lightsteelblue']
  9. customPalette_set = sns.set_palette(sns.color_palette(colors_set))
  10. sns.lmplot(
  11. x=xName
  12. , y=yName
  13. , data=data
  14. , fit_reg=False
  15. , legend=True
  16. , hue=clusterName
  17. , scatter_kws={"s": stroke}
  18. , palette=customPalette_set
  19. )
  20. plt.show()
  21. def plotMapping(data, xName="UMAP_0", yName="UMAP_1"):
  22. colors_set1 = [
  23. "lightcoral", "lightseagreen", "mediumorchid", "orange", "burlywood"
  24. , "cornflowerblue", "plum", "yellowgreen"]
  25. customPalette_set1 = sns.set_palette(sns.color_palette(colors_set1))
  26. sns.lmplot(x=xName
  27. , y=yName
  28. , data=data
  29. , fit_reg=False
  30. , legend=False
  31. , scatter_kws={"s": 3}
  32. , palette=customPalette_set1)
  33. plt.show()
  34. def vizx(feature_list, cluster_df_list, main_data, umap_data, cont_features, rev_dict, xName="FDC_1", yName="FDC_2"):
  35. vizlimit = 15
  36. plt.rcParams["figure.figsize"] = (12, 6)
  37. col = sns.color_palette("Set2")
  38. rows = 3
  39. columns = 3
  40. for feature in feature_list:
  41. print('Feature name:', feature.upper())
  42. print('\n')
  43. if len(main_data[feature].value_counts()) <= vizlimit:
  44. for cluster_counter, cluster in enumerate(cluster_df_list):
  45. print('Cluster '+ str(cluster_counter + 1) + ' frequency distribution')
  46. if feature in rev_dict:
  47. r = rev_dict[feature]
  48. print(cluster.replace({feature:r})[feature].value_counts())
  49. else:
  50. print(cluster[feature].value_counts())
  51. print('\n')
  52. print('\n')
  53. print('\n')
  54. cluster_bar = []
  55. for cluster in cluster_df_list:
  56. if feature in rev_dict:
  57. y = np.array(cluster.replace({feature:r})[feature].value_counts())
  58. x = np.array(cluster.replace({feature:r})[feature].value_counts().index)
  59. cluster_bar.append([x,y])
  60. else:
  61. y = np.array(cluster[feature].value_counts().sort_index())
  62. x = np.array(cluster[feature].value_counts().sort_index().index)
  63. cluster_bar.append([x,y])
  64. cluster_bar = np.array(cluster_bar)
  65. figx, ax = plt.subplots(rows, columns)
  66. figx.set_size_inches(10.5, 28.5)
  67. cluster_in_subplot_axis_dict = np.array(
  68. [[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[2,0],[1,1],[2,2]])
  69. c = 0
  70. for i in range(rows):
  71. for j in range(columns):
  72. if c >= len(cluster_df_list):
  73. break
  74. ax[i,j].bar(cluster_bar[c,0], cluster_bar[c,1], color=col)
  75. ax[i,j].tick_params(axis='x', which='major', labelsize=8, rotation=90)
  76. ax[i,j].set_title('Cluster: ' + str(c + 1))
  77. c += 1
  78. means = []
  79. sds = []
  80. cluster_labels = []
  81. for cluster_counter, cluster in enumerate(cluster_df_list):
  82. if feature in cont_features:
  83. print('Cluster '+ str(cluster_counter + 1) + ' summary statistics')
  84. print('\n')
  85. cm = cluster[feature].mean()
  86. cs = cluster[feature].std()
  87. print('feature mean:', cm)
  88. print('feature standard deviation:', cs)
  89. print('feature median:', cluster[feature].median())
  90. print('\n')
  91. means.append(cm)
  92. sds.append(cs)
  93. cluster_labels.append('C' + str(cluster_counter + 1))
  94. means = np.array(means)
  95. sds = np.array(sds)
  96. cluster_labels = np.array(cluster_labels)
  97. print('\n')
  98. print('Distribution of feature across clusters')
  99. if feature in cont_features:
  100. fig, ax7 = plt.subplots()
  101. ax7.bar(cluster_labels, means, yerr=sds, color=sns.color_palette("Set3"))
  102. ax7.tick_params(axis='both', which='major', labelsize=10)
  103. plt.xlabel(feature, fontsize=15)
  104. plt.show()
  105. print('\n')
  106. print('\n')
  107. customPalette_set = sns.set_palette(sns.color_palette(
  108. [ 'lightgray', 'lightcoral', 'cornflowerblue', 'orange', 'mediumorchid'
  109. , 'lightseagreen', 'olive', 'chocolate', 'steelblue', 'paleturquoise'
  110. , 'lightgreen', 'burlywood','lightsteelblue'
  111. ]))
  112. if feature not in cont_features:
  113. print('Feature distribution in UMAP embedding')
  114. if feature in rev_dict:
  115. r = rev_dict[feature]
  116. umap_data[feature] = np.array(main_data.replace({feature:r})[feature])
  117. else:
  118. umap_data[feature] = np.array(main_data[feature])
  119. sns.lmplot(x=xName, y=yName,
  120. data=umap_data,
  121. fit_reg=False,
  122. legend=True,
  123. hue=feature, # color by cluster
  124. scatter_kws={"s": 20},
  125. palette=customPalette_set) # specify the point size
  126. plt.show()
  127. print('\n')
  128. print('\n')