blob: 23286e0963d25dc56e56d7da316ecc66d43b39fe [file] [edit]
############################################################################
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
############################################################################
"""Test PMI2 graceful termination on MPI_Abort (requires MPICH).
Open MPI dropped PMI2 client support, so MPICH is required to exercise the
PMI2 server-side abort path in src/plugins/mpi/pmi2/. PMI2 has no broken-
socket _errhandler() equivalent, so only the explicit MPI_Abort() path is
covered here.
"""
import atf
import os
import pytest
@pytest.fixture(scope="module", autouse=True)
def setup():
    """Declare the module-wide prerequisites for the PMI2 abort tests.

    Runs once per module (autouse). The calls are kept in their original
    order: the MPICH environment module is loaded before the MPI/compiler
    requirement is checked, and Slurm is required to be running last,
    after all configuration requirements have been declared.
    """
    # Tool chain: lmod, the "mpich" module, and a pmi2-capable MPI with mpicc.
    atf.require_lmod()
    atf.module_load("mpich")
    atf.require_mpi("pmi2", "mpicc")

    # Configuration requirements. KillWait is the delay the test relies on
    # before the final SIGKILL of ranks that trap SIGTERM.
    atf.require_auto_config("needs MPI/PMI2 configuration")
    required_parameters = (
        ("MpiDefault", "pmi2"),
        ("KillWait", "5"),
    )
    for parameter, value in required_parameters:
        atf.require_config_parameter(parameter, value)

    # One node with 4 CPUs, matching the 4-task steps the tests launch.
    cpus_needed = 4
    atf.require_nodes(1, [("CPUs", cpus_needed)])

    atf.require_slurm_running()
@pytest.fixture(scope="module")
def mpich_program():
    """Compile mpi_signal_test.c with MPICH's mpicc.

    The source is taken from the testsuite scripts directory and the binary
    is written to the current working directory. Compilation failure is
    fatal.

    Returns:
        The path of the compiled ``mpi_signal_test_pmi2`` binary.
    """
    scripts_dir = atf.properties["testsuite_scripts_dir"]
    source_file = f"{scripts_dir}/mpi_signal_test.c"
    binary = f"{os.getcwd()}/mpi_signal_test_pmi2"

    compile_cmd = f"mpicc -o {binary} {source_file}"
    atf.run_command(compile_cmd, fatal=True)
    return binary
@pytest.fixture(autouse=True)
def shm_cleanup():
    """Remove leaked MPICH shared-memory files after every test.

    MPICH only unlinks its shared memory on a clean MPI_Finalize(). If a
    rank ends through MPI_Abort(), SIGKILL/SIGTERM, a segfault, an OOM
    kill, etc., the shared memory is leaked and may interfere with future
    MPI_Init() calls.

    Cleaning up is good practice in general, but it is especially needed
    here because these tests terminate ranks via MPI_Abort() on purpose.
    """
    # Nothing to prepare; all the work happens at teardown.
    yield

    # Best effort (fatal=False): a failed removal must not fail the test.
    cleanup_cmd = "rm -f /dev/shm/psm* /dev/shm/mpich*"
    atf.run_command(cleanup_cmd, user="root", fatal=False)
@pytest.mark.xfail(
    atf.get_version("sbin/slurmd") < (26, 5),
    reason="Ticket 24022: PMI2 graceful termination via SIG_TERM_KILL added in 26.05",
)
@pytest.mark.parametrize("trap", [False, True])
def test_pmi2_abort_kills_failed_step(mpich_program, trap):
    """Surviving ranks are killed when an MPICH/PMI2 rank calls MPI_Abort().

    A 4-task step is launched through a batch job with ``--mpi=pmi2``; the
    job's stdout and stderr files are then inspected.

    - trap=False: SIGTERM (default action) terminates ranks immediately.
    - trap=True: ranks catch SIGTERM and continue; the final SIGKILL after
      KillWait must still kill them.
    """
    # Pick the file-name suffix and the test-program arguments per variant.
    if trap:
        variant, program_args = "trap", "abort trap"
    else:
        variant, program_args = "notrap", "abort"

    out_path = f"out_{variant}"
    err_path = f"err_{variant}"
    batch_script = f"job_{variant}.sh"

    atf.make_bash_script(
        batch_script, f"srun --mpi=pmi2 -n4 {mpich_program} {program_args}"
    )
    sbatch_args = f"-N1 -n4 --output={out_path} --error={err_path} {batch_script}"
    job_id = atf.submit_job_sbatch(sbatch_args, fatal=True)

    atf.wait_for_job_state(job_id, "DONE", fatal=True)
    atf.wait_for_file(out_path, fatal=True)

    # Reading stdout is fatal on failure; reading stderr is not, so a
    # problem there surfaces through the assertion message below instead.
    job_stdout = atf.run_command_output(f"cat {out_path}", fatal=True)
    job_stderr = atf.run_command_output(f"cat {err_path}", fatal=False)

    # Stepd's "DUE TO TASK FAILURE" message proves SIG_TERM_KILL was used.
    assert (
        "DUE TO TASK FAILURE" in job_stderr
    ), f"Expected SIG_TERM_KILL termination. stderr:\n{job_stderr}"

    # Only the trapping variant is expected to report the caught SIGTERM.
    if trap:
        assert (
            "rank_got_sigterm" in job_stdout
        ), f"Surviving ranks should have caught SIGTERM. stdout:\n{job_stdout}"