# mesalab/plotting/all_hrd_plotter.py
import numpy as np
import matplotlib.pyplot as plt
import math
import time
from matplotlib import colors
import os
import logging
import pandas as pd # Ensure pandas is imported as it's used for DataFrames
from mesalab.plotting.plot_config import DEFAULT_PLOT_CONFIG
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Configure logging for feedback during execution.
# NOTE(review): calling basicConfig at import time configures the root logger of any
# program importing this module; consider moving it to the application entry point.
# basicConfig does nothing if the root logger already has handlers.
_LOG_FORMAT = '%(levelname)s: %(message)s'
logging.basicConfig(level=logging.WARNING, format=_LOG_FORMAT)
# Minimum fall of 'center_h1' below its first value that marks the start of the main sequence.
_H1_DROP_THRESHOLD = 1e-4
# Columns every run needs in order to be drawn on an HR diagram.
_REQUIRED_TRACK_COLUMNS = ('log_L', 'log_Teff', 'model_number')


def generate_all_hr_diagrams(all_history_data_flat: list, model_name: str, output_dir: str,
                             logT_blue_edge: list, logL_blue_edge: list,
                             logT_red_edge: list, logL_red_edge: list,
                             drop_zams: bool = False,
                             plot_cfg: dict = None):
    """
    Generate Hertzsprung-Russell (HR) diagrams for pre-loaded MESA run data.

    Runs are grouped by metallicity ('initial_Z'); each metallicity is saved as a
    single PNG image containing one subplot per run, sorted by initial mass.
    The grid uses up to ``plot_cfg["all_hrd"]["max_cols"]`` columns (never more
    than the number of runs), with as many rows as needed.

    The pre-main-sequence (pre-MS) phase can be excluded if 'drop_zams' is True.
    A run whose data is unusable (missing columns, or fewer than two points after
    trimming) is skipped: its subplot is hidden and a warning is logged.

    Args:
        all_history_data_flat (list): A flat list of full, untrimmed history DataFrames
            for all MESA runs. Each DataFrame is expected to have 'initial_Z',
            'initial_mass', 'log_Teff', 'log_L' and 'model_number' columns
            (this is the 'flattened_full_history_data_for_plotting' from mesa_analyzer).
        model_name (str): The name of the MESA model, used in output file names
            (e.g., 'nad_convos').
        output_dir (str): Directory where the HR diagram images are saved; created
            if it does not exist.
        logT_blue_edge (list): log effective temperatures of the instability strip's blue edge.
        logL_blue_edge (list): log luminosities of the instability strip's blue edge.
        logT_red_edge (list): log effective temperatures of the instability strip's red edge.
        logL_red_edge (list): log luminosities of the instability strip's red edge.
        drop_zams (bool, optional): If True, the pre-MS phase is trimmed from the start
            of each track: the track starts at the first row where 'center_h1' has
            fallen by more than 1e-4 below its initial value, or, as a fallback, at the
            row of minimum 'log_L'. Rows are located by position, so any DataFrame
            index works. Defaults to False (the full track is plotted).
        plot_cfg (dict, optional): Plotting configuration with "all_hrd", "scatter" and
            "colorbar" sections. If None, DEFAULT_PLOT_CONFIG from 'plot_config.py' is used.

    Returns:
        None

    Notes:
        Each subplot colours its points by 'model_number' normalised to that run's own
        range; the single colorbar is drawn next to, and describes, the last plotted run.

    Example:
        >>> import pandas as pd
        >>> import os
        >>> from mesalab.plotting import all_hrd_plotter
        >>> import numpy as np
        >>> # Define the output directory
        >>> output_dir = 'output/plots'
        >>> os.makedirs(output_dir, exist_ok=True)
        >>>
        >>> # Define dummy data for two different runs (M=1.0, Z=0.012 and M=1.5, Z=0.012)
        >>> df1 = pd.DataFrame({
        ...     'initial_mass': [1.0] * 20,
        ...     'initial_Z': [0.012] * 20,
        ...     'log_Teff': np.linspace(3.7, 3.8, 20),
        ...     'log_L': np.linspace(1.0, 1.5, 20),
        ...     'model_number': np.arange(20),
        ...     'center_h1': np.linspace(0.7, 0.6, 20)
        ... })
        >>> df2 = pd.DataFrame({
        ...     'initial_mass': [1.5] * 20,
        ...     'initial_Z': [0.012] * 20,
        ...     'log_Teff': np.linspace(3.8, 3.9, 20),
        ...     'log_L': np.linspace(1.5, 2.0, 20),
        ...     'model_number': np.arange(20),
        ...     'center_h1': np.linspace(0.7, 0.6, 20)
        ... })
        >>> # The function expects a flat list of DataFrames
        >>> all_data = [df1, df2]
        >>>
        >>> # Create dummy instability strip data
        >>> logT_blue = [3.8, 3.75, 3.7]
        >>> logL_blue = [1.5, 1.0, 0.5]
        >>> logT_red = [3.7, 3.65, 3.6]
        >>> logL_red = [1.5, 1.0, 0.5]
        >>>
        >>> # Call the function
        >>> all_hrd_plotter.generate_all_hr_diagrams(
        ...     all_history_data_flat=all_data,
        ...     model_name='dummy_model',
        ...     output_dir=output_dir,
        ...     logT_blue_edge=logT_blue,
        ...     logL_blue_edge=logL_blue,
        ...     logT_red_edge=logT_red,
        ...     logL_red_edge=logL_red,
        ...     drop_zams=True
        ... )
    """
    if plot_cfg is None:
        plot_cfg = DEFAULT_PLOT_CONFIG
    logging.info(f"Starting HR diagram generation for model '{model_name}'.")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        logging.info(f"Created output directory: {output_dir}")

    data_by_metallicity = _group_by_metallicity(all_history_data_flat)
    for z_value in sorted(data_by_metallicity):
        runs = _sort_runs_by_mass(data_by_metallicity[z_value], z_value)
        logging.info(f'➡ Processing HR diagrams for Z={z_value:.4f} with {len(runs)} masses...')
        if not runs:
            logging.warning(f"No runs found for Z={z_value:.4f}. Skipping HR diagram generation for this metallicity.")
            continue
        _render_metallicity_figure(z_value, runs, model_name, output_dir,
                                   (logT_blue_edge, logL_blue_edge),
                                   (logT_red_edge, logL_red_edge),
                                   drop_zams, plot_cfg)
    logging.info("✔ All HR diagram generation complete.")


def _group_by_metallicity(history_dfs):
    """Map each first 'initial_Z' value to the list of its DataFrames (input order).

    Empty DataFrames and those without an 'initial_Z' column are skipped with a warning.
    """
    groups = {}
    for df in history_dfs:
        if 'initial_Z' in df.columns and not df.empty:
            groups.setdefault(df['initial_Z'].iloc[0], []).append(df)
        else:
            logging.warning("Skipping a DataFrame in HRD plotting due to missing 'initial_Z' or being empty.")
    return groups


def _initial_mass(df):
    """Return the first 'initial_mass' value of ``df`` as a float, or NaN if unavailable."""
    if 'initial_mass' in df.columns and not df.empty:
        return float(df['initial_mass'].iloc[0])
    return float('nan')


def _sort_runs_by_mass(dfs, z_value):
    """Return ``(mass, df)`` pairs sorted by ascending mass; runs without a mass come last."""
    runs = [(_initial_mass(df), df) for df in dfs]
    if any(math.isnan(mass) for mass, _ in runs):
        logging.warning(f" 'initial_mass' column not found for some runs at Z={z_value:.4f}. Placing them after the sorted runs.")
    # Key on (is_nan, mass) so NaN masses cannot disturb the ordering of the valid ones.
    runs.sort(key=lambda run: (math.isnan(run[0]), run[0]))
    logging.info(f"Sorted {len(runs)} runs by mass for Z={z_value:.4f}.")
    return runs


def _trim_at_min_log_l(df):
    """Return the rows of ``df`` from the (positional) minimum of 'log_L' onward.

    An all-NaN (or empty) 'log_L' yields an empty frame, which the caller reports
    as insufficient data instead of crashing.
    """
    log_l = df['log_L'].to_numpy(dtype=float)
    if np.isnan(log_l).all():
        return df.iloc[:0].copy()
    return df.iloc[int(np.nanargmin(log_l)):].copy()


def _trim_pre_ms(df, run_label):
    """Drop the pre-main-sequence part of one track (used when drop_zams is True).

    The track starts at the first row whose 'center_h1' is more than
    _H1_DROP_THRESHOLD below the first row's value. If 'center_h1' is missing,
    never drops, is unusable, or drops only at the last row, the 'log_L' minimum
    is used instead. All locations are positional, so the DataFrame index
    (labels) does not matter.
    """
    if 'center_h1' not in df.columns:
        logging.warning(f" 'center_h1' not found for {run_label}. Falling back to log_L minimum trimming.")
        return _trim_at_min_log_l(df)
    try:
        center_h1 = df['center_h1'].to_numpy(dtype=float)
        drop_positions = np.flatnonzero(center_h1 < center_h1[0] - _H1_DROP_THRESHOLD)
    except Exception as e:  # best effort: unusable center_h1 data falls back to log_L
        logging.warning(f"Error during 'center_h1' trimming for {run_label}: {e}. Falling back to log_L minimum trimming.")
        return _trim_at_min_log_l(df)
    if drop_positions.size == 0:
        logging.warning(f" 'center_h1' did not drop below threshold (>{_H1_DROP_THRESHOLD}) for {run_label}. Falling back to log_L minimum trimming.")
        return _trim_at_min_log_l(df)
    zams_pos = int(drop_positions[0])
    # Require at least two rows from the ZAMS point on, otherwise nothing useful is left to plot.
    if zams_pos >= len(df) - 1:
        logging.warning(f" 'center_h1' drop index is too close to end of data for {run_label}. Falling back to log_L minimum trimming.")
        return _trim_at_min_log_l(df)
    logging.info(f"Trimmed pre-MS using 'center_h1' drop criterion (threshold={_H1_DROP_THRESHOLD}) for {run_label}.")
    return df.iloc[zams_pos:].copy()


def _prepare_track(df, mass, z_value, drop_zams):
    """Return ``(log_Teff, log_L, model_number)`` arrays for one run, or None to skip it."""
    run_label = f"M={mass:.1f} (Z={z_value:.4f})"
    missing = [col for col in _REQUIRED_TRACK_COLUMNS if col not in df.columns]
    if missing or df.empty:
        logging.warning(f"Missing column(s) {missing} or empty DataFrame for {run_label}. Skipping plot.")
        return None
    if drop_zams:
        track = _trim_pre_ms(df, run_label)
    else:
        track = df
        logging.info(f"Pre-MS trimming skipped for {run_label} as 'drop_zams' is False.")
    if len(track) < 2:
        logging.warning(f"Not enough data points after (potential) trimming for {run_label} to plot HR diagram. Skipping plot.")
        return None
    return (track['log_Teff'].to_numpy(),
            track['log_L'].to_numpy(),
            track['model_number'].to_numpy(dtype=float))


def _label_outer_axes(axes, plotted, cols, label_size):
    """Add the y-label to the first plotted panel of each row and the x-label to the
    lowest plotted panel of each column.

    Tick labels are kept on every panel: the panels do not share axes, so each one
    is autoscaled independently and removing its ticks would hide its values.
    """
    plotted_set = set(plotted)
    for idx in plotted:
        row_start = idx - idx % cols
        if not any(j in plotted_set for j in range(row_start, idx)):
            axes[idx].set_ylabel(r'$\log L/L_\odot$', fontsize=label_size)
        if not any(j in plotted_set for j in range(idx + cols, len(axes), cols)):
            axes[idx].set_xlabel(r'$\log (T_{\rm eff}/K)$', fontsize=label_size)


def _render_metallicity_figure(z_value, runs, model_name, output_dir,
                               blue_edge, red_edge, drop_zams, plot_cfg):
    """Draw one grid of HR diagrams for a single metallicity and save it as PNG.

    ``runs`` is a non-empty list of ``(mass, df)`` pairs; ``blue_edge``/``red_edge``
    are ``(logT, logL)`` sequences for the instability strip edges.
    """
    hrd_cfg = plot_cfg["all_hrd"]
    scatter_cfg = plot_cfg["scatter"]
    num_plots = len(runs)
    cols = min(num_plots, hrd_cfg["max_cols"])
    rows = math.ceil(num_plots / cols)
    # squeeze=False always yields a 2-D array, so ravel() works for any grid shape.
    fig, axes = plt.subplots(rows, cols,
                             figsize=hrd_cfg["figsize"],
                             facecolor=hrd_cfg["facecolor"],
                             squeeze=False)
    axes = axes.ravel()
    try:
        scatter = None
        plotted = []  # grid positions that actually received a track
        for idx, (mass, df) in enumerate(runs):
            ax = axes[idx]
            track = _prepare_track(df, mass, z_value, drop_zams)
            if track is None:
                ax.set_visible(False)
                continue
            log_teff, log_l, model_number = track
            ax.set_title(rf'{mass:.1f} M$_\odot$', fontsize=hrd_cfg["title_size"])
            norm = colors.Normalize(vmin=np.nanmin(model_number), vmax=np.nanmax(model_number))
            scatter = ax.scatter(log_teff, log_l, c=model_number, cmap=scatter_cfg["cmap"], norm=norm,
                                 s=scatter_cfg["dot_size"], alpha=scatter_cfg["alpha"],
                                 edgecolors='none', zorder=2)
            ax.plot(blue_edge[0], blue_edge[1], color='blue', linestyle='dashed',
                    linewidth=1.5, zorder=1, label='Blue Edge')
            ax.plot(red_edge[0], red_edge[1], color='red', linestyle='dashed',
                    linewidth=1.5, zorder=1, label='Red Edge')
            ax.invert_xaxis()  # HR diagram convention: temperature decreases to the right
            plotted.append(idx)

        # Remove grid slots beyond the last run entirely.
        for ax in axes[num_plots:]:
            fig.delaxes(ax)

        _label_outer_axes(axes, plotted, cols, hrd_cfg["label_size"])
        fig.suptitle(f'Hertzsprung-Russell Diagram (Z = {z_value:.4f})',
                     fontsize=hrd_cfg["title_size"], y=1.02)

        if plotted:
            # Each panel is normalised to its own model_number range, so the colorbar
            # (built from the last scatter) is attached to the last plotted panel.
            divider = make_axes_locatable(axes[plotted[-1]])
            cax = divider.append_axes(
                "right",
                size=plot_cfg["colorbar"]["size"],
                pad=plot_cfg["colorbar"]["padding"]
            )
            cbar = fig.colorbar(scatter, cax=cax)
            cbar.set_label(
                "Model Number (evolutionary stage)",
                fontsize=plot_cfg["colorbar"]["label_size"]
            )

        fig.tight_layout()
        filename = os.path.join(output_dir, f'HR_diagram_{model_name}_z{z_value:.4f}.png')
        fig.savefig(filename, dpi=hrd_cfg["dpi"], bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    logging.info(f"✔ Saved HR diagram: {filename}")