# File: mt/tools/anomaly_scatter_plot.py
# Last modified: 2025-03-14 18:02:23 +01:00
# 94 lines, 3.0 KiB, Python

import matplotlib.pyplot as plt
import numpy as np
# Seed NumPy's global RNG so every run draws the same sample.
np.random.seed(0)

# 1. NORMAL DATA: 50 two-dimensional standard-normal draws, shrunk by 0.75
#    so the cloud stays compact and a single circle can enclose it.
normal_scale = 0.75
normal_data = normal_scale * np.random.randn(50, 2)

# 2. ANOMALOUS DATA: four hand-picked points kept in one array -- three
#    packed tightly around (3, 3) plus one isolated point at (0.4, 4.0).
anomaly_cluster_1 = np.array(
    [
        [3.0, 3.0],
        [3.2, 3.1],
        [2.8, 2.9],
        [0.4, 4.0],
    ]
)

# 3. BOUNDARY CIRCLE: centred on the mean of the normal data; the radius is
#    the distance to the farthest normal point plus a 0.2 margin, so every
#    normal point lies inside the circle.
center = np.mean(normal_data, axis=0)
distances = np.linalg.norm(normal_data - center, axis=1)
radius = distances.max() + 0.2

# Sample the circle at 200 angles over [0, 2*pi] for plotting.
theta = np.linspace(0, 2 * np.pi, 200)
center_x, center_y = center
circle_x = center_x + radius * np.cos(theta)
circle_y = center_y + radius * np.sin(theta)
# 4. Draw everything on a single 7x7-inch figure, using the figure's
#    current axes explicitly instead of the implicit pyplot state.
fig = plt.figure(figsize=(7, 7))
ax = fig.gca()

# Normal points: blue circles.
ax.scatter(
    normal_data[:, 0],
    normal_data[:, 1],
    marker="o",
    color="blue",
    label="Normal Data",
)

# Anomalous points: red crosses, all drawn as one series.
ax.scatter(
    anomaly_cluster_1[:, 0],
    anomaly_cluster_1[:, 1],
    marker="x",
    color="red",
    label="Anomalies",
)

# Dashed black circle marking the boundary around the normal data.
ax.plot(circle_x, circle_y, linestyle="--", color="black", label="Boundary")

# 5. No per-cluster text annotations, title, or grid are drawn; the legend
#    alone identifies each series.
ax.legend(loc="upper left")
ax.set_xlabel("x")
ax.set_ylabel("y")

# Hide every tick mark and tick label on both axes (major and minor), so
# only the axis labels "x" and "y" remain.
hidden_ticks = {
    "bottom": False,  # no tick marks on the bottom edge
    "top": False,  # no tick marks on the top edge
    "left": False,  # no tick marks on the left edge
    "right": False,  # no tick marks on the right edge
    "labelbottom": False,  # no tick labels under the x-axis
    "labelleft": False,  # no tick labels beside the y-axis
}
ax.tick_params(axis="both", which="both", **hidden_ticks)

# Equal axis scaling keeps the boundary circular rather than elliptical.
ax.axis("equal")

# Write the figure to the current working directory.
fig.savefig("scatter_plot.png")