Source code for synkit.Vis.chemical_space

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Global matplotlib text configuration, applied once when this module is imported.
plt.rcParams.update(
    {
        "text.usetex": True,  # Enable LaTeX rendering
        "font.family": "serif",  # Optional: use serif font
    }
)


def scatter_plot(
    data_train,
    data_test,
    size_train=10,
    size_test=10,
    title=None,
    ax=None,
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
):
    """Draw train and test points as two overlaid scatter series.

    The x values are read from the second column and the y values from the
    third column of the concatenated train/test frame. The input frames are
    not modified: the ``Type`` tag used to separate the two series is added
    to copies only.

    :param data_train: Data frame of training points (at least three columns).
    :param data_test: Data frame of test points (at least three columns).
    :param size_train: Marker size used for the ``Train`` series.
    :param size_test: Marker size used for the ``Test`` series.
    :param title: Optional axes title; nothing is set when falsy.
    :param ax: Axes to draw on. A new 12x8 figure is created when ``None``.
    :param xlabel: Label of the x axis.
    :param ylabel: Label of the y axis.
    :returns: Tuple ``(ax, handles, labels)`` where ``handles`` and ``labels``
        come from ``ax.get_legend_handles_labels()`` so the caller can build
        the legend itself.
    :raises ValueError: If either frame is empty or has fewer than three
        columns.
    """
    if data_train.empty or data_test.empty:
        raise ValueError("Input data frames cannot be empty.")

    if data_train.columns.size < 3 or data_test.columns.size < 3:
        raise ValueError("Data frames must have at least three columns.")

    # ``assign`` returns tagged copies, so the caller's frames never gain a
    # "Type" column as a side effect of plotting.
    data_combined = pd.concat(
        [data_train.assign(Type="Train"), data_test.assign(Type="Test")]
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    # deepskyblue and magenta are far enough apart to stay distinguishable
    # even at the low alpha used below.
    palette = {"Train": "deepskyblue", "Test": "magenta"}
    marker_sizes = {"Train": size_train, "Test": size_test}

    # Coordinates are taken by position, not by name; resolve them once.
    x_col = data_combined.columns[1]
    y_col = data_combined.columns[2]

    for dtype, color in palette.items():
        subset = data_combined[data_combined["Type"] == dtype]
        ax.scatter(
            subset[x_col],
            subset[y_col],
            color=color,
            label=dtype,
            s=marker_sizes[dtype],
            alpha=0.1,
            edgecolor="none",
        )

    if title:
        ax.set_title(f"{title}", fontsize=24, fontweight="bold")

    ax.set_xlabel(xlabel, fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    # Dashed grid on major and minor ticks, drawn underneath the markers.
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    # Hand the legend entries back instead of drawing a legend here.
    handles, labels = ax.get_legend_handles_labels()
    return ax, handles, labels
def adjust_legend_handles(handles, colors):
    """Build fully opaque legend patches for the given handles.

    Each handle is paired with the color at the same position and replaced
    by a ``mpatches.Patch`` of that color carrying the handle's label, so
    the legend stays readable even when the plotted markers are nearly
    transparent. Pairing stops at the shorter of the two inputs.

    :param handles: Legend handles exposing ``get_label()``.
    :param colors: Colors, one per handle, in the same order.
    :returns: List of new patch handles.
    """
    return [
        mpatches.Patch(color=patch_color, label=original.get_label())
        for original, patch_color in zip(handles, colors)
    ]