# blob: adf38f782e34069cb61ce14ae19c20fcf3a925c7 [file] [edit]
############################################################################
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
############################################################################
from pathlib import Path
import atf
import pytest
# Markers searched for in the batch job's output file.
# The test expects the first to appear and the second to be absent, i.e. the
# memory-hungry task starts allocating but is killed before it completes.
oom_started: str = "Allocating"
oom_finished: str = "Done."
# Printed by the non-allocating step tasks after their sleep.
sleeper_survived: str = "SLEEPER SURVIVED OOM_KILL_STEP"
# Printed by the follow-up step that runs after the OOM step.
post_oom_success: str = "POST OOM STEP SUCCEEDED"

# Memory sizes, all in MiB.
step_memory_mib: int = 32
job_memory_mib: int = 128
# Large enough to exceed step_memory_mib, but still fit inside job_memory_mib.
oom_allocation_mib: int = 92
@pytest.fixture(scope="module", autouse=True)
def setup():
    """Declare the Slurm configuration this module needs, then require Slurm to run.

    Applied automatically once per module (autouse, module scope).
    """
    # slurm.conf: memory must be a tracked resource and tasks must be
    # confined by the cgroup task plugin.
    atf.require_config_parameter("SelectType", ["select/cons_tres", "select/linear"])
    atf.require_config_parameter_includes(
        "SelectTypeParameters", ["CR_Core_Memory", "CR_Memory"]
    )
    atf.require_config_parameter_includes("TaskPlugin", "cgroup")

    # cgroup.conf: applied in the order listed.
    cgroup_requirements = (
        ("ConstrainRAMSpace", "yes"),
        ("AllowedRAMSpace", 100),
        # Make OOM deterministic
        ("ConstrainSwapSpace", "yes"),
        ("AllowedSwapSpace", 0),
        # Ensure step_memory_mib is not clipped
        ("MinRAMSpace", 30),
        ("CgroupPlugin", "autodetect"),
    )
    for parameter, value in cgroup_requirements:
        atf.require_config_parameter(parameter, value, source="cgroup")

    # Two nodes so the multi-node variant can place one task per node.
    atf.require_nodes(2, [("CPUs", 2), ("RealMemory", job_memory_mib)])
    atf.require_slurm_running()
@pytest.fixture(scope="module")
def step_script(use_memory_program):
    """Write the per-task step script and return its absolute path as a string.

    In the generated script the task with SLURM_PROCID 0 runs
    ``use_memory_program`` with ``oom_allocation_mib`` as its first argument;
    every other task sleeps for five seconds and then echoes the
    ``sleeper_survived`` marker followed by its SLURM_PROCID.
    """
    script_body = f"""
if [ "$SLURM_PROCID" -eq 0 ]; then
{use_memory_program} {oom_allocation_mib} 1
else
sleep 5
echo "{sleeper_survived} $SLURM_PROCID"
fi
"""
    script_path = Path("step.sh").absolute()
    atf.make_bash_script(script_path, script_body)
    return str(script_path)
@pytest.mark.parametrize(
    "num_nodes",
    [
        pytest.param(1, id="single-node"),
        pytest.param(2, id="multi-node"),
    ],
)
@pytest.mark.parametrize(
    "oom_kill_step",
    [
        pytest.param(0, id="oom-kill-step-disabled"),
        pytest.param(1, id="oom-kill-step-enabled"),
    ],
)
def test_oom_kill_step(num_nodes, oom_kill_step, step_script):
    """Test that sbatch's --oom-kill-step flag controls OOM step cleanup.

    A two-task batch job is submitted on ``num_nodes`` nodes (both tasks on
    the same node for the single-node case, one task per node otherwise).
    Its first step runs ``step_script`` under a ``step_memory_mib`` limit, so
    the task that allocates ``oom_allocation_mib`` is expected to be OOM
    killed. A second, single-task step then prints ``post_oom_success``.

    Args:
        num_nodes: Number of nodes requested for the job (1 or 2).
        oom_kill_step: Value passed to ``sbatch --oom-kill-step`` (0 or 1).
        step_script: Path of the script run by every task of the first step.
    """
    ntasks_per_node = 2 if num_nodes == 1 else 1
    output_file = f"oom_kill_{num_nodes}_{oom_kill_step}.out"
    batch_script = f"oom_kill_{num_nodes}_{oom_kill_step}.sh"

    # --kill-on-bad-exit=0 keeps srun itself from tearing down the step when
    # one task fails, so the outcome for the sleeping task depends on
    # --oom-kill-step alone.
    atf.make_bash_script(
        batch_script,
        f"""
#SBATCH --output={output_file}
#SBATCH --nodes={num_nodes}
#SBATCH --ntasks=2
#SBATCH --ntasks-per-node={ntasks_per_node}
#SBATCH --mem={job_memory_mib}M
srun --mem={step_memory_mib}M --kill-on-bad-exit=0 {step_script}
srun --nodes=1 --ntasks=1 --ntasks-per-node=1 echo "{post_oom_success}"
""",
    )

    job_id = atf.submit_job_sbatch(
        f"--oom-kill-step={oom_kill_step} {batch_script}", fatal=True
    )
    atf.wait_for_job_state(job_id, "DONE", fatal=True)

    # Poll the output file until the final step's marker shows up, so the
    # assertions below inspect the complete output.
    for _ in atf.timer():
        output = atf.run_command_output(f"cat {output_file}")
        if post_oom_success in output:
            break
    else:
        # Explicit failure instead of `assert False`, which is removed when
        # Python runs with -O.
        pytest.fail(f"Output file should contain the last line: {post_oom_success}")

    assert oom_started in output, "OOM script task should start"
    assert (
        "step tasks have been OOM Killed" in output
    ), "Step OOM killer should be triggered"

    if oom_kill_step:
        assert (
            sleeper_survived not in output
        ), "--oom-kill-step=1 should kill all step tasks after the OOM event"
    else:
        assert (
            sleeper_survived in output
        ), "--oom-kill-step=0 should NOT kill healthy step tasks after the OOM event"

    assert (
        oom_finished not in output
    ), "OOM script finished allocation instead of being killed"