# visualize.py — seaborn/matplotlib plotting helpers for inspecting clusters.
  1. import seaborn as sns
  2. import matplotlib.pyplot as plt
  3. def plotCluster(data, clusterName="cluster", xName="FDC_1", yName="FDC_2", stroke=20):
  4. colors_set = [
  5. 'lightcoral', 'cornflowerblue', 'orange','mediumorchid', 'lightseagreen'
  6. , 'olive', 'chocolate', 'steelblue', 'paleturquoise', 'lightgreen'
  7. , 'burlywood', 'lightsteelblue']
  8. customPalette_set = sns.set_palette(sns.color_palette(colors_set))
  9. sns.lmplot(
  10. x=xName
  11. , y=yName
  12. , data=data
  13. , fit_reg=False
  14. , legend=True
  15. , hue=clusterName
  16. , scatter_kws={"s": stroke}
  17. , palette=customPalette_set
  18. )
  19. plt.show()
  20. def plotMapping(data, xName="UMAP_0", yName="UMAP_1"):
  21. colors_set1 = [
  22. "lightcoral", "lightseagreen", "mediumorchid", "orange", "burlywood"
  23. , "cornflowerblue", "plum", "yellowgreen"]
  24. customPalette_set1 = sns.set_palette(sns.color_palette(colors_set1))
  25. sns.lmplot(x=xName
  26. , y=yName
  27. , data=data
  28. , fit_reg=False
  29. , legend=False
  30. , scatter_kws={"s": 3}
  31. , palette=customPalette_set1)
  32. plt.show()
  33. def vizx(feature_list, cluster_df_list, main_data, umap_data, cont_features):
  34. vizlimit = 15
  35. plt.rcParams["figure.figsize"] = (12, 6)
  36. col = sns.color_palette("Set2")
  37. rows = 3
  38. columns = 3
  39. for feature in feature_list:
  40. print('Feature name:', feature.upper())
  41. print('\n')
  42. if len(main_data[feature].value_counts()) <= vizlimit:
  43. for cluster_counter, cluster in enumerate(cluster_df_list):
  44. print('Cluster '+ str(cluster_counter + 1) + ' frequency distribution')
  45. if feature in list(rev_dict.keys()):
  46. feat_keys=rev_dict[feature]
  47. r = dict(zip(feat_keys.values(), feat_keys.keys()))
  48. print(cluster.replace({feature:r})[feature].value_counts())
  49. else:
  50. print(cluster[feature].value_counts())
  51. print('\n')
  52. print('\n')
  53. print('\n')
  54. cluster_bar = []
  55. for cluster in cluster_df_list:
  56. if feature in list(rev_dict.keys()):
  57. y = np.array(cluster.replace({feature:r})[feature].value_counts())
  58. x = np.array(cluster.replace({feature:r})[feature].value_counts().index)
  59. cluster_bar.append([x,y])
  60. else:
  61. y = np.array(cluster[feature].value_counts().sort_index())
  62. x = np.array(cluster[feature].value_counts().sort_index().index)
  63. cluster_bar.append([x,y])
  64. cluster_bar = np.array(cluster_bar)
  65. figx, ax = plt.subplots(rows, columns)
  66. figx.set_size_inches(10.5, 28.5)
  67. cluster_in_subplot_axis_dict = np.array([[0,0],[0,1],[0,2],[1,0],[1,1],[1,2],[2,0],[1,1],[2,2]])
  68. c = 0
  69. for i in range(rows):
  70. for j in range(columns):
  71. ax[i,j].bar(cluster_bar[c,0], cluster_bar[c,1], color=col)
  72. ax[i,j].tick_params(axis='x', which='major', labelsize=8, rotation=90)
  73. ax[i,j].set_title('Cluster: ' + str(c + 1))
  74. if c > len(cluster_df_list):
  75. break
  76. else:
  77. c += 1
  78. means = []
  79. sds = []
  80. cluster_labels = []
  81. for cluster_counter, cluster in enumerate(cluster_df_list):
  82. if feature in cont_features:
  83. print('Cluster '+ str(cluster_counter + 1) + ' summary statistics')
  84. print('\n')
  85. cm = cluster[feature].mean()
  86. cs = cluster[feature].std()
  87. print('feature mean:', cm)
  88. print('feature standard deviation:', cs)
  89. print('feature median:', cluster[feature].median())
  90. print('\n')
  91. means.append(cm)
  92. sds.append(cs)
  93. cluster_labels.append('C' + str(cluster_counter + 1))
  94. means = np.array(means)
  95. sds = np.array(sds)
  96. cluster_labels = np.array(cluster_labels)
  97. print('\n')
  98. print('Distribution of feature across clusters')
  99. if feature in cont_features:
  100. fig, ax7 = plt.subplots()
  101. ax7.bar(cluster_labels, means, yerr=sds, color=sns.color_palette("Set3"))
  102. ax7.tick_params(axis='both', which='major', labelsize=10)
  103. plt.xlabel(feature, fontsize=15)
  104. plt.show()
  105. print('\n')
  106. print('\n')
  107. colors_set = ['lightgray', 'lightcoral', 'cornflowerblue', 'orange', 'mediumorchid'
  108. , 'lightseagreen', 'olive', 'chocolate', 'steelblue', 'paleturquoise', 'lightgreen'
  109. , 'burlywood','lightsteelblue']
  110. customPalette_set = sns.set_palette(sns.color_palette(colors_set))
  111. if feature not in cont_features:
  112. print('Feature distribution in UMAP embedding')
  113. if feature in list(rev_dict.keys()):
  114. umap_data[feature] = np.array(main_data.replace({feature:r})[feature])
  115. else:
  116. umap_data[feature] = np.array(main_data[feature])
  117. sns.lmplot(x="FDC_1", y="FDC_2",
  118. data=umap_data,
  119. fit_reg=False,
  120. legend=True,
  121. hue=feature, # color by cluster
  122. scatter_kws={"s": 20},
  123. palette=customPalette_set) # specify the point size
  124. plt.show()
  125. print('\n')
  126. print('\n')