from adjustText import adjust_text
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import warnings
from .plotting import BaseVisualizer
from .clustering import create_clusters
from .dim_reduction import reduce_dim
from .similarity import n_most_similar
[docs]
class Visualizer3D(BaseVisualizer):
def __init__(self, loader):
super().__init__(loader)
self.reduced = None
self.reduced_subset = None
def _setup_3d(self, reduced_emb, theme, grid, tokens, title, def_title, labels=None):
'''base private function to config plotly 3d plot'''
df = pd.DataFrame(reduced_emb, columns=['x', 'y', 'z'])
style = self.get_theme(theme)
kwargs = {}
if labels is not None:
colors, legend_labels = self.map_colors(labels, theme)
df['label'] = labels
df['name'] = [f'Cluster {label+1}' for label in labels]
color_discrete_map = {
legend_labels[label_num][1]: f'rgb({int(color_tuple[0]*255)}, {int(color_tuple[1]*255)}, {int(color_tuple[2]*255)})'
for label_num, (color_tuple, _) in legend_labels.items()
}
kwargs['color'] = 'name'
kwargs['color_discrete_map'] = color_discrete_map
else:
kwargs['color_discrete_sequence'] = [style['points']]
fig = px.scatter_3d(df, x='x', y='y', z='z',**kwargs)
fig.update_traces(
text=tokens,
hovertemplate='%{text}<extra></extra>',
hoverlabel=dict(
bgcolor=style['bg'],
font=dict(color=style['text'])),
marker=dict(size=5, opacity=0.6, line=dict(width=0))
)
fig.update_layout(
height=500,
title=title if title else def_title,
title_x=0.5,
title_xanchor='center',
scene=dict(bgcolor=style['bg'],
xaxis=dict(
backgroundcolor=style['bg'],
showticklabels=False, showgrid=grid,
gridcolor=style['grid_color'],
zeroline=False, title=None
),
yaxis=dict(
backgroundcolor=style['bg'],
showticklabels=False, showgrid=grid,
gridcolor=style['grid_color'],
zeroline=False, title=None
),
zaxis=dict(
backgroundcolor=style['bg'],
showticklabels=False, showgrid=grid,
gridcolor=style['grid_color'],
zeroline=False, title=None
)
),
paper_bgcolor=style['bg'],
font=dict(color=style['text']),
)
return fig
[docs]
def plot_static(self, red_method: str = 'auto', grid: bool = True, theme: str = 'light1', title: str = None, nlabels: int = 0, use_subset: bool = False):
'''
Creates a simple static 3D scatterplot of the embeddings.
Parameters
-----------
red_method : str, default='auto'
Dimensionality reduction method to apply ('pca', 'tsne', 'umap', etc.). If 'auto' searches for cached reduction, if None runs pca.
grid : bool, default=True
If True, displays a background grid on the plot.
theme : str, default='light1'
Color theme to apply.
title : str, optional
Title to display on the plot.
nlabels : int, default=0
Number of word labels to display. If 0, no labels are shown.
use_subset : bool, default=False
If True, uses the embedding subset instead of the full embeddings.
Returns
--------
fig : matplotlib.figure.Figure
ax : matplotlib.axes.Axes
'''
reduced_emb, tokens = self._set_embeddings(use_subset=use_subset, red_method=red_method, dims=3)
fig, ax, colors = self._setup_plot(theme, grid, title, dims=3)
ax.scatter(reduced_emb[:, 0], reduced_emb[:, 1], reduced_emb[:, 2], c=colors['points'], alpha=0.5, s=14, marker='o')
texts = []
if nlabels > 0:
sparse_indices = self.select_sparse_labels(reduced_emb, nlabels)
for i in sparse_indices:
texts.append(ax.text(reduced_emb[i, 0], reduced_emb[i, 1], reduced_emb[i, 2], tokens[i],
color=colors['text'], fontsize=7, alpha=1, ha='center', va='bottom'))
plt.rcParams['figure.dpi'] = 600
plt.show()
return fig, ax
[docs]
def plot_embeddings(self, red_method='auto', grid=True, theme='light1', title=None, use_subset=False):
'''
Creates an interactive 3D scatterplot of embeddings using Plotly.
Parameters:
-----------
red_method : str, default='auto'
Dimensionality reduction method to apply ('pca', 'tsne', 'umap', etc.). If 'auto' searches for cached reduction, if None runs pca.
grid : bool, default=True
Whether to display grid lines.
theme : str, default='light1'
Plot color theme.
title : str, optional
Title of the plot. If None, no title is shown.
use_subset : bool, default=False
If True, uses the embedding subset instead of the full embeddings.
Returns:
--------
fig : plotly.graph_objects.Figure
Notes:
------
In 3D plotting Plotly.py tends to use GPU to visualize an high number of elements and label, so it is possible that this function does not work properly with a whole embedding set.
'''
warnings.warn(
"Without a suitable GPU, full 3D visualization may be slow or unstable. "
"It is recommended to use a subset of the data for optimal performance and user experience."
)
reduced_emb, tokens = self._set_embeddings(use_subset=use_subset, red_method=red_method, dims=3)
fig = self._setup_3d(reduced_emb=reduced_emb, theme=theme, grid=grid, tokens=tokens, title=title, def_title="Word Embedding 3D Plot")
return fig
[docs]
def plot_similarity(self, target_word: str, dist: str = 'cosine', n: int = 10, red_method: str = 'pca', grid: bool = True, theme: str = 'light1', title: str = None):
'''
Creates a dynamic 3D scatterplot showing the most similar words to a target word.
Parameters
-----------
target_word : str
The word for which to find and plot the most similar words.
dist : str, default='cosine'
Distance metric to use when computing word similarity.
n : int, default=10
Number of similar words to display.
red_method : str, default='pca'
Dimensionality reduction method to apply ('pca', 'tsne', 'umap', etc.).
grid : bool, default=True
If True, displays a background grid on the plot.
theme : str, default='light1'
Color theme to apply to the plot.
title : str, optional
Title to display. If None, a default title will be generated.
Returns
--------
fig : plotly.graph_objects.Figure
'''
warnings.warn(
"The parameter names target_word will be renamed to target in a future release. "
"Please update your code accordingly.",
FutureWarning
)
similar_words, similar_vecs, _ = n_most_similar(self.loader, target_word, dist, n)
target_vec = self.loader.get_embedding(target_word)
vectors = np.vstack([target_vec.reshape(1, -1), similar_vecs])
words = [target_word] + similar_words
reduced_emb = reduce_dim(vectors, method=red_method, n_dimensions=3)
fig = self._setup_3d(reduced_emb=reduced_emb, theme=theme, grid=grid, tokens=words, title=title, def_title=f"Top {n} words similar to '{target_word}'")
style = self.get_theme(theme)
fig.add_trace(go.Scatter3d(
x=[target_vec[0]],
y=[target_vec[1]],
z=[target_vec[2]],
mode='markers',
marker=dict(
size=5,
color=style['target'],
symbol='circle'
),
text=[target_word],
hovertemplate='%{text}<extra></extra>',
showlegend=False
))
return fig
[docs]
def plot_clusters(self, n_clusters=5, method='kmeans', red_method='auto', show_centers=False, grid=True, theme='light1', title=None, nlabels=0, use_subset=False):
'''
Creates a 3D scatterplot of clustered embeddings using a clustering algorithm.
Parameters:
-----------
n_clusters : int, default=5
Number of clusters to generate.
method : str, default='kmeans'
Clustering method to use ('kmeans' or others supported by create_clusters).
red_method : str, default='auto'
Dimensionality reduction method to apply ('pca', 'tsne', 'umap', etc.). If 'auto' searches for cached reduction, if None runs pca.
show_centers : bool, default=False
If True, displays cluster centers on the plot.
grid : bool, default=True
Whether to display grid lines.
theme : str, default='light1'
Plot color theme.
title : str, optional
Title of the plot. If None, no title is shown.
nlabels : int, default=0
Number of token labels to display on the plot.
use_subset : bool, default=False
If True, uses the embedding subset instead of the full embeddings.
Returns:
--------
fig : plotly.graph_objects.Figure
Notes:
------
In 3D plotting Plotly.py tends to use GPU to visualize an high number of elements and label, so it is possible that this function does not work properly with a whole embedding set.
'''
warnings.warn(
"Without a suitable GPU, full 3D visualization may be slow or unstable. "
"It is recommended to use a subset of the data for optimal performance and user experience."
)
reduced_emb, tokens = self._set_embeddings(use_subset=use_subset, red_method=red_method, dims=3)
clusters, centers, reduced_emb = create_clusters(reduced_emb, n_clusters=n_clusters, method=method)
clusters_colors, legend_labels = self.map_colors(clusters, theme=theme)
fig = self._setup_3d(reduced_emb=reduced_emb, theme=theme, grid=grid, tokens=tokens, title=title, def_title=f"3D Clustering Scatterplot", labels=clusters)
color = 'white' if 'dark' in theme else 'black'
if show_centers and centers is not None:
fig.add_trace(go.Scatter3d(
x=centers[:, 0],
y=centers[:, 1],
z=centers[:, 2],
mode='markers',
marker=dict(
size=7,
color=color,
symbol='circle'
),
name='Centers'
))
return fig