blob: 9791289717990b807238641faa43fa70cff9b2fe [file] [edit]
############################################################################
# Copyright (C) SchedMD LLC.
############################################################################
from pathlib import Path
import atf
import pytest
# Relative tolerance passed to pytest.approx when accounted memory values are
# compared with the peaks recorded by the tasks.
memory_tolerance = 0.05
# Allocation sizes (MiB) the task script passes to the memory program: the
# high value for rank 1 or a single-task step, the low value otherwise.
low_mem_mib = 64
high_mem_mib = 192
# Value used both for the batch job's --mem option and for the RealMemory
# required on the test nodes.
job_mem_mib = 320
@pytest.fixture(scope="module", autouse=True)
def setup():
    """Declare the configuration this module needs, then require Slurm running."""
    atf.require_config_parameter("CgroupPlugin", "autodetect", source="cgroup")
    atf.require_config_parameter("JobAcctGatherType", "jobacct_gather/cgroup")

    # Option values that must be absent from the configuration.
    excluded_values = (
        ("JobAcctGatherParams", "no_file_cache"),
        ("AccountingStorageEnforce", "nojobs"),
        ("AccountingStorageEnforce", "nosteps"),
    )
    for parameter, value in excluded_values:
        atf.require_config_parameter_excludes(parameter, value)

    atf.require_accounting()

    # Assumes task cgroups expose memory.peak (cgroup/v2) or
    # memory.max_usage_in_bytes (cgroup/v1).

    # Two nodes are needed for the max-memory node to be meaningful.
    atf.require_nodes(2, [("RealMemory", job_mem_mib)])

    # Force termination-only sampling, so short peaks come from cgroup peak
    # data.
    atf.require_config_parameter("JobAcctGatherFrequency", 0)

    atf.require_slurm_running()
@pytest.fixture(scope="module")
def task_script(use_memory_program):
    """Create the per-task bash script and return its absolute path as a string.

    The script takes the output directory as its first argument, picks the
    allocation size for its rank, and replaces itself with the memory program.
    That program's output is redirected to ``mem.<rank>.out`` and is expected
    to be a record of the form
    `<task_id> <node> <requested_bytes> <peak_kb> <final_kb>`.
    """
    script_path = Path("task.sh").absolute()
    # The script text is kept flush-left so the generated file is unchanged.
    script_body = f"""
output_dir="$1"
# Rank 1 holds the high peak so MaxRSSTask attribution is non-zero.
# A single task uses the high peak too for tolerance headroom.
if [[ $SLURM_NTASKS -eq 1 || $SLURM_PROCID -eq 1 ]]; then
mib={high_mem_mib}
else
mib={low_mem_mib}
fi
# Use exec so that use_memory_program is the only process in the cgroup
exec {use_memory_program} "$mib" 0 --time-mem-valley 5 \\
--prefix "$SLURM_PROCID $SLURMD_NODENAME $((mib * 1024 * 1024))" \\
> "$output_dir/mem.$SLURM_PROCID.out"
"""
    atf.make_bash_script(script_path, script_body)
    return str(script_path)
def parse_output(output_files):
    """Read one task record from each output file.

    Each file is expected to hold five whitespace-separated fields: task id,
    node name, requested bytes, peak RSS in kB and final RSS in kB. Reading is
    retried under ``atf.timer()`` until the content parses; if it never does,
    the test fails via ``pytest.fail``.

    Returns a list of dicts, one per file and in the same order, with the two
    kB values converted to bytes.
    """
    parsed = []
    for output_file in output_files:
        entry = None
        for _ in atf.timer():
            tokens = atf.run_command_output(f"cat {output_file}", quiet=True).split()
            try:
                # Wrong field count or a non-integer value both raise
                # ValueError, which simply triggers another attempt.
                task_id, node, requested_bytes, vmhwm_kb, final_vmrss_kb = tokens
                entry = {
                    "task_id": task_id,
                    "node": node,
                    "requested_bytes": int(requested_bytes),
                    "vmhwm_bytes": int(vmhwm_kb) * 1024,
                    "final_vmrss_bytes": int(final_vmrss_kb) * 1024,
                }
            except ValueError:
                continue
            break
        else:
            pytest.fail(f"Unable to read {output_file}")
        parsed.append(entry)
    return parsed
def parse_mem(fields):
    """Return the integer value of the first usable ``mem=`` entry.

    ``fields`` is a comma-separated list of ``key=value`` items, as found in
    the sacct TRESUsageIn* columns. Entries whose key is not exactly ``mem``
    or whose value is not all digits are skipped. Returns None when no entry
    qualifies.
    """
    for entry in fields.split(","):
        name, _, amount = entry.partition("=")
        if name != "mem":
            continue
        if amount.isdigit():
            return int(amount)
    return None
def try_get_mem_obs(step_id):
    """Query sacct once for a step's memory accounting and parse it.

    Returns a dict keyed by ``tot``, ``ave``, ``max`` (integers extracted by
    ``parse_mem``), ``ave_rss``, ``max_rss`` (integers) and ``max_node``,
    ``max_task`` (stripped strings). Returns None when the output is not
    exactly one non-blank line with the expected number of columns, or when
    any numeric column cannot be parsed.
    """
    columns = {
        "tot": "TRESUsageInTot",
        "ave": "TRESUsageInAve",
        "max": "TRESUsageInMax",
        "ave_rss": "AveRSS",
        "max_rss": "MaxRSS",
        "max_node": "MaxRSSNode",
        "max_task": "MaxRSSTask",
    }
    format_spec = ",".join(columns.values())
    raw = atf.run_command_output(
        f"sacct -j {step_id} -nP --noconvert -o {format_spec}",
        fatal=True,
    )

    rows = [row for row in raw.splitlines() if row.strip()]
    if len(rows) != 1:
        return None

    cells = rows[0].split("|")
    if len(cells) != len(columns):
        return None

    obs = dict(zip(columns, cells))
    try:
        for key in ("tot", "ave", "max"):
            obs[key] = parse_mem(obs[key])
        for key in ("ave_rss", "max_rss"):
            obs[key] = int(obs[key])
    except ValueError:
        return None
    for key in ("max_node", "max_task"):
        obs[key] = obs[key].strip()

    # parse_mem reports a missing or unusable mem= entry as None.
    for key in ("tot", "ave", "max"):
        if obs[key] is None:
            return None
    return obs
def get_mem_obs(step_id):
    """Poll ``try_get_mem_obs`` until it yields a result.

    Retries under ``atf.timer()`` and returns the first non-None result. If
    the timer runs out first, the test fails via ``pytest.fail``.
    """
    for _ in atf.timer():
        parsed = try_get_mem_obs(step_id)
        if parsed is None:
            continue
        return parsed
    pytest.fail("Unable to parse sacct output")
def assert_mem_released(records):
    """Assert every task's RSS dropped substantially after its peak.

    For each record, the gap between peak and final RSS must be strictly
    greater than half of the requested allocation.
    """
    for entry in records:
        threshold = entry["requested_bytes"] * 0.5
        dropped = entry["vmhwm_bytes"] - entry["final_vmrss_bytes"]
        assert (
            dropped > threshold
        ), f"RSS should go down after the peak memory was released, but got: {records}"
def max_peak_record(records):
    """Return the record with the largest ``vmhwm_bytes``.

    On a tie the earliest record wins.
    """

    def peak_of(entry):
        return entry["vmhwm_bytes"]

    return max(records, key=peak_of)
def assert_peak_not_final(mem_obs, records):
    """Assert accounting captured the short RSS peak, not post-release RSS.

    Each accounted value must exceed the matching final RSS figure by more
    than half of the matching requested size. Totals and averages are taken
    over all records; the max fields are checked against the record with the
    highest peak.
    """
    top = max_peak_record(records)
    count = len(records)

    final_tot = sum(entry["final_vmrss_bytes"] for entry in records)
    final_ave = final_tot / count
    requested_tot = sum(entry["requested_bytes"] for entry in records)
    requested_ave = requested_tot / count

    top_final = top["final_vmrss_bytes"]
    top_margin = top["requested_bytes"] * 0.5

    # (mem_obs key, sacct label, final RSS baseline, required margin)
    checks = (
        ("tot", "TRESUsageInTot", final_tot, requested_tot * 0.5),
        ("max", "TRESUsageInMax", top_final, top_margin),
        ("max_rss", "MaxRSS", top_final, top_margin),
        ("ave", "TRESUsageInAve", final_ave, requested_ave * 0.5),
        ("ave_rss", "AveRSS", final_ave, requested_ave * 0.5),
    )
    for field, label, final_value, min_delta in checks:
        above_final = mem_obs[field] - final_value
        assert (
            above_final > min_delta
        ), f"{label} should be a peak, but got like final RSS"
def assert_max_ids(mem_obs, records):
    """Assert MaxRSSNode and MaxRSSTask identify the highest-peak task.

    Args:
        mem_obs: Parsed sacct values; ``max_node`` and ``max_task`` are read.
        records: Per-task records with ``node``, ``task_id`` and
            ``vmhwm_bytes`` keys.

    Raises:
        AssertionError: If the reported node is not one the tasks ran on, or
            if the node or task differs from the highest-peak record.
    """
    peak_record = max_peak_record(records)
    node = mem_obs["max_node"]
    task = mem_obs["max_task"]

    known_nodes = {record["node"] for record in records}
    assert node in known_nodes, f"Unknown MaxRSSNode {node}"
    # Compare with the peak record's own node. A node -> task_id mapping keeps
    # only one task per node, so it would reject a correct MaxRSSNode whenever
    # several tasks share the peak node.
    assert (
        node == peak_record["node"]
    ), f"MaxRSSNode {node} should match highest VmHWM"
    assert (
        task == peak_record["task_id"]
    ), f"MaxRSSTask {task} should match highest VmHWM"
def assert_mem_matches_records(mem_obs, records):
    """Assert accounted memory matches the recorded per-task peaks.

    The sum, maximum and mean of the ``vmhwm_bytes`` values are compared with
    the corresponding accounting fields, within the relative tolerance
    ``memory_tolerance``.
    """
    # Tot/Ave/Max reflect per-task peaks only in completed-job (sacct)
    # accounting
    peaks = [entry["vmhwm_bytes"] for entry in records]
    peak_sum = sum(peaks)
    peak_max = max(peaks)
    peak_mean = peak_sum / len(peaks)

    # (mem_obs key, sacct label, expected value)
    checks = (
        ("tot", "TRESUsageInTot", peak_sum),
        ("max", "TRESUsageInMax", peak_max),
        ("max_rss", "MaxRSS", peak_max),
        ("ave", "TRESUsageInAve", peak_mean),
        ("ave_rss", "AveRSS", peak_mean),
    )
    for field, label, value in checks:
        wanted = pytest.approx(value, rel=memory_tolerance)
        assert (
            mem_obs[field] == wanted
        ), f"{label} should be {value}, but got {mem_obs[field]}"
def assert_peak_accounted(srun_cmd, node_count, task_count, output_dir):
    """Run ``srun_cmd`` in a batch job and check its step 0 memory accounting.

    Submits the job, waits for it to be done and for step 0 to be accounted,
    then reads ``mem.<rank>.out`` for each of the ``task_count`` ranks from
    ``output_dir``. It asserts that memory was released after the peak, that
    accounting shows the peak rather than final RSS, and that the accounted
    values match the records.

    Returns the ``(mem_obs, records)`` pair for further checks.
    """
    sbatch_args = (
        f"-N{node_count} --mem={job_mem_mib} -o mem.%j.out -e mem.%j.err "
        f'--wrap="{srun_cmd}"'
    )
    job_id = atf.submit_job_sbatch(sbatch_args, fatal=True)
    atf.wait_for_job_state(job_id, "DONE", fatal=True)
    atf.wait_for_step_accounted(job_id, 0, fatal=True)

    task_outputs = [output_dir / f"mem.{rank}.out" for rank in range(task_count)]
    records = parse_output(task_outputs)
    mem_obs = get_mem_obs(f"{job_id}.0")

    assert_mem_released(records)
    assert_peak_not_final(mem_obs, records)
    assert_mem_matches_records(mem_obs, records)
    return mem_obs, records
def test_mem_sacct(task_script):
    """Check that sacct keeps the short peak of a one-node, one-task step."""
    work_dir = Path(".").absolute()
    step_cmd = f"srun -N1 -n1 {task_script} {work_dir}"
    assert_peak_accounted(
        step_cmd,
        node_count=1,
        task_count=1,
        output_dir=work_dir,
    )
def test_mem_sacct_nodes(task_script):
    """Check that sacct names the task and node holding the highest peak.

    Runs two tasks, one per node, then checks MaxRSSNode and MaxRSSTask on top
    of the common peak-accounting assertions.
    """
    work_dir = Path(".").absolute()
    step_cmd = f"srun -N2 -n2 --ntasks-per-node=1 {task_script} {work_dir}"
    observed, task_records = assert_peak_accounted(
        step_cmd,
        node_count=2,
        task_count=2,
        output_dir=work_dir,
    )
    assert_max_ids(observed, task_records)