| ############################################################################ |
| # Copyright (C) SchedMD LLC. |
| ############################################################################ |
| import atf |
| import pytest |
| import re |
| |
| IMEX_CHANNEL_PATH = "/dev/nvidia-caps-imex-channels" |
| NUM_NODES = 4 |
| CHANNEL_MAX = NUM_NODES - 1 |
| TTY_MAJOR_NUM = 4 |
| |
| |
| # Setup |
| @pytest.fixture(scope="module", autouse=True) |
| def setup(): |
| atf.require_nodes(NUM_NODES) |
| atf.require_config_parameter("SwitchType", "switch/nvidia_imex") |
| atf.require_config_parameter( |
| "SwitchParameters", |
| f"imex_channel_count={CHANNEL_MAX},imex_dev_major={TTY_MAJOR_NUM}", |
| ) |
| atf.require_slurm_running() |
| |
| |
| def _simple_channel_job(): |
| output = atf.run_command( |
| f"srun --quiet ls {IMEX_CHANNEL_PATH}", |
| fatal=False, |
| ) |
| assert output["exit_code"] == 0, "Expected srun to run successfully" |
| |
| lines = output["stdout"].strip().splitlines() |
| |
| pattern = re.compile(r"^channel\d+$") |
| |
| for line in lines: |
| assert line, "Empty line found" |
| assert pattern.match(line), f"Invalid line format: {line}" |
| |
| return lines |
| |
| |
| def _job_expect_channel(channel): |
| channel_list = _simple_channel_job() |
| |
| assert channel_list == [channel], f"Unexpected channel found in {IMEX_CHANNEL_PATH}" |
| |
| |
| def _job_expect_pending(): |
| job_id = atf.submit_job_sbatch( |
| f'--wrap="srun ls {IMEX_CHANNEL_PATH}; sleep 30"', fatal=True |
| ) |
| if atf.get_version("bin/scontrol") >= (25, 11): |
| # Dev #50642: Unique IMEX channel per segment |
| reason = "NvidiaImexChannels" |
| else: |
| reason = None |
| atf.wait_for_job_state(job_id, "PENDING", desired_reason=reason, fatal=True) |
| |
| return job_id |
| |
| |
| def test_single_channel(): |
| """Test single channel creation""" |
| |
| # Run twice to make sure that channels ids are getting released properly |
| for _ in range(2): |
| _job_expect_channel("channel1") |
| |
| |
| def _allocate_single_channel(): |
| job_id = atf.submit_job_sbatch( |
| f'--wrap="srun ls {IMEX_CHANNEL_PATH}; sleep 30"', fatal=True |
| ) |
| atf.wait_for_job_state(job_id, "RUNNING", fatal=True) |
| |
| return job_id |
| |
| |
| @pytest.mark.xfail( |
| atf.get_version("sbin/slurmctld") < (25, 11), |
| reason="Dev #50642: Unique IMEX channel per segment", |
| ) |
| def test_multiple_channels(): |
| """Test channel creation for multiple jobs""" |
| |
| # Run twice to make sure that channels ids are getting released properly |
| for _ in range(2): |
| running_jobid = 0 |
| pending_jobid = 0 |
| |
| atf.cancel_all_jobs() |
| for i in range(CHANNEL_MAX): |
| # Make sure that the channel id is as expected for a quick job |
| _job_expect_channel(f"channel{i+1}") |
| # Allocate a channel with a long job to keep it from the next job |
| running_jobid = _allocate_single_channel() |
| |
| # At this point there should be no more channels left, and a job should |
| # pend waiting for a channel to be released |
| pending_jobid = _job_expect_pending() |
| |
| # Cancel last running job |
| atf.cancel_jobs([running_jobid]) |
| |
| # Previously pending job should start running |
| atf.wait_for_job_state(pending_jobid, "RUNNING", fatal=True) |
| |
| |
| @pytest.mark.xfail( |
| atf.get_version("sbin/slurmd") < (25, 11), |
| reason="Dev #50642: Unique IMEX channel per segment", |
| ) |
| def test_batch_channel(): |
| """Test channel creation for batch step""" |
| |
| output_file = "output.txt" |
| |
| job_id = atf.submit_job_sbatch( |
| f'--wrap="ls {IMEX_CHANNEL_PATH}" --output={output_file}', fatal=True |
| ) |
| atf.wait_for_job_state(job_id, "DONE", fatal=True) |
| |
| # Now check the output file |
| with open(output_file, "r") as f: |
| contents = f.read().strip() |
| |
| assert contents == "channel1", f"Expected 'channel1' but got: {contents!r}" |