Adding a 3D plot to the k-means clustering algorithm (#12372)

* Adding a 3D plot to the k-means clustering algorithm

* Update k_means_clust.py

* Update k_means_clust.py

---------

Co-authored-by: Maxim Smolskiy <mithridatus@mail.ru>
This commit is contained in:
lorenzo30salgado
2025-08-30 22:58:54 +02:00
committed by GitHub
parent 501576f90e
commit e3a263c1ed

View File

@@ -37,7 +37,13 @@ Usage:
heterogeneity,
k
)
5. Transfers Dataframe into excel format it must have feature called
5. Plot the labeled 3D data points with centroids.
plot_kmeans(
X,
centroids,
cluster_assignment
)
6. Transfers Dataframe into excel format it must have feature called
'Clust' with k means clustering numbers in it.
"""
@@ -126,6 +132,19 @@ def plot_heterogeneity(heterogeneity, k):
plt.show()
def plot_kmeans(data, centroids, cluster_assignment):
ax = plt.axes(projection="3d")
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=cluster_assignment, cmap="viridis")
ax.scatter(
centroids[:, 0], centroids[:, 1], centroids[:, 2], c="red", s=100, marker="x"
)
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.set_title("3D K-Means Clustering Visualization")
plt.show()
def kmeans(
data, k, initial_centroids, maxiter=500, record_heterogeneity=None, verbose=False
):
@@ -193,6 +212,7 @@ if False: # change to true to run this test case.
verbose=True,
)
plot_heterogeneity(heterogeneity, k)
plot_kmeans(dataset["data"], centroids, cluster_assignment)
def report_generator(