blob: f1165ae1fe7397766c6d88d74aff85662c908772 [file] [log] [blame] [edit]
############################################################################
# Copyright (C) SchedMD LLC.
############################################################################
import atf
import pytest
import pexpect
import re
import json
suser = atf.properties["slurm-user"]
# Setup
@pytest.fixture(scope="module", autouse=True)
def setup():
atf.require_slurm_running()
def test_env_variables(printenv):
"""Verify the appropriate job environment variables are set."""
env_vars = []
env_vars_g0 = [] # variables that need to be greater than zero
env_vars.append("SLURM_LAUNCH_NODE_IPADDR")
env_vars.append("SLURM_LOCALID")
env_vars.append("SLURM_NNODES")
env_vars.append("SLURM_NODEID")
env_vars.append("SLURM_NODELIST")
env_vars.append("SLURM_PROCID")
env_vars.append("SLURM_SRUN_COMM_HOST")
env_vars.append("SLURM_STEPID")
env_vars.append("SLURM_TOPOLOGY_ADDR")
env_vars.append("SLURM_TOPOLOGY_ADDR_PATTERN")
env_vars_g0.append("SLURM_CPUS_ON_NODE")
env_vars_g0.append("SLURM_CPUS_PER_TASK")
env_vars_g0.append("SLURM_JOB_ID")
env_vars_g0.append("SLURM_NTASKS")
env_vars_g0.append("SLURM_SRUN_COMM_PORT")
env_vars_g0.append("SLURM_TASKS_PER_NODE")
env_vars_g0.append("SLURM_TASK_PID")
output = atf.run_job_output(f"-N1 -n1 --cpus-per-task=1 {printenv}", user=suser)
env_dict = json.loads(output)
missing = [k for k in env_vars if k not in env_dict]
assert (
len(missing) == 0
), f"All environment variables should be printed, but these were missing: {missing}"
missing = [k for k in env_vars_g0 if k not in env_dict or int(env_dict[k]) <= 0]
assert (
len(missing) == 0
), f"All environment variables should be printed and greater than 0, but these were not: {missing}"
def test_user_env_variables():
"""Verify that user environment variables are propagated to the job."""
file_in = atf.module_tmp_path / "env_script"
env_var = "TEST_ENV_VAR"
env_val = "123"
atf.make_bash_script(file_in, f"""env | grep {env_var}; exit 0""")
output = atf.run_job_output(f"--export={env_var}={env_val} {file_in}").strip()
assert (
output == f"{env_var}={env_val}"
), "Environment variables not propagated with --export"
child = pexpect.spawn(f"env {env_var}={env_val}")
child.sendline(f"srun {file_in}")
assert (
child.expect([f"{env_var}={env_val}", pexpect.EOF]) == 0
), "Environment variables not propagated"
child.close()
output = atf.run_job_output(
f"--export=ALL {file_in}", env_vars=f"{env_var}={env_val}"
).strip()
assert (
output == f"{env_var}={env_val}"
), "Environment variables not propagated with export=ALL"
output = atf.run_job_output(
f"--export=NONE {file_in}", env_vars=f"{env_var}={env_val}"
).strip()
assert (
output != f"{env_var}={env_val}"
), "Environment variables were propagated with export=NONE"
def test_slurm_directed_env_variables():
"""Verify that Slurm directed environment variables are processed: SLURM_DEBUG, SLURM_NNODES, SLURN_NPROCS, SLURM_OVERCOMMIT, SLURM_STDOUTMODE."""
file_in = atf.module_tmp_path / "file_in"
file_out = str(atf.module_tmp_path / "file_out.output")
sorted_file_out = atf.module_tmp_path / "sorted_file_out"
min_nodes = 1
max_nodes = 2
slurm_debug = "SLURM_DEBUG"
slurm_debug_val = "1"
slurm_nnodes = "SLURM_NNODES"
slurm_nnodes_val = str(max_nodes - min_nodes)
slurm_nprocs = "SLURM_NPROCS"
slurm_nprocs_val = "5"
slurm_stdoutmode = "SLURM_STDOUTMODE"
slurm_stdoutmode_val = file_out
slurm_overcommit = "SLURM_OVERCOMMIT"
slurm_overcommit_val = "1"
atf.make_bash_script(file_in, "env | grep SLURM_")
atf.run_job(
file_in,
env_vars=f"{slurm_debug}={slurm_debug_val} {slurm_nnodes}={slurm_nnodes_val} {slurm_nprocs}={slurm_nprocs_val} {slurm_stdoutmode}={slurm_stdoutmode_val} {slurm_overcommit}={slurm_overcommit_val}",
)
atf.wait_for_file(file_out)
atf.run_command(f"sort {file_out} > {sorted_file_out}")
with open(sorted_file_out) as f:
output = f.read()
task_count = len(re.findall(rf"{slurm_nprocs}=(\d+)", output))
stale_count = len(re.findall(r"Stale file handle", output))
task_count += stale_count
assert task_count == int(
slurm_nprocs_val
), f"Did not process {slurm_nprocs} environment variable ({task_count} != {slurm_nprocs_val})"
match = re.findall(r"SLURM_NODEID=(\d+)", output)
# add one since nodeID starts at zero
max_node_val = int(max(match)) + 1
assert (
max_node_val >= min_nodes and max_node_val <= max_nodes
), f"Did not process {slurm_nnodes} environment variable max_node_val = {max_node_val}"
assert (
re.search(rf"{slurm_debug}={slurm_debug_val}", output) is not None
), f"Did not process {slurm_debug} environment variable"
assert (
re.search(rf"{slurm_overcommit}={slurm_overcommit_val}", output) is not None
), f"Did not process {slurm_overcommit} environment variable"