diff --git a/.github/actions/docker-build/action.yml b/.github/actions/docker-build/action.yml index 33458a4551..cf5c8a7d28 100644 --- a/.github/actions/docker-build/action.yml +++ b/.github/actions/docker-build/action.yml @@ -16,9 +16,6 @@ inputs: image_tag: description: 'Custom image tag (optional, defaults to framework:latest)' required: false - ngc_ci_access_token: - description: 'NGC CI Access Token' - required: false ci_token: description: 'CI Token' required: false @@ -49,6 +46,12 @@ inputs: torch_backend: description: 'Optional override for TORCH_BACKEND build-arg (e.g., cu129)' required: false + enable_kvbm: + description: 'Enable KVBM support (optional)' + required: false + dynamo_base_image: + description: 'Pre-built Dynamo base image to use instead of building from scratch' + required: false outputs: image_tag: @@ -61,20 +64,9 @@ runs: - name: Set up Docker Buildx uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 #v3.11.1 with: - driver: docker + driver: docker-container # Enable BuildKit for enhanced metadata buildkitd-flags: --debug - - name: Login to ECR - shell: bash - env: - ECR_HOSTNAME: ${{ inputs.aws_account_id }}.dkr.ecr.${{ inputs.aws_default_region }}.amazonaws.com - run: | - aws ecr get-login-password --region ${{ inputs.aws_default_region }} | docker login --username AWS --password-stdin ${ECR_HOSTNAME} - - name: Login to NGC - if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'push' - shell: bash - run: | - echo "${{ inputs.ngc_ci_access_token }}" | docker login nvcr.io -u '$oauthtoken' --password-stdin - name: Cleanup if: always() shell: bash @@ -90,9 +82,12 @@ runs: AWS_ACCESS_KEY_ID: ${{ inputs.aws_access_key_id }} AWS_SECRET_ACCESS_KEY: ${{ inputs.aws_secret_access_key }} PLATFORM: ${{ inputs.platform }} + ECR_HOSTNAME: ${{ inputs.aws_account_id }}.dkr.ecr.${{ inputs.aws_default_region }}.amazonaws.com GITHUB_RUN_ID: ${{ github.run_id }} GITHUB_JOB: ${{ github.job }} + GITHUB_REF_NAME: ${{ github.ref_name }} run: | + set -x # Determine image tag if [ -n "${{ inputs.image_tag }}" ]; then IMAGE_TAG="${{ inputs.image_tag }}" @@ -112,18 +107,34 @@ runs: echo "๐Ÿ“ Build log will be saved to: ${BUILD_LOG_FILE}" # Collect optional overrides provided by the workflow + # Set base cache args and set --cache-to if this is a main commit EXTRA_ARGS="" + EXTRA_ARGS="--cache-to type=inline " + EXTRA_ARGS+="--cache-from type=registry,ref=${ECR_HOSTNAME}/ai-dynamo/dynamo:${{ inputs.framework }}-${PLATFORM##*/}-cache " + EXTRA_ARGS+="--cache-from type=registry,ref=${ECR_HOSTNAME}/ai-dynamo/dynamo:main-${{ inputs.framework }}-${PLATFORM##*/} " + if [[ "$GITHUB_REF_NAME" == "main" ]]; then + EXTRA_ARGS+="--cache-to type=registry,ref=${ECR_HOSTNAME}/ai-dynamo/dynamo:${{ inputs.framework }}-${PLATFORM##*/}-cache,mode=max " + fi + + echo "$EXTRA_ARGS" + # Collect optional overrides provided by the workflow if [ -n "${{ inputs.base_image_tag }}" ]; then - EXTRA_ARGS+=" --base-image-tag ${{ inputs.base_image_tag }}" + EXTRA_ARGS+="--base-image-tag ${{ inputs.base_image_tag }} " fi if [ -n "${{ inputs.runtime_image_tag }}" ]; then - EXTRA_ARGS+=" --build-arg RUNTIME_IMAGE_TAG=${{ inputs.runtime_image_tag }}" + EXTRA_ARGS+="--build-arg RUNTIME_IMAGE_TAG=${{ inputs.runtime_image_tag }} " fi if [ -n "${{ inputs.cuda_version }}" ]; then - EXTRA_ARGS+=" --build-arg CUDA_VERSION=${{ inputs.cuda_version }}" + EXTRA_ARGS+="--build-arg CUDA_VERSION=${{ inputs.cuda_version }} " fi if [ -n "${{ inputs.torch_backend }}" ]; then 
- EXTRA_ARGS+=" --build-arg TORCH_BACKEND=${{ inputs.torch_backend }}" + EXTRA_ARGS+="--build-arg TORCH_BACKEND=${{ inputs.torch_backend }} " + fi + if [ -n "${{ inputs.dynamo_base_image }}" ]; then + EXTRA_ARGS+=" --dynamo-base-image ${{ inputs.dynamo_base_image }}" + fi + if [ -n "${{ inputs.enable_kvbm }}" ]; then + EXTRA_ARGS+=" --build-arg ENABLE_KVBM=${{ inputs.enable_kvbm }}" fi # Execute build and capture output (show on console AND save to file) @@ -144,6 +155,26 @@ runs: # Exit with the build's exit code exit ${BUILD_EXIT_CODE} + - name: Run Sanity Check on Runtime Image + if: inputs.target == 'runtime' + shell: bash + run: | + IMAGE_TAG="${{ steps.build.outputs.image_tag }}" + echo "Running sanity check on image: $IMAGE_TAG" + + # Run the sanity check script inside the container + # The script is located in /workspace/deploy/sanity_check.py in runtime containers + set +e + docker run --rm "$IMAGE_TAG" python /workspace/deploy/sanity_check.py --runtime-check --no-gpu-check + SANITY_CHECK_EXIT_CODE=$? + set -e + if [ ${SANITY_CHECK_EXIT_CODE} -ne 0 ]; then + echo "ERROR: Sanity check failed - ai-dynamo packages not properly installed" + exit ${SANITY_CHECK_EXIT_CODE} + else + echo "โœ… Sanity check passed" + fi + - name: Capture Build Metrics id: metrics shell: bash @@ -223,8 +254,7 @@ runs: chmod +x .github/scripts/parse_buildkit_output.py # Check for build logs and build stage arguments dynamically - BASE_BUILD_LOG="build-logs/base-image-build.log" - FRAMEWORK_BUILD_LOG="build-logs/framework-${FRAMEWORK_LOWER}-build.log" + BUILD_LOG="build-logs/single-stage-build.log" # Path to container metadata created in previous step CONTAINER_METADATA="build-metrics/metrics-${{ inputs.framework }}-${PLATFORM_ARCH}-${WORKFLOW_ID}-${JOB_ID}.json" @@ -237,18 +267,11 @@ runs: # Build stage arguments dynamically based on which logs exist STAGE_ARGS=() - if [ -f "$BASE_BUILD_LOG" ]; then - echo " โœ“ Found base image log: ${BASE_BUILD_LOG}" - STAGE_ARGS+=("base:${BASE_BUILD_LOG}") - else - echo " โ„น๏ธ No base image log found" - fi - - if [ -f "$FRAMEWORK_BUILD_LOG" ]; then - echo " โœ“ Found framework log: ${FRAMEWORK_BUILD_LOG}" - STAGE_ARGS+=("runtime:${FRAMEWORK_BUILD_LOG}") + if [ -f "$BUILD_LOG" ]; then + echo " โœ“ Found base image log: ${BUILD_LOG}" + STAGE_ARGS+=("runtime:${BUILD_LOG}") else - echo " โ„น๏ธ No framework log found" + echo " โ„น๏ธ No image log found" fi # Check for any additional stage logs (e.g., build-logs/stage3-*.log) @@ -280,13 +303,6 @@ runs: if [ ${PARSER_EXIT_CODE} -eq 0 ] && [ -f "$COMPREHENSIVE_JSON" ]; then echo "โœ… Comprehensive build metrics generated successfully" echo "๐Ÿ“„ Output file: ${COMPREHENSIVE_JSON}" - echo "" - echo "==========================================" - echo "๐Ÿ“‹ FULL JSON OUTPUT (for debugging)" - echo "==========================================" - cat "$COMPREHENSIVE_JSON" - echo "" - echo "==========================================" else echo "โš ๏ธ Metrics generation had issues but continuing..." 
fi @@ -296,7 +312,7 @@ runs: uses: actions/upload-artifact@v4 if: always() with: - name: build-metrics-${{ inputs.framework }}-${{ env.PLATFORM_ARCH }}-${{ github.run_id }}-${{ job.check_run_id }} + name: build-metrics-${{ inputs.framework }}-${{ inputs.target }}-${{ env.PLATFORM_ARCH }}-${{ github.run_id }}-${{ job.check_run_id }} path: build-metrics/build-${{ inputs.framework }}-${{ env.PLATFORM_ARCH }}-${{ github.run_id }}-${{ job.check_run_id }}.json retention-days: 7 diff --git a/.github/actions/docker-login/action.yml b/.github/actions/docker-login/action.yml new file mode 100644 index 0000000000..1e24aff400 --- /dev/null +++ b/.github/actions/docker-login/action.yml @@ -0,0 +1,46 @@ +name: 'Docker Login' +description: 'Login to multiple container registries (ECR, NGC, ACR)' + +inputs: + ngc_ci_access_token: + description: 'NGC CI Access Token' + required: false + aws_default_region: + description: 'AWS Default Region' + required: false + aws_account_id: + description: 'AWS Account ID' + required: false + azure_acr_hostname: + description: 'Azure ACR hostname' + required: false + azure_acr_user: + description: 'Azure ACR user' + required: false + azure_acr_password: + description: 'Azure ACR password' + required: false + +runs: + using: "composite" + steps: + - name: ECR Login + shell: bash + if: ${{ inputs.aws_default_region != '' && inputs.aws_account_id != '' }} + env: + ECR_HOSTNAME: ${{ inputs.aws_account_id }}.dkr.ecr.${{ inputs.aws_default_region }}.amazonaws.com + run: | + set -euo pipefail + aws ecr get-login-password --region ${{ inputs.aws_default_region }} | docker login --username AWS --password-stdin "${ECR_HOSTNAME}" + - name: NGC Login + if: ${{ inputs.ngc_ci_access_token != '' }} + shell: bash + run: | + set -euo pipefail + echo "${{ inputs.ngc_ci_access_token }}" | docker login nvcr.io -u '$oauthtoken' --password-stdin + - name: ACR Login + shell: bash + if: ${{ inputs.azure_acr_hostname != '' && inputs.azure_acr_user != '' && inputs.azure_acr_password != '' }} + run: | + set -euo pipefail + echo "${{ inputs.azure_acr_password }}" | docker login "${{ inputs.azure_acr_hostname }}" --username "${{ inputs.azure_acr_user }}" --password-stdin diff --git a/.github/actions/docker-tag-push/action.yml b/.github/actions/docker-tag-push/action.yml index 51428e9985..4b1a96b2b1 100644 --- a/.github/actions/docker-tag-push/action.yml +++ b/.github/actions/docker-tag-push/action.yml @@ -1,12 +1,18 @@ +name: 'Docker Tag and Push' description: 'Tag and Push Docker Images' inputs: local_image: description: 'Local Image Name:Tag' required: true - push_tag: - description: 'Target Name:Tag' + push_tags: + description: 'Target Name:Tag (newline-separated list for multiple tags)' required: true + # There isn't a clean way to have an additional tag that is conditional + # Adding this to handle this use-case (we want multiple tags for main builds) + conditional_tag: + description: 'Optional tag for conditionals' + required: false aws_push: description: 'Push to AWS Boolean' required: false @@ -21,54 +27,56 @@ inputs: aws_default_region: description: 'AWS Default Region' required: false - aws_access_key_id: - description: 'AWS Access Key ID' - required: false - aws_secret_access_key: - description: 'AWS Secret Access Key' - required: false azure_acr_hostname: description: 'Azure ACR hostname' required: false - azure_acr_user: - description: 'Azure ACR user' - required: false - azure_acr_password: - description: 'Azure ACR password' - required: false outputs: - image_tag: - description: 
'Image Tag' - value: ${{ inputs.push_tag }} + image_tags: + description: 'Image Tags' + value: ${{ inputs.push_tags }} runs: using: "composite" steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - - name: ACR Login - shell: bash - if: ${{ inputs.azure_push == 'true' }} - run: | - echo "${{ inputs.azure_acr_password }}" | docker login ${{ inputs.azure_acr_hostname }} --username ${{ inputs.azure_acr_user }} --password-stdin - name: ECR Tag and Push shell: bash if: ${{ inputs.aws_push == 'true' }} env: LOCAL_IMAGE: ${{ inputs.local_image }} - PUSH_TAG: ${{ inputs.push_tag }} + PUSH_TAGS: ${{ inputs.push_tags }} + CONDITIONAL_TAG: ${{ inputs.conditional_tag }} ECR_HOSTNAME: ${{ inputs.aws_account_id }}.dkr.ecr.${{ inputs.aws_default_region }}.amazonaws.com run: | - docker tag ${LOCAL_IMAGE} ${ECR_HOSTNAME}/${PUSH_TAG} - docker push ${ECR_HOSTNAME}/${PUSH_TAG} + set -euo pipefail + if [[ ${CONDITIONAL_TAG} != '' ]]; then + docker tag ${LOCAL_IMAGE} ${ECR_HOSTNAME}/${CONDITIONAL_TAG} + docker push ${ECR_HOSTNAME}/${CONDITIONAL_TAG} + fi + while IFS= read -r TAG; do + if [ -z "$TAG" ]; then + continue + fi + echo "Tagging and pushing: ${ECR_HOSTNAME}/${TAG}" + docker tag "${LOCAL_IMAGE}" "${ECR_HOSTNAME}/${TAG}" + docker push "${ECR_HOSTNAME}/${TAG}" + done <<< "$PUSH_TAGS" - name: ACR Tag and Push shell: bash if: ${{ inputs.azure_push == 'true' }} env: LOCAL_IMAGE: ${{ inputs.local_image }} - PUSH_TAG: ${{ inputs.push_tag }} + PUSH_TAGS: ${{ inputs.push_tags }} AZURE_ACR_HOSTNAME: ${{ inputs.azure_acr_hostname }} run: | - docker tag ${LOCAL_IMAGE} ${AZURE_ACR_HOSTNAME}/${PUSH_TAG} - docker push ${AZURE_ACR_HOSTNAME}/${PUSH_TAG} + set -euo pipefail + while IFS= read -r TAG; do + if [ -z "$TAG" ]; then + continue + fi + echo "Tagging and pushing: ${AZURE_ACR_HOSTNAME}/${TAG}" + docker tag "${LOCAL_IMAGE}" "${AZURE_ACR_HOSTNAME}/${TAG}" + docker push "${AZURE_ACR_HOSTNAME}/${TAG}" + done <<< "$PUSH_TAGS" diff --git a/.github/actions/pytest/action.yml b/.github/actions/pytest/action.yml index 0f44baf3a7..5cc89b4bc5 100644 --- a/.github/actions/pytest/action.yml +++ b/.github/actions/pytest/action.yml @@ -24,6 +24,10 @@ inputs: description: 'Platform architecture (amd64, arm64)' required: false default: 'amd64' + dry_run: + description: 'Run pytest in dry-run mode (collect tests only, do not execute)' + required: false + default: 'false' runs: @@ -54,31 +58,50 @@ runs: # Run pytest with detailed output and JUnit XML set +e # Don't exit on test failures - # Detect GPU availability and conditionally add GPU flags - GPU_FLAGS="" - if command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then - echo "GPU detected, enabling GPU runtime" - GPU_FLAGS="--runtime=nvidia --gpus all" + # Determine docker runtime flags and pytest command based on dry_run mode + if [[ "${{ inputs.dry_run }}" == "true" ]]; then + echo "๐Ÿ” Running pytest in dry-run mode (collect-only, no GPU required)" + GPU_FLAGS="" + PYTEST_CMD="pytest -v --collect-only -m \"${{ inputs.pytest_marks }}\"" else - echo "No GPU detected, running in CPU-only mode" + echo "๐Ÿš€ Running pytest in normal mode" + PYTEST_CMD="pytest -v --tb=short --basetemp=/tmp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 -m \"${{ inputs.pytest_marks }}\"" + + # Detect GPU availability and conditionally add GPU flags + GPU_FLAGS="" + if command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then + echo "โœ“ GPU detected, enabling GPU runtime" + 
GPU_FLAGS="--runtime=nvidia --gpus all" + else + echo "โš ๏ธ No GPU detected, running in CPU-only mode" + fi fi + # Get absolute path for test-results directory and ensure it has proper permissions + TEST_RESULTS_DIR="$(pwd)/test-results" + chmod 777 "${TEST_RESULTS_DIR}" + echo "๐Ÿ“ Test results will be saved to: ${TEST_RESULTS_DIR}" + docker run ${GPU_FLAGS} --rm -w /workspace \ --cpus=${NUM_CPUS} \ --network host \ --name ${{ env.CONTAINER_ID }}_pytest \ + -v "${TEST_RESULTS_DIR}:/workspace/test-results" \ ${{ inputs.image_tag }} \ - bash -c "mkdir -p /workspace/test-results && pytest -v --tb=short --basetemp=/tmp -o cache_dir=/tmp/.pytest_cache --junitxml=/workspace/test-results/${{ env.PYTEST_XML_FILE }} --durations=10 -m \"${{ inputs.pytest_marks }}\"" + bash -c "${PYTEST_CMD}" TEST_EXIT_CODE=$? echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> $GITHUB_ENV echo "๐Ÿงช Tests completed with exit code: ${TEST_EXIT_CODE}" - # Copy test results from container to host - docker cp ${{ env.CONTAINER_ID }}_pytest:/workspace/test-results . || echo "Failed to copy test results" - - # Clean up container - docker rm -f ${{ env.CONTAINER_ID }}_pytest || echo "Failed to clean up container" + # Verify test results were written (only in normal mode) + if [[ "${{ inputs.dry_run }}" != "true" ]]; then + if [[ -f "${TEST_RESULTS_DIR}/${{ env.PYTEST_XML_FILE }}" ]]; then + echo "โœ… Test results file found: ${TEST_RESULTS_DIR}/${{ env.PYTEST_XML_FILE }}" + else + echo "โš ๏ธ Test results file not found: ${TEST_RESULTS_DIR}/${{ env.PYTEST_XML_FILE }}" + fi + fi # Always continue to results processing exit 0 @@ -103,23 +126,9 @@ runs: ERROR_TESTS=$(grep -o 'errors="[0-9]*"' "$JUNIT_FILE" | grep -o '[0-9]*' | head -1 || echo "0") echo "๐Ÿ“Š ${TOTAL_TESTS} tests completed (${FAILED_TESTS} failed, ${ERROR_TESTS} errors)" - # Create uniquely named metadata file with step context information - # Use framework-testtype-arch to make it unique per test run - METADATA_FILE="test-results/test_metadata_${{ inputs.framework }}_${STR_TEST_TYPE}_${{ inputs.platform_arch }}.json" - JUNIT_NAME="pytest_test_report_${{ inputs.framework }}_${STR_TEST_TYPE}_${{ inputs.platform_arch }}.xml" - # Rename XML file to unique name + JUNIT_NAME="pytest_test_report_${{ inputs.framework }}_${STR_TEST_TYPE}_${{ inputs.platform_arch }}_${{ github.run_id }}_${{ job.check_run_id }}.xml" mv "$JUNIT_FILE" "test-results/$JUNIT_NAME" - - echo '{' > "$METADATA_FILE" - echo ' "job_name": "${{ github.job }}",' >> "$METADATA_FILE" - echo ' "framework": "${{ inputs.framework }}",' >> "$METADATA_FILE" - echo ' "test_type": "${{ inputs.test_type }}",' >> "$METADATA_FILE" - echo ' "platform_arch": "${{ inputs.platform_arch }}",' >> "$METADATA_FILE" - echo ' "junit_xml_file": "'"$JUNIT_NAME"'",' >> "$METADATA_FILE" - echo ' "step_name": "Run ${{ inputs.test_type }} tests"' >> "$METADATA_FILE" - echo '}' >> "$METADATA_FILE" - echo "๐Ÿ“ Created test metadata file: $METADATA_FILE" echo "๐Ÿ“ Renamed XML file to: $JUNIT_NAME" else echo "โš ๏ธ JUnit XML file not found - test results may not be available for upload" @@ -135,8 +144,6 @@ runs: uses: actions/upload-artifact@v4 if: always() # Always upload test results, even if tests failed with: - name: test-results-${{ inputs.framework }}-${{ env.STR_TEST_TYPE }}-${{ env.PLATFORM_ARCH }} - path: | - test-results/pytest_test_report_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}.xml - test-results/test_metadata_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch 
}}.json - retention-days: 7 \ No newline at end of file + name: test-results-${{ inputs.framework }}-${{ env.STR_TEST_TYPE }}-${{ env.PLATFORM_ARCH }}-${{ github.run_id }}-${{ job.check_run_id }} + path: test-results/pytest_test_report_${{ inputs.framework }}_${{ env.STR_TEST_TYPE }}_${{ inputs.platform_arch }}_${{ github.run_id }}_${{ job.check_run_id }}.xml + retention-days: 7 diff --git a/.github/workflows/container-validation-backends.yml b/.github/workflows/container-validation-backends.yml index c8af077259..eeec8c9500 100644 --- a/.github/workflows/container-validation-backends.yml +++ b/.github/workflows/container-validation-backends.yml @@ -69,14 +69,15 @@ jobs: uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 + - name: Docker Login + uses: ./.github/actions/docker-login with: - driver: docker - - name: Login to ECR - shell: bash - env: - ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com - run: | - aws ecr get-login-password --region ${{ secrets.AWS_DEFAULT_REGION }} | docker login --username AWS --password-stdin ${ECR_HOSTNAME} + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - name: Linter shell: bash env: @@ -91,7 +92,6 @@ jobs: run: | cd deploy/cloud/operator docker build --target tester --progress=plain --build-arg DOCKER_PROXY=${ECR_HOSTNAME}/dockerhub/ . - - name: Set up Go uses: actions/setup-go@44694675825211faa026b3c33043df3e48a5fa00 # v6.0.0 with: @@ -120,14 +120,10 @@ jobs: uses: ./.github/actions/docker-tag-push with: local_image: dynamo-operator:latest - push_tag: ai-dynamo/dynamo:${{ github.sha }}-operator-${{ matrix.platform.arch }} + push_tags: ai-dynamo/dynamo:${{ github.sha }}-operator-${{ matrix.platform.arch }} aws_push: 'false' azure_push: 'true' - aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} - aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} - azure_acr_user: ${{ secrets.AZURE_ACR_USER }} - azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} vllm: needs: changed-files @@ -147,6 +143,15 @@ jobs: echo ${K8S_NODE_NAME} - name: Checkout code uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0 + - name: Docker Login + uses: ./.github/actions/docker-login + with: + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - name: Build Container id: build-image uses: ./.github/actions/docker-build @@ -158,27 +163,30 @@ jobs: runtime_image_tag: ${{ matrix.platform.arch == 'arm64' && '12.9.0-runtime-ubuntu24.04' || '' }} cuda_version: ${{ matrix.platform.arch == 'arm64' && '129' || '' }} torch_backend: ${{ matrix.platform.arch == 'arm64' && 'cu129' || '' }} - ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} ci_token: ${{ secrets.CI_TOKEN }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} sccache_s3_bucket: ${{ secrets.SCCACHE_S3_BUCKET }} aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_access_key_id: ${{ 
secrets.AWS_ACCESS_KEY_ID }} aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} - name: Docker Tag and Push uses: ./.github/actions/docker-tag-push with: local_image: ${{ steps.build-image.outputs.image_tag }} - push_tag: ai-dynamo/dynamo:${{ github.sha }}-vllm-${{ matrix.platform.arch }} - # OPS-1145: Switch aws_push to true - aws_push: 'false' + push_tags: ai-dynamo/dynamo:${{ github.sha }}-vllm-${{ matrix.platform.arch }} + conditional_tag: ${{ github.ref_name == 'main' && format('ai-dynamo/dynamo:main-vllm-{0}', matrix.platform.arch) || '' }} + aws_push: 'true' azure_push: 'true' aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} - azure_acr_user: ${{ secrets.AZURE_ACR_USER }} - azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - - name: Run tests if: ${{ matrix.platform.arch != 'arm64' }} uses: ./.github/actions/pytest @@ -207,7 +215,15 @@ jobs: echo ${K8S_NODE_NAME} - name: Checkout repository uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0 - + - name: Docker Login + uses: ./.github/actions/docker-login + with: + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - name: Build Container id: build-image uses: ./.github/actions/docker-build @@ -215,28 +231,23 @@ jobs: framework: sglang target: runtime platform: 'linux/${{ matrix.platform.arch }}' - ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} ci_token: ${{ secrets.CI_TOKEN }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} sccache_s3_bucket: ${{ secrets.SCCACHE_S3_BUCKET }} aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - - name: Docker Tag and Push uses: ./.github/actions/docker-tag-push with: local_image: ${{ steps.build-image.outputs.image_tag }} - push_tag: ai-dynamo/dynamo:${{ github.sha }}-sglang-${{ matrix.platform.arch }} - # OPS-1145: Switch aws_push to true - aws_push: 'false' + push_tags: ai-dynamo/dynamo:${{ github.sha }}-sglang-${{ matrix.platform.arch }} + conditional_tag: ${{ github.ref_name == 'main' && format('ai-dynamo/dynamo:main-sglang-{0}', matrix.platform.arch) || '' }} + aws_push: 'true' azure_push: 'true' aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} - azure_acr_user: ${{ secrets.AZURE_ACR_USER }} - azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - - name: Run tests if: ${{ matrix.platform.arch != 'arm64' }} uses: ./.github/actions/pytest @@ -265,7 +276,15 @@ jobs: echo ${K8S_NODE_NAME} - name: Checkout code uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4.3.0 - + - name: Docker Login + uses: ./.github/actions/docker-login + with: + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + 
aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - name: Build Container id: build-image uses: ./.github/actions/docker-build @@ -273,28 +292,23 @@ jobs: framework: trtllm target: runtime platform: 'linux/${{ matrix.platform.arch }}' - ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} ci_token: ${{ secrets.CI_TOKEN }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} sccache_s3_bucket: ${{ secrets.SCCACHE_S3_BUCKET }} aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - - name: Docker Tag and Push uses: ./.github/actions/docker-tag-push with: local_image: ${{ steps.build-image.outputs.image_tag }} - push_tag: ai-dynamo/dynamo:${{ github.sha }}-trtllm-${{ matrix.platform.arch }} - # OPS-1145: Switch aws_push to true - aws_push: 'false' + push_tags: ai-dynamo/dynamo:${{ github.sha }}-trtllm-${{ matrix.platform.arch }} + conditional_tag: ${{ github.ref_name == 'main' && format('ai-dynamo/dynamo:main-trtllm-{0}', matrix.platform.arch) || '' }} + aws_push: 'true' azure_push: 'true' aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} - azure_acr_user: ${{ secrets.AZURE_ACR_USER }} - azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - - name: Run tests if: ${{ matrix.platform.arch != 'arm64' }} uses: ./.github/actions/pytest @@ -396,6 +410,7 @@ jobs: export KUBECONFIG=$(pwd)/.kubeconfig kubectl config set-context --current --namespace=$NAMESPACE - name: Run Fault Tolerance Tests + id: run-ft-tests run: | set -x export KUBECONFIG=$(pwd)/.kubeconfig @@ -417,14 +432,49 @@ jobs: pip install -r container/deps/requirements.test.txt pip install kubernetes==32.0.1 kubernetes_asyncio kr8s pyyaml requests tabulate pydantic - # Run the pytest command (tests orchestrate K8s, don't need dynamo package) + # Create test-results directory + mkdir -p test-results + + # Run the pytest command with JUnit XML output + set +e # Don't exit on test failures pytest tests/fault_tolerance/deploy/test_deployment.py \ -m 'k8s and fault_tolerance' \ -k '${{ matrix.framework.test_scenario }}' \ -s -v \ --namespace ${NAMESPACE} \ --image ${IMAGE} \ - --client-type legacy + --client-type legacy \ + --junitxml=test-results/pytest_ft_report.xml \ + --tb=short + + TEST_EXIT_CODE=$? 
+ echo "TEST_EXIT_CODE=${TEST_EXIT_CODE}" >> $GITHUB_ENV + echo "๐Ÿงช Fault tolerance tests completed with exit code: ${TEST_EXIT_CODE}" + + exit ${TEST_EXIT_CODE} + continue-on-error: true + + - name: Process Fault Tolerance Test Results + if: always() + run: | + set -x + + # Rename JUnit XML with unique naming if it exists + if [ -f "test-results/pytest_ft_report.xml" ]; then + mv "test-results/pytest_ft_report.xml" "test-results/pytest_ft_report_${{ matrix.framework.name }}_amd64_${{ github.run_id }}_${{ job.check_run_id }}.xml" + echo "โœ… JUnit XML report renamed with unique identifier" + else + echo "โš ๏ธ JUnit XML report not found" + fi + + - name: Upload Fault Tolerance Test Results + uses: actions/upload-artifact@v4 + if: always() + with: + name: test-results-${{ matrix.framework.name }}-fault_tolerance-amd64-${{ github.run_id }}-${{ job.check_run_id }} + path: test-results/pytest_ft_report_${{ matrix.framework.name }}_amd64_${{ github.run_id }}_${{ job.check_run_id }}.xml + retention-days: 7 + - name: Cleanup if: always() timeout-minutes: 5 @@ -448,56 +498,6 @@ jobs: kubectl delete namespace $NAMESPACE || true echo "Namespace $NAMESPACE completed." - # Upload metrics for this workflow and all its jobs - upload-workflow-metrics: - name: Upload Workflow Metrics - runs-on: gitlab - if: always() # Always run, even if other jobs fail - needs: [backend-status-check] # Wait for the status check which waits for all build jobs - - steps: - - name: Check out repository - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install requests - - - name: Download build metrics - uses: actions/download-artifact@v4 - with: - pattern: build-metrics-* - path: build-metrics/ - merge-multiple: true - continue-on-error: true # Don't fail if artifacts don't exist - - - name: Download test results - uses: actions/download-artifact@v4 - with: - pattern: test-results-* - path: test-results/ - merge-multiple: true - continue-on-error: true # Don't fail if artifacts don't exist - - - name: Upload Complete Workflow Metrics - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - WORKFLOW_INDEX: ${{ secrets.WORKFLOW_INDEX }} - JOB_INDEX: ${{ secrets.JOB_INDEX }} - STEPS_INDEX: ${{ secrets.STEPS_INDEX }} - # Container and test index configuration - CONTAINER_INDEX: ${{ secrets.CONTAINER_INDEX }} - TEST_INDEX: ${{ secrets.TEST_INDEX }} - run: | - # Upload complete workflow metrics including container metrics - python3 .github/workflows/upload_complete_workflow_metrics.py - deploy-operator: runs-on: cpu-amd-m5-2xlarge # TODO: Uncomment this when we have a way to test the deploy-operator job in CI. 
@@ -617,6 +617,7 @@ jobs: kubectl config set-context --current --namespace=$NAMESPACE --kubeconfig "${KUBECONFIG}" kubectl config get-contexts - name: Run Tests + id: run-tests env: NAMESPACE: ${{ needs.deploy-operator.outputs.NAMESPACE }} run: | @@ -624,6 +625,9 @@ jobs: export KUBECONFIG=$(pwd)/.kubeconfig kubectl config set-context --current --namespace=$NAMESPACE + # Redirect all output to a log file while still showing it + exec > >(tee -a test-output.log) 2>&1 + cd examples/backends/$FRAMEWORK export FRAMEWORK_RUNTIME_IMAGE="${{ secrets.AZURE_ACR_HOSTNAME }}/ai-dynamo/dynamo:${{ github.sha }}-${FRAMEWORK}-amd64" export KUBE_NS=$NAMESPACE @@ -716,6 +720,32 @@ jobs: echo "Test passed: Response matches expected format and content" fi exit $TEST_RESULT + continue-on-error: true + + - name: Process Deployment Test Results + if: always() + run: | + set -x + + # Create test-results directory + mkdir -p test-results + + # Copy and rename the test output log with unique naming + if [ -f "test-output.log" ]; then + cp test-output.log "test-results/deploy_test_output_${{ env.FRAMEWORK }}_${{ matrix.profile }}_amd64_${{ github.run_id }}_${{ job.check_run_id }}.log" + echo "โœ… Test output log copied to test-results/" + else + echo "โš ๏ธ test-output.log not found" + fi + + - name: Upload Deployment Test Results + uses: actions/upload-artifact@v4 + if: always() + with: + name: test-results-${{ env.FRAMEWORK }}-deploy-${{ matrix.profile }}-amd64-${{ github.run_id }}-${{ job.check_run_id }} + path: test-results/deploy_test_output_${{ env.FRAMEWORK }}_${{ matrix.profile }}_amd64_${{ github.run_id }}_${{ job.check_run_id }}.log + retention-days: 7 + - name: Cleanup if: always() timeout-minutes: 5 diff --git a/.github/workflows/container-validation-dynamo.yml b/.github/workflows/container-validation-dynamo.yml index c068eecabb..8a0e182fc6 100644 --- a/.github/workflows/container-validation-dynamo.yml +++ b/.github/workflows/container-validation-dynamo.yml @@ -33,8 +33,9 @@ jobs: uses: docker/setup-buildx-action@v3 - name: Login to NGC if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name == 'push' - run: | - echo "${{ secrets.NGC_CI_ACCESS_TOKEN }}" | docker login nvcr.io -u '$oauthtoken' --password-stdin + uses: ./.github/actions/docker-login + with: + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} - name: Define Image Tag id: define_image_tag run: | @@ -43,20 +44,20 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.CI_TOKEN }} run: | - ./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target dev --framework none + ./container/build.sh --tag ${{ steps.define_image_tag.outputs.image_tag }} --target dev --framework none --enable-kvbm - name: Start services with docker-compose working-directory: ./deploy run: | docker compose up -d nats-server etcd-server - - name: Run Rust checks (block-manager + integration tests) + - name: Run Rust checks (block-manager + media-nixl + integration tests) run: | docker run --rm -w /workspace/lib/llm \ --name ${{ env.CONTAINER_ID }}_rust_checks \ ${{ steps.define_image_tag.outputs.image_tag }} \ bash -ec 'rustup component add rustfmt clippy && \ cargo fmt -- --check && \ - cargo clippy --features block-manager --no-deps --all-targets -- -D warnings && \ - cargo test --locked --all-targets --features=block-manager && \ + cargo clippy --features block-manager,media-nixl --no-deps --all-targets -- -D warnings && \ + cargo test --locked --all-targets --features=block-manager,media-nixl && \ cargo test --locked 
--features integration -- --nocapture' - name: Cleanup services if: always() diff --git a/.github/workflows/generate-docs.yml b/.github/workflows/generate-docs.yml index e281129925..24f1a6cdea 100644 --- a/.github/workflows/generate-docs.yml +++ b/.github/workflows/generate-docs.yml @@ -13,18 +13,56 @@ # See the License for the specific language governing permissions and # limitations under the License. -name: Generate Documentation +# Dynamo docs build and publish workflow +# Build: +# - Builds documentation using Docker container +# - Creates artifact for downstream use +# - Runs on: main, release/*, tags, PRs (docs changes only), manual dispatch +# +# Publish: +# - Main branch: publish to S3 under 'dev' (development docs) +# - Tagged commits: publish to S3 under 'archive/X.Y.Z' AND update 'latest' to match the release +# - Manual dispatch: publish specified version to archive (does NOT update 'latest') +# - PRs: no S3 publish (only internal preview deployment if targeting release branch) +# - Version manifest: automatically updated in S3 when publishing new versions (versions1.json) +# - Akamai: flushes cache for the target path after publish (when DOCS_AKAMAI_ENABLED=true) +# +# Required Configuration: +# - Repository variable: DOCS_PUBLISH_S3_TARGET_PATH (prefix under S3 bucket, e.g., "dynamo") +# - Repository variable: DOCS_BASE_URL (base URL for docs site, e.g., "https://docs.nvidia.com/dynamo") +# - Secrets: AWS credentials (DOCS_AWS_ACCESS_KEY_ID, DOCS_AWS_SECRET_ACCESS_KEY, DOCS_AWS_S3_BUCKET, DOCS_AWS_REGION) +# - Secrets: DOCS_TOKEN (GitHub PAT for PR preview deployment to external repo) +# - Secrets (optional): DOCS_AWS_IAM_STS_ROLE (for OIDC authentication instead of IAM keys) +# - Secrets (optional): DOCS_AKAMAI_* EdgeGrid credentials for cache flush +# - Variable (optional): DOCS_AKAMAI_ENABLED (set to 'true' to enable Akamai cache flush) +# +# Commit message flags: +# - '/skip-dev': skip publishing 'dev' on main branch +# - '/not-latest': publish version to archive but don't update 'latest' +name: Generate and Publish Documentation on: push: branches: - main - release/* + tags: + - '*' pull_request: paths: - 'docs/**' - 'container/Dockerfile.docs' - '.github/workflows/generate-docs.yml' + workflow_dispatch: + inputs: + version: + description: 'Optional: Version to publish (e.g., 1.2.3). If not provided, publishes as dev.' + required: false + type: string + ref: + description: 'Optional: Git ref to checkout (tag, branch, or SHA). Use to build docs from older tags.' 
+ required: false + type: string jobs: build-docs: @@ -33,13 +71,40 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + ref: ${{ inputs.ref || github.ref }} + + - name: Determine docs version + id: version + shell: bash + run: | + VERSION="dev" + # Option 1: Tag push (e.g., v0.3.0 -> 0.3.0) + if [[ "${{ github.ref_type }}" == "tag" ]]; then + TAG="${{ github.ref_name }}" + if [[ "${TAG}" =~ ^v([0-9]+(\.[0-9]+){1,2}([._-](post|rc|dev)[0-9]+)?)$ ]]; then + VERSION="${BASH_REMATCH[1]}" + echo "::notice::Detected version from tag: ${VERSION}" + fi + # Option 2: Manual dispatch with version input + elif [[ -n "${{ inputs.version || '' }}" ]]; then + VERSION="${{ inputs.version }}" + echo "::notice::Using version from manual input: ${VERSION}" + fi + + echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + echo "Building docs for version: ${VERSION}" - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Generate documentation + env: + DOCS_VERSION: ${{ steps.version.outputs.version }} run: | - docker build -t docs-builder -f container/Dockerfile.docs . + docker build -t docs-builder \ + --build-arg DYNAMO_DOCS_VERSION="${DOCS_VERSION}" \ + -f container/Dockerfile.docs . - name: Copy documentation out of container run: | @@ -159,3 +224,388 @@ jobs: body: comment }); } + + publish-s3: + name: Publish docs to S3 and flush Akamai + needs: [build-docs] + runs-on: ubuntu-latest + if: ${{ github.event_name != 'pull_request' }} + + permissions: + contents: read + id-token: write + actions: read + + env: + S3_BUCKET: ${{ secrets.DOCS_AWS_S3_BUCKET }} + DOCS_DIR: dynamo-docs + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v4 + with: + # Use OIDC (role assumption) if available, otherwise use IAM keys + role-to-assume: ${{ secrets.DOCS_AWS_IAM_STS_ROLE }} + aws-access-key-id: ${{ secrets.DOCS_AWS_ACCESS_KEY_ID }} + aws-secret-access-key: ${{ secrets.DOCS_AWS_SECRET_ACCESS_KEY }} + aws-region: ${{ secrets.DOCS_AWS_REGION }} + + - name: Verify AWS identity + run: | + aws sts get-caller-identity >/dev/null || { + echo "::error::Failed to authenticate with AWS. Check credentials configuration." + exit 1 + } + + - name: Download documentation artifacts + uses: actions/download-artifact@v4 + with: + pattern: dynamo-docs-* + path: ${{ env.DOCS_DIR }} + + - name: Validate documentation artifacts + run: | + # The artifact is downloaded into a subdirectory, move contents up one level + ARTIFACT_DIR=$(find "${{ env.DOCS_DIR }}" -mindepth 1 -maxdepth 1 -type d | head -n 1) + if [[ -z "${ARTIFACT_DIR}" ]]; then + echo "::error::No artifact directory found" + exit 1 + fi + + echo "::notice::Moving contents from ${ARTIFACT_DIR} to ${{ env.DOCS_DIR }}" + mv "${ARTIFACT_DIR}"/* "${{ env.DOCS_DIR }}/" + rmdir "${ARTIFACT_DIR}" + + # Validate extraction + if [[ ! -d "${{ env.DOCS_DIR }}" ]] || [[ -z "$(ls -A ${{ env.DOCS_DIR }})" ]]; then + echo "::error::Documentation directory is empty after extraction" + exit 1 + fi + + echo "::notice::Documentation size: $(du -sh ${{ env.DOCS_DIR }} | cut -f1)" + + - name: Determine version and validate inputs + id: vars + env: + ARTIFACTS_PATH: dynamo-docs + TARGET_PATH: ${{ vars.DOCS_PUBLISH_S3_TARGET_PATH }} + COMMIT_MSG: ${{ github.event.head_commit.message || '' }} + shell: bash + run: | + set -euo pipefail + + if [[ -z "${TARGET_PATH}" ]]; then + echo "::error::target-path was not provided. 
Set repository variable DOCS_PUBLISH_S3_TARGET_PATH." + exit 1 + fi + + if [[ ! -d "${ARTIFACTS_PATH}" ]]; then + echo "::error::Failed to find documentation artifacts at ${ARTIFACTS_PATH}" + exit 1 + fi + + # Determine version from various sources + VERSION="" + PUBLISH_TO_LATEST="false" + + # Option 1: Direct tag push + if [[ "${{ github.ref_type }}" == "tag" ]]; then + TAG="${{ github.ref_name }}" + if [[ "${TAG}" =~ ^v([0-9]+(\.[0-9]+){1,2}([._-](post|rc|dev)[0-9]+)?)$ ]]; then + VERSION="${BASH_REMATCH[1]}" + echo "Detected version from tag: ${VERSION}" + PUBLISH_TO_LATEST="true" + fi + + # Check for /not-latest flag in commit message + if [[ "${COMMIT_MSG}" =~ /not-latest ]]; then + PUBLISH_TO_LATEST="false" + echo "Detected /not-latest flag in commit message" + fi + + # Option 2: Manual dispatch with version input + elif [[ -n "${{ inputs.version || '' }}" ]]; then + VERSION="${{ inputs.version }}" + echo "Using version from manual input: ${VERSION}" + + # Don't publish to latest on manual dispatch + PUBLISH_TO_LATEST="false" + echo "Manual dispatch detected - will not publish to latest" + fi + + echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + echo "artifacts_path=${ARTIFACTS_PATH}" >> "$GITHUB_OUTPUT" + echo "publish_to_latest=${PUBLISH_TO_LATEST}" >> "$GITHUB_OUTPUT" + + if [[ -n "${VERSION}" ]]; then + echo "::notice::Publishing version: ${VERSION}" + if [[ "${PUBLISH_TO_LATEST}" == "true" ]]; then + echo "::notice::Will also publish to 'latest'" + else + echo "::notice::Will NOT publish to 'latest'" + fi + else + echo "::notice::Publishing as dev (no version detected)" + fi + + - name: Normalize S3 path + id: paths + env: + S3_TARGET_ROOT: ${{ env.S3_BUCKET }} + TARGET_PATH: ${{ vars.DOCS_PUBLISH_S3_TARGET_PATH }} + shell: bash + run: | + set -euo pipefail + S3_ROOT="${S3_TARGET_ROOT%/}" + S3_PATH="${TARGET_PATH#/}" + S3_PATH="${S3_PATH%/}" + echo "S3_TARGET_PATH...${S3_PATH}" + echo "s3_root=${S3_ROOT}" >> "$GITHUB_OUTPUT" + echo "s3_path=${S3_PATH}" >> "$GITHUB_OUTPUT" + + - name: Publish version + if: ${{ steps.vars.outputs.version != '' }} + working-directory: ${{ env.DOCS_DIR }} + id: publish_version + env: + S3_ROOT: ${{ steps.paths.outputs.s3_root }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + VERSION: ${{ steps.vars.outputs.version }} + shell: bash + run: | + set -euo pipefail + echo "Publishing version ${VERSION} to ${S3_ROOT}/${S3_PATH}/archive/${VERSION}" + aws s3 sync . 
"${S3_ROOT}/${S3_PATH}/archive/${VERSION}" --exclude .buildinfo --exclude .doctrees --delete + echo "published=true" >> "$GITHUB_OUTPUT" + + - name: Update versions manifest in S3 + if: ${{ steps.publish_version.outputs.published == 'true' }} + env: + DOCS_BASE_URL: ${{ vars.DOCS_BASE_URL }} + S3_ROOT: ${{ steps.paths.outputs.s3_root }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + VERSION: ${{ steps.vars.outputs.version }} + shell: bash + run: | + set -euo pipefail + + MANIFEST_URL="${S3_ROOT}/${S3_PATH}/versions1.json" + LOCAL_MANIFEST="/tmp/versions1.json" + + # Download existing manifest from S3 + aws s3 cp "${MANIFEST_URL}" "${LOCAL_MANIFEST}" + + # Check if version already exists in manifest + if jq -e ".[] | select(.version == \"${VERSION}\")" "${LOCAL_MANIFEST}" > /dev/null 2>&1; then + echo "Version ${VERSION} already exists in manifest, skipping update" + else + echo "Adding version ${VERSION} to manifest" + + # Create new version entry and insert after "dev" and "latest" (index 2) + jq --arg version "${VERSION}" \ + --arg url "${DOCS_BASE_URL}/archive/${VERSION}/" \ + '.[0:2] + [{version: $version, url: $url}] + .[2:]' \ + "${LOCAL_MANIFEST}" > "${LOCAL_MANIFEST}.tmp" + mv "${LOCAL_MANIFEST}.tmp" "${LOCAL_MANIFEST}" + + # Upload updated manifest to S3 + aws s3 cp "${LOCAL_MANIFEST}" "${MANIFEST_URL}" + echo "โœ… Added ${VERSION} to versions1.json" + fi + + - name: Publish latest + if: ${{ steps.publish_version.outputs.published == 'true' && steps.vars.outputs.publish_to_latest == 'true' }} + working-directory: ${{ env.DOCS_DIR }} + id: publish_latest + env: + S3_ROOT: ${{ steps.paths.outputs.s3_root }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + shell: bash + run: | + set -euo pipefail + echo "Publishing latest to ${S3_ROOT}/${S3_PATH}/latest" + aws s3 sync . "${S3_ROOT}/${S3_PATH}/latest" --exclude .buildinfo --exclude .doctrees --delete + echo "published_latest=true" >> "$GITHUB_OUTPUT" + + - name: Publish dev (main branch) + # Publish main branch to 'dev' directory for development docs + # Skip if commit message contains '/skip-dev' anywhere + if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message || '', '/skip-dev') }} + working-directory: ${{ env.DOCS_DIR }} + id: publish_dev + env: + S3_ROOT: ${{ steps.paths.outputs.s3_root }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + shell: bash + run: | + set -euo pipefail + echo "Publishing development docs to ${S3_ROOT}/${S3_PATH}/dev" + aws s3 sync . "${S3_ROOT}/${S3_PATH}/dev" --exclude .buildinfo --exclude .doctrees --delete + echo "published=true" >> "$GITHUB_OUTPUT" + + - name: Update versions manifest in all archive directories + # Update versions*.json in ALL archive directories so old docs show current version list + # Only run when publishing a version (not for dev builds) + if: ${{ steps.vars.outputs.version != '' }} + working-directory: ${{ env.DOCS_DIR }} + env: + S3_ROOT: ${{ steps.paths.outputs.s3_root }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + shell: bash + run: | + set -euo pipefail + + # Get list of all archive directories + echo "Updating version manifests in all archive directories..." 
+ ARCHIVE_DIRS=$(aws s3 ls "${S3_ROOT}/${S3_PATH}/archive/" | grep "PRE" | awk '{print $2}' | tr -d '/') + + for file in versions.json versions1.json; do + if [[ -f "${file}" ]]; then + for dir in ${ARCHIVE_DIRS}; do + echo "Updating ${file} in archive/${dir}/" + aws s3 cp "${file}" "${S3_ROOT}/${S3_PATH}/archive/${dir}/${file}" || { + echo "::warning::Failed to update ${file} in archive/${dir}" + } + done + fi + done + + echo "โœ… Version manifests updated in all archive directories" + + - name: Collect publish outputs + id: publish + env: + S3_PATH: ${{ steps.paths.outputs.s3_path }} + VERSION: ${{ steps.vars.outputs.version }} + PUBLISHED_VERSION: ${{ steps.publish_version.outputs.published || 'false' }} + PUBLISHED_LATEST: ${{ steps.publish_latest.outputs.published_latest || 'false' }} + PUBLISHED_DEV: ${{ steps.publish_dev.outputs.published || 'false' }} + shell: bash + run: | + set -euo pipefail + echo "s3_target_path=${S3_PATH}" >> "$GITHUB_OUTPUT" + echo "request_name=Publish docs from ${GITHUB_REPOSITORY}@${GITHUB_SHA:0:8}" >> "$GITHUB_OUTPUT" + echo "published_latest=${PUBLISHED_LATEST}" >> "$GITHUB_OUTPUT" + + # Determine what to flush based on what was published + # - Version publish: flush entire path (versions.json updated in all archive dirs) + # - Dev publish only: flush just the dev directory + if [[ "${PUBLISHED_VERSION}" == "true" ]]; then + echo "perform_flush=true" >> "$GITHUB_OUTPUT" + echo "flush_path=${S3_PATH}" >> "$GITHUB_OUTPUT" + echo "::notice::Will flush entire ${S3_PATH} (version publish updates all archives)" + elif [[ "${PUBLISHED_DEV}" == "true" ]]; then + echo "perform_flush=true" >> "$GITHUB_OUTPUT" + echo "flush_path=${S3_PATH}/dev" >> "$GITHUB_OUTPUT" + echo "::notice::Will flush ${S3_PATH}/dev only (dev publish)" + else + echo "perform_flush=false" >> "$GITHUB_OUTPUT" + echo "flush_path=" >> "$GITHUB_OUTPUT" + fi + + - name: Flush Akamai cache + # Only run if cache flush is needed AND Akamai is enabled + if: ${{ steps.publish.outputs.perform_flush == 'true' && vars.DOCS_AKAMAI_ENABLED == 'true' }} + env: + FLUSH_PATH: ${{ steps.publish.outputs.flush_path }} + REQUEST_NAME: ${{ steps.publish.outputs.request_name }} + # Use repository variable or secret for notification emails + # Format: JSON array of email addresses, e.g., '["email1@example.com", "email2@example.com"]' + EMAILS_JSON: ${{ secrets.DOCS_AKAMAI_NOTIFICATION_EMAILS }} + AKAMAI_CLIENT_SECRET: ${{ secrets.DOCS_AKAMAI_CLIENT_SECRET }} + AKAMAI_HOST: ${{ secrets.DOCS_AKAMAI_HOST }} + AKAMAI_ACCESS_TOKEN: ${{ secrets.DOCS_AKAMAI_ACCESS_TOKEN }} + AKAMAI_CLIENT_TOKEN: ${{ secrets.DOCS_AKAMAI_CLIENT_TOKEN }} + shell: bash + run: | + set -euo pipefail + + # Install required tools for Akamai + sudo apt-get update -qq + sudo apt-get install -y -qq jq xsltproc + pip install -q httpie httpie-edgegrid + + echo "Flushing Akamai cache for path: ${FLUSH_PATH}" + + # Generate Akamai ECCU request XML using the XSLT template + XSLT_TEMPLATE="${GITHUB_WORKSPACE}/.github/workflows/templates/akamai-eccu-flush.xslt" + + if [[ ! 
-f "${XSLT_TEMPLATE}" ]]; then + echo "::error::XSLT template file not found at ${XSLT_TEMPLATE}" + exit 1 + fi + + # Process XSLT to generate ECCU request XML + xsltproc --stringparam target-path "${FLUSH_PATH}" "${XSLT_TEMPLATE}" "${XSLT_TEMPLATE}" | \ + sed 's/xmlns:match="x" //' > /tmp/flush.xml + + # Prepare Akamai EdgeGrid credentials + echo "[default]" > ~/.edgerc + echo "client_secret = ${AKAMAI_CLIENT_SECRET}" >> ~/.edgerc + echo "host = ${AKAMAI_HOST}" >> ~/.edgerc + echo "access_token = ${AKAMAI_ACCESS_TOKEN}" >> ~/.edgerc + echo "client_token = ${AKAMAI_CLIENT_TOKEN}" >> ~/.edgerc + + # Validate and prepare email list JSON + if [[ -n "${EMAILS_JSON}" ]]; then + echo "${EMAILS_JSON}" | jq -c . > /tmp/email-addresses.json || { + echo "::error::Invalid JSON format for AKAMAI_NOTIFICATION_EMAILS" + exit 1 + } + else + echo '[]' > /tmp/email-addresses.json + fi + + # Submit ECCU request to Akamai + http --ignore-stdin --auth-type edgegrid -a default: POST :/eccu-api/v1/requests \ + metadata=@"/tmp/flush.xml" \ + propertyName=docs.nvidia.com \ + propertyNameExactMatch=true \ + propertyType=HOST_HEADER \ + requestName="${REQUEST_NAME}" \ + statusUpdateEmails:=@/tmp/email-addresses.json || { + echo "::warning::Failed to flush Akamai cache, but continuing workflow" + # Don't fail the workflow if cache flush fails + } + + - name: Summary + if: always() + env: + VERSION: ${{ steps.vars.outputs.version }} + S3_PATH: ${{ steps.paths.outputs.s3_path }} + PUBLISHED_VERSION: ${{ steps.publish_version.outputs.published || 'false' }} + PUBLISHED_LATEST: ${{ steps.publish.outputs.published_latest || 'false' }} + PUBLISHED_DEV: ${{ steps.publish_dev.outputs.published || 'false' }} + CACHE_FLUSHED: ${{ steps.publish.outputs.perform_flush }} + FLUSH_PATH: ${{ steps.publish.outputs.flush_path }} + run: | + echo "## ๐Ÿ“š Documentation Publishing Summary" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Source" >> $GITHUB_STEP_SUMMARY + echo "- **Workflow Run:** [#${{ github.run_id }}](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }})" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Published To" >> $GITHUB_STEP_SUMMARY + if [[ "${PUBLISHED_VERSION}" == "true" ]]; then + echo "- โœ… **Version:** \`${VERSION}\` โ†’ \`s3://.../${S3_PATH}/archive/${VERSION}\`" >> $GITHUB_STEP_SUMMARY + if [[ "${PUBLISHED_LATEST}" == "true" ]]; then + echo "- โœ… **Latest:** \`${VERSION}\` โ†’ \`s3://.../${S3_PATH}/latest\` (updated to match release)" >> $GITHUB_STEP_SUMMARY + else + echo "- โญ๏ธ **Latest:** not updated (manual dispatch or /not-latest flag)" >> $GITHUB_STEP_SUMMARY + fi + fi + if [[ "${PUBLISHED_DEV}" == "true" ]]; then + echo "- โœ… **Dev:** \`s3://.../${S3_PATH}/dev\` (main branch)" >> $GITHUB_STEP_SUMMARY + fi + if [[ "${PUBLISHED_VERSION}" != "true" ]] && [[ "${PUBLISHED_DEV}" != "true" ]]; then + echo "- โš ๏ธ No documentation was published" >> $GITHUB_STEP_SUMMARY + fi + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Cache" >> $GITHUB_STEP_SUMMARY + if [[ "${CACHE_FLUSHED}" == "true" ]]; then + echo "- โœ… Akamai cache flush requested for \`${FLUSH_PATH}\`" >> $GITHUB_STEP_SUMMARY + else + echo "- โญ๏ธ Cache flush skipped (nothing published or Akamai disabled)" >> $GITHUB_STEP_SUMMARY + fi diff --git a/.github/workflows/nightly-ci.yml b/.github/workflows/nightly-ci.yml index 84eba7b624..bee2426d1d 100644 --- a/.github/workflows/nightly-ci.yml +++ b/.github/workflows/nightly-ci.yml @@ -1,34 +1,96 @@ -name: Nightly CI +# 
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +name: Nightly CI pipeline on: schedule: - - cron: '0 8 * * *' # Every day at 12:00 AM PST (08:00 UTC) - workflow_dispatch: + - cron: '0 8 * * *' # Every day at 12:00 AM PST (08:00 UTC) + +permissions: + contents: read + +defaults: + run: + shell: bash --noprofile --norc -eo pipefail {0} + +env: + REGISTRY_IMAGE: ai-dynamo/dynamo + NIGHTLY_IMAGE_PREFIX: nightly + +############################## BUILD JOBS ############################## jobs: - vllm: + build-amd64: + name: Build ${{ matrix.framework }} (amd64) + runs-on: cpu-amd-m5-4xlarge + timeout-minutes: 120 strategy: fail-fast: false matrix: - platform: - - { arch: amd64, runner: gpu-l40-amd64 } - - { arch: arm64, runner: cpu-arm-r8g-4xlarge } - name: vllm (${{ matrix.platform.arch }}) - runs-on: ${{ matrix.platform.runner }} + framework: [vllm, trtllm, sglang] + env: + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com steps: - - name: Checkout code - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 #v4.3.1 - - name: Build vLLM Docker Image - id: build-vllm + - uses: actions/checkout@v4 + - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} + - name: Pull existing images for cache + shell: bash + continue-on-error: true + run: | + echo "Attempting to pull existing images for layer caching..." 
+ docker pull "${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-amd64" || echo "Framework image not found in cache" + docker pull "${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-amd64" || echo "Runtime image not found in cache" + echo "Cache pull completed" + - name: Build Framework Image + id: build_framework uses: ./.github/actions/docker-build with: - framework: vllm + framework: ${{ matrix.framework }} + target: framework + platform: linux/amd64 + base_image_tag: '' + runtime_image_tag: '' + cuda_version: '' + torch_backend: '' + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + ci_token: ${{ secrets.CI_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + sccache_s3_bucket: ${{ secrets.SCCACHE_S3_BUCKET }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + image_tag: framework-${{ matrix.framework }}-amd64:${{ github.run_id }} + - name: Tag and Push Framework Images + uses: ./.github/actions/docker-tag-push + with: + local_image: framework-${{ matrix.framework }}-amd64:${{ github.run_id }} + push_tags: | + ${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-amd64 + ${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-amd64-run-${{ github.run_id }} + aws_push: 'true' + azure_push: 'false' + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + - name: Build Runtime Image + id: build_runtime + uses: ./.github/actions/docker-build + with: + framework: ${{ matrix.framework }} target: runtime - platform: linux/${{ matrix.platform.arch }} - base_image_tag: ${{ matrix.platform.arch == 'arm64' && '25.06-cuda12.9-devel-ubuntu24.04' || '' }} - runtime_image_tag: ${{ matrix.platform.arch == 'arm64' && '12.9.0-runtime-ubuntu24.04' || '' }} - cuda_version: ${{ matrix.platform.arch == 'arm64' && '129' || '' }} - torch_backend: ${{ matrix.platform.arch == 'arm64' && 'cu129' || '' }} + platform: linux/amd64 + base_image_tag: '' + runtime_image_tag: '' + cuda_version: '' + torch_backend: '' ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} ci_token: ${{ secrets.CI_TOKEN }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} @@ -36,70 +98,77 @@ jobs: aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - image_tag: nightly-vllm-${{ matrix.platform.arch }} - - name: Tag and Push vLLM Nightly Image + image_tag: runtime-${{ matrix.framework }}-amd64:${{ github.run_id }} + - name: Tag and Push Runtime Images uses: ./.github/actions/docker-tag-push with: - local_image: ${{ steps.build-vllm.outputs.image_tag }} - # Tag the image nightly - push_tag: ai-dynamo/dynamo:nightly-vllm-${{ matrix.platform.arch }} - aws_push: 'false' + local_image: runtime-${{ matrix.framework }}-amd64:${{ github.run_id }} + push_tags: | + ${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-amd64 + ${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-amd64-run-${{ github.run_id }} + aws_push: 'true' azure_push: 'true' aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} azure_acr_user: ${{ secrets.AZURE_ACR_USER }} azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - - name: 
Run unit tests - if: ${{ matrix.platform.arch != 'arm64' }} - uses: ./.github/actions/pytest - with: - image_tag: nightly-vllm-${{ matrix.platform.arch }} - pytest_marks: "vllm and unit" - framework: "vllm" - test_type: "unit" - platform_arch: ${{ matrix.platform.arch }} - - name: Run e2e tests - if: ${{ matrix.platform.arch != 'arm64' }} - uses: ./.github/actions/pytest - with: - image_tag: nightly-vllm-${{ matrix.platform.arch }} - pytest_marks: "nightly and vllm and gpu_1" - framework: "vllm" - test_type: "e2e" - platform_arch: ${{ matrix.platform.arch }} - #################### - # Framework Builds # - #################### - vllm-framework: + build-arm64: + name: Build ${{ matrix.framework }} (arm64) + runs-on: cpu-arm-r8g-4xlarge + timeout-minutes: 120 strategy: fail-fast: false matrix: - platform: - - { arch: amd64, runner: cpu-amd-m5-4xlarge } - - { arch: arm64, runner: cpu-arm-r8g-4xlarge } - name: vllm-framework (${{ matrix.platform.arch }}) - runs-on: ${{ matrix.platform.runner }} + include: + - framework: vllm + base_image_tag: '25.06-cuda12.9-devel-ubuntu24.04' + runtime_image_tag: '12.9.0-runtime-ubuntu24.04' + cuda_version: '129' + torch_backend: 'cu129' + - framework: trtllm + base_image_tag: '25.06-py3' + runtime_image_tag: '' + cuda_version: '129' + torch_backend: 'cu129' + - framework: sglang + base_image_tag: '' + runtime_image_tag: '' + cuda_version: '' + torch_backend: '' env: - FRAMEWORK: vllm - steps: &framework-build-steps - - name: Checkout code - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 #v4.3.1 + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + steps: + - uses: actions/checkout@v4 + - name: Login to Container Registries + uses: ./.github/actions/docker-login with: - ref: main - - name: Build Image - id: build-image + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} + azure_acr_user: ${{ secrets.AZURE_ACR_USER }} + azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} + - name: Pull existing images for cache + shell: bash + continue-on-error: true + run: | + echo "Attempting to pull existing images for layer caching..." 
+ docker pull "${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-arm64" || echo "Framework image not found in cache" + docker pull "${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-arm64" || echo "Runtime image not found in cache" + echo "Cache pull completed" + - name: Build Framework Image + id: build_framework uses: ./.github/actions/docker-build with: - framework: ${{ env.FRAMEWORK }} + framework: ${{ matrix.framework }} target: framework - platform: linux/${{ matrix.platform.arch }} - # Ternary operations that are specific to vllm/arm64, empty str for all other combinations - base_image_tag: ${{ (matrix.platform.arch == 'arm64' && env.FRAMEWORK == 'vllm') && '25.06-cuda12.9-devel-ubuntu24.04' || '' }} - runtime_image_tag: ${{ (matrix.platform.arch == 'arm64' && env.FRAMEWORK == 'vllm') && '12.9.0-runtime-ubuntu24.04' || '' }} - cuda_version: ${{ (matrix.platform.arch == 'arm64' && env.FRAMEWORK == 'vllm') && '129' || '' }} - torch_backend: ${{ (matrix.platform.arch == 'arm64' && env.FRAMEWORK == 'vllm') && 'cu129' || '' }} + platform: linux/arm64 + base_image_tag: ${{ matrix.base_image_tag }} + runtime_image_tag: ${{ matrix.runtime_image_tag }} + cuda_version: ${{ matrix.cuda_version }} + torch_backend: ${{ matrix.torch_backend }} ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} ci_token: ${{ secrets.CI_TOKEN }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} @@ -107,39 +176,630 @@ jobs: aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - - name: Docker Tag and Push + image_tag: framework-${{ matrix.framework }}-arm64:${{ github.run_id }} + - name: Tag and Push Framework Images uses: ./.github/actions/docker-tag-push with: - local_image: ${{ steps.build-image.outputs.image_tag }} - push_tag: ai-dynamo/dynamo:main-${{ env.FRAMEWORK }}-framework-${{ matrix.platform.arch }} + local_image: framework-${{ matrix.framework }}-arm64:${{ github.run_id }} + push_tags: | + ${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-arm64 + ${{ env.REGISTRY_IMAGE }}:main-${{ matrix.framework }}-framework-arm64-run-${{ github.run_id }} aws_push: 'true' azure_push: 'false' aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + - name: Build Runtime Image + id: build_runtime + uses: ./.github/actions/docker-build + with: + framework: ${{ matrix.framework }} + target: runtime + platform: linux/arm64 + base_image_tag: ${{ matrix.base_image_tag }} + runtime_image_tag: ${{ matrix.runtime_image_tag }} + cuda_version: ${{ matrix.cuda_version }} + torch_backend: ${{ matrix.torch_backend }} + ngc_ci_access_token: ${{ secrets.NGC_CI_ACCESS_TOKEN }} + ci_token: ${{ secrets.CI_TOKEN }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + sccache_s3_bucket: ${{ secrets.SCCACHE_S3_BUCKET }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }} + aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + image_tag: runtime-${{ matrix.framework }}-arm64:${{ github.run_id }} + - name: Tag and Push Runtime Images + uses: ./.github/actions/docker-tag-push + with: + local_image: runtime-${{ matrix.framework }}-arm64:${{ github.run_id }} + push_tags: | + ${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-arm64 + ${{ env.REGISTRY_IMAGE }}:${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ 
matrix.framework }}-arm64-run-${{ github.run_id }} + aws_push: 'true' + azure_push: 'true' + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} azure_acr_hostname: ${{ secrets.AZURE_ACR_HOSTNAME }} azure_acr_user: ${{ secrets.AZURE_ACR_USER }} azure_acr_password: ${{ secrets.AZURE_ACR_PASSWORD }} - sglang-framework: + +############################## TEST JOBS ############################## + + unit-tests: + name: ${{ matrix.framework }}-${{ matrix.arch.arch }}-unit + needs: [build-amd64, build-arm64] + if: always() + runs-on: ${{ matrix.arch.runner }} + timeout-minutes: 45 strategy: fail-fast: false matrix: - platform: - - { arch: amd64, runner: cpu-amd-m5-4xlarge } - - { arch: arm64, runner: cpu-arm-r8g-4xlarge } - name: sglang-framework (${{ matrix.platform.arch }}) - runs-on: ${{ matrix.platform.runner }} - env: - FRAMEWORK: sglang - steps: *framework-build-steps - trtllm-framework: + framework: [vllm, trtllm, sglang] + arch: + - arch: amd64 + runner: gpu-l40-amd64 + - arch: arm64 + runner: cpu-arm-r8g-4xlarge + steps: + - uses: actions/checkout@v4 + - name: Check if build succeeded + id: check_build + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set +x + echo "Checking build status for ${{ matrix.framework }} (${{ matrix.arch.arch }})" + # Determine which build job to check + if [ "${{ matrix.arch.arch }}" = "amd64" ]; then + BUILD_JOB_NAME="Build ${{ matrix.framework }} (amd64)" + else + BUILD_JOB_NAME="Build ${{ matrix.framework }} (arm64)" + fi + # Query GitHub API for job status using curl (token from env to avoid log exposure) + JOBS=$(curl -s -S -L --fail-with-body \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" 2>&1) + if [ $? -ne 0 ]; then + echo "Error: Failed to query GitHub API" + exit 1 + fi + # Find the specific build job and check its conclusion + BUILD_STATUS=$(echo "$JOBS" | jq -r --arg job_name "$BUILD_JOB_NAME" '.jobs[] | select(.name == $job_name) | .conclusion') + echo "Build status for '$BUILD_JOB_NAME': $BUILD_STATUS" + if [ "$BUILD_STATUS" != "success" ]; then + echo "Build failed or did not complete successfully. Failing tests." + exit 1 + fi + echo "Build succeeded. Proceeding with tests." 
+ - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + - name: Pull nightly image + shell: bash + env: + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + IMAGE_TAG: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + run: | + docker pull ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} + docker tag ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} ${IMAGE_TAG} + - name: Run Unit Tests + uses: ./.github/actions/pytest + with: + image_tag: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + pytest_marks: "unit and (nightly or post_merge or pre_merge)" + framework: ${{ matrix.framework }} + test_type: unit + platform_arch: ${{ matrix.arch.arch }} + cpu_limit: '8' + dry_run: ${{ matrix.arch.arch == 'arm64' && 'true' || 'false' }} + + integration-tests: + name: ${{ matrix.framework }}-${{ matrix.arch.arch }}-integ + needs: [build-amd64, build-arm64] + if: always() + runs-on: ${{ matrix.arch.runner }} + timeout-minutes: ${{ matrix.arch.timeout }} strategy: fail-fast: false matrix: - platform: - - { arch: amd64, runner: cpu-amd-m5-4xlarge } - - { arch: arm64, runner: cpu-arm-r8g-4xlarge } - name: trtllm-framework (${{ matrix.platform.arch }}) - runs-on: ${{ matrix.platform.runner }} + framework: [vllm, trtllm, sglang] + arch: + - arch: amd64 + runner: gpu-l40-amd64 + timeout: 90 + - arch: arm64 + runner: cpu-arm-r8g-4xlarge + timeout: 90 + steps: + - uses: actions/checkout@v4 + - name: Check if build succeeded + id: check_build + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set +x + echo "Checking build status for ${{ matrix.framework }} (${{ matrix.arch.arch }})" + BUILD_JOB_NAME="Build ${{ matrix.framework }} (${{ matrix.arch.arch }})" + JOBS=$(curl -s -S -L --fail-with-body \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" 2>&1) + if [ $? -ne 0 ]; then + echo "Error: Failed to query GitHub API" + exit 1 + fi + BUILD_STATUS=$(echo "$JOBS" | jq -r --arg job_name "$BUILD_JOB_NAME" '.jobs[] | select(.name == $job_name) | .conclusion') + echo "Build status for '$BUILD_JOB_NAME': $BUILD_STATUS" + if [ "$BUILD_STATUS" != "success" ]; then + echo "Build failed or did not complete successfully. Marking tests as failed." + exit 1 + fi + echo "Build succeeded. Proceeding with tests." 
+ - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + - name: Pull nightly image + shell: bash + env: + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + IMAGE_TAG: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + run: | + docker pull ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} + docker tag ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} ${IMAGE_TAG} + - name: Run Integration Tests + uses: ./.github/actions/pytest + with: + image_tag: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + pytest_marks: "integration and (nightly or post_merge or pre_merge)" + framework: ${{ matrix.framework }} + test_type: integration + platform_arch: ${{ matrix.arch.arch }} + dry_run: ${{ matrix.arch.arch == 'arm64' && 'true' || 'false' }} + + e2e-single-gpu-tests: + name: ${{ matrix.framework }}-${{ matrix.arch.arch }}-1gpu-e2e + needs: [build-amd64, build-arm64] + if: always() + runs-on: ${{ matrix.arch.runner }} + timeout-minutes: ${{ matrix.arch.timeout }} + strategy: + fail-fast: false + matrix: + framework: [vllm, trtllm, sglang] + arch: + - arch: amd64 + runner: gpu-l40-amd64 + timeout: 120 + - arch: arm64 + runner: cpu-arm-r8g-4xlarge + timeout: 120 + steps: + - uses: actions/checkout@v4 + - name: Check if build succeeded + id: check_build + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set +x + echo "Checking build status for ${{ matrix.framework }} (${{ matrix.arch.arch }})" + BUILD_JOB_NAME="Build ${{ matrix.framework }} (${{ matrix.arch.arch }})" + JOBS=$(curl -s -S -L --fail-with-body \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" 2>&1) + if [ $? -ne 0 ]; then + echo "Error: Failed to query GitHub API" + echo "skip=true" >> $GITHUB_OUTPUT + exit 0 + fi + BUILD_STATUS=$(echo "$JOBS" | jq -r --arg job_name "$BUILD_JOB_NAME" '.jobs[] | select(.name == $job_name) | .conclusion') + echo "Build status for '$BUILD_JOB_NAME': $BUILD_STATUS" + if [ "$BUILD_STATUS" != "success" ]; then + echo "Build failed or did not complete successfully. Failing tests." + exit 1 + fi + echo "Build succeeded. Proceeding with tests." 
+ - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + - name: Pull nightly image + shell: bash + env: + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + IMAGE_TAG: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + run: | + docker pull ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} + docker tag ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} ${IMAGE_TAG} + - name: Run E2E Tests (gpu_1) + uses: ./.github/actions/pytest + with: + image_tag: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + pytest_marks: "${{ matrix.framework }} and e2e and gpu_1" + framework: ${{ matrix.framework }} + test_type: e2e-single-gpu + platform_arch: ${{ matrix.arch.arch }} + dry_run: ${{ matrix.arch.arch == 'arm64' && 'true' || 'false' }} + + e2e-multi-gpu-tests: + name: ${{ matrix.framework }}-${{ matrix.arch.arch }}-2gpu-e2e + needs: [build-amd64, build-arm64] + if: always() + runs-on: ${{ matrix.arch.runner }} + timeout-minutes: ${{ matrix.arch.timeout }} + strategy: + fail-fast: false + matrix: + framework: [vllm, trtllm, sglang] + arch: + - arch: amd64 + runner: gpu-l40-amd64 + timeout: 150 + - arch: arm64 + runner: cpu-arm-r8g-4xlarge + timeout: 150 + steps: + - uses: actions/checkout@v4 + - name: Check if build succeeded + id: check_build + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set +x + echo "Checking build status for ${{ matrix.framework }} (${{ matrix.arch.arch }})" + BUILD_JOB_NAME="Build ${{ matrix.framework }} (${{ matrix.arch.arch }})" + JOBS=$(curl -s -S -L --fail-with-body \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" 2>&1) + if [ $? -ne 0 ]; then + echo "Error: Failed to query GitHub API" + echo "skip=true" >> $GITHUB_OUTPUT + exit 0 + fi + BUILD_STATUS=$(echo "$JOBS" | jq -r --arg job_name "$BUILD_JOB_NAME" '.jobs[] | select(.name == $job_name) | .conclusion') + echo "Build status for '$BUILD_JOB_NAME': $BUILD_STATUS" + if [ "$BUILD_STATUS" != "success" ]; then + echo "Build failed or did not complete successfully. Marking tests as failed." + exit 1 + fi + echo "Build succeeded. Proceeding with tests." 
+ - name: Login to Container Registries + uses: ./.github/actions/docker-login + with: + aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + - name: Pull nightly image + shell: bash + env: + ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + IMAGE_TAG: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + run: | + docker pull ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} + docker tag ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} ${IMAGE_TAG} + - name: Run E2E Tests (gpu_2) + uses: ./.github/actions/pytest + with: + image_tag: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + pytest_marks: "(nightly or post_merge or pre_merge) and e2e and gpu_2" + framework: ${{ matrix.framework }} + test_type: e2e-multi-gpu + platform_arch: ${{ matrix.arch.arch }} + dry_run: 'true' + + # component-tests: + # name: ${{ matrix.framework }}-${{ matrix.arch.arch }}-${{ matrix.component }} + # needs: [build-amd64, build-arm64] + # if: always() + # runs-on: ${{ matrix.arch.runner }} + # timeout-minutes: ${{ matrix.arch.timeout }} + # strategy: + # fail-fast: false + # matrix: + # framework: [vllm, trtllm, sglang] + # arch: + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: router + # marks: "nightly and router" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: planner + # marks: "nightly and planner" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: router + # marks: "nightly and router" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: planner + # marks: "nightly and planner" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: router + # marks: "nightly and router" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: planner + # marks: "nightly and planner" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: router + # marks: "nightly and router" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: planner + # marks: "nightly and planner" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: router + # marks: "nightly and router" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 90 + # component: planner + # marks: "nightly and planner" + # - arch: amd64 + # runner: gpu-l40-amd64 + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: router + # marks: "nightly and router" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 60 + # component: planner + # marks: "nightly and planner" + # - arch: arm64 + # runner: cpu-arm-r8g-4xlarge + # timeout: 150 + # component: kvbm + # marks: "nightly and (kvbm or kvbm_v2)" + + # steps: + # - uses: actions/checkout@v4 + # - name: Check if build 
succeeded + # id: check_build + # env: + # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # run: | + # set +x + # echo "Checking build status for ${{ matrix.framework }} (${{ matrix.arch.arch }})" + + # if [ "${{ matrix.arch.arch }}" = "amd64" ]; then + # BUILD_JOB_NAME="Build ${{ matrix.framework }} (amd64)" + # else + # BUILD_JOB_NAME="Build ${{ matrix.framework }} (arm64)" + # fi + + # JOBS=$(curl -s -S -L --fail-with-body \ + # -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + # -H "Accept: application/vnd.github.v3+json" \ + # "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" 2>&1) + + # if [ $? -ne 0 ]; then + # echo "Error: Failed to query GitHub API" + # echo "skip=true" >> $GITHUB_OUTPUT + # exit 0 + # fi + + # BUILD_STATUS=$(echo "$JOBS" | jq -r --arg job_name "$BUILD_JOB_NAME" '.jobs[] | select(.name == $job_name) | .conclusion') + + # echo "Build status for '$BUILD_JOB_NAME': $BUILD_STATUS" + + # if [ "$BUILD_STATUS" != "success" ]; then + # echo "Build failed or did not complete successfully. Marking tests as failed." + # exit 1 + # fi + + # echo "Build succeeded. Proceeding with tests." + # - name: Login to Container Registries + # uses: ./.github/actions/docker-login + # with: + # aws_default_region: ${{ secrets.AWS_DEFAULT_REGION }} + # aws_account_id: ${{ secrets.AWS_ACCOUNT_ID }} + # - name: Pull nightly image + # shell: bash + # env: + # ECR_HOSTNAME: ${{ secrets.AWS_ACCOUNT_ID }}.dkr.ecr.${{ secrets.AWS_DEFAULT_REGION }}.amazonaws.com + # IMAGE_TAG: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + # run: | + # docker pull ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} + # docker tag ${ECR_HOSTNAME}/${{ env.REGISTRY_IMAGE }}:${IMAGE_TAG} ${IMAGE_TAG} + # - name: Run Component Tests (${{ matrix.component }}) + # uses: ./.github/actions/pytest + # with: + # image_tag: ${{ env.NIGHTLY_IMAGE_PREFIX }}-${{ matrix.framework }}-${{ matrix.arch.arch }} + # pytest_marks: "${{ matrix.marks }}" + # framework: ${{ matrix.framework }} + # test_type: component-${{ matrix.component }} + # platform_arch: ${{ matrix.arch.arch }} + + ############################## RESULTS SUMMARY ############################## + results-summary: + name: Results Summary + runs-on: ubuntu-latest + if: always() + needs: [build-amd64, build-arm64, unit-tests, integration-tests, e2e-single-gpu-tests, e2e-multi-gpu-tests] # component-tests + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Gather job metadata + id: gather + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set +x -e + echo "# Nightly CI Results Summary" > results.md + echo "" >> results.md + echo "| Stage | Status | Runner | Duration (min) | Artifacts |" >> results.md + echo "|-------|--------|--------|----------------|-----------|" >> results.md + + curl -s -S -L --fail-with-body \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github.v3+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" \ + 2>/dev/null | jq -c '.jobs[] | {id, name, runner_name, conclusion, started_at, completed_at}' > jobs.jsonl + + while read job_entry; do + job_id=$(echo "$job_entry" | jq -r '.id') + name=$(echo "$job_entry" | jq -r '.name') + runner=$(echo "$job_entry" | jq -r '.runner_name') + status=$(echo "$job_entry" | jq -r '.conclusion') + started=$(echo "$job_entry" | jq -r '.started_at') + completed=$(echo "$job_entry" | jq -r '.completed_at') + 
minutes="N/A" + if [[ "$started" != "null" && "$completed" != "null" ]]; then + start_epoch=$(date -d "$started" +%s) + end_epoch=$(date -d "$completed" +%s) + minutes=$(( (end_epoch - start_epoch)/60 )) + fi + artifact_link="https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}#job-$job_id" + printf "| %s | %s | %s | %s | [Log & Artifacts](%s) |\n" "$name" "$status" "$runner" "$minutes" "$artifact_link" >> results.md + done < jobs.jsonl + + echo "" >> results.md + echo "---" >> results.md + - name: Display workflow summary + run: cat results.md + - name: Upload results summary as job summary + run: cat results.md >> $GITHUB_STEP_SUMMARY + - name: Upload results as artifact for Slack + uses: actions/upload-artifact@v4 + if: always() + with: + name: nightly-results-summary + path: results.md + retention-days: 7 + + ############################## SLACK NOTIFICATION ############################## + notify-slack: + name: Notify Slack + runs-on: cpu-amd-m5-4xlarge + if: always() && github.event_name == 'schedule' && !github.event.repository.fork + needs: results-summary + permissions: + contents: read env: - FRAMEWORK: trtllm - steps: *framework-build-steps + HAS_SLACK_WEBHOOK: ${{ secrets.SLACK_NOTIFY_NIGHTLY_WEBHOOK_URL != '' }} + steps: + - name: Send Slack notification + if: env.HAS_SLACK_WEBHOOK == 'true' + continue-on-error: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_NOTIFY_NIGHTLY_WEBHOOK_URL }} + RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + run: | + set -euo pipefail + + JOBS_JSON=$(mktemp) + trap 'rm -f "$JOBS_JSON"' EXIT + + if ! curl -sSL \ + -H "Authorization: Bearer ${GITHUB_TOKEN}" \ + -H "Accept: application/vnd.github+json" \ + "https://api.github.com/repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/jobs?per_page=100" \ + > "$JOBS_JSON"; then + echo "Error: Failed to fetch job data from GitHub API" + exit 1 + fi + + if [ ! 
-s "$JOBS_JSON" ]; then + echo "Error: No job data received" + exit 1 + fi + + TOTAL_JOBS=$(jq '[.jobs[]] | length' "$JOBS_JSON") + SUCCESS_COUNT=$(jq '[.jobs[] | select(.conclusion == "success")] | length' "$JOBS_JSON") + FAILED_COUNT=$(jq '[.jobs[] | select(.conclusion == "failure")] | length' "$JOBS_JSON") + + if [ "$FAILED_COUNT" -eq 0 ]; then + STATUS="Success ✅" + STATUS_EMOJI=":white_check_mark:" + else + STATUS="Failed ❌" + STATUS_EMOJI=":x:" + fi + + # Main message with summary + SUMMARY_TEXT="*Nightly CI Pipeline - ${STATUS}*"$'\n'"Summary: ${SUCCESS_COUNT}/${TOTAL_JOBS} jobs passed"$'\n'"<${RUN_URL}|View Workflow Summary>" + + if [ "$FAILED_COUNT" -eq 0 ]; then + # Success - simple message + PAYLOAD=$(jq -n \ + --arg text "$SUMMARY_TEXT" \ + '{text: $text}') + else + # Failed - message with blocks + FAILED_JOBS=$(jq -r '.jobs[] | select(.conclusion == "failure") | "• " + .name' "$JOBS_JSON") + FAILED_JOBS_TEXT="*Failed Jobs (${FAILED_COUNT}):*"$'\n'"${FAILED_JOBS}" + + PAYLOAD=$(jq -n \ + --arg summary "$SUMMARY_TEXT" \ + --arg failed "$FAILED_JOBS_TEXT" \ + '{ + text: $summary, + blocks: [ + { + type: "section", + text: { + type: "mrkdwn", + text: $summary + } + }, + { + type: "section", + text: { + type: "mrkdwn", + text: $failed + } + } + ] + }') + fi + + if curl -sSf -X POST -H "Content-Type: application/json" -d "$PAYLOAD" "$SLACK_WEBHOOK_URL"; then + echo "Slack notification sent successfully" + else + echo "Warning: Failed to send Slack notification" + exit 1 + fi diff --git a/.github/workflows/templates/README.md b/.github/workflows/templates/README.md new file mode 100644 index 0000000000..add5d96a08 --- /dev/null +++ b/.github/workflows/templates/README.md @@ -0,0 +1,21 @@ +# Workflow Templates + +This directory contains reusable templates and utilities for GitHub Actions workflows. + +## Files + +### akamai-eccu-flush.xslt + +XSLT template for generating Akamai ECCU (Edge Content Control Utility) XML requests. + +**Purpose**: Generates XML for cache invalidation requests to Akamai CDN. + +**Usage**: +```bash +xsltproc --stringparam target-path "path/to/flush" \ + akamai-eccu-flush.xslt akamai-eccu-flush.xslt > eccu-request.xml +``` + +**Used by**: `.github/workflows/publish-s3.yml` for flushing CDN cache after documentation deployment. + +The template creates a hierarchical XML structure with nested `match:recursive-dirs` elements representing the directory path to invalidate in the Akamai cache.
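+
+**Example output** (illustrative only; element names other than `match:recursive-dirs` are assumed from Akamai's ECCU request schema rather than taken from this template): for a target path such as `docs/latest`, the generated request nests one `match:recursive-dirs` element per path segment around a revalidation directive:
+
+```xml
+<eccu>
+  <!-- one recursive-dirs element per path segment; revalidate the leaf immediately -->
+  <match:recursive-dirs value="docs">
+    <match:recursive-dirs value="latest">
+      <revalidate>now</revalidate>
+    </match:recursive-dirs>
+  </match:recursive-dirs>
+</eccu>
+```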
diff --git a/.github/workflows/templates/akamai-eccu-flush.xslt b/.github/workflows/templates/akamai-eccu-flush.xslt new file mode 100644 index 0000000000..80a05c7523 --- /dev/null +++ b/.github/workflows/templates/akamai-eccu-flush.xslt @@ -0,0 +1,78 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + now + + + + + + diff --git a/CODEOWNERS b/CODEOWNERS index cbae9d0262..f0f982e24b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -18,7 +18,7 @@ Cargo.toml @ai-dynamo/dynamo-rust-codeowners # Dynamo deploy /deploy/ @ai-dynamo/dynamo-deploy-codeowners /examples/*/deploy/ @ai-dynamo/dynamo-deploy-codeowners - +/examples/backends/*/deploy/ @ai-dynamo/dynamo-deploy-codeowners # CI/CD /.github/ @ai-dynamo/Devops /.github/workflows/*.ps1 @ai-dynamo/Devops diff --git a/Cargo.lock b/Cargo.lock index 417a3d1d5b..cba8d9f485 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1602,6 +1602,19 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "console" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.2", + "windows-sys 0.61.2", +] + [[package]] name = "console-api" version = "0.8.1" @@ -2370,7 +2383,7 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" dependencies = [ - "console", + "console 0.15.11", "shell-words", "tempfile", "thiserror 1.0.69", @@ -2650,6 +2663,7 @@ dependencies = [ "bytes", "candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "chrono", + "clap 4.5.53", "criterion 0.3.6", "cudarc", "dashmap 5.5.3", @@ -2671,6 +2685,7 @@ dependencies = [ "hyper 1.8.1", "hyper-util", "image", + "indicatif 0.18.3", "insta", "itertools 0.14.0", "json-five", @@ -4051,7 +4066,7 @@ dependencies = [ "dirs", "futures", "http 1.4.0", - "indicatif", + "indicatif 0.17.11", "libc", "log", "num_cpus", @@ -4672,7 +4687,7 @@ version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ - "console", + "console 0.15.11", "number_prefix", "portable-atomic", "rayon", @@ -4680,6 +4695,19 @@ dependencies = [ "web-time", ] +[[package]] +name = "indicatif" +version = "0.18.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9375e112e4b463ec1b1c6c011953545c65a30164fbab5b581df32b3abf0dcb88" +dependencies = [ + "console 0.16.1", + "portable-atomic", + "unicode-width 0.2.2", + "unit-prefix", + "web-time", +] + [[package]] name = "inlinable_string" version = "0.1.15" @@ -4734,7 +4762,7 @@ version = "1.44.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5c943d4415edd8153251b6f197de5eb1640e56d84e8d9159bea190421c73698" dependencies = [ - "console", + "console 0.15.11", "globset", "once_cell", "pest", @@ -6135,7 +6163,7 @@ dependencies = [ "http 1.4.0", "image", "indexmap 2.12.1", - "indicatif", + "indicatif 0.17.11", "interprocess", "itertools 0.14.0", "libc", @@ -11546,6 +11574,12 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + [[package]] name = "universal-hash" version = "0.5.1" diff --git 
a/benchmarks/profiler/profile_sla.py b/benchmarks/profiler/profile_sla.py index 560c1000fe..cad24242f5 100644 --- a/benchmarks/profiler/profile_sla.py +++ b/benchmarks/profiler/profile_sla.py @@ -50,6 +50,7 @@ profile_prefill_aiconfigurator, ) from benchmarks.profiler.utils.profiler_argparse import create_profiler_parser +from benchmarks.profiler.webui.select_config import pick_config_with_webui from deploy.utils.dynamo_deployment import ( DynamoDeploymentClient, cleanup_remaining_deployments, @@ -476,45 +477,57 @@ async def run_profile(args): # Safety guards: no results โ†’ exit early with a clear message if not prefill_data.num_gpus: logger.error("No prefill results produced; skipping recommendations.") + return - # select best parallel mapping for prefill - if min(prefill_data.ttft) > args.ttft: - logger.warning( - "No engine configuration satisfies the TTFT requirement, please try a smaller model or more powerful hardware" + if args.pick_with_webui: + # select best P/D config in webUI + selected_prefill_idx, selected_decode_idx = pick_config_with_webui( + prefill_data, decode_data, args ) - selected_prefill_idx = int(np.argmin(np.array(prefill_data.ttft))) else: - valid_indices = [ - i for i, ttft in enumerate(prefill_data.ttft) if ttft <= args.ttft - ] - # Among valid TP sizes, select the one with highest throughput per GPU - valid_thpts = [prefill_data.thpt_per_gpu[i] for i in valid_indices] - max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))] - selected_prefill_idx = max_thpt_idx - logger.info( - f"Suggested prefill parallel mapping: {prefill_data.parallel_mapping_labels[selected_prefill_idx]} on {prefill_data.num_gpus[selected_prefill_idx]} GPU(s) (TTFT {prefill_data.ttft[selected_prefill_idx]:.2f} ms, throughput {prefill_data.thpt_per_gpu[selected_prefill_idx]:.2f} tokens/s/GPU)" - ) + # automatically select P/D config within SLA with the highest throughput/GPU + # select best parallel mapping for prefill + if min(prefill_data.ttft) > args.ttft: + logger.warning( + "No engine configuration satisfies the TTFT requirement, please try a smaller model or more powerful hardware" + ) + selected_prefill_idx = int(np.argmin(np.array(prefill_data.ttft))) + else: + valid_indices = [ + i + for i, ttft in enumerate(prefill_data.ttft) + if ttft <= args.ttft + ] + # Among valid TP sizes, select the one with highest throughput per GPU + valid_thpts = [prefill_data.thpt_per_gpu[i] for i in valid_indices] + max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))] + selected_prefill_idx = max_thpt_idx + logger.info( + f"Suggested prefill parallel mapping: {prefill_data.parallel_mapping_labels[selected_prefill_idx]} on {prefill_data.num_gpus[selected_prefill_idx]} GPU(s) (TTFT {prefill_data.ttft[selected_prefill_idx]:.2f} ms, throughput {prefill_data.thpt_per_gpu[selected_prefill_idx]:.2f} tokens/s/GPU)" + ) - # select best parallel mapping for decode - if not decode_data.num_gpus: - logger.error("No decode results produced; skipping recommendations.") - return - if min(decode_data.itl) > args.itl: - logger.warning( - "No engine configuration satisfies the ITL requirement, please try a smaller model or more powerful hardware" + # select best parallel mapping for decode + if not decode_data.num_gpus: + logger.error( + "No decode results produced; skipping recommendations." 
+ ) + return + if min(decode_data.itl) > args.itl: + logger.warning( + "No engine configuration satisfies the ITL requirement, please try a smaller model or more powerful hardware" + ) + selected_decode_idx = int(np.argmin(np.array(decode_data.itl))) + else: + valid_indices = [ + i for i, itl in enumerate(decode_data.itl) if itl <= args.itl + ] + # Among valid TP sizes, select the one with highest throughput per GPU + valid_thpts = [decode_data.thpt_per_gpu[i] for i in valid_indices] + max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))] + selected_decode_idx = max_thpt_idx + logger.info( + f"Suggested decode parallel mapping: {decode_data.parallel_mapping_labels[selected_decode_idx]} on {decode_data.num_gpus[selected_decode_idx]} GPU(s) (ITL {decode_data.itl[selected_decode_idx]:.2f} ms, throughput {decode_data.thpt_per_gpu[selected_decode_idx]:.2f} tokens/s/GPU)" ) - selected_decode_idx = int(np.argmin(np.array(decode_data.itl))) - else: - valid_indices = [ - i for i, itl in enumerate(decode_data.itl) if itl <= args.itl - ] - # Among valid TP sizes, select the one with highest throughput per GPU - valid_thpts = [decode_data.thpt_per_gpu[i] for i in valid_indices] - max_thpt_idx = valid_indices[int(np.argmax(valid_thpts))] - selected_decode_idx = max_thpt_idx - logger.info( - f"Suggested decode parallel mapping: {decode_data.parallel_mapping_labels[selected_decode_idx]} on {decode_data.num_gpus[selected_decode_idx]} GPU(s) (ITL {decode_data.itl[selected_decode_idx]:.2f} ms, throughput {decode_data.thpt_per_gpu[selected_decode_idx]:.2f} tokens/s/GPU)" - ) if args.dry_run: # use min value for prefill and decode GPU counts diff --git a/benchmarks/profiler/utils/defaults.py b/benchmarks/profiler/utils/defaults.py index c15b510d7c..f0f97c635b 100644 --- a/benchmarks/profiler/utils/defaults.py +++ b/benchmarks/profiler/utils/defaults.py @@ -30,6 +30,10 @@ AIPERF_PREFILL_BENCHMARK_OSL = 5 AIPERF_PREFILL_ATTN_DP_NUM_REQ_RATIO = 4 +# Cost calculation defaults +# TODO: allow user to configure this in GUI +GPU_COST_PER_HOUR = 3.0 # Cost per GPU per hour in dollars + class EngineType(str, Enum): PREFILL = "prefill" diff --git a/benchmarks/profiler/utils/pareto.py b/benchmarks/profiler/utils/pareto.py index 0ab1104673..9e8d52de54 100644 --- a/benchmarks/profiler/utils/pareto.py +++ b/benchmarks/profiler/utils/pareto.py @@ -4,33 +4,39 @@ def compute_pareto(x, y): """ - compute the pareto front (top-left is better) for the given x and y values - return sorted lists of the x and y values for the pareto front + Compute the pareto front (top-left is better) for the given x and y values. + + Returns: + tuple: (xs, ys, indices) where: + - xs: list of x values on the pareto front + - ys: list of y values on the pareto front + - indices: list of original indices corresponding to the pareto points """ # Validate inputs if x is None or y is None: - return [], [] + return [], [], [] if len(x) != len(y): raise ValueError("x and y must have the same length") if len(x) == 0: - return [], [] + return [], [], [] - # Build point list and sort by x asc, then y desc so we prefer smaller x and larger y. - points = list(zip(x, y)) + # Build point list with original indices and sort by x asc, then y desc + points = [(x[i], y[i], i) for i in range(len(x))] points.sort(key=lambda p: (p[0], -p[1])) - # Single pass to keep only non-dominated points (minimize x, maximize y). 
+ # Single pass to keep only non-dominated points (minimize x, maximize y) pareto = [] max_y = float("-inf") - for px, py in points: + for px, py, idx in points: if py > max_y: - pareto.append((px, py)) + pareto.append((px, py, idx)) max_y = py # Return sorted by x ascending for convenience pareto.sort(key=lambda p: (p[0], p[1])) - xs = [px for px, _ in pareto] - ys = [py for _, py in pareto] - return xs, ys + xs = [px for px, _, _ in pareto] + ys = [py for _, py, _ in pareto] + indices = [idx for _, _, idx in pareto] + return xs, ys, indices diff --git a/benchmarks/profiler/utils/plot.py b/benchmarks/profiler/utils/plot.py index 10c7077022..68e14b1b4f 100644 --- a/benchmarks/profiler/utils/plot.py +++ b/benchmarks/profiler/utils/plot.py @@ -21,6 +21,7 @@ from matplotlib import cm from scipy.interpolate import griddata +from benchmarks.profiler.utils.defaults import GPU_COST_PER_HOUR from benchmarks.profiler.utils.pareto import compute_pareto logger = logging.getLogger(__name__) @@ -297,13 +298,11 @@ def plot_pd_joint_results(isl, osl, prefill_data, decode_data, output_dir): decode_data: DecodeProfileData instance containing profiling results output_dir: directory to save the plot """ - GPU_COST_PER_HOUR = 3.0 # $3/hour - # compute pareto front for prefill - p_ttft, p_thpt = compute_pareto(prefill_data.ttft, prefill_data.thpt_per_gpu) + p_ttft, p_thpt, _ = compute_pareto(prefill_data.ttft, prefill_data.thpt_per_gpu) # compute pareto front for decode - d_itl, d_thpt = compute_pareto(decode_data.itl, decode_data.thpt_per_gpu) + d_itl, d_thpt, _ = compute_pareto(decode_data.itl, decode_data.thpt_per_gpu) # convert to cost per thousand requests p_ttft = np.array(p_ttft) diff --git a/benchmarks/profiler/utils/profiler_argparse.py b/benchmarks/profiler/utils/profiler_argparse.py index 4a35ef8387..ee84075f53 100644 --- a/benchmarks/profiler/utils/profiler_argparse.py +++ b/benchmarks/profiler/utils/profiler_argparse.py @@ -3,6 +3,7 @@ import argparse import ast +import os from typing import Any, Dict import yaml @@ -84,6 +85,8 @@ def create_profiler_parser() -> argparse.Namespace: aic_backend: String (aiconfigurator backend of the target model, if not provided, will use args.backend, default: "") aic_backend_version: String (specify backend version when using aiconfigurator to estimate perf, default: None) dry_run: Boolean (dry run the profile job, default: False) + pick_with_webui: Boolean (pick the best parallelization mapping using webUI, default: False) + webui_port: Int (webUI port, default: $PROFILER_WEBUI_PORT or 8000) sla: isl: Int (target input sequence length, default: 3000) osl: Int (target output sequence length, default: 500) @@ -113,6 +116,8 @@ def create_profiler_parser() -> argparse.Namespace: help="Configuration as Python dict literal, YAML, or JSON string. CLI args override config values. " "Example: \"{'engine': {'backend': 'vllm', 'config': '/path'}, 'sla': {'isl': 3000}}\"", ) + + # CLI arguments with config-aware defaults (using nested .get() for cleaner code) parser.add_argument( "--model", type=str, @@ -126,7 +131,6 @@ def create_profiler_parser() -> argparse.Namespace: help="Container image to use for DGD components (frontend, planner, workers). 
Overrides images in config file.", ) - # CLI arguments with config-aware defaults (using nested .get() for cleaner code) parser.add_argument( "--namespace", type=str, @@ -233,6 +237,23 @@ def create_profiler_parser() -> argparse.Namespace: default=config.get("hardware", {}).get("enable_gpu_discovery", False), help="Enable automatic GPU discovery from Kubernetes cluster nodes. When enabled, overrides any manually specified hardware configuration. Requires cluster-wide node access permissions.", ) + parser.add_argument( + "--pick-with-webui", + action="store_true", + default=config.get("sweep", {}).get("pick_with_webui", False), + help="Pick the best parallelization mapping using webUI", + ) + + default_webui_port = 8000 + webui_port_env = os.environ.get("PROFILER_WEBUI_PORT") + if webui_port_env: + default_webui_port = int(webui_port_env) + parser.add_argument( + "--webui-port", + type=int, + default=config.get("sweep", {}).get("webui_port", default_webui_port), + help="WebUI port", + ) # Dynamically add all planner arguments from planner_argparse.py add_planner_arguments_to_parser(parser, prefix="planner-") diff --git a/benchmarks/profiler/webui/data_template.json b/benchmarks/profiler/webui/data_template.json new file mode 100644 index 0000000000..f18d67eda2 --- /dev/null +++ b/benchmarks/profiler/webui/data_template.json @@ -0,0 +1,98 @@ +{ + "settings": { + "allow_confirm_datapoint": true, + "hide_show_config": true + }, + "prefill": { + "chart": { + "labels": [], + "datasets": [ + { + "label": "Prefill Performance", + "data": [], + "backgroundColor": "#1f77b4", + "borderColor": "#1f77b4" + } + ], + "target_line": { + "value": 0.0, + "label": "Target TTFT: ? ms" + }, + "axes": { + "x": { + "title": "Time to First Token (ms)", + "min": 0 + }, + "y": { + "title": "Prefill Throughput per GPU (tokens/s/GPU)", + "min": 0 + } + } + }, + "table": { + "columns": [ + "GPUs", + "TTFT (ms)", + "Throughput (tokens/s/GPU)", + "Action" + ], + "data": [] + } + }, + "decode": { + "chart": { + "datasets": [], + "target_line": { + "value": 0.0, + "label": "Target ITL: ? ms" + }, + "axes": { + "x": { + "title": "Inter Token Latency (ms)", + "min": 0 + }, + "y": { + "title": "Decode Throughput per GPU (tokens/s/GPU)", + "min": 0 + } + } + }, + "table": { + "columns": [ + "GPUs", + "ITL (ms)", + "Throughput (tokens/s/GPU)", + "Action" + ], + "data": [] + } + }, + "cost": { + "chart": { + "datasets": [], + "axes": { + "x": { + "title": "Tokens per User", + "min": 0 + }, + "y": { + "title": "Cost ($)", + "min": 0 + } + }, + "title": "Cost Per 1000 ? requests" + }, + "table": { + "columns": [ + "TTFT (ms)", + "Prefill Thpt (tokens/s/GPU)", + "ITL (ms)", + "Decode Thpt (tokens/s/GPU)", + "Tokens/User", + "Cost ($)", + "Action" + ], + "data": [] + } + } +} diff --git a/benchmarks/profiler/webui/select_config.py b/benchmarks/profiler/webui/select_config.py new file mode 100644 index 0000000000..c8cd894882 --- /dev/null +++ b/benchmarks/profiler/webui/select_config.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import os +import queue +from pathlib import Path + +from benchmarks.profiler.webui.utils import ( + PlotType, + create_gradio_interface, + create_selection_handler, + populate_cost_data, + populate_decode_data, + populate_prefill_data, + wait_for_selection, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S" +) +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + + +def generate_config_data(prefill_data, decode_data, args): + """ + Generate JSON data file for WebUI from profiling results. + + Args: + prefill_data: PrefillProfileData instance + decode_data: DecodeProfileData instance + args: Arguments containing SLA targets (ttft, itl, isl, osl) and output_dir + + Returns a JSON data file for WebUI consumption, + see https://github.com/ai-dynamo/aiconfigurator/blob/main/src/aiconfigurator/webapp/components/profiling/standalone/sample_profiling_data.json for more details + """ + # Load template + template_path = Path(__file__).parent / "data_template.json" + with open(template_path, "r") as f: + data = json.load(f) + + # Construct output path + output_path = os.path.join(args.output_dir, "webui_data.json") + + # Set SLA targets + data[PlotType.PREFILL]["chart"]["target_line"]["value"] = args.ttft + data[PlotType.PREFILL]["chart"]["target_line"][ + "label" + ] = f"Target TTFT: {args.ttft} ms" + + data[PlotType.DECODE]["chart"]["target_line"]["value"] = args.itl + data[PlotType.DECODE]["chart"]["target_line"][ + "label" + ] = f"Target ITL: {args.itl} ms" + + data[PlotType.COST]["chart"][ + "title" + ] = f"Cost Per 1000 i{args.isl}o{args.osl} requests" + + # Populate data sections + populate_prefill_data(data, prefill_data) + populate_decode_data(data, decode_data) + populate_cost_data(data, prefill_data, decode_data, args) + + # Save JSON file + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as f: + json.dump(data, f, indent=4) + + logger.info(f"Generated WebUI config data at {output_path}") + return data + + +def pick_config_with_webui(prefill_data, decode_data, args): + """ + Launch WebUI for user to pick configurations. 
+ + Args: + prefill_data: PrefillProfileData instance + decode_data: DecodeProfileData instance + args: Arguments containing SLA targets and output_dir + + Returns: + tuple[int, int]: (selected_prefill_idx, selected_decode_idx) + """ + # Generate JSON data file and load it + generate_config_data(prefill_data, decode_data, args) + + output_path = os.path.join(args.output_dir, "webui_data.json") + with open(output_path, "r") as f: + json_data_str = f.read() + data_dict = json.loads(json_data_str) + + logger.info(f"Launching WebUI on port {args.webui_port}...") + + # Queue to communicate selection from UI to main thread + selection_queue: queue.Queue[tuple[int | None, int | None]] = queue.Queue() + + # Track individual selections + prefill_selection = {"idx": None} + decode_selection = {"idx": None} + + # Create selection handler and Gradio interface + handle_selection = create_selection_handler( + data_dict, selection_queue, prefill_selection, decode_selection + ) + demo = create_gradio_interface(json_data_str, handle_selection) + + return wait_for_selection(demo, selection_queue, args.webui_port) diff --git a/benchmarks/profiler/webui/utils.py b/benchmarks/profiler/webui/utils.py new file mode 100644 index 0000000000..749d0880a5 --- /dev/null +++ b/benchmarks/profiler/webui/utils.py @@ -0,0 +1,414 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +import queue +import threading +from enum import Enum + +import gradio as gr +import numpy as np +from aiconfigurator.webapp.components.profiling import ( + create_performance_results_section, + create_profiling_ui_components, + inject_profiling_assets, + load_profiling_javascript, +) + +from benchmarks.profiler.utils.defaults import GPU_COST_PER_HOUR +from benchmarks.profiler.utils.pareto import compute_pareto + +logger = logging.getLogger(__name__) + + +class PlotType(str, Enum): + """Enum for the three plot/config types in the WebUI.""" + + PREFILL = "prefill" + DECODE = "decode" + COST = "cost" + + +# Color palette for chart datasets +# TODO: handle case with more than 8 lines +CHART_COLORS = [ + "#1f77b4", # blue + "#ff7f0e", # orange + "#2ca02c", # green + "#d62728", # red + "#9467bd", # purple + "#8c564b", # brown + "#e377c2", # pink + "#7f7f7f", # gray +] + +# TODO: is this too long? 
+WEB_UI_SELECTION_TIMEOUT = 3600 + + +def populate_prefill_data(data, prefill_data): + """Populate prefill chart and table data.""" + if not prefill_data.num_gpus: + return + + # Get unique GPU counts for labels + unique_gpus = sorted(set(prefill_data.num_gpus)) + data[PlotType.PREFILL]["chart"]["labels"] = [f"{gpu} GPUs" for gpu in unique_gpus] + + # Populate chart data points + chart_data = [] + for i, (gpu, ttft, thpt, label) in enumerate( + zip( + prefill_data.num_gpus, + prefill_data.ttft, + prefill_data.thpt_per_gpu, + prefill_data.parallel_mapping_labels, + ) + ): + chart_data.append( + { + "x": round(ttft, 2), + "y": round(thpt, 2), + "gpu": gpu, + "tableIdx": i, + "gpuLabel": f"{gpu} GPUs [{label}]", + } + ) + data[PlotType.PREFILL]["chart"]["datasets"][0]["data"] = chart_data + + # Populate table data + table_data = [] + for i, (gpu, ttft, thpt, label) in enumerate( + zip( + prefill_data.num_gpus, + prefill_data.ttft, + prefill_data.thpt_per_gpu, + prefill_data.parallel_mapping_labels, + ) + ): + # TODO: Add actual config YAML data + config_yaml = f"prefill_config_{i}.yaml" + table_data.append([gpu, round(ttft, 2), round(thpt, 2), config_yaml]) + data[PlotType.PREFILL]["table"]["data"] = table_data + + +def populate_decode_data(data, decode_data): + """Populate decode chart and table data.""" + if not decode_data.num_gpus: + return + + # Group by GPU count for multiple datasets + gpu_groups: dict[int, list[dict[str, float | int]]] = {} + for i, (gpu, itl, thpt, label) in enumerate( + zip( + decode_data.num_gpus, + decode_data.itl, + decode_data.thpt_per_gpu, + decode_data.parallel_mapping_labels, + ) + ): + if gpu not in gpu_groups: + gpu_groups[gpu] = [] + gpu_groups[gpu].append({"x": round(itl, 2), "y": round(thpt, 2), "tableIdx": i}) + + # Create datasets for each GPU count with different colors + datasets = [] + for idx, (gpu, points) in enumerate(sorted(gpu_groups.items())): + color = CHART_COLORS[idx % len(CHART_COLORS)] + datasets.append( + { + "label": f"{gpu} GPUs", + "data": points, + "backgroundColor": color, + "borderColor": color, + } + ) + data[PlotType.DECODE]["chart"]["datasets"] = datasets + + # Populate table data + table_data = [] + for i, (gpu, itl, thpt, label) in enumerate( + zip( + decode_data.num_gpus, + decode_data.itl, + decode_data.thpt_per_gpu, + decode_data.parallel_mapping_labels, + ) + ): + config_yaml = f"decode_config_{i}.yaml" + table_data.append([gpu, round(itl, 2), round(thpt, 2), config_yaml]) + data[PlotType.DECODE]["table"]["data"] = table_data + + +def populate_cost_data(data, prefill_data, decode_data, args): + """Populate cost chart and table data with pareto-optimal configurations.""" + if not prefill_data.num_gpus or not decode_data.num_gpus: + return + + # Compute pareto front for prefill (minimize TTFT, maximize throughput) + p_ttft, p_thpt, prefill_pareto_indices = compute_pareto( + prefill_data.ttft, prefill_data.thpt_per_gpu + ) + + # Compute pareto front for decode (minimize ITL, maximize throughput) + d_itl, d_thpt, decode_pareto_indices = compute_pareto( + decode_data.itl, decode_data.thpt_per_gpu + ) + + # Convert to numpy arrays + p_ttft = np.array(p_ttft) + p_thpt = np.array(p_thpt) + d_itl = np.array(d_itl) + d_thpt = np.array(d_thpt) + + # Generate cost datasets - one line per prefill config + cost_datasets = [] + table_data = [] + cost_index_mapping = {} # Map cost table row idx -> (prefill_idx, decode_idx) + table_idx = 0 + + for p_idx, (_p_ttft, _p_thpt) in enumerate(zip(p_ttft, p_thpt)): + # Calculate prefill cost 
(fixed for this line) + prefill_cost = args.isl * 1000 / _p_thpt * GPU_COST_PER_HOUR / 3600 + + # For each decode config, calculate total cost + line_data = [] + for d_idx, (_d_itl, _d_thpt) in enumerate(zip(d_itl, d_thpt)): + # Calculate decode cost + decode_cost = args.osl * 1000 / _d_thpt * GPU_COST_PER_HOUR / 3600 + total_cost = prefill_cost + decode_cost + + # X-axis: tokens per user (based on ITL) + tokens_per_user = 1000 / _d_itl + + line_data.append( + { + "x": round(tokens_per_user, 2), + "y": round(total_cost, 2), + "tableIdx": table_idx, + } + ) + + # Store mapping from cost table row to original indices + orig_prefill_idx = prefill_pareto_indices[p_idx] + orig_decode_idx = decode_pareto_indices[d_idx] + cost_index_mapping[table_idx] = (orig_prefill_idx, orig_decode_idx) + + # Add to table data + table_data.append( + [ + round(_p_ttft, 2), + round(_p_thpt, 2), + round(_d_itl, 2), + round(_d_thpt, 2), + round(tokens_per_user, 2), + round(total_cost, 2), + f"cost_config_{table_idx}.yaml", # TODO: Add actual config + ] + ) + table_idx += 1 + + # Create dataset for this prefill config + color = CHART_COLORS[p_idx % len(CHART_COLORS)] + cost_datasets.append( + { + "label": f"TTFT: {_p_ttft:.2f}ms", + "data": line_data, + "backgroundColor": color, + "borderColor": color, + } + ) + + data[PlotType.COST]["chart"]["datasets"] = cost_datasets + data[PlotType.COST]["table"]["data"] = table_data + + # Store the index mapping in the JSON for reference + data[PlotType.COST]["index_mapping"] = { + str(k): list(v) for k, v in cost_index_mapping.items() + } + + +def create_selection_handler( + data_dict, selection_queue, prefill_selection, decode_selection +): + """Create a selection handler closure for the WebUI. + + Args: + data_dict: Parsed JSON data containing cost index mapping + selection_queue: Queue to communicate selections to main thread + prefill_selection: Dict tracking prefill selection state + decode_selection: Dict tracking decode selection state + + Returns: + Callable: Selection handler function for Gradio + """ + + def handle_selection(selection_json): + """Handle datapoint selection from table.""" + if not selection_json or selection_json.strip() == "": + return + + try: + selection = json.loads(selection_json) + plot_type = selection.get("plotType") + row_idx = selection.get("rowIndex") + + logger.info(f"Selection received: {plot_type}, row {row_idx}") + + # Store selection for later confirmation + if plot_type == PlotType.COST: + # Cost selection - use index mapping to get original indices + cost_index_mapping = data_dict[PlotType.COST].get("index_mapping", {}) + mapping_entry = cost_index_mapping.get(str(row_idx)) + + if mapping_entry: + prefill_idx, decode_idx = mapping_entry + if prefill_idx is not None and decode_idx is not None: + logger.info( + f"Cost selection determines: Prefill={prefill_idx}, Decode={decode_idx}" + ) + # Auto-submit for cost selection + selection_queue.put((prefill_idx, decode_idx)) + elif plot_type == PlotType.PREFILL: + prefill_selection["idx"] = row_idx + logger.info(f"Prefill selected: {row_idx}") + # Check if we have both selections + if decode_selection["idx"] is not None: + logger.info( + f"Both selections complete: Prefill={row_idx}, Decode={decode_selection['idx']}" + ) + selection_queue.put((row_idx, decode_selection["idx"])) + else: + logger.info("Waiting for decode selection...") + elif plot_type == PlotType.DECODE: + decode_selection["idx"] = row_idx + logger.info(f"Decode selected: {row_idx}") + # Check if we have both selections + if 
prefill_selection["idx"] is not None: + logger.info( + f"Both selections complete: Prefill={prefill_selection['idx']}, Decode={row_idx}" + ) + selection_queue.put((prefill_selection["idx"], row_idx)) + else: + logger.info("Waiting for prefill selection...") + + except Exception as e: + logger.error(f"Error handling selection: {e}") + + return handle_selection + + +def create_gradio_interface(json_data_str, handle_selection): + """Create the Gradio interface for configuration selection. + + Args: + json_data_str: JSON string containing profiling data + handle_selection: Selection handler function + + Returns: + gr.Blocks: Configured Gradio demo + """ + with gr.Blocks(title="Configuration Selection") as demo: + # Create hidden UI components (reused from AIC profiling module) + ui_components = create_profiling_ui_components() + selection_input = ui_components["selection_input"] + selection_button = ui_components["selection_button"] + json_data = ui_components["json_data"] + + # Inject CSS and modal (reused from AIC profiling module) + inject_profiling_assets() + + gr.Markdown("# 📊 Profiling Results - Select Configuration") + gr.Markdown( + """ + **Two ways to select prefill and decode configs:** + 1. **Cost Analysis** (recommended): Click any row in the Cost Analysis table - automatically determines both prefill and decode + 2. **Individual**: Click one row in the Prefill table AND one row in the Decode table + The selection will be processed automatically once complete. + + > 📝 **Note:** The dotted red lines in the prefill and decode charts show the target TTFT and ITL SLAs (defaults are used if none are specified). + + > ⚠️ **Warning:** The TTFT values here represent the ideal case when requests arrive uniformly, minimizing queueing. Real-world TTFT may be higher than profiling results. To mitigate the issue, the planner uses [correction factors](https://github.com/ai-dynamo/dynamo/blob/main/docs/planner/sla_planner.md#2-correction-factor-calculation) to adjust dynamically at runtime. + """ + ) + + # Performance Results Section (reused from AIC profiling module) + create_performance_results_section() + + # Handle selection button + selection_button.click( + fn=handle_selection, + inputs=[selection_input], + outputs=[], + ) + + # Trigger visualization when JSON data changes + json_data.change( + fn=None, + inputs=[json_data], + outputs=[], + js=( + "(data) => { if (data && data.trim() && window.initializeVisualizations) " + "window.initializeVisualizations(data); }" + ), + ) + + # Load JavaScript and data automatically on page load + def load_data(): + """Load profiling data.""" + return json_data_str + + demo.load( + fn=load_data, inputs=[], outputs=[json_data], js=load_profiling_javascript() + ) + + return demo + + +def wait_for_selection(demo, selection_queue, port): + """Launch the demo and wait for user selection. + + Args: + demo: Gradio demo instance + selection_queue: Queue to receive selection from UI + port: Port number for the WebUI + + Returns: + tuple[int, int]: (selected_prefill_idx, selected_decode_idx) + """ + + # Launch the interface in a separate thread + def launch_thread(): + demo.launch( + server_name="0.0.0.0", + server_port=port, + share=False, + prevent_thread_lock=True, + ) + + thread = threading.Thread(target=launch_thread, daemon=True) + thread.start() + + logger.info(f"WebUI launched. 
Waiting for user selection on http://0.0.0.0:{port}") + logger.info("Please select a row from the Cost Analysis table") + + # Block and wait for selection + try: + selected_prefill_idx, selected_decode_idx = selection_queue.get( + timeout=WEB_UI_SELECTION_TIMEOUT + ) + logger.info( + f"User selected: Prefill={selected_prefill_idx}, Decode={selected_decode_idx}" + ) + + # Close the demo + demo.close() + + return selected_prefill_idx, selected_decode_idx + + except queue.Empty: + logger.error("Selection timeout - no selection made within 1 hour") + demo.close() + # Return default + return 0, 0 diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml index 608ca9a19c..f8d98d5b17 100644 --- a/benchmarks/pyproject.toml +++ b/benchmarks/pyproject.toml @@ -40,13 +40,13 @@ classifiers = [ ] dependencies = [ - "aiconfigurator @ git+https://github.com/ai-dynamo/aiconfigurator.git@release/0.4.0", + "aiconfigurator[webapp] @ git+https://github.com/ai-dynamo/aiconfigurator.git@bdc142609b97c23a298115f09a9f88ae143f48d8", "networkx", "pandas", "pydantic>=2", "tabulate", "types-tabulate", - # Satisfies vLLM 0.11.0 (>=4.55.2), vLLM 0.11.2 (>=4.56.0,<5), TRT-LLM 1.2.0rc2/rc3 (==4.56.0), SGLang 0.5.4.post3 (==4.57.1) + # Satisfies vLLM 0.11.0 (>=4.55.2), vLLM 0.11.2 (>=4.56.0,<5), TRT-LLM 1.2.0rc5 (==4.56.0), SGLang 0.5.6 (==4.57.1) "transformers>=4.56.0,<=4.57.1", "pytest-mypy", ] diff --git a/components/src/dynamo/common/utils/input_params.py b/components/src/dynamo/common/utils/input_params.py new file mode 100644 index 0000000000..7201101306 --- /dev/null +++ b/components/src/dynamo/common/utils/input_params.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +class InputParamManager: + def __init__(self, tokenizer): + self.tokenizer = tokenizer + + def get_input_param(self, request: dict, use_tokenizer: bool): + """ + Get the input parameter for the request. 
+ """ + + if use_tokenizer: + print(f"Request: {request}") + if self.tokenizer is None: + raise ValueError("Tokenizer is not available") + + if "messages" in request: + return self.tokenizer.apply_chat_template( + request["messages"], tokenize=False, add_generation_prompt=True + ) + elif "prompt" in request: + return request["prompt"] + elif "text" in request: + return request["text"] + else: + raise ValueError("No input parameter found in request") + + return request.get("token_ids") diff --git a/components/src/dynamo/mocker/args.py b/components/src/dynamo/mocker/args.py index 3bdde43daf..6180a1c9b6 100644 --- a/components/src/dynamo/mocker/args.py +++ b/components/src/dynamo/mocker/args.py @@ -113,6 +113,7 @@ def create_temp_engine_args_file(args) -> Path: else None, "is_prefill": getattr(args, "is_prefill_worker", None), "is_decode": getattr(args, "is_decode_worker", None), + "enable_local_indexer": getattr(args, "enable_local_indexer", None), } # Remove None values to only include explicitly set arguments @@ -284,6 +285,12 @@ def parse_args(): default=False, help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", ) + parser.add_argument( + "--enable-local-indexer", + action="store_true", + default=False, + help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)", + ) parser.add_argument( "--store-kv", type=str, diff --git a/components/src/dynamo/planner/kube.py b/components/src/dynamo/planner/kube.py index 6759946737..aa0d7e45fd 100644 --- a/components/src/dynamo/planner/kube.py +++ b/components/src/dynamo/planner/kube.py @@ -78,11 +78,48 @@ def get_graph_deployment(self, graph_deployment_name: str) -> dict: ) raise - def update_graph_replicas( - self, graph_deployment_name: str, component_name: str, replicas: int + def update_service_replicas( + self, graph_deployment_name: str, service_name: str, replicas: int + ) -> None: + """ + Update replicas for a service using Scale subresource when DGDSA exists. + Falls back to DGD patch for backward compatibility with older operators. 
+ + Args: + graph_deployment_name: Name of the DynamoGraphDeployment + service_name: Name of the service in DGD.spec.services + replicas: Desired number of replicas + """ + # DGDSA naming convention: <graph_deployment_name>-<service_name> (lowercased) + adapter_name = f"{graph_deployment_name}-{service_name.lower()}" + + try: + # Try to scale via DGDSA Scale subresource + self.custom_api.patch_namespaced_custom_object_scale( + group="nvidia.com", + version="v1alpha1", + namespace=self.current_namespace, + plural="dynamographdeploymentscalingadapters", + name=adapter_name, + body={"spec": {"replicas": replicas}}, + ) + logger.info(f"Scaled DGDSA {adapter_name} to {replicas} replicas") + + except client.ApiException as e: + if e.status == 404: + # DGDSA doesn't exist - fall back to DGD patch (old operator) + logger.info( + f"DGDSA {adapter_name} not found, falling back to DGD update" + ) + self._update_dgd_replicas(graph_deployment_name, service_name, replicas) + else: + raise + + def _update_dgd_replicas( + self, graph_deployment_name: str, service_name: str, replicas: int ) -> None: - """Update the replicas count for a component in a DynamoGraphDeployment""" - patch = {"spec": {"services": {component_name: {"replicas": replicas}}}} + """Update replicas directly in DGD (fallback for old operators)""" + patch = {"spec": {"services": {service_name: {"replicas": replicas}}}} self.custom_api.patch_namespaced_custom_object( group="nvidia.com", version="v1alpha1", @@ -91,6 +128,20 @@ def update_graph_replicas( name=graph_deployment_name, body=patch, ) + logger.info( + f"Updated DGD {graph_deployment_name} service {service_name} to {replicas} replicas" + ) + + def update_graph_replicas( + self, graph_deployment_name: str, component_name: str, replicas: int + ) -> None: + """ + Update replicas for a service. Now uses DGDSA when available. + + Deprecated: Use update_service_replicas() instead for clarity. + This method is kept for backward compatibility. 
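For context, a minimal standalone sketch of the Scale-subresource call issued above, using the official `kubernetes` Python client. The deployment name `my-graph`, service `decode`, namespace `default`, and replica count are hypothetical; the group, version, and plural mirror the code in this patch.

```python
from kubernetes import client, config

# Load credentials; use config.load_kube_config() when running outside the cluster
config.load_incluster_config()
custom_api = client.CustomObjectsApi()

# DGDSA name follows <graph_deployment_name>-<service_name lowercased>
adapter_name = "my-graph-decode"

# Patch only the Scale subresource of the scaling adapter
custom_api.patch_namespaced_custom_object_scale(
    group="nvidia.com",
    version="v1alpha1",
    namespace="default",
    plural="dynamographdeploymentscalingadapters",
    name=adapter_name,
    body={"spec": {"replicas": 3}},
)
```

If the adapter does not exist (HTTP 404, e.g. an older operator), the planner falls back to patching `spec.services.<service>.replicas` on the DGD itself, as implemented in `_update_dgd_replicas`.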
+ """ + self.update_service_replicas(graph_deployment_name, component_name, replicas) def is_deployment_ready(self, deployment: dict) -> bool: """Check if a graph deployment is ready""" diff --git a/components/src/dynamo/planner/utils/planner_core.py b/components/src/dynamo/planner/utils/planner_core.py index ff6a4156f7..ac40d52ebf 100644 --- a/components/src/dynamo/planner/utils/planner_core.py +++ b/components/src/dynamo/planner/utils/planner_core.py @@ -24,7 +24,7 @@ PrefillInterpolator, ) from dynamo.planner.utils.pre_swept_results_utils import PreSweptResultsHelper -from dynamo.planner.utils.prometheus import PrometheusAPIClient +from dynamo.planner.utils.prometheus import MetricSource, PrometheusAPIClient from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_mooncake from dynamo.runtime import DistributedRuntime from dynamo.runtime.logging import configure_dynamo_logging @@ -58,6 +58,67 @@ def is_valid(self) -> bool: ) +class PlannerPrometheusMetrics: + """Container for all Planner Prometheus metrics.""" + + def __init__(self, prefix: str = "planner"): + # Worker counts + self.num_p_workers = Gauge( + f"{prefix}:num_p_workers", "Number of prefill workers" + ) + self.num_d_workers = Gauge( + f"{prefix}:num_d_workers", "Number of decode workers" + ) + + # Observed metrics + self.observed_ttft = Gauge( + f"{prefix}:observed_ttft", "Observed time to first token (ms)" + ) + self.observed_itl = Gauge( + f"{prefix}:observed_itl", "Observed inter-token latency (ms)" + ) + self.observed_request_rate = Gauge( + f"{prefix}:observed_request_rate", "Observed request rate (req/s)" + ) + self.observed_request_duration = Gauge( + f"{prefix}:observed_request_duration", "Observed request duration (s)" + ) + self.observed_isl = Gauge( + f"{prefix}:observed_isl", "Observed input sequence length" + ) + self.observed_osl = Gauge( + f"{prefix}:observed_osl", "Observed output sequence length" + ) + + # Correction factors + self.p_correction_factor = Gauge( + f"{prefix}:p_correction_factor", "Prefill correction factor" + ) + self.d_correction_factor = Gauge( + f"{prefix}:d_correction_factor", "Decode correction factor" + ) + + # Predicted metrics + self.predicted_request_rate = Gauge( + f"{prefix}:predicted_request_rate", "Predicted request rate (req/s)" + ) + self.predicted_isl = Gauge( + f"{prefix}:predicted_isl", "Predicted input sequence length" + ) + self.predicted_osl = Gauge( + f"{prefix}:predicted_osl", "Predicted output sequence length" + ) + self.predicted_num_p = Gauge( + f"{prefix}:predicted_num_p", "Predicted number of prefill replicas" + ) + self.predicted_num_d = Gauge( + f"{prefix}:predicted_num_d", "Predicted number of decode replicas" + ) + + # Cumulative GPU usage + self.gpu_hours = Gauge(f"{prefix}:gpu_hours", "Cumulative GPU hours used") + + class Planner: def __init__( self, @@ -89,9 +150,20 @@ def __init__( else: raise ValueError(f"Invalid environment: {args.environment}") + # Use backend metrics for vLLM (queries vllm:* metrics directly from workers) + # Use frontend metrics for other backends (queries dynamo_frontend_* metrics) + metric_source = ( + MetricSource.VLLM + if args.backend.lower() == "vllm" + else MetricSource.FRONTEND + ) + logger.info( + f"Initializing Prometheus client with metric_source='{metric_source}' for backend '{args.backend}'" + ) self.prometheus_api_client = PrometheusAPIClient( args.metric_pulling_prometheus_endpoint, args.namespace, + metric_source=metric_source, ) self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor]( @@ -153,13 
+225,10 @@ def __init__( self.prometheus_port = args.metric_reporting_prometheus_port # Initialize Prometheus metrics - # TODO: use proper naming - self.num_p_workers_gauge = Gauge( - "num_p_workers", "Number of prefill workers" - ) - self.num_d_workers_gauge = Gauge( - "num_d_workers", "Number of decode workers" - ) + self.prometheus_metrics = PlannerPrometheusMetrics() + + # Track cumulative GPU hours + self.cumulative_gpu_hours = 0.0 # Start Prometheus HTTP server if port is specified if self.prometheus_port != 0: @@ -246,8 +315,21 @@ async def observe_metrics(self): # Update Prometheus metrics if server is running if self.prometheus_port != 0: - self.num_p_workers_gauge.set(len(self.p_endpoints)) - self.num_d_workers_gauge.set(len(self.d_endpoints)) + self.prometheus_metrics.num_p_workers.set(len(self.p_endpoints)) + self.prometheus_metrics.num_d_workers.set(len(self.d_endpoints)) + + # Calculate and accumulate GPU hours for this interval + # TODO: track startup and shutdown times to get more accurate GPU hours + interval_gpu_hours = ( + ( + len(self.p_endpoints) * self.args.prefill_engine_num_gpu + + len(self.d_endpoints) * self.args.decode_engine_num_gpu + ) + * self.args.adjustment_interval + / 3600 + ) + self.cumulative_gpu_hours += interval_gpu_hours + self.prometheus_metrics.gpu_hours.set(self.cumulative_gpu_hours) # Prometheus returns seconds, convert to milliseconds self.last_metrics.ttft = ( @@ -294,6 +376,19 @@ async def observe_metrics(self): f"Observed ttft: {self.last_metrics.ttft:.2f}ms itl: {self.last_metrics.itl:.2f}ms" ) + # Update observed metrics in Prometheus + if self.prometheus_port != 0: + self.prometheus_metrics.observed_ttft.set(self.last_metrics.ttft) + self.prometheus_metrics.observed_itl.set(self.last_metrics.itl) + self.prometheus_metrics.observed_request_rate.set( + self.last_metrics.num_req / self.args.adjustment_interval + ) + self.prometheus_metrics.observed_request_duration.set( + self.last_metrics.request_duration + ) + self.prometheus_metrics.observed_isl.set(self.last_metrics.isl) + self.prometheus_metrics.observed_osl.set(self.last_metrics.osl) + self.num_req_predictor.add_data_point(self.last_metrics.num_req) self.isl_predictor.add_data_point(self.last_metrics.isl) self.osl_predictor.add_data_point(self.last_metrics.osl) @@ -446,6 +541,15 @@ async def make_adjustments(self): logger.info( f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}" ) + + # Update correction factor metrics in Prometheus + if self.prometheus_port != 0: + self.prometheus_metrics.p_correction_factor.set( + self.p_correction_factor + ) + self.prometheus_metrics.d_correction_factor.set( + self.d_correction_factor + ) except Exception as e: logger.error(f"Failed to correct prediction factors: {e}") return @@ -453,10 +557,23 @@ async def make_adjustments(self): next_num_req, next_isl, next_osl = self.predict_load() if next_num_req is not None and next_isl is not None and next_osl is not None: + # Update predicted load metrics in Prometheus + if self.prometheus_port != 0: + self.prometheus_metrics.predicted_request_rate.set( + next_num_req / self.args.adjustment_interval + ) + self.prometheus_metrics.predicted_isl.set(next_isl) + self.prometheus_metrics.predicted_osl.set(next_osl) + try: next_num_p, next_num_d = self._compute_replica_requirements( next_num_req, next_isl, next_osl ) + + # Update predicted replica metrics in Prometheus + if self.prometheus_port != 0: + self.prometheus_metrics.predicted_num_p.set(next_num_p) + 
self.prometheus_metrics.predicted_num_d.set(next_num_d) except Exception as e: logger.error(f"Failed to compute number of replicas: {e}") return diff --git a/components/src/dynamo/planner/utils/prometheus.py b/components/src/dynamo/planner/utils/prometheus.py index 99a314832d..c657e93cb2 100644 --- a/components/src/dynamo/planner/utils/prometheus.py +++ b/components/src/dynamo/planner/utils/prometheus.py @@ -15,11 +15,15 @@ import logging import typing +from enum import Enum from prometheus_api_client import PrometheusConnect from pydantic import BaseModel, ValidationError from dynamo import prometheus_names +from dynamo.prometheus_names import ( + frontend_service as metric_names, # Note that we are mapping from frontend metric names to VLLM +) from dynamo.runtime.logging import configure_dynamo_logging configure_dynamo_logging() @@ -32,9 +36,11 @@ class FrontendMetric(BaseModel): endpoint: typing.Optional[str] = None instance: typing.Optional[str] = None job: typing.Optional[str] = None - model: typing.Optional[str] = None - namespace: typing.Optional[str] = None - pod: typing.Optional[str] = None + model: typing.Optional[str] = None # Frontend uses this label + model_name: typing.Optional[str] = None # Backend (vLLM) uses this label + namespace: typing.Optional[str] = None # Kubernetes namespace + pod: typing.Optional[str] = None # Pod name (used for backend filtering) + engine: typing.Optional[str] = None # vLLM engine index class FrontendMetricContainer(BaseModel): @@ -42,10 +48,78 @@ class FrontendMetricContainer(BaseModel): value: typing.Tuple[float, float] # [timestamp, value] +class MetricSource(Enum): + FRONTEND = "frontend" + VLLM = "vllm" + SGLANG = "sglang" # not supported yet + TRTLLM = "trtllm" # not supported yet + + +METRIC_SOURCE_MAP = { # sourced from prometheus_names.py + MetricSource.VLLM: { + metric_names.TIME_TO_FIRST_TOKEN_SECONDS: "vllm:time_to_first_token_seconds", # histogram + metric_names.INTER_TOKEN_LATENCY_SECONDS: "vllm:inter_token_latency_seconds", # histogram + metric_names.REQUEST_DURATION_SECONDS: "vllm:e2e_request_latency_seconds", # histogram - vLLM's e2e latency + metric_names.INPUT_SEQUENCE_TOKENS: "vllm:prompt_tokens_total", # counter - total prompt tokens + metric_names.OUTPUT_SEQUENCE_TOKENS: "vllm:generation_tokens_total", # counter - total generation tokens + metric_names.REQUESTS_TOTAL: "vllm:request_success_total", # counter + }, + MetricSource.FRONTEND: { + metric_names.TIME_TO_FIRST_TOKEN_SECONDS: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.TIME_TO_FIRST_TOKEN_SECONDS}", + metric_names.INTER_TOKEN_LATENCY_SECONDS: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.INTER_TOKEN_LATENCY_SECONDS}", + metric_names.REQUEST_DURATION_SECONDS: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.REQUEST_DURATION_SECONDS}", + metric_names.INPUT_SEQUENCE_TOKENS: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.INPUT_SEQUENCE_TOKENS}", + metric_names.OUTPUT_SEQUENCE_TOKENS: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.OUTPUT_SEQUENCE_TOKENS}", + metric_names.REQUESTS_TOTAL: f"{prometheus_names.name_prefix.FRONTEND}_{metric_names.REQUESTS_TOTAL}", + }, +} + +METRIC_SOURCE_MODEL_ATTR = { + MetricSource.VLLM: "model_name", + MetricSource.FRONTEND: "model", +} + + class PrometheusAPIClient: - def __init__(self, url: str, dynamo_namespace: str): + """ + Client for querying Dynamo metrics from Prometheus. 
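As a rough standalone illustration of the histogram-average pattern this client relies on (increase of the `_sum` series divided by increase of the `_count` series); the Prometheus URL, metric name, and interval below are hypothetical:

```python
from prometheus_api_client import PrometheusConnect

prom = PrometheusConnect(url="http://prometheus:9090", disable_ssl=True)

metric = "vllm:time_to_first_token_seconds"  # histogram exposed by vLLM workers
interval = "60s"

sum_result = prom.custom_query(query=f"increase({metric}_sum[{interval}])")
count_result = prom.custom_query(query=f"increase({metric}_count[{interval}])")

# Each result entry looks like {"metric": {...labels...}, "value": [timestamp, "value"]}
total_seconds = sum(float(r["value"][1]) for r in sum_result)
total_samples = sum(float(r["value"][1]) for r in count_result)

avg_ttft_seconds = total_seconds / total_samples if total_samples else 0.0
print(avg_ttft_seconds)
```

The client below applies the same formula but additionally filters the returned series by model label (`model` for the frontend source, `model_name` for vLLM) and by `dynamo_namespace` before dividing.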
+ + Supports querying both frontend and backend metrics: + - Frontend metrics: {prometheus_names.name_prefix.FRONTEND}_* (from Dynamo HTTP frontend) + - Backend metrics: vllm:* (from vLLM engine workers) + + Usage: + # Query frontend metrics (default) + frontend_client = PrometheusAPIClient(url="http://prometheus:9090", + dynamo_namespace="my-deployment") + ttft = frontend_client.get_avg_time_to_first_token("60s", "llama-3-8b") + + # Query backend worker metrics + backend_client = PrometheusAPIClient(url="http://prometheus:9090", + dynamo_namespace="my-deployment", + metric_source=MetricSource.VLLM) + ttft = backend_client.get_avg_time_to_first_token("60s", "llama-3-8b") + """ + + def __init__( + self, + url: str, + dynamo_namespace: str, + metric_source: MetricSource = MetricSource.FRONTEND, + ): + """ + Initialize Prometheus API client. + + Args: + url: Prometheus server URL + dynamo_namespace: Dynamo namespace to filter metrics + metric_source: Either MetricSource.FRONTEND or MetricSource.VLLM. + """ + self.prom = PrometheusConnect(url=url, disable_ssl=True) self.dynamo_namespace = dynamo_namespace + self.metric_source = metric_source + self.model_attr = METRIC_SOURCE_MODEL_ATTR[self.metric_source] def _get_average_metric( self, full_metric_name: str, interval: str, operation_name: str, model_name: str @@ -55,45 +129,127 @@ def _get_average_metric( increase(metric_sum[interval])/increase(metric_count[interval]) Args: - full_metric_name: Full metric name (e.g., 'dynamo_frontend_inter_token_latency_seconds') + full_metric_name: Full metric name (e.g., metric_names.INTER_TOKEN_LATENCY_SECONDS or metric_names.TIME_TO_FIRST_TOKEN_SECONDS) interval: Time interval for the query (e.g., '60s') operation_name: Human-readable name for error logging + model_name: Model name to filter by Returns: Average metric value or 0 if no data/error """ try: - # Prepend the frontend metric prefix if not already present - if not full_metric_name.startswith(prometheus_names.name_prefix.FRONTEND): - full_metric_name = ( - f"{prometheus_names.name_prefix.FRONTEND}_{full_metric_name}" - ) - query = f"increase({full_metric_name}_sum[{interval}])/increase({full_metric_name}_count[{interval}])" - result = self.prom.custom_query(query=query) - if not result: + full_metric_name = METRIC_SOURCE_MAP[self.metric_source][full_metric_name] + + # Query sum and count separately + sum_query = f"increase({full_metric_name}_sum[{interval}])" + count_query = f"increase({full_metric_name}_count[{interval}])" + + sum_result = self.prom.custom_query(query=sum_query) + count_result = self.prom.custom_query(query=count_query) + + if not sum_result or not count_result: # No data available yet (no requests made) - return 0 silently logger.warning( f"No prometheus metric data available for {full_metric_name}, use 0 instead" ) return 0 - metrics_containers = parse_frontend_metric_containers(result) - values = [] - for container in metrics_containers: - # Frontend lowercases model names for Prometheus labels so we need to do case-insensitive comparison - if ( - container.metric.model - and container.metric.model.lower() == model_name.lower() - and container.metric.dynamo_namespace == self.dynamo_namespace - ): - values.append(container.value[1]) - - if not values: + sum_containers = parse_frontend_metric_containers(sum_result) + count_containers = parse_frontend_metric_containers(count_result) + + # Sum up values for matching containers + total_sum = 0.0 + total_count = 0.0 + + for container in sum_containers: + model_value = 
getattr(container.metric, self.model_attr, None) + model_match = model_value and model_value.lower() == model_name.lower() + namespace_match = ( + container.metric.dynamo_namespace == self.dynamo_namespace + ) + + # Filter by model and namespace + if model_match and namespace_match: + total_sum += container.value[1] + + for container in count_containers: + model_value = getattr(container.metric, self.model_attr, None) + model_match = model_value and model_value.lower() == model_name.lower() + namespace_match = ( + container.metric.dynamo_namespace == self.dynamo_namespace + ) + + # Filter by model and namespace + if model_match and namespace_match: + total_count += container.value[1] + + if total_count == 0: logger.warning( f"No prometheus metric data available for {full_metric_name} with model {model_name} and dynamo namespace {self.dynamo_namespace}, use 0 instead" ) return 0 - return sum(values) / len(values) + + return total_sum / total_count + + except Exception as e: + logger.error(f"Error getting {operation_name}: {e}") + return 0 + + def _get_counter_average( + self, counter_metric: str, interval: str, model_name: str, operation_name: str + ) -> float: + """ + Get average value from a counter metric by dividing total increase by request count increase. + Used for vLLM token counters (prompt_tokens_total, generation_tokens_total). + + Formula: increase(counter_total[interval]) / increase(request_success_total[interval]) + """ + try: + full_metric_name = METRIC_SOURCE_MAP[self.metric_source][counter_metric] + requests_metric = METRIC_SOURCE_MAP[self.metric_source][ + metric_names.REQUESTS_TOTAL + ] + + # Query both the counter and request count + counter_query = f"increase({full_metric_name}[{interval}])" + requests_query = f"increase({requests_metric}[{interval}])" + + counter_result = self.prom.custom_query(query=counter_query) + requests_result = self.prom.custom_query(query=requests_query) + + if not counter_result or not requests_result: + logger.warning( + f"No prometheus metric data available for {full_metric_name}, use 0 instead" + ) + return 0 + + counter_containers = parse_frontend_metric_containers(counter_result) + requests_containers = parse_frontend_metric_containers(requests_result) + + # Sum up values for matching pods + total_counter = 0.0 + total_requests = 0.0 + + for container in counter_containers: + model_value = getattr(container.metric, self.model_attr, None) + if model_value and model_value.lower() == model_name.lower(): + if container.metric.dynamo_namespace == self.dynamo_namespace: + total_counter += container.value[1] + + for container in requests_containers: + model_value = getattr(container.metric, self.model_attr, None) + if model_value and model_value.lower() == model_name.lower(): + if container.metric.dynamo_namespace == self.dynamo_namespace: + total_requests += container.value[1] + + if total_requests == 0: + logger.warning( + f"No requests for {operation_name} calculation, use 0 instead" + ) + return 0 + + average = total_counter / total_requests + return average except Exception as e: logger.error(f"Error getting {operation_name}: {e}") @@ -101,7 +257,7 @@ def _get_average_metric( def get_avg_inter_token_latency(self, interval: str, model_name: str): return self._get_average_metric( - prometheus_names.frontend_service.INTER_TOKEN_LATENCY_SECONDS, + metric_names.INTER_TOKEN_LATENCY_SECONDS, interval, "avg inter token latency", model_name, @@ -109,7 +265,7 @@ def get_avg_inter_token_latency(self, interval: str, model_name: str): def 
get_avg_time_to_first_token(self, interval: str, model_name: str): return self._get_average_metric( - prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, + metric_names.TIME_TO_FIRST_TOKEN_SECONDS, interval, "avg time to first token", model_name, @@ -117,35 +273,38 @@ def get_avg_time_to_first_token(self, interval: str, model_name: str): def get_avg_request_duration(self, interval: str, model_name: str): return self._get_average_metric( - prometheus_names.frontend_service.REQUEST_DURATION_SECONDS, + metric_names.REQUEST_DURATION_SECONDS, interval, "avg request duration", model_name, ) def get_avg_request_count(self, interval: str, model_name: str): - # This function follows a different query pattern than the other metrics + """ + Get request count over the specified interval. + + For frontend: queries dynamo_frontend_requests_total + For backend: queries vllm:request_success_total + """ try: - requests_total_metric = prometheus_names.frontend_service.REQUESTS_TOTAL - # Prepend the frontend metric prefix if not already present - if not requests_total_metric.startswith( - prometheus_names.name_prefix.FRONTEND - ): - requests_total_metric = ( - f"{prometheus_names.name_prefix.FRONTEND}_{requests_total_metric}" - ) + requests_total_metric = METRIC_SOURCE_MAP[self.metric_source][ + metric_names.REQUESTS_TOTAL + ] + raw_res = self.prom.custom_query( query=f"increase({requests_total_metric}[{interval}])" ) metrics_containers = parse_frontend_metric_containers(raw_res) total_count = 0.0 for container in metrics_containers: - # Frontend lowercases model names for Prometheus labels so we need to do case-insensitive comparison - if ( - container.metric.model - and container.metric.model.lower() == model_name.lower() - and container.metric.dynamo_namespace == self.dynamo_namespace - ): + model_value = getattr(container.metric, self.model_attr, None) + model_match = model_value and model_value.lower() == model_name.lower() + namespace_match = ( + container.metric.dynamo_namespace == self.dynamo_namespace + ) + + # Filter by model and namespace + if model_match and namespace_match: total_count += container.value[1] return total_count except Exception as e: @@ -153,16 +312,32 @@ def get_avg_request_count(self, interval: str, model_name: str): return 0 def get_avg_input_sequence_tokens(self, interval: str, model_name: str): + if self.metric_source == MetricSource.VLLM: + # Backend uses prompt_tokens counter (not histogram) + return self._get_counter_average( + metric_names.INPUT_SEQUENCE_TOKENS, + interval, + model_name, + "input_sequence_tokens", + ) return self._get_average_metric( - prometheus_names.frontend_service.INPUT_SEQUENCE_TOKENS, + metric_names.INPUT_SEQUENCE_TOKENS, interval, "avg input sequence tokens", model_name, ) def get_avg_output_sequence_tokens(self, interval: str, model_name: str): + if self.metric_source == MetricSource.VLLM: + # Backend uses generation_tokens counter (not histogram) + return self._get_counter_average( + metric_names.OUTPUT_SEQUENCE_TOKENS, + interval, + model_name, + "output_sequence_tokens", + ) return self._get_average_metric( - prometheus_names.frontend_service.OUTPUT_SEQUENCE_TOKENS, + metric_names.OUTPUT_SEQUENCE_TOKENS, interval, "avg output sequence tokens", model_name, diff --git a/components/src/dynamo/sglang/args.py b/components/src/dynamo/sglang/args.py index 6cbb3bde30..50964db13c 100644 --- a/components/src/dynamo/sglang/args.py +++ b/components/src/dynamo/sglang/args.py @@ -70,7 +70,7 @@ "flags": ["--use-sglang-tokenizer"], "action": 
"store_true", "default": False, - "help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend. Cannot be used with --custom-jinja-template.", + "help": "Use SGLang's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend. Cannot be used with --custom-jinja-template.", }, "multimodal-processor": { "flags": ["--multimodal-processor"], diff --git a/components/src/dynamo/sglang/health_check.py b/components/src/dynamo/sglang/health_check.py index 5d6c2b2be3..baebb2bae1 100644 --- a/components/src/dynamo/sglang/health_check.py +++ b/components/src/dynamo/sglang/health_check.py @@ -53,7 +53,9 @@ class SglangHealthCheckPayload(HealthCheckPayload): Provides SGLang defaults and inherits environment override support from base class. """ - def __init__(self, engine: Optional[sgl.Engine] = None) -> None: + def __init__( + self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False + ) -> None: """Initialize SGLang health check payload with model-specific BOS token. Args: @@ -62,7 +64,6 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None: bos_token_id = _get_bos_token_id_from_engine(engine) self.default_payload = { - "token_ids": [bos_token_id], "stop_conditions": { "max_tokens": 1, # Generate only 1 token "ignore_eos": False, @@ -75,6 +76,12 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None: "eos_token_ids": [], "annotations": [], } + + if use_text_input: + self.default_payload["prompt"] = "Test" + else: + self.default_payload["token_ids"] = [bos_token_id] + super().__init__() @@ -84,7 +91,9 @@ class SglangPrefillHealthCheckPayload(HealthCheckPayload): The prefill handler expects a wrapped structure with 'request' and 'sampling_params'. """ - def __init__(self, engine: Optional[sgl.Engine] = None) -> None: + def __init__( + self, engine: Optional[sgl.Engine] = None, use_text_input: bool = False + ) -> None: """Initialize SGLang prefill health check payload with proper wrapped structure. 
Args: @@ -93,9 +102,7 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None: bos_token_id = _get_bos_token_id_from_engine(engine) self.default_payload = { - "request": { - "token_ids": [bos_token_id], - }, + "request": {}, "sampling_params": { "max_new_tokens": 1, # Generate only 1 token "temperature": 0.0, @@ -104,4 +111,10 @@ def __init__(self, engine: Optional[sgl.Engine] = None) -> None: "ignore_eos": False, }, } + + if use_text_input: + self.default_payload["request"]["prompt"] = "Test" # type: ignore + else: + self.default_payload["request"]["token_ids"] = [bos_token_id] # type: ignore + super().__init__() diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index 65fd7efc4f..afc6eaecbb 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -103,11 +103,8 @@ async def init(runtime: DistributedRuntime, config: Config): server_args, dynamo_args = config.server_args, config.dynamo_args # Prevent SGLang from blocking on non-leader nodes - # We can switch this to 0 and leverage our own metrics - # after https://github.com/sgl-project/sglang/pull/13686 - # is merged in if server_args.node_rank >= 1: - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "1" + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" engine = sgl.Engine(server_args=server_args) @@ -123,6 +120,23 @@ async def init(runtime: DistributedRuntime, config: Config): await _handle_non_leader_node(engine, generate_endpoint) return + # Register engine routes for profiling + async def start_profile_handler(body: dict) -> dict: + """Handle /engine/start_profile requests""" + await engine.tokenizer_manager.start_profile(**body) + return {"status": "ok", "message": "Profiling started"} + + async def stop_profile_handler(body: dict) -> dict: + """Handle /engine/stop_profile requests""" + await engine.tokenizer_manager.stop_profile() + return {"status": "ok", "message": "Profiling stopped"} + + runtime.register_engine_route("start_profile", start_profile_handler) + runtime.register_engine_route("stop_profile", stop_profile_handler) + logging.info( + "Registered engine routes: /engine/start_profile, /engine/stop_profile" + ) + prefill_client = None prefill_router_client = None if config.serving_mode == DisaggregationMode.DECODE: @@ -154,8 +168,10 @@ async def init(runtime: DistributedRuntime, config: Config): handler = DecodeWorkerHandler( component, engine, config, publisher, prefill_client, prefill_router_client ) - - health_check_payload = SglangHealthCheckPayload(engine).to_dict() + print(f"Config: {config}") + health_check_payload = SglangHealthCheckPayload( + engine, use_text_input=dynamo_args.use_sglang_tokenizer + ).to_dict() logging.info( f"Registering model with endpoint types: {dynamo_args.dyn_endpoint_types}" @@ -205,11 +221,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): server_args, dynamo_args = config.server_args, config.dynamo_args # Prevent SGLang from blocking on non-leader nodes - # We can switch this to 0 and leverage our own metrics - # after https://github.com/sgl-project/sglang/pull/13686 - # is merged in if server_args.node_rank >= 1: - os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "1" + os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0" engine = sgl.Engine(server_args=server_args) @@ -225,6 +238,23 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): await _handle_non_leader_node(engine, generate_endpoint) return + # Register engine routes for profiling + async 
def start_profile_handler(body: dict) -> dict: + """Handle /engine/start_profile requests""" + await engine.tokenizer_manager.start_profile(**body) + return {"status": "ok", "message": "Profiling started"} + + async def stop_profile_handler(body: dict) -> dict: + """Handle /engine/stop_profile requests""" + await engine.tokenizer_manager.stop_profile() + return {"status": "ok", "message": "Profiling stopped"} + + runtime.register_engine_route("start_profile", start_profile_handler) + runtime.register_engine_route("stop_profile", stop_profile_handler) + logging.info( + "Registered engine routes: /engine/start_profile, /engine/stop_profile" + ) + # Perform dummy warmup for prefill worker to avoid initial TTFT hit # Only needed on leader node that handles requests await _warmup_prefill_engine(engine, server_args) @@ -291,7 +321,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config): ready_event = asyncio.Event() handler = EmbeddingWorkerHandler(component, engine, config, publisher) - health_check_payload = SglangHealthCheckPayload(engine).to_dict() + health_check_payload = SglangHealthCheckPayload( + engine, use_text_input=dynamo_args.use_sglang_tokenizer + ).to_dict() try: # Start endpoint immediately and register model concurrently @@ -396,16 +428,24 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con await pd_worker_client.wait_for_instances() - tasks = [ - generate_endpoint.serve_endpoint( - handler.generate, - graceful_shutdown=True, - metrics_labels=[("model", server_args.served_model_name)], - ) - ] + ready_event = asyncio.Event() try: - await asyncio.gather(*tasks) + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + graceful_shutdown=True, + metrics_labels=[("model", server_args.served_model_name)], + ), + register_llm_with_readiness_gate( + None, # encode worker doesn't have engine + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Text, + readiness_gate=ready_event, + ), + ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") raise @@ -439,11 +479,24 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): await handler.async_init() + health_check_payload = SglangHealthCheckPayload(engine).to_dict() + ready_event = asyncio.Event() + try: - await generate_endpoint.serve_endpoint( - handler.generate, - metrics_labels=[("model", server_args.served_model_name)], - graceful_shutdown=True, + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + metrics_labels=[("model", server_args.served_model_name)], + graceful_shutdown=True, + health_check_payload=health_check_payload, + ), + register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + readiness_gate=ready_event, + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") @@ -468,6 +521,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co await handler.async_init() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() + ready_event = asyncio.Event() try: await asyncio.gather( @@ -476,7 +530,14 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)], health_check_payload=health_check_payload, - ) + ), + register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + readiness_gate=ready_event, + ), ) except Exception 
as e: logging.error(f"Failed to serve endpoints: {e}") diff --git a/components/src/dynamo/sglang/publisher.py b/components/src/dynamo/sglang/publisher.py index 358d116643..2658b5a1af 100644 --- a/components/src/dynamo/sglang/publisher.py +++ b/components/src/dynamo/sglang/publisher.py @@ -10,7 +10,7 @@ import zmq import zmq.asyncio from prometheus_client import CollectorRegistry, multiprocess -from sglang.srt.utils import get_local_ip_auto, get_zmq_socket +from sglang.srt.utils import get_local_ip_auto, get_zmq_socket, maybe_wrap_ipv6_address from dynamo.common.utils.prometheus import register_engine_metrics_callback from dynamo.llm import ( @@ -26,6 +26,30 @@ from dynamo.sglang.args import Config +def format_zmq_endpoint(endpoint_template: str, ip_address: str) -> str: + """Format ZMQ endpoint by replacing wildcard with IP address. + + Properly handles IPv6 addresses by wrapping them in square brackets. + Uses SGLang's maybe_wrap_ipv6_address for consistent formatting. + + Args: + endpoint_template: ZMQ endpoint template with wildcard (e.g., "tcp://*:5557") + ip_address: IP address to use (can be IPv4 or IPv6) + + Returns: + Formatted ZMQ endpoint string + + Example: + >>> format_zmq_endpoint("tcp://*:5557", "192.168.1.1") + 'tcp://192.168.1.1:5557' + >>> format_zmq_endpoint("tcp://*:5557", "2a02:6b8:c46:2b4:0:74c1:75b0:0") + 'tcp://[2a02:6b8:c46:2b4:0:74c1:75b0:0]:5557' + """ + # Use SGLang's utility to wrap IPv6 addresses in brackets + formatted_ip = maybe_wrap_ipv6_address(ip_address) + return endpoint_template.replace("*", formatted_ip) + + class DynamoSglangPublisher: """ Handles SGLang kv events and metrics reception and publishing. @@ -121,7 +145,7 @@ def init_kv_event_publish(self) -> Optional[ZmqKvEventPublisher]: if self.server_args.kv_events_config: kv_events = json.loads(self.server_args.kv_events_config) ep = kv_events.get("endpoint") - zmq_ep = ep.replace("*", get_local_ip_auto()) if ep else None + zmq_ep = format_zmq_endpoint(ep, get_local_ip_auto()) if ep else None zmq_config = ZmqKvEventPublisherConfig( worker_id=self.generate_endpoint.connection_id(), diff --git a/components/src/dynamo/sglang/request_handlers/handler_base.py b/components/src/dynamo/sglang/request_handlers/handler_base.py index 4d4472e19a..ededd819d4 100644 --- a/components/src/dynamo/sglang/request_handlers/handler_base.py +++ b/components/src/dynamo/sglang/request_handlers/handler_base.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import base64 +import json import logging import random import socket @@ -10,9 +12,11 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple import sglang as sgl +from sglang.srt.tracing import trace as sglang_trace from sglang.srt.utils import get_local_ip_auto from dynamo._core import Client, Component, Context +from dynamo.common.utils.input_params import InputParamManager from dynamo.sglang.args import Config from dynamo.sglang.publisher import DynamoSglangPublisher @@ -49,6 +53,13 @@ def __init__( self.prefill_client = prefill_client self.serving_mode = config.serving_mode self.skip_tokenizer_init = config.server_args.skip_tokenizer_init + self.enable_trace = config.server_args.enable_trace + + self.input_param_manager = InputParamManager( + self.engine.tokenizer_manager.tokenizer + if not self.skip_tokenizer_init + else None + ) @abstractmethod async def generate(self, request: Dict[str, Any], context: Context): @@ -68,23 +79,13 @@ def cleanup(self) -> None: pass def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]: - 
"""Get the appropriate input parameter for SGLang engine. - - Args: - request: Request dict with token_ids or messages. + request_input = self.input_param_manager.get_input_param( + request, use_tokenizer=not self.skip_tokenizer_init + ) - Returns: - Dict with either input_ids or prompt for engine. - """ - if self.skip_tokenizer_init: - return {"input_ids": request["token_ids"]} - else: - # use sglang's chat templating itself but leave tokenization to the - # interal engine's TokenizerManager - prompt = self.engine.tokenizer_manager.tokenizer.apply_chat_template( - request["messages"], tokenize=False, add_generation_prompt=True - ) - return {"prompt": prompt} + return { + "prompt" if isinstance(request_input, str) else "input_ids": request_input + } @staticmethod def _generate_bootstrap_room() -> int: @@ -117,6 +118,39 @@ def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]: return bootstrap_host, bootstrap_port + def _propagate_trace_context_to_sglang( + self, context: Context, bootstrap_room: int = 0 + ): + """Propagate Dynamo's trace context to SGLang for distributed tracing. SGLang expects a certain + format derived by loooking at https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py + in the to_dict() method. + + Args: + context: Dynamo Context object containing trace information. + bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated). + """ + trace_id = context.trace_id + span_id = context.span_id + if not trace_id or not span_id: + return + + # Build trace context for SGLang + trace_context = { + str(bootstrap_room): { + "root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"}, + "prev_span": { + "span_id": int(span_id, 16), + "trace_id": int(trace_id, 16), + }, + } + } + + # Encode and propagate + base64_context = base64.b64encode( + json.dumps(trace_context, ensure_ascii=False).encode("utf-8") + ).decode("utf-8") + sglang_trace.trace_set_remote_propagate_context(base64_context) + async def _handle_cancellation( self, request_id_future: asyncio.Future, context: Context ): diff --git a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py index e7fd9f17ae..47572e2f54 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/decode_handler.py @@ -112,6 +112,7 @@ async def generate( RuntimeError: If no bootstrap info received from prefill worker. 
""" logging.debug(f"New Request ID: {context.id()}") + trace_id = context.trace_id sampling_params = self._build_sampling_params(request) input_param = self._get_input_param(request) @@ -154,6 +155,11 @@ async def generate( if not bootstrap_info: raise RuntimeError("No bootstrap info received from prefill worker") + if self.enable_trace: + self._propagate_trace_context_to_sglang( + context, bootstrap_info["bootstrap_room"] + ) + decode = await self.engine.async_generate( **input_param, sampling_params=sampling_params, @@ -161,6 +167,7 @@ async def generate( bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_room=bootstrap_info["bootstrap_room"], + rid=trace_id, ) if self.skip_tokenizer_init: @@ -170,10 +177,14 @@ async def generate( async for out in self._process_text_stream(decode, context): yield out else: + if self.enable_trace: + self._propagate_trace_context_to_sglang(context) + agg = await self.engine.async_generate( **input_param, sampling_params=sampling_params, stream=True, + rid=trace_id, ) if self.skip_tokenizer_init: async for out in self._process_token_stream(agg, context): diff --git a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py index dc55ab9762..e019ea5c9e 100644 --- a/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py +++ b/components/src/dynamo/sglang/request_handlers/llm/prefill_handler.py @@ -64,6 +64,7 @@ async def generate( Bootstrap info dict with host, port, and room for decode worker connection. """ logging.debug(f"New Request ID: {context.id()}") + trace_id = context.trace_id bootstrap_room = self._generate_bootstrap_room() bootstrap_info = { @@ -76,6 +77,10 @@ async def generate( input_param = self._get_input_param(request["request"]) + # Propagate trace context to SGLang + if self.enable_trace: + self._propagate_trace_context_to_sglang(context, bootstrap_room) + results = await self.engine.async_generate( **input_param, sampling_params=request["sampling_params"], @@ -83,6 +88,7 @@ async def generate( bootstrap_host=self.bootstrap_host, bootstrap_port=self.bootstrap_port, bootstrap_room=bootstrap_room, + rid=trace_id, ) task = asyncio.create_task(self._consume_results(results, context)) diff --git a/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py b/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py index cbb2f904a7..957c936b34 100644 --- a/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py +++ b/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py @@ -159,7 +159,7 @@ async def generate( # Create descriptor for the multimodal data descriptor = connect.Descriptor(precomputed_embeddings) - with self._connector.create_readable(descriptor) as readable: + with await self._connector.create_readable(descriptor) as readable: request.serialized_request = readable.metadata() logger.debug(f"Request: {request.model_dump_json()}") @@ -184,6 +184,5 @@ async def async_init(self, runtime: DistributedRuntime): # Create and initialize a dynamo connector for this worker. # We'll needs this to move data between this worker and remote workers efficiently. 
self._connector = connect.Connector() - await self._connector.initialize() logger.info("Startup completed.") diff --git a/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py b/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py index 8f94afd3fd..e4a10fc71a 100644 --- a/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py +++ b/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py @@ -77,7 +77,6 @@ def __init__(self): async def initialize(self): """Initialize the connector for embeddings processing""" self._connector = connect.Connector() - await self._connector.initialize() async def process_embeddings(self, request: SglangMultimodalRequest): """Process embeddings from serialized request""" @@ -103,7 +102,6 @@ async def process_embeddings(self, request: SglangMultimodalRequest): "Connector is None - this should not happen after initialization" ) self._connector = connect.Connector() - await self._connector.initialize() read_op = await self._connector.begin_read( request.serialized_request, descriptor diff --git a/components/src/dynamo/sglang/tests/test_sglang_unit.py b/components/src/dynamo/sglang/tests/test_sglang_unit.py index 5835131dbe..9bd60d18fe 100644 --- a/components/src/dynamo/sglang/tests/test_sglang_unit.py +++ b/components/src/dynamo/sglang/tests/test_sglang_unit.py @@ -26,7 +26,6 @@ pytest.mark.gpu_1, pytest.mark.pre_merge, ] - # Create SGLang-specific CLI args fixture # This will use monkeypatch to write to argv mock_sglang_cli = make_cli_args_fixture("dynamo.sglang") diff --git a/components/src/dynamo/trtllm/encode_helper.py b/components/src/dynamo/trtllm/encode_helper.py index a022489ce6..c8ac97a2b7 100644 --- a/components/src/dynamo/trtllm/encode_helper.py +++ b/components/src/dynamo/trtllm/encode_helper.py @@ -241,7 +241,7 @@ async def process_embedding_request( # Create readable operation with main embeddings tensor (works for both formats) descriptor = nixl_connect.Descriptor(encodings) - with connector.create_readable(descriptor) as readable_op: + with await connector.create_readable(descriptor) as readable_op: # Get the metadata for the readable operation op_metadata = readable_op.metadata() diff --git a/components/src/dynamo/trtllm/main.py b/components/src/dynamo/trtllm/main.py index 59a35b39d3..80238cb4f0 100644 --- a/components/src/dynamo/trtllm/main.py +++ b/components/src/dynamo/trtllm/main.py @@ -22,7 +22,6 @@ import uvloop from prometheus_client import REGISTRY from tensorrt_llm.llmapi import ( - BuildConfig, CapacitySchedulerPolicy, DynamicBatchConfig, KvCacheConfig, @@ -162,13 +161,6 @@ async def init(runtime: DistributedRuntime, config: Config): else: gpus_per_node = config.gpus_per_node - build_config = BuildConfig( - max_batch_size=config.max_batch_size, - max_num_tokens=config.max_num_tokens, - max_beam_width=config.max_beam_width, - max_seq_len=config.max_seq_len, - ) - kv_cache_config = KvCacheConfig( free_gpu_memory_fraction=config.free_gpu_memory_fraction ) @@ -190,7 +182,6 @@ async def init(runtime: DistributedRuntime, config: Config): "pipeline_parallel_size": config.pipeline_parallel_size, "moe_expert_parallel_size": config.expert_parallel_size, "backend": Backend.PYTORCH, - "build_config": build_config, "kv_cache_config": kv_cache_config, "gpus_per_node": gpus_per_node, "max_num_tokens": config.max_num_tokens, @@ -332,7 +323,6 @@ async def init(runtime: DistributedRuntime, config: Config): connector = None logging.info("Initializing NIXL Connect.") connector = 
nixl_connect.Connector() - await connector.initialize() dump_config( config.dump_config_to, {"engine_args": engine_args, "dynamo_args": config} diff --git a/components/src/dynamo/trtllm/request_handlers/handler_base.py b/components/src/dynamo/trtllm/request_handlers/handler_base.py index 58390bcedc..9500b25135 100644 --- a/components/src/dynamo/trtllm/request_handlers/handler_base.py +++ b/components/src/dynamo/trtllm/request_handlers/handler_base.py @@ -106,6 +106,76 @@ def check_error(self, result: dict): result["finish_reason"] == "stop" or result["finish_reason"] == "error" ) + @staticmethod + def _extract_logprobs( + output, num_output_tokens_so_far: int + ) -> tuple[list[float] | None, list[list[dict]] | None]: + """ + Extract logprobs from the TRTLLM output for new tokens. + + Args: + output: TRTLLM CompletionOutput object + num_output_tokens_so_far: Number of tokens already processed + Returns: + Tuple of (log_probs, top_logprobs) in Dynamo's expected format: + - log_probs: List of log probabilities for each new token + - top_logprobs: List of top logprobs dicts for each new token + """ + if output.logprobs is None: + return None, None + + # Get logprobs for new tokens only + new_logprobs = output.logprobs[num_output_tokens_so_far:] + if not new_logprobs: + return None, None + + # From TRTLLM CompletionOutput API, logprobs: (TokenLogprobs | List[float], optional) + # Expect TokenLogprobs output when logprobs is set, check edge case where list[float] is returned instead + if isinstance(new_logprobs[0], float): + return [float(lp) for lp in new_logprobs], None + + log_probs = [] + top_logprobs = [] + + for token_idx, token_logprobs_dict in enumerate(new_logprobs): + if token_logprobs_dict is None: + continue + + # Get the actual token_id that was generated at this position + actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx] + + # Extract log probability for the selected token + if actual_token_id in token_logprobs_dict: + selected_logprob = token_logprobs_dict[actual_token_id] + log_probs.append(float(selected_logprob.logprob)) + else: + # Fallback: use the first logprob if selected token not found + first_logprob = next(iter(token_logprobs_dict.values()), None) + if first_logprob: + log_probs.append(float(first_logprob.logprob)) + + # Build top_logprobs list for this token position + # NOTE: TRTLLM LogProb API doesn't have decoded_token, will default to None + token_top_logprobs = [] + for tok_id, logprob_info in token_logprobs_dict.items(): + token_top_logprobs.append( + { + "rank": logprob_info.rank + if hasattr(logprob_info, "rank") + else 0, + "token_id": tok_id, + "token": ( + logprob_info.decoded_token + if hasattr(logprob_info, "decoded_token") + else None + ), + "logprob": float(logprob_info.logprob), + } + ) + top_logprobs.append(token_top_logprobs) + + return log_probs if log_probs else None, top_logprobs if top_logprobs else None + async def _handle_cancellation( self, generation_result: GenerationResult, context: Context ): @@ -236,6 +306,26 @@ async def generate_locally( if hasattr(sampling_params, key): setattr(sampling_params, key, value) + # Additional sampling params in output options + output_options = request.get("output_options", {}) + if output_options: + logprobs_value = output_options.get("logprobs") + + # Handle logprobs + if logprobs_value is not None: + if hasattr(sampling_params, "logprobs"): + setattr( + sampling_params, "logprobs", max(1, int(logprobs_value)) + ) # If top_logprobs = 0, still want to see chosen token logprob + + # 
Handle prompt_logprobs + prompt_logprobs_value = output_options.get("prompt_logprobs") + if prompt_logprobs_value: + if hasattr(sampling_params, "prompt_logprobs"): + setattr( + sampling_params, "prompt_logprobs", int(prompt_logprobs_value) + ) + max_tokens = request["stop_conditions"]["max_tokens"] if max_tokens: sampling_params.max_tokens = max_tokens @@ -302,6 +392,15 @@ async def generate_locally( out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + # Extract logprobs from the output + log_probs, top_logprobs = self._extract_logprobs( + output, num_output_tokens_so_far + ) + if log_probs: + out["log_probs"] = log_probs + if top_logprobs: + out["top_logprobs"] = top_logprobs + if output.finish_reason: out["finish_reason"] = output.finish_reason if output.stop_reason: @@ -369,8 +468,12 @@ async def generate_locally( # 2. Per-request errors - send to client, don't shutdown except RequestError as e: - logging.warning(f"Request {request_id} error: {e}") - yield {"finish_reason": "error", "token_ids": []} + error_msg = str(e) + logging.warning(f"Request {request_id} error: {error_msg}") + yield { + "finish_reason": {"error": error_msg}, + "token_ids": [], + } # 3. ALL OTHER ERRORS - graceful shutdown except Exception as e: @@ -384,7 +487,7 @@ async def generate_locally( # Try to send error to client before shutdown try: yield { - "finish_reason": "error", + "finish_reason": {"error": error_msg}, "token_ids": [], } except Exception: diff --git a/components/src/dynamo/vllm/args.py b/components/src/dynamo/vllm/args.py index 56dec6f8e8..767a9d3b93 100644 --- a/components/src/dynamo/vllm/args.py +++ b/components/src/dynamo/vllm/args.py @@ -40,6 +40,7 @@ class Config: custom_jinja_template: Optional[str] = None store_kv: str request_plane: str + enable_local_indexer: bool = False # mirror vLLM model: str @@ -69,6 +70,9 @@ class Config: # dump config to file dump_config_to: Optional[str] = None + # Use vLLM's tokenizer for pre/post processing + use_vllm_tokenizer: bool = False + def has_connector(self, connector_name: str) -> bool: """ Check if a specific connector is enabled. @@ -201,12 +205,41 @@ def parse_args() -> Config: default=os.environ.get("DYN_REQUEST_PLANE", "nats"), help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", ) + parser.add_argument( + "--enable-local-indexer", + action="store_true", + help="Enable worker-local KV indexer for tracking this worker's own KV cache state.", + ) + parser.add_argument( + "--use-vllm-tokenizer", + action="store_true", + default=False, + help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.", + ) add_config_dump_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() engine_args = AsyncEngineArgs.from_cli_args(args) + # Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor. + # With TP=1, vLLM defaults to UniProcExecutor which runs scheduler and worker in the same + # process. This causes a hot loop in _process_engine_step that doesn't release the GIL, + # blocking NIXL's add_remote_agent from completing. Using "mp" backend forces separate + # processes, avoiding the GIL contention. + # Note: Only apply for NIXL - other connectors (kvbm, lmcache) work fine with UniProcExecutor + # and forcing mp can expose race conditions in vLLM's scheduler. 
+ # See: https://github.com/vllm-project/vllm/issues/29369 + connector_list = [c.lower() for c in args.connector] if args.connector else [] + uses_nixl = "nixl" in connector_list + tp_size = getattr(engine_args, "tensor_parallel_size", None) or 1 + if uses_nixl and tp_size == 1 and engine_args.distributed_executor_backend is None: + logger.info( + "Setting --distributed-executor-backend=mp for TP=1 to avoid " + "UniProcExecutor GIL contention with NIXL connector" + ) + engine_args.distributed_executor_backend = "mp" + if engine_args.enable_prefix_caching is None: logger.debug( "--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)" @@ -285,6 +318,8 @@ def parse_args() -> Config: config.mm_prompt_template = args.mm_prompt_template config.store_kv = args.store_kv config.request_plane = args.request_plane + config.enable_local_indexer = args.enable_local_indexer + config.use_vllm_tokenizer = args.use_vllm_tokenizer # Validate custom Jinja template file exists if provided if config.custom_jinja_template is not None: @@ -353,24 +388,6 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]: logger.info("No kv_events_config required: prefix caching is disabled") return None - # There is a bug with KV events publishing when LORA is enabled. - # This is fixed in https://github.com/vllm-project/vllm/pull/27728 but not released yet. - # remove below check once new vLLM version is released with the fix. - if config.engine_args.enable_lora: - if config.engine_args.kv_events_config is None: - # No explicit kv events config provided by user, we'll disable kv cache because LoRA is enabled and its not supported yet. - return None - else: - # User provided their own kv events config and it'll not work when LoRA is enabled. - message = ( - "KV events doesn't work when LoRA is enabled due to upstream vLLM bug. " - "Please see https://github.com/vllm-project/vllm/pull/27728." - "For now, either disable lora or dont use explicit kv envents config." - "Dont set both --kv-events-config and --enable-lora in vllm command line args." - ) - logger.error(message) - raise ValueError(message) - # If user provided their own config, use that if c := getattr(config.engine_args, "kv_events_config"): # Warn user that enable_kv_cache_events probably should be True (user may have omitted it from JSON) diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 55ee6ffcf3..f8d17c7369 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -5,16 +5,18 @@ import logging import os import tempfile +import time from abc import ABC, abstractmethod from contextlib import asynccontextmanager from typing import Any, AsyncGenerator, Dict, Final -from vllm.inputs import TokensPrompt +from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.v1.engine.exceptions import EngineDeadError +from dynamo.common.utils.input_params import InputParamManager from dynamo.llm import ( ModelInput, ModelType, @@ -70,10 +72,11 @@ def build_sampling_params( model_max_len: int | None = None, ) -> SamplingParams: """ - Build SamplingParams from a PreprocessedRequest. + Build SamplingParams from a PreprocessedRequest (internal protocol format). 
Args: - request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions' + request: The PreprocessedRequest dict with 'sampling_options', 'stop_conditions', + and 'output_options' default_sampling_params: Default sampling parameters to initialize with Returns: @@ -82,8 +85,22 @@ def build_sampling_params( sampling_params = SamplingParams(**default_sampling_params) sampling_params.detokenize = False - # Apply sampling_options + # Handle guided_decoding - convert to StructuredOutputsParams + guided_decoding = request["sampling_options"].get("guided_decoding") + if guided_decoding is not None and isinstance(guided_decoding, dict): + sampling_params.structured_outputs = StructuredOutputsParams( + json=guided_decoding.get("json"), + regex=guided_decoding.get("regex"), + choice=guided_decoding.get("choice"), + grammar=guided_decoding.get("grammar"), + whitespace_pattern=guided_decoding.get("whitespace_pattern"), + ) + + # Apply remaining sampling_options for key, value in request["sampling_options"].items(): + # Skip guided_decoding - already handled above + if key == "guided_decoding": + continue if value is not None and hasattr(sampling_params, key): setattr(sampling_params, key, value) @@ -102,6 +119,41 @@ def build_sampling_params( existing = sampling_params.stop_token_ids or [] sampling_params.stop_token_ids = list(set(existing).union(value)) + # Apply output_options (logprobs, prompt_logprobs, etc.) + output_options = request.get("output_options", {}) + if output_options: + # Handle logprobs - vLLM expects this as an integer or None + logprobs_value = output_options.get("logprobs") + if logprobs_value is not None and logprobs_value != "": + try: + parsed_logprobs = int(logprobs_value) + if parsed_logprobs < 0: + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.logprobs = parsed_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid logprobs value: {logprobs_value} (must be integer), ignoring" + ) + + # Handle prompt_logprobs - vLLM expects this as an integer or None + prompt_logprobs_value = output_options.get("prompt_logprobs") + if prompt_logprobs_value is not None and prompt_logprobs_value != "": + try: + parsed_prompt_logprobs = int(prompt_logprobs_value) + if parsed_prompt_logprobs < 0: + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be non-negative), ignoring" + ) + else: + sampling_params.prompt_logprobs = parsed_prompt_logprobs + except (ValueError, TypeError): + logger.warning( + f"Invalid prompt_logprobs value: {prompt_logprobs_value} (must be integer), ignoring" + ) + # If max_tokens wasn't provided (None or missing), compute a dynamic default provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None) token_ids = request.get("token_ids", []) @@ -114,6 +166,61 @@ def build_sampling_params( return sampling_params +def build_sampling_params_openai( + request: Dict[str, Any], + default_sampling_params: Dict[str, Any], +) -> SamplingParams: + """ + Build SamplingParams from an OpenAI-compatible request format. + + Args: + request: The OpenAI-style request dict with parameters like temperature, max_tokens, etc. 
+ default_sampling_params: Default sampling parameters to initialize with + + Returns: + SamplingParams configured from the request + """ + sampling_params = SamplingParams(**default_sampling_params) + sampling_params.detokenize = True + + # Map common OpenAI parameters to SamplingParams + openai_mapping = { + "temperature": "temperature", + "top_p": "top_p", + "presence_penalty": "presence_penalty", + "frequency_penalty": "frequency_penalty", + "seed": "seed", + "top_k": "top_k", + "repetition_penalty": "repetition_penalty", + "min_p": "min_p", + "length_penalty": "length_penalty", + "use_beam_search": "use_beam_search", + } + + for req_key, param_key in openai_mapping.items(): + if req_key in request and request[req_key] is not None: + if hasattr(sampling_params, param_key): + setattr(sampling_params, param_key, request[req_key]) + + # Handle max_tokens + if "max_tokens" in request and request["max_tokens"] is not None: + sampling_params.max_tokens = request["max_tokens"] + + # Handle stop sequences + if "stop" in request and request["stop"] is not None: + sampling_params.stop = request["stop"] + + # Handle ignore_eos (custom extension) + if "ignore_eos" in request and request["ignore_eos"] is not None: + sampling_params.ignore_eos = request["ignore_eos"] + + # Handle min_tokens (custom extension) + if "min_tokens" in request and request["min_tokens"] is not None: + sampling_params.min_tokens = request["min_tokens"] + + return sampling_params + + class BaseWorkerHandler(ABC): """ Request handler for the generate and clear_kv_blocks endpoints. @@ -129,6 +236,7 @@ def __init__( enable_multimodal: bool = False, generate_endpoint=None, config=None, + use_vllm_tokenizer: bool = False, ): self.runtime = runtime self.component = component @@ -146,6 +254,14 @@ def __init__( self.lora_id_for_name: dict[str, int] = {} self.lora_name_to_path: dict[str, str] = {} + self.use_vllm_tokenizer = use_vllm_tokenizer + + # Initialize InputParamManager for text-in-text-out mode + tokenizer = None + if use_vllm_tokenizer and hasattr(engine, "tokenizer"): + tokenizer = engine.tokenizer + self.input_param_manager = InputParamManager(tokenizer) + @abstractmethod async def generate(self, request, context) -> AsyncGenerator[dict, None]: raise NotImplementedError @@ -563,6 +679,66 @@ def _build_completion_usage(request_output: RequestOutput) -> Dict[str, Any]: ), } + @staticmethod + def _extract_logprobs( + output, num_output_tokens_so_far: int + ) -> tuple[list[float] | None, list[list[dict]] | None]: + """ + Extract logprobs from vLLM CompletionOutput for new tokens. 
+ + Args: + output: vLLM CompletionOutput object + num_output_tokens_so_far: Number of tokens already processed + + Returns: + Tuple of (log_probs, top_logprobs) in Dynamo's expected format: + - log_probs: List of log probabilities for each new token + - top_logprobs: List of top logprobs dicts for each new token + """ + if output.logprobs is None: + return None, None + + # Get logprobs for new tokens only + new_logprobs = output.logprobs[num_output_tokens_so_far:] + if not new_logprobs: + return None, None + + log_probs = [] + top_logprobs = [] + + for token_idx, token_logprobs_dict in enumerate(new_logprobs): + if token_logprobs_dict is None: + continue + + # Get the actual token_id that was generated at this position + actual_token_id = output.token_ids[num_output_tokens_so_far + token_idx] + + # Extract log probability for the selected token + # vLLM guarantees the selected token is always in the logprobs dict + selected_logprob = token_logprobs_dict[actual_token_id] + log_probs.append(float(selected_logprob.logprob)) + + # Build top_logprobs list for this token position + token_top_logprobs = [] + for tok_id, logprob_info in token_logprobs_dict.items(): + token_top_logprobs.append( + { + "rank": ( + logprob_info.rank if hasattr(logprob_info, "rank") else 0 + ), + "token_id": tok_id, + "token": ( + logprob_info.decoded_token + if hasattr(logprob_info, "decoded_token") + else None + ), + "logprob": float(logprob_info.logprob), + } + ) + top_logprobs.append(token_top_logprobs) + + return log_probs if log_probs else None, top_logprobs if top_logprobs else None + async def generate_tokens( self, prompt, @@ -608,6 +784,16 @@ async def generate_tokens( output = res.outputs[0] next_total_toks = len(output.token_ids) out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} + + # Extract logprobs for new tokens if available + log_probs, top_logprobs = self._extract_logprobs( + output, num_output_tokens_so_far + ) + if log_probs is not None: + out["log_probs"] = log_probs + if top_logprobs is not None: + out["top_logprobs"] = top_logprobs + if output.finish_reason: out["finish_reason"] = output.finish_reason out[ @@ -655,6 +841,7 @@ def __init__( enable_multimodal: bool = False, generate_endpoint=None, config=None, + use_vllm_tokenizer: bool = False, ): super().__init__( runtime, @@ -665,6 +852,7 @@ def __init__( enable_multimodal, generate_endpoint, config, + use_vllm_tokenizer, ) async def generate(self, request, context): @@ -672,6 +860,17 @@ async def generate(self, request, context): request_id = context.id() logger.debug(f"Decode Request ID: {request_id}") + if self.use_vllm_tokenizer: + # Text-in-text-out mode: use InputParamManager and OpenAI-compatible format + async for chunk in self._generate_text_mode(request, context, request_id): + yield chunk + else: + # Token-in-token-out mode: internal protocol format + async for chunk in self._generate_token_mode(request, context, request_id): + yield chunk + + async def _generate_token_mode(self, request, context, request_id): + """Generate tokens using internal protocol format (token-in-token-out).""" # Extract and decode multimodal data if present multi_modal_data = await self._extract_multimodal_data(request) @@ -745,6 +944,81 @@ async def generate(self, request, context): self.runtime.shutdown() os._exit(1) + async def _generate_text_mode(self, request, context, request_id): + """Generate text using OpenAI-compatible format (text-in-text-out).""" + # Get text input using InputParamManager + input_text = 
self.input_param_manager.get_input_param( + request, use_tokenizer=True + ) + + # Build prompt for vLLM + prompt = TextPrompt(prompt=input_text) + + # Build sampling params from OpenAI-style request + sampling_params = build_sampling_params_openai( + request, self.default_sampling_params + ) + + dp_rank = request.get("dp_rank", None) + openai_request_id = request.get("id") or request.get("request_id", request_id) + previous_text = "" + + async with self._abort_monitor(context, request_id): + try: + gen = self.engine_client.generate( + prompt, + sampling_params, + request_id, + data_parallel_rank=dp_rank, + ) + + async for res in gen: + if not res.outputs: + yield { + "id": openai_request_id, + "created": int(time.time()), + "object": "chat.completion.chunk", + "model": "unknown", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": "error", + } + ], + } + break + + output = res.outputs[0] + # Calculate the delta text (new text since last chunk) + delta_text = output.text[len(previous_text) :] + previous_text = output.text + + choice_data = { + "index": 0, + "delta": { + "role": "assistant", + "content": delta_text, + }, + "finish_reason": output.finish_reason, + } + + chunk = { + "id": openai_request_id, + "created": int(time.time()), + "object": "chat.completion.chunk", + "model": "unknown", + "choices": [choice_data], + } + + yield chunk + + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) + class PrefillWorkerHandler(BaseWorkerHandler): def __init__( @@ -757,6 +1031,7 @@ def __init__( enable_multimodal: bool = False, generate_endpoint=None, config=None, + use_vllm_tokenizer: bool = False, ): super().__init__( runtime, @@ -767,6 +1042,7 @@ def __init__( enable_multimodal, generate_endpoint, config, + use_vllm_tokenizer, ) async def generate(self, request, context): @@ -774,6 +1050,17 @@ async def generate(self, request, context): request_id = context.id() logger.debug(f"Prefill Request ID: {request_id}") + if self.use_vllm_tokenizer: + # Text-in-text-out mode: use InputParamManager + async for chunk in self._generate_text_mode(request, context, request_id): + yield chunk + else: + # Token-in-token-out mode: internal protocol format + async for chunk in self._generate_token_mode(request, context, request_id): + yield chunk + + async def _generate_token_mode(self, request, context, request_id): + """Generate prefill using internal protocol format (token-in-token-out).""" # Extract and decode multimodal data if present multi_modal_data = await self._extract_multimodal_data(request) @@ -877,3 +1164,77 @@ async def generate(self, request, context): raise GeneratorExit( "Prefill engine was shut down during token generation" ) from None + + async def _generate_text_mode(self, request, context, request_id): + """Generate prefill using OpenAI-compatible format (text-in-text-out).""" + # Get text input using InputParamManager + input_text = self.input_param_manager.get_input_param( + request, use_tokenizer=True + ) + + # Build prompt for vLLM + prompt = TextPrompt(prompt=input_text) + + # Build sampling params from OpenAI-style request + sampling_params = build_sampling_params_openai( + request, self.default_sampling_params + ) + sampling_params.detokenize = False # Prefill doesn't need detokenization + + # Configure for prefill-only mode with remote decode + if sampling_params.extra_args is None: + sampling_params.extra_args = {} 
+ sampling_params.extra_args["kv_transfer_params"] = { + "do_remote_decode": True, + } + sampling_params_defaults = { + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + } + # Add only missing keys + for k, v in sampling_params_defaults.items(): + sampling_params.extra_args["kv_transfer_params"].setdefault(k, v) + # Override for prefill: only generate 1 token + sampling_params.max_tokens = 1 + sampling_params.min_tokens = 1 + + dp_rank = request.get("dp_rank", None) + + async with self._abort_monitor(context, request_id, is_prefill=True): + try: + gen = self.engine_client.generate( + prompt, sampling_params, request_id, data_parallel_rank=dp_rank + ) + except EngineDeadError as e: + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) + + try: + async for res in gen: + logger.debug(f"kv transfer params: {res.kv_transfer_params}") + + token_ids = res.outputs[0].token_ids if res.outputs else [] + + output: Dict[str, Any] = { + "token_ids": list(token_ids), + "disaggregated_params": ( + {"kv_transfer_params": res.kv_transfer_params} + if res.kv_transfer_params + else None + ), + "completion_usage": BaseWorkerHandler._build_completion_usage( + request_output=res + ), + } + + yield output + except asyncio.CancelledError: + # raise the error because we cannot migrate prefill requests + raise GeneratorExit( + "Prefill engine was shut down during token generation" + ) from None diff --git a/components/src/dynamo/vllm/health_check.py b/components/src/dynamo/vllm/health_check.py index d24230930d..fdf85e241a 100644 --- a/components/src/dynamo/vllm/health_check.py +++ b/components/src/dynamo/vllm/health_check.py @@ -8,11 +8,15 @@ """ import logging +from typing import TYPE_CHECKING, Optional from dynamo.health_check import HealthCheckPayload logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from vllm.v1.engine.async_llm import AsyncLLM + def _get_bos_token_id_from_engine(engine_client) -> int: """ @@ -45,6 +49,36 @@ def _get_bos_token_id_from_engine(engine_client) -> int: return 1 +def _make_default_payload( + engine_client: Optional["AsyncLLM"], use_text_input: bool +) -> dict: + sampling_options = { + "temperature": 0.0, + } + + stop_conditions = { + "max_tokens": 1, + "stop": None, + "stop_token_ids": None, + "include_stop_str_in_output": False, + "ignore_eos": False, + } + + if use_text_input: + return { + "prompt": "Test", + **sampling_options, + **stop_conditions, + } + else: + bos_token_id = _get_bos_token_id_from_engine(engine_client) + return { + "token_ids": [bos_token_id], + "sampling_options": sampling_options, + "stop_conditions": stop_conditions, + } + + class VllmHealthCheckPayload(HealthCheckPayload): """ vLLM-specific health check payload. @@ -52,32 +86,18 @@ class VllmHealthCheckPayload(HealthCheckPayload): Provides vLLM defaults and inherits environment override support from base class. """ - def __init__(self, engine_client=None): + def __init__(self, engine_client=None, use_text_input: bool = False): """ Initialize vLLM health check payload with vLLM-specific defaults. Args: engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from. If provided, will attempt to use the model's actual BOS token. + use_text_input: If True, use text-based input (prompt field) instead of token_ids. + This should match the use_vllm_tokenizer config setting. 
""" - bos_token_id = _get_bos_token_id_from_engine(engine_client) - # Set vLLM default payload - minimal request that completes quickly - # The handler expects token_ids, sampling_options, and stop_conditions - self.default_payload = { - "token_ids": [bos_token_id], - "sampling_options": { - "max_tokens": 1, - "temperature": 0.0, - }, - "stop_conditions": { - "stop": None, - "stop_token_ids": None, - "include_stop_str_in_output": False, - "ignore_eos": False, - "min_tokens": 0, - }, - } + self.default_payload = _make_default_payload(engine_client, use_text_input) super().__init__() @@ -88,7 +108,7 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions. """ - def __init__(self, engine_client=None): + def __init__(self, engine_client=None, use_text_input: bool = False): """ Initialize vLLM prefill health check payload with proper PreprocessedRequest structure. @@ -96,23 +116,5 @@ def __init__(self, engine_client=None): engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from. If provided, will attempt to use the model's actual BOS token. """ - bos_token_id = _get_bos_token_id_from_engine(engine_client) - - # Prefill handler expects PreprocessedRequest format: token_ids, sampling_options, stop_conditions - # The handler will override max_tokens/min_tokens to 1 and add do_remote_decode - self.default_payload = { - "token_ids": [bos_token_id], - "sampling_options": { - "temperature": 0.0, - "top_p": 1.0, - "top_k": -1, - }, - "stop_conditions": { - "stop": None, - "stop_token_ids": None, - "include_stop_str_in_output": False, - "ignore_eos": False, - "min_tokens": 0, - }, - } + self.default_payload = _make_default_payload(engine_client, use_text_input) super().__init__() diff --git a/components/src/dynamo/vllm/main.py b/components/src/dynamo/vllm/main.py index d698add3f8..b26d663891 100644 --- a/components/src/dynamo/vllm/main.py +++ b/components/src/dynamo/vllm/main.py @@ -224,6 +224,7 @@ def setup_kv_event_publisher( worker_id=generate_endpoint.connection_id(), kv_block_size=vllm_config.cache_config.block_size, zmq_endpoint=zmq_endpoint, + enable_local_indexer=config.enable_local_indexer, ) kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config) kv_publishers.append(kv_publisher) @@ -336,6 +337,7 @@ async def register_vllm_model( runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"] runtime_config.max_num_seqs = runtime_values["max_num_seqs"] runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] + runtime_config.enable_local_indexer = config.enable_local_indexer # Add tool/reasoning parsers for decode models if model_type != ModelType.Prefill: @@ -384,6 +386,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): enable_multimodal=config.enable_multimodal, generate_endpoint=generate_endpoint, config=config, + use_vllm_tokenizer=config.use_vllm_tokenizer, ) handler.add_temp_dir(prometheus_temp_dir) @@ -418,8 +421,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): # Register prefill model with ModelType.Prefill if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register + model_input = ( + ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens + ) await register_vllm_model( - ModelInput.Tokens, + model_input, ModelType.Prefill, generate_endpoint, config, @@ -428,7 +434,9 @@ async def init_prefill(runtime: DistributedRuntime, config: 
Config): migration_limit=0, # Prefill doesn't support migration ) - health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() + health_check_payload = VllmPrefillHealthCheckPayload( + engine_client, use_text_input=config.use_vllm_tokenizer + ).to_dict() try: logger.debug("Starting serve_endpoint for prefill worker") @@ -497,6 +505,7 @@ async def init(runtime: DistributedRuntime, config: Config): enable_multimodal=config.enable_multimodal, generate_endpoint=generate_endpoint, config=config, + use_vllm_tokenizer=config.use_vllm_tokenizer, ) handler.add_temp_dir(prometheus_temp_dir) @@ -536,6 +545,10 @@ async def init(runtime: DistributedRuntime, config: Config): f"Registering model with endpoint types: {config.dyn_endpoint_types}" ) + model_input = ( + ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens + ) + # Warn if custom template provided but chat endpoint not enabled if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types: logger.warning( @@ -544,7 +557,7 @@ async def init(runtime: DistributedRuntime, config: Config): ) await register_vllm_model( - ModelInput.Tokens, + model_input, model_type, generate_endpoint, config, @@ -553,7 +566,9 @@ async def init(runtime: DistributedRuntime, config: Config): migration_limit=config.migration_limit, ) - health_check_payload = VllmHealthCheckPayload(engine_client).to_dict() + health_check_payload = VllmHealthCheckPayload( + engine_client, use_text_input=config.use_vllm_tokenizer + ).to_dict() try: logger.debug("Starting serve_endpoint for decode worker") diff --git a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py index 059ba57a9d..d72804f284 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py @@ -69,7 +69,6 @@ async def async_init(self, runtime: DistributedRuntime): # Create and initialize a dynamo connector for this worker. # We'll needs this to move data between this worker and remote workers efficiently. self._connector = connect.Connector() - await self._connector.initialize() logger.info("Encode worker startup completed.") async def generate( @@ -130,7 +129,7 @@ async def generate( request.embeddings_shape = tuple(embeddings.shape) descriptor = connect.Descriptor(embeddings_cpu) - with self._connector.create_readable(descriptor) as readable: + with await self._connector.create_readable(descriptor) as readable: request.serialized_request = readable.metadata() # Clear the image URL as hint that the image is passed as embeddings. 
request.multimodal_input.image_url = None diff --git a/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py b/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py index eb84c20190..1ee10d02cd 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/processor_handler.py @@ -11,7 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.outputs import RequestOutput -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike as AnyTokenizer from dynamo.runtime import Client diff --git a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py index 727d5bdb87..0db99946ea 100644 --- a/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py +++ b/components/src/dynamo/vllm/multimodal_handlers/worker_handler.py @@ -52,7 +52,6 @@ def __init__( async def async_init(self, runtime: DistributedRuntime): """Async initialization - connector needs async setup""" self._connector = connect.Connector() - await self._connector.initialize() logger.info("Multimodal Decode Worker async initialization completed.") async def generate(self, request: vLLMMultimodalRequest, context): @@ -138,7 +137,6 @@ async def async_init(self, runtime: DistributedRuntime): """Async initialization for connector that requires async setup""" # Initialize the connector asynchronously self._connector = connect.Connector() - await self._connector.initialize() logger.info("Multimodal PD Worker async initialization completed.") async def generate(self, request: vLLMMultimodalRequest, context): diff --git a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py index fe8d95dc81..3a693131d9 100644 --- a/components/src/dynamo/vllm/multimodal_utils/chat_processor.py +++ b/components/src/dynamo/vllm/multimodal_utils/chat_processor.py @@ -28,9 +28,22 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_engine import RequestPrompt +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.inputs.data import TokensPrompt from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike as AnyTokenizer + + +class StubEngineClient: + """ + Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion. + Provides the minimal attributes required by OpenAIServingModels. 
+ """ + + def __init__(self, model_config: ModelConfig): + self.model_config = model_config + self.input_processor = None + self.io_processor = None @runtime_checkable @@ -120,12 +133,19 @@ class ChatProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingChat( - engine_client=None, - model_config=model_config, - models=None, - request_logger=None, + engine_client=stub_engine, + models=serving_models, response_role="assistant", + request_logger=None, chat_template=None, chat_template_content_format="auto", ) @@ -186,7 +206,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): yield raw_response @@ -220,7 +239,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): break @@ -267,10 +285,17 @@ class CompletionsProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingCompletion( - engine_client=None, - model_config=model_config, - models=None, + engine_client=stub_engine, + models=serving_models, request_logger=None, ) diff --git a/components/src/dynamo/vllm/multimodal_utils/protocol.py b/components/src/dynamo/vllm/multimodal_utils/protocol.py index ef8d2bea91..c05f6cdeeb 100644 --- a/components/src/dynamo/vllm/multimodal_utils/protocol.py +++ b/components/src/dynamo/vllm/multimodal_utils/protocol.py @@ -26,7 +26,7 @@ from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401 from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import RequestMetrics +from vllm.v1.metrics.stats import RequestStateStats import dynamo.nixl_connect as connect @@ -156,7 +156,7 @@ class MyRequestOutput(BaseModel): https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85 This class is used to serialize the RequestOutput and any recursively defined types - We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses + We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -167,7 +167,7 @@ class MyRequestOutput(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None outputs: List[CompletionOutput] finished: bool - metrics: Optional[RequestMetrics] = None + metrics: Optional[RequestStateStats] = None kv_transfer_params: Optional[dict[str, Any]] = None # lora_request: Optional[LoRARequest] = None # encoder_prompt: Optional[str] = None diff --git a/container/Dockerfile b/container/Dockerfile index ae639f7314..333c572dfa 100644 
--- a/container/Dockerfile +++ b/container/Dockerfile @@ -1,6 +1,18 @@ # syntax=docker/dockerfile:1.10.0 # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# NOTE FOR dynamo_base AND wheel_builder STAGES: +# +# All changes to the dynamo_base and wheel_builder stages should be replicated across +# the following Dockerfiles: +# - Dockerfile +# - Dockerfile.vllm +# - Dockerfile.sglang +# - Dockerfile.trtllm +# This duplication was introduced purposely to quickly enable Docker layer caching and +# deduplication. Please ensure these stages stay in sync until the duplication can be +# addressed. ################################## ########## Build Arguments ######## @@ -14,6 +26,7 @@ ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_MEDIA_NIXL ARG CARGO_BUILD_JOBS # Define general architecture ARGs for supporting both x86 and aarch64 builds. @@ -35,9 +48,9 @@ ARG SCCACHE_BUCKET="" ARG SCCACHE_REGION="" # NIXL configuration -ARG NIXL_UCX_REF=v1.19.0 -ARG NIXL_REF=0.7.1 -ARG NIXL_GDRCOPY_REF=v2.5.1 +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF ################################## ########## Base Image ############ @@ -201,10 +214,11 @@ ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ - CC=${USE_SCCACHE:+sccache gcc} && \ - CXX=${USE_SCCACHE:+sccache g++} && \ - export CC=${CC} && \ - export CXX=${CXX} && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ cd /usr/local/src && \ git clone https://github.com/openucx/ucx.git && \ cd ucx && \ @@ -235,6 +249,11 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ source ${VIRTUAL_ENV}/bin/activate && \ git clone --depth 1 --branch ${NIXL_REF} "https://github.com/ai-dynamo/nixl.git" && \ cd nixl && \ @@ -260,6 +279,11 @@ RUN echo "$NIXL_LIB_DIR" > /etc/ld.so.conf.d/nixl.conf && \ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ cd /workspace/nixl && \ uv build .
--out-dir /opt/dynamo/dist/nixl --python $PYTHON_VERSION @@ -274,11 +298,20 @@ ARG ENABLE_KVBM RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export RUSTC_WRAPPER="sccache"; \ + fi && \ source ${VIRTUAL_ENV}/bin/activate && \ cd /opt/dynamo && \ uv build --wheel --out-dir /opt/dynamo/dist && \ cd /opt/dynamo/lib/bindings/python && \ - maturin build --release --out /opt/dynamo/dist && \ + if [ "$ENABLE_MEDIA_NIXL" = "true" ]; then \ + maturin build --release --features dynamo-llm/media-nixl --out /opt/dynamo/dist; \ + else \ + maturin build --release --out /opt/dynamo/dist; \ + fi && \ if [ "$ENABLE_KVBM" = "true" ]; then \ cd /opt/dynamo/lib/bindings/kvbm && \ maturin build --release --out target/wheels && \ @@ -354,7 +387,7 @@ USER dynamo ENV HOME=/home/dynamo \ DYNAMO_HOME=/opt/dynamo \ CARGO_TARGET_DIR=/opt/dynamo/target -ENV LD_LIBRARY_PATH=${NIXL_LIB_DIR}:${NIXL_PLUGIN_DIR}:/usr/local/ucx/lib:/usr/local/ucx/lib/ucx:${LD_LIBRARY_PATH} +ENV LD_LIBRARY_PATH=${NIXL_LIB_DIR}:${NIXL_PLUGIN_DIR}:/usr/local/ucx/lib:/usr/local/ucx/lib/ucx:/usr/local/cuda/compat/lib.real:${LD_LIBRARY_PATH} # Create and activate virtual environment ARG PYTHON_VERSION @@ -379,10 +412,15 @@ RUN uv pip install \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ if [ "$ENABLE_KVBM" = "true" ]; then \ - uv pip install /opt/dynamo/wheelhouse/kvbm*.whl; \ - fi \ - && cd /workspace/benchmarks \ - && UV_GIT_LFS=1 uv pip install --no-cache . + KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ + if [ -z "$KVBM_WHEEL" ]; then \ + echo "ERROR: ENABLE_KVBM is true but no KVBM wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install "$KVBM_WHEEL"; \ + fi && \ + cd /workspace/benchmarks && \ + UV_GIT_LFS=1 uv pip install --no-cache . # Setup launch banner in common directory accessible to all users RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/dynamo/launch_message.txt \ diff --git a/container/Dockerfile.docs b/container/Dockerfile.docs index ffcf58d767..f9376fd363 100644 --- a/container/Dockerfile.docs +++ b/container/Dockerfile.docs @@ -18,6 +18,10 @@ FROM ubuntu:24.04 ARG DYNAMO_COMMIT_SHA ENV DYNAMO_COMMIT_SHA=$DYNAMO_COMMIT_SHA +# Version for documentation (e.g., "0.3.0" for releases, "dev" for main/PRs) +ARG DYNAMO_DOCS_VERSION=dev +ENV DYNAMO_DOCS_VERSION=$DYNAMO_DOCS_VERSION + COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ RUN apt-get update && \ diff --git a/container/Dockerfile.local_dev b/container/Dockerfile.local_dev index c24bec9718..5ce054659f 100644 --- a/container/Dockerfile.local_dev +++ b/container/Dockerfile.local_dev @@ -14,22 +14,16 @@ ARG DEV_BASE="" FROM ${DEV_BASE} AS local-dev -# Don't want dynamo to be editable, just change uid and gid. -ENV USERNAME=dynamo -ARG USER_UID -ARG USER_GID -ARG WORKSPACE_DIR=/workspace - -ARG DYNAMO_COMMIT_SHA -ENV DYNAMO_COMMIT_SHA=$DYNAMO_COMMIT_SHA - -ARG ARCH +# Switch to root for package installation (dev stage ends as dynamo user) +USER root +# Reset SHELL to non-login bash (dev stage uses login shell) +SHELL ["/bin/bash", "-c"] # Update package lists and install developer utilities. 
Some of these may exist in the base image, # but to ensure consistency across all dev images, we explicitly list all required dev tools here. RUN apt-get update && apt-get install -y \ # Development utilities - curl wget git vim nano \ + curl wget git vim nano less \ # System utilities htop nvtop tmux screen \ # Network utilities @@ -45,21 +39,9 @@ RUN apt-get update && apt-get install -y \ # File utilities tree fd-find ripgrep \ # Shell utilities - zsh fish bash-completion - -# https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user -# Configure user with sudo access for Dev Container workflows -RUN apt-get install -y sudo gnupg2 gnupg1 \ - && echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \ - && chmod 0440 /etc/sudoers.d/$USERNAME \ - && mkdir -p /home/$USERNAME \ - # Handle GID conflicts: if target GID exists and it's not our group, remove it - && (getent group $USER_GID | grep -v "^$USERNAME:" && groupdel $(getent group $USER_GID | cut -d: -f1) || true) \ - # Create group if it doesn't exist, otherwise modify existing group - && (getent group $USERNAME > /dev/null 2>&1 && groupmod -g $USER_GID $USERNAME || groupadd -g $USER_GID $USERNAME) \ - && usermod -u $USER_UID -g $USER_GID $USERNAME \ - && chown -R $USERNAME:$USERNAME /home/$USERNAME \ - && chsh -s /bin/bash $USERNAME + zsh fish bash-completion \ + # User management + sudo gnupg2 gnupg1 # Install awk separately with fault tolerance # awk is a virtual package with multiple implementations (gawk, mawk, original-awk). @@ -71,12 +53,34 @@ RUN (apt-get install -y gawk || \ echo "Warning: Could not install any awk implementation") && \ (which awk && echo "awk successfully installed: $(which awk)" || echo "awk not available") + +# Don't want dynamo to be editable, just change uid and gid. 
+ENV USERNAME=dynamo +ARG USER_UID +ARG USER_GID +ARG WORKSPACE_DIR=/workspace +ARG ARCH=amd64 + # Add NVIDIA devtools repository and install development tools RUN wget -qO - https://developer.download.nvidia.com/devtools/repos/ubuntu2404/${ARCH}/nvidia.pub | gpg --dearmor -o /etc/apt/keyrings/nvidia-devtools.gpg && \ echo "deb [signed-by=/etc/apt/keyrings/nvidia-devtools.gpg] https://developer.download.nvidia.com/devtools/repos/ubuntu2404/${ARCH} /" | tee /etc/apt/sources.list.d/nvidia-devtools.list && \ apt-get update && \ apt-get install -y nsight-systems-2025.5.1 +# https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user +# Configure user with sudo access for Dev Container workflows +RUN echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \ + && chmod 0440 /etc/sudoers.d/$USERNAME \ + && mkdir -p /home/$USERNAME \ + # Handle GID conflicts: if target GID exists and it's not our group, remove it + && (getent group $USER_GID | grep -v "^$USERNAME:" && groupdel $(getent group $USER_GID | cut -d: -f1) || true) \ + # Create group if it doesn't exist, otherwise modify existing group + && (getent group $USERNAME > /dev/null 2>&1 && groupmod -g $USER_GID $USERNAME || groupadd -g $USER_GID $USERNAME) \ + && usermod -u $USER_UID -g $USER_GID -G 0 $USERNAME \ + && chown $USERNAME:$USER_GID /home/$USERNAME \ + && chsh -s /bin/bash $USERNAME + + # Clean up package lists at the end RUN rm -rf /var/lib/apt/lists/* @@ -87,45 +91,26 @@ ENV WORKSPACE_DIR=${WORKSPACE_DIR} # Path configuration notes: # - DYNAMO_HOME: Main project directory (workspace mount point) # - CARGO_TARGET_DIR: Build artifacts in workspace/target for persistence -# - CARGO_HOME: Must be in $HOME/.cargo (not workspace) because: -# * Workspace gets mounted to different paths where cargo binaries may not exist -# * Contains critical cargo binaries and registry that need consistent paths -# - RUSTUP_HOME: Must be in $HOME/.rustup (not workspace) because: -# * Contains rust toolchain binaries that must be at expected system paths -# * Workspace mount point would break rustup's toolchain resolution # - PATH: Includes cargo binaries for rust tool access ENV HOME=/home/$USERNAME ENV DYNAMO_HOME=${WORKSPACE_DIR} ENV CARGO_TARGET_DIR=${WORKSPACE_DIR}/target -ENV CARGO_HOME=${HOME}/.cargo -ENV RUSTUP_HOME=${HOME}/.rustup +# NOTE: CARGO_HOME and RUSTUP_HOME are already inherited from dev stage (Dockerfile.sglang|trtllm|vllm) ENV PATH=${CARGO_HOME}/bin:$PATH -# Copy Rust toolchain from system directories to user home directories with proper ownership -RUN rsync -a --chown=$USER_UID:$USER_GID /usr/local/rustup/ $RUSTUP_HOME/ - -RUN rsync -a --chown=$USER_UID:$USER_GID /usr/local/cargo/ $CARGO_HOME/ - -# Copy virtual environment with proper ownership using rsync instead of chown. -# Why rsync instead of chown -R: -# chown -R is extremely slow in Docker containers, especially on large directory trees -# like Python virtual environments with thousands of files. This is a well-documented -# Docker performance issue. rsync --chown is 3-4x faster as it sets ownership during copy. 
-RUN rsync -a --chown=$USER_UID:$USER_GID ${VIRTUAL_ENV}/ /tmp/venv-temp/ && \ - rm -rf ${VIRTUAL_ENV} && \ - mv /tmp/venv-temp ${VIRTUAL_ENV} - -# At this point, we are executing as the ubuntu user +# Switch to dynamo user (dev stage has umask 002, so files should already be group-writable) USER $USERNAME WORKDIR $HOME # https://code.visualstudio.com/remote/advancedcontainers/persist-bash-history RUN SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=$HOME/.commandhistory/.bash_history" \ && mkdir -p $HOME/.commandhistory \ + && chmod g+w $HOME/.commandhistory \ && touch $HOME/.commandhistory/.bash_history \ && echo "$SNIPPET" >> "$HOME/.bashrc" -RUN mkdir -p /home/$USERNAME/.cache/ +RUN mkdir -p /home/$USERNAME/.cache/ \ + && chmod g+w /home/$USERNAME/.cache/ ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] CMD [] diff --git a/container/Dockerfile.sglang b/container/Dockerfile.sglang index bff39a2dfe..7296472f7d 100644 --- a/container/Dockerfile.sglang +++ b/container/Dockerfile.sglang @@ -1,6 +1,33 @@ # syntax=docker/dockerfile:1.10.0 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# NOTE FOR dynamo_base AND wheel_builder STAGES: +# +# All changes to the dynamo_base and wheel_builder stages should be replicated across +# the following Dockerfiles: +# - Dockerfile +# - Dockerfile.vllm +# - Dockerfile.sglang +# - Dockerfile.trtllm +# This duplication was introduced purposely to quickly enable Docker layer caching and +# deduplication. Please ensure these stages stay in sync until the duplication can be +# addressed. +# +# Throughout this file, we make certain paths group-writable because this allows +# both the dynamo user (UID 1000) and Dev Container users (UID != 1000) to work +# properly without needing slow chown -R operations (which can add 2-10 extra +# minutes). +# +# DEVELOPMENT PATHS THAT MUST BE GROUP-WRITABLE (for non-virtualenv containers): +# /workspace - Users create/modify project files +# /home/dynamo - Users create config/cache files +# /home/dynamo/.local - SGLang uses $HOME/.local/lib/python3.10/site-packages for pip install +# +# HOW TO ACHIEVE GROUP-WRITABLE PERMISSIONS: +# 1. SHELL + /etc/profile.d - Login shell sources umask 002 globally for all RUN commands (775/664) +# 2. COPY --chmod=775 - Sets permissions on copied children (not destination) +# 3. chmod g+w (no -R) - Fixes destination dirs only (milliseconds vs minutes) # This section contains build arguments that are common and shared with # the plain Dockerfile, so they should NOT have a default. The source of truth is from build.sh.
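For reference, a minimal sketch of mechanism 1 from the note above (illustrative only, not an additional change in this diff; it condenses the umask-related lines that appear later in this Dockerfile):

    # umask is a shell builtin, so it cannot be set with ENV; write it to /etc/profile.d instead
    RUN mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh
    # A login shell (-l) sources /etc/profile.d, so every subsequent RUN picks up umask 002
    SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"]

With umask 002 in effect, files and directories created by later RUN steps default to 664/775, which is what keeps the paths listed above group-writable without any recursive chmod.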
@@ -11,19 +38,296 @@ ARG BASE_IMAGE_TAG ARG FRAMEWORK_IMAGE ARG FRAMEWORK_IMAGE_TAG ARG PYTHON_VERSION +ARG ENABLE_KVBM +ARG ENABLE_MEDIA_NIXL +ARG CARGO_BUILD_JOBS ARG CUDA_VERSION ARG ARCH=amd64 ARG ARCH_ALT=x86_64 -ARG CARGO_BUILD_JOBS # sccache configuration - inherit from base build ARG USE_SCCACHE ARG SCCACHE_BUCKET="" ARG SCCACHE_REGION="" -ARG DYNAMO_BASE_IMAGE="dynamo:latest-none" -FROM ${DYNAMO_BASE_IMAGE} AS dynamo_base +# NIXL configuration +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + +################################## +########## Base Image ############ +################################## + +FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynamo_base + +ARG ARCH +ARG ARCH_ALT + +USER root +WORKDIR /opt/dynamo + +# Install uv package manager +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +# Install NATS server +ENV NATS_VERSION="v2.10.28" +RUN --mount=type=cache,target=/var/cache/apt \ + wget --tries=3 --waitretry=5 https://github.com/nats-io/nats-server/releases/download/${NATS_VERSION}/nats-server-${NATS_VERSION}-${ARCH}.deb && \ + dpkg -i nats-server-${NATS_VERSION}-${ARCH}.deb && rm nats-server-${NATS_VERSION}-${ARCH}.deb + +# Install etcd +ENV ETCD_VERSION="v3.5.21" +RUN wget --tries=3 --waitretry=5 https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-${ARCH}.tar.gz -O /tmp/etcd.tar.gz && \ + mkdir -p /usr/local/bin/etcd && \ + tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1 && \ + rm /tmp/etcd.tar.gz +ENV PATH=/usr/local/bin/etcd/:$PATH + +# Rust Setup +# Rust environment setup +ENV RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + PATH=/usr/local/cargo/bin:$PATH \ + RUST_VERSION=1.90.0 + +# Define Rust target based on ARCH_ALT ARG +ARG RUSTARCH=${ARCH_ALT}-unknown-linux-gnu + +# Install Rust +RUN wget --tries=3 --waitretry=5 "https://static.rust-lang.org/rustup/archive/1.28.1/${RUSTARCH}/rustup-init" && \ + chmod +x rustup-init && \ + ./rustup-init -y --no-modify-path --profile minimal --default-toolchain $RUST_VERSION --default-host ${RUSTARCH} && \ + rm rustup-init && \ + chmod -R a+w $RUSTUP_HOME $CARGO_HOME + + +################################## +##### Wheel Build Image ########## +################################## + +# Redeclare ARCH_ALT ARG so it's available for interpolation in the FROM instruction +ARG ARCH_ALT + +FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS wheel_builder + +# Redeclare ARGs for this stage +ARG ARCH +ARG ARCH_ALT +ARG CARGO_BUILD_JOBS + +WORKDIR /workspace + +# Copy CUDA from base stage +COPY --from=dynamo_base /usr/local/cuda /usr/local/cuda +COPY --from=dynamo_base /etc/ld.so.conf.d/hpcx.conf /etc/ld.so.conf.d/hpcx.conf + +# Set environment variables first so they can be used in COPY commands +ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} \ + RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + CARGO_TARGET_DIR=/opt/dynamo/target \ + PATH=/usr/local/cargo/bin:$PATH + +# Copy artifacts from base stage +COPY --from=dynamo_base $RUSTUP_HOME $RUSTUP_HOME +COPY --from=dynamo_base $CARGO_HOME $CARGO_HOME +# Install system dependencies +RUN yum groupinstall -y 'Development Tools' && \ + dnf install -y almalinux-release-synergy && \ + dnf config-manager --set-enabled powertools && \ + dnf install -y \ + # Build tools + cmake \ + ninja-build \ + clang-devel \ + gcc-c++ \ + flex \ + wget \ + # Kernel module build dependencies + dkms \ + # Protobuf support + protobuf-compiler \ + # RDMA/InfiniBand support (required for UCX build with --with-verbs) + 
libibverbs \ + libibverbs-devel \ + rdma-core \ + rdma-core-devel \ + libibumad \ + libibumad-devel \ + librdmacm-devel \ + numactl-devel + +# Ensure a modern protoc is available (required for --experimental_allow_proto3_optional) +RUN set -eux; \ + PROTOC_VERSION=25.3; \ + case "${ARCH_ALT}" in \ + x86_64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-x86_64.zip" ;; \ + aarch64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-aarch_64.zip" ;; \ + *) echo "Unsupported architecture: ${ARCH_ALT}" >&2; exit 1 ;; \ + esac; \ + wget --tries=3 --waitretry=5 -O /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/${PROTOC_ZIP}"; \ + rm -f /usr/local/bin/protoc /usr/bin/protoc; \ + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc include/*; \ + chmod +x /usr/local/bin/protoc; \ + ln -s /usr/local/bin/protoc /usr/bin/protoc; \ + protoc --version + +# Point build tools explicitly at the modern protoc +ENV PROTOC=/usr/local/bin/protoc + +ENV CUDA_PATH=/usr/local/cuda \ + PATH=/usr/local/cuda/bin:$PATH \ + LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH \ + NVIDIA_DRIVER_CAPABILITIES=video,compute,utility + +# Create virtual environment for building wheels +ARG PYTHON_VERSION +ENV VIRTUAL_ENV=/workspace/.venv +RUN uv venv ${VIRTUAL_ENV} --python $PYTHON_VERSION && \ + uv pip install --upgrade meson pybind11 patchelf maturin[patchelf] + +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + +# Build and install gdrcopy +RUN git clone --depth 1 --branch ${NIXL_GDRCOPY_REF} https://github.com/NVIDIA/gdrcopy.git && \ + cd gdrcopy/packages && \ + CUDA=/usr/local/cuda ./build-rpm-packages.sh && \ + rpm -Uvh gdrcopy-kmod-*.el8.noarch.rpm && \ + rpm -Uvh gdrcopy-*.el8.${ARCH_ALT}.rpm && \ + rpm -Uvh gdrcopy-devel-*.el8.noarch.rpm + +# Install SCCACHE if requested +ARG USE_SCCACHE +ARG SCCACHE_BUCKET +ARG SCCACHE_REGION +COPY container/use-sccache.sh /tmp/use-sccache.sh +RUN if [ "$USE_SCCACHE" = "true" ]; then \ + /tmp/use-sccache.sh install; \ + fi + +# Set SCCACHE environment variables +ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ + SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION}} \ + RUSTC_WRAPPER=${USE_SCCACHE:+sccache} + +# Build and install UCX +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout $NIXL_UCX_REF && \ + ./autogen.sh && \ + ./contrib/configure-release \ + --prefix=/usr/local/ucx \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + /tmp/use-sccache.sh show-stats "UCX" && \ + echo "/usr/local/ucx/lib" > /etc/ld.so.conf.d/ucx.conf && \ + echo "/usr/local/ucx/lib/ucx" >> /etc/ld.so.conf.d/ucx.conf && \ + ldconfig + +# build and install nixl +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export 
SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + git clone --depth 1 --branch ${NIXL_REF} "https://github.com/ai-dynamo/nixl.git" && \ + cd nixl && \ + mkdir build && \ + meson setup build/ --prefix=/opt/nvidia/nvda_nixl --buildtype=release \ + -Dcudapath_lib="/usr/local/cuda/lib64" \ + -Dcudapath_inc="/usr/local/cuda/include" \ + -Ducx_path="/usr/local/ucx" && \ + cd build && \ + ninja && \ + ninja install && \ + /tmp/use-sccache.sh show-stats "NIXL" + +ENV NIXL_LIB_DIR=/opt/nvidia/nvda_nixl/lib64 \ + NIXL_PLUGIN_DIR=/opt/nvidia/nvda_nixl/lib64/plugins \ + NIXL_PREFIX=/opt/nvidia/nvda_nixl +ENV LD_LIBRARY_PATH=${NIXL_LIB_DIR}:${NIXL_PLUGIN_DIR}:/usr/local/ucx/lib:/usr/local/ucx/lib/ucx:${LD_LIBRARY_PATH} + +RUN echo "$NIXL_LIB_DIR" > /etc/ld.so.conf.d/nixl.conf && \ + echo "$NIXL_PLUGIN_DIR" >> /etc/ld.so.conf.d/nixl.conf && \ + ldconfig + +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /workspace/nixl && \ + uv build . --out-dir /opt/dynamo/dist/nixl --python $PYTHON_VERSION + +# Copy source code (order matters for layer caching) +COPY pyproject.toml README.md LICENSE Cargo.toml Cargo.lock rust-toolchain.toml hatch_build.py /opt/dynamo/ +COPY launch/ /opt/dynamo/launch/ +COPY lib/ /opt/dynamo/lib/ +COPY components/ /opt/dynamo/components/ + +# Build dynamo wheels +ARG ENABLE_KVBM +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export RUSTC_WRAPPER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + cd /opt/dynamo && \ + uv build --wheel --out-dir /opt/dynamo/dist && \ + cd /opt/dynamo/lib/bindings/python && \ + if [ "$ENABLE_MEDIA_NIXL" = "true" ]; then \ + maturin build --release --features dynamo-llm/media-nixl --out /opt/dynamo/dist; \ + else \ + maturin build --release --out /opt/dynamo/dist; \ + fi && \ + if [ "$ENABLE_KVBM" = "true" ]; then \ + cd /opt/dynamo/lib/bindings/kvbm && \ + maturin build --release --out target/wheels && \ + auditwheel repair \ + --exclude libnixl.so \ + --exclude libnixl_build.so \ + --exclude libnixl_common.so \ + --plat manylinux_2_28_${ARCH_ALT} \ + --wheel-dir /opt/dynamo/dist \ + target/wheels/*.whl; \ + fi && \ + /tmp/use-sccache.sh show-stats "Dynamo" ######################################################## ########## Framework Development Image ################ @@ -51,16 +355,13 @@ ARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee ARG DEEPEP_GB_COMMIT=1b14ad661c7640137fcfe93cccb2694ede1220b0 ARG CMAKE_BUILD_PARALLEL_LEVEL=2 ARG SGL_KERNEL_VERSION=0.3.16.post5 -ARG SGLANG_COMMIT=0.5.4.post3 +ARG SGLANG_COMMIT=0.5.6 ARG GDRCOPY_COMMIT=v2.4.4 ARG NVSHMEM_VERSION=3.3.9 ARG GRACE_BLACKWELL=false ARG ARCH ARG ARCH_ALT ARG 
PYTHON_VERSION -ARG USE_SCCACHE -ARG SCCACHE_BUCKET -ARG SCCACHE_REGION ARG CARGO_BUILD_JOBS ARG CUDA_VERSION @@ -163,21 +464,6 @@ RUN apt-get update \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean -# Install sccache if requested -COPY container/use-sccache.sh /tmp/use-sccache.sh -RUN if [ "$USE_SCCACHE" = "true" ]; then \ - /tmp/use-sccache.sh install; \ -fi - -# Set environment variables - they'll be empty strings if USE_SCCACHE=false -ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ - SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION}} \ - SCCACHE_S3_KEY_PREFIX=${USE_SCCACHE:+${ARCH}} \ - RUSTC_WRAPPER=${USE_SCCACHE:+sccache} \ - CMAKE_C_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} \ - CMAKE_CXX_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} \ - CMAKE_CUDA_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} - WORKDIR /sgl-workspace # GDRCopy installation @@ -190,18 +476,25 @@ RUN git clone --depth 1 --branch ${GDRCOPY_COMMIT} https://github.com/NVIDIA/gdr # Fix DeepEP IBGDA symlink RUN ln -sf /usr/lib/$(uname -m)-linux-gnu/libmlx5.so.1 /usr/lib/$(uname -m)-linux-gnu/libmlx5.so -# Create dynamo user EARLY - before copying files, with group 0 for OpenShift compatibility +# Create dynamo user with group 0 for OpenShift compatibility RUN userdel -r ubuntu > /dev/null 2>&1 || true \ && useradd -m -s /bin/bash -g 0 dynamo \ && [ `id -u dynamo` -eq 1000 ] \ && mkdir -p /workspace /home/dynamo/.cache /opt/dynamo \ - && chown -R dynamo: /sgl-workspace /workspace /home/dynamo /opt/dynamo \ - && chmod -R g+w /sgl-workspace /workspace /home/dynamo/.cache /opt/dynamo + # Non-recursive chown - only the directories themselves, not contents + && chown dynamo:0 /sgl-workspace /workspace /home/dynamo /home/dynamo/.cache /opt/dynamo \ + # No chmod needed: umask 002 handles new files, COPY --chmod handles copied content + # Set umask globally for all subsequent RUN commands (must be done as root before USER dynamo) + # NOTE: Setting ENV UMASK=002 does NOT work - umask is a shell builtin, not an environment variable + && mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh USER dynamo ENV HOME=/home/dynamo +# This picks up the umask 002 from the /etc/profile.d/00-umask.sh file for subsequent RUN commands +SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"] -# Install SGLang (requires CUDA 12.8.1 or 12.9.1) +# Install SGLang (requires CUDA 12.8.1 or 12.9.1). Note that the system-wide site-packages is not writable here, +# so SGLang gets installed to ~/.local/lib/python/site-packages.
RUN python3 -m pip install --no-cache-dir --ignore-installed pip==25.3 setuptools==80.9.0 wheel==0.45.1 html5lib==1.1 six==1.17.0 \ && git clone --depth 1 --branch v${SGLANG_COMMIT} https://github.com/sgl-project/sglang.git \ && cd sglang \ @@ -235,10 +528,7 @@ RUN --mount=type=cache,target=/var/cache/curl,uid=1000,gid=0 \ && sed -i 's/#define NUM_CPU_TIMEOUT_SECS 100/#define NUM_CPU_TIMEOUT_SECS 1000/' csrc/kernels/configs.cuh # Build and install NVSHMEM library only (without python library) -RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ - --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ - export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ - cd /sgl-workspace/nvshmem && \ +RUN cd /sgl-workspace/nvshmem && \ if [ "$GRACE_BLACKWELL" = true ]; then CUDA_ARCH="90;100;120"; else CUDA_ARCH="90"; fi && \ NVSHMEM_SHMEM_SUPPORT=0 \ NVSHMEM_UCX_SUPPORT=0 \ @@ -249,15 +539,11 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ NVSHMEM_USE_GDRCOPY=1 \ cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH} -DNVSHMEM_BUILD_PYTHON_LIB=OFF && \ - cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL} && \ - /tmp/use-sccache.sh show-stats "NVSHMEM" + cmake --build build --target install -j${CMAKE_BUILD_PARALLEL_LEVEL} # Build nvshmem4py wheels separately (Python 3.10, CUDA 12) to avoid building the python library twice for multiple python versions # Need to reconfigure with PYTHON_LIB=ON to add the nvshmem4py subdirectory -RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ - --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ - export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ - cd /sgl-workspace/nvshmem && \ +RUN cd /sgl-workspace/nvshmem && \ if [ "$GRACE_BLACKWELL" = true ]; then CUDA_ARCH="90;100;120"; else CUDA_ARCH="90"; fi && \ NVSHMEM_SHMEM_SUPPORT=0 \ NVSHMEM_UCX_SUPPORT=0 \ @@ -268,19 +554,17 @@ RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ NVSHMEM_USE_GDRCOPY=1 \ cmake -S . -B build/ -DCMAKE_INSTALL_PREFIX=${NVSHMEM_DIR} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH} -DNVSHMEM_BUILD_PYTHON_LIB=ON && \ - cmake --build build --target build_nvshmem4py_wheel_cu12_${PYTHON_VERSION} -j${CMAKE_BUILD_PARALLEL_LEVEL} && \ - /tmp/use-sccache.sh show-stats "NVSHMEM4PY" + cmake --build build --target build_nvshmem4py_wheel_cu12_${PYTHON_VERSION} -j${CMAKE_BUILD_PARALLEL_LEVEL} # Install DeepEP -RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ - --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ - export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ - cd /sgl-workspace/DeepEP && \ +RUN cd /sgl-workspace/DeepEP && \ NVSHMEM_DIR=${NVSHMEM_DIR} TORCH_CUDA_ARCH_LIST="9.0;10.0" pip install --no-build-isolation . 
# Copy rust installation from dynamo_base to avoid duplication efforts -COPY --from=dynamo_base /usr/local/rustup /usr/local/rustup -COPY --from=dynamo_base /usr/local/cargo /usr/local/cargo +# Pattern: COPY --chmod=775 ; RUN chmod g+w because COPY --chmod only affects /*, not +COPY --from=dynamo_base --chown=dynamo:0 --chmod=775 /usr/local/rustup /usr/local/rustup +COPY --from=dynamo_base --chown=dynamo:0 --chmod=775 /usr/local/cargo /usr/local/cargo +RUN chmod g+w /usr/local/rustup /usr/local/cargo ENV RUSTUP_HOME=/usr/local/rustup \ CARGO_HOME=/usr/local/cargo \ @@ -333,16 +617,20 @@ ${NIXL_PLUGIN_DIR}:\ /usr/local/nvidia/lib64:\ ${LD_LIBRARY_PATH} -# Copy NATS and ETCD from dynamo_base, and UCX/NIXL +# Copy NATS and ETCD from dynamo_base, and UCX/NIXL from wheel_builder COPY --from=dynamo_base /usr/bin/nats-server /usr/bin/nats-server COPY --from=dynamo_base /usr/local/bin/etcd/ /usr/local/bin/etcd/ -COPY --from=dynamo_base /usr/local/ucx /usr/local/ucx -COPY --from=dynamo_base $NIXL_PREFIX $NIXL_PREFIX +COPY --from=wheel_builder /usr/local/ucx /usr/local/ucx +COPY --chown=dynamo: --from=wheel_builder $NIXL_PREFIX $NIXL_PREFIX +COPY --chown=dynamo: --from=wheel_builder /opt/nvidia/nvda_nixl/lib64/. ${NIXL_LIB_DIR}/ +COPY --chown=dynamo: --from=wheel_builder /opt/dynamo/dist/nixl/ /opt/dynamo/wheelhouse/nixl/ +COPY --chown=dynamo: --from=wheel_builder /workspace/nixl/build/src/bindings/python/nixl-meta/nixl-*.whl /opt/dynamo/wheelhouse/nixl/ ENV PATH=/usr/local/bin/etcd/:/usr/local/cuda/nvvm/bin:${HOME}/.local/bin:$PATH # Install Dynamo wheels from dynamo_base wheelhouse -COPY --chown=dynamo: benchmarks/ /opt/dynamo/benchmarks/ -COPY --chown=dynamo: --from=dynamo_base /opt/dynamo/wheelhouse/ /opt/dynamo/wheelhouse/ +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 benchmarks/ /opt/dynamo/benchmarks/ +COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /opt/dynamo/dist/*.whl /opt/dynamo/wheelhouse/ RUN python3 -m pip install \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ @@ -361,7 +649,16 @@ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi --requirement /tmp/requirements.test.txt ## Copy attribution files and launch banner with correct ownership -COPY --chown=dynamo: ATTRIBUTION* LICENSE /workspace/ +COPY --chmod=664 --chown=dynamo:0 ATTRIBUTION* LICENSE /workspace/ + +# Copy tests, benchmarks, deploy and components for CI with correct ownership +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 tests /workspace/tests +COPY --chmod=775 --chown=dynamo:0 examples /workspace/examples +COPY --chmod=775 --chown=dynamo:0 benchmarks /workspace/benchmarks +COPY --chmod=775 --chown=dynamo:0 deploy /workspace/deploy +COPY --chmod=775 --chown=dynamo:0 components/ /workspace/components/ +COPY --chmod=775 --chown=dynamo:0 recipes/ /workspace/recipes/ # Setup launch banner in common directory accessible to all users RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/dynamo/launch_message.txt \ @@ -369,20 +666,16 @@ RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/ # Setup environment for all users USER root -RUN chmod 755 /opt/dynamo/.launch_screen && \ +# Fix directory permissions: COPY --chmod only affects contents, not the directory itself +RUN chmod g+w /workspace /workspace/* /opt/dynamo 
/opt/dynamo/* && \ + chown dynamo:0 /workspace /opt/dynamo/ && \ + chmod 755 /opt/dynamo/.launch_screen && \ echo 'source /opt/dynamo/venv/bin/activate' >> /etc/bash.bashrc && \ echo 'cat /opt/dynamo/.launch_screen' >> /etc/bash.bashrc USER dynamo # Copy tests, benchmarks, deploy and components for CI with correct ownership -COPY --chown=dynamo: tests /workspace/tests -COPY --chown=dynamo: examples /workspace/examples -COPY --chown=dynamo: benchmarks /workspace/benchmarks -COPY --chown=dynamo: deploy /workspace/deploy -COPY --chown=dynamo: components/ /workspace/components/ -COPY --chown=dynamo: recipes/ /workspace/recipes/ - ARG DYNAMO_COMMIT_SHA ENV DYNAMO_COMMIT_SHA=$DYNAMO_COMMIT_SHA @@ -419,6 +712,8 @@ ENV VIRTUAL_ENV=/opt/dynamo/venv \ PATH="/opt/dynamo/venv/bin:${PATH}" USER root +# venv permissions are handled by umask 002 set earlier + # Install development tools and utilities RUN apt-get update -y && \ apt-get install -y --no-install-recommends \ @@ -477,7 +772,7 @@ RUN curl --retry 3 --retry-delay 2 -LSso /usr/local/bin/clang-format https://git && rm -rf clangd_18.1.3 clangd.zip # Editable install of dynamo -COPY pyproject.toml README.md hatch_build.py /workspace/ +COPY --chmod=664 pyproject.toml README.md hatch_build.py /workspace/ RUN python3 -m pip install --no-deps -e . # Install Python development packages diff --git a/container/Dockerfile.trtllm b/container/Dockerfile.trtllm index ea55a535eb..6e0d3ae0cf 100644 --- a/container/Dockerfile.trtllm +++ b/container/Dockerfile.trtllm @@ -1,5 +1,32 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# NOTE FOR dynamo_base AND wheel_builder STAGES: +# +# All changes to dynamo_base and wheel_builder stages should be replicated across +# the Dockerfile and Dockerfile.${FRAMEWORK} images: +# - Dockerfile +# - Dockerfile.vllm +# - Dockerfile.sglang +# - Dockerfile.trtllm +# This duplication was introduced purposely to quickly enable Docker layer caching and +# deduplication. Please ensure these stages stay in sync until the duplication can be +# addressed. +# +# Throughout this file, we make certain paths group-writable because this allows +# both the dynamo user (UID 1000) and Dev Container users (UID != 1000) to work +# properly without needing slow chown -R operations (which can add 2-10 extra +# minutes). +# +# DEVELOPMENT PATHS THAT MUST BE GROUP-WRITABLE (for virtualenv containers): +# /workspace - Users create/modify project files +# /home/dynamo - Users create config/cache files +# /opt/dynamo/venv - TensorRT-LLM uses venv, so entire venv must be writable for pip install +# +# HOW TO ACHIEVE GROUP-WRITABLE PERMISSIONS: +# 1. SHELL + /etc/profile.d - Login shell sources umask 002 globally for all RUN commands (775/664) +# 2. COPY --chmod=775 - Sets permissions on copied children (not destination) +# 3. chmod g+w (no -R) - Fixes destination dirs only (milliseconds vs minutes) # This section contains build arguments that are common and shared with # the plain Dockerfile, so they should NOT have a default. The source of truth is from build.sh.
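The three-step scheme listed in the note above can be condensed into a small stand-alone sketch. The base image and paths here are illustrative; the real stages interleave these steps with the framework-specific setup shown elsewhere in this diff:

```dockerfile
# Condensed sketch of the three steps above; base image and paths are illustrative.
FROM ubuntu:24.04 AS permissions_demo
USER root
# Step 1: write the umask into /etc/profile.d (as root), then route later RUN commands
# through a login shell so they source it and create 775/664 group-writable files.
RUN useradd -m -s /bin/bash -g 0 dynamo && \
    mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh && \
    mkdir -p /workspace && chown dynamo:0 /workspace
SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"]
# Step 2: COPY --chmod makes the copied content group-writable ...
COPY --chmod=775 --chown=dynamo:0 deploy /workspace/deploy
# Step 3: ... while the destination directories themselves still need a non-recursive chmod g+w.
RUN chmod g+w /workspace /workspace/*
USER dynamo
# New files now come out group-writable thanks to the umask sourced by the login shell.
RUN touch /workspace/demo-file && stat -c '%a %n' /workspace/demo-file
```

The key point is that none of the steps walks a large directory tree, which is what keeps the permission fix at millisecond cost compared to a recursive `chown`/`chmod`.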
@@ -8,6 +35,8 @@ ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM +ARG ENABLE_MEDIA_NIXL +ARG CARGO_BUILD_JOBS ARG PYTORCH_BASE_IMAGE="nvcr.io/nvidia/pytorch" ARG PYTORCH_BASE_IMAGE_TAG="25.10-py3" @@ -20,6 +49,16 @@ ARG TENSORRTLLM_PIP_WHEEL="tensorrt-llm" ARG TENSORRTLLM_INDEX_URL="https://pypi.nvidia.com/" ARG GITHUB_TRTLLM_COMMIT +# SCCACHE configuration +ARG USE_SCCACHE +ARG SCCACHE_BUCKET="" +ARG SCCACHE_REGION="" + +# NIXL configuration +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + # Define general architecture ARGs for supporting both x86 and aarch64 builds. # ARCH: Used for package suffixes (e.g., amd64, arm64) # ARCH_ALT: Used for Rust targets, manylinux suffix (e.g., x86_64, aarch64) @@ -35,12 +74,282 @@ ARG GITHUB_TRTLLM_COMMIT ARG ARCH=amd64 ARG ARCH_ALT=x86_64 -ARG DYNAMO_BASE_IMAGE="dynamo:latest-none" -FROM ${DYNAMO_BASE_IMAGE} AS dynamo_base - # Copy artifacts from NGC PyTorch image FROM ${PYTORCH_BASE_IMAGE}:${PYTORCH_BASE_IMAGE_TAG} AS pytorch_base +################################## +########## Base Image ############ +################################## + +FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynamo_base + +ARG ARCH +ARG ARCH_ALT + +USER root +WORKDIR /opt/dynamo + +# Install uv package manager +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +# Install NATS server +ENV NATS_VERSION="v2.10.28" +RUN --mount=type=cache,target=/var/cache/apt \ + wget --tries=3 --waitretry=5 https://github.com/nats-io/nats-server/releases/download/${NATS_VERSION}/nats-server-${NATS_VERSION}-${ARCH}.deb && \ + dpkg -i nats-server-${NATS_VERSION}-${ARCH}.deb && rm nats-server-${NATS_VERSION}-${ARCH}.deb + +# Install etcd +ENV ETCD_VERSION="v3.5.21" +RUN wget --tries=3 --waitretry=5 https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-${ARCH}.tar.gz -O /tmp/etcd.tar.gz && \ + mkdir -p /usr/local/bin/etcd && \ + tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1 && \ + rm /tmp/etcd.tar.gz +ENV PATH=/usr/local/bin/etcd/:$PATH + +# Rust Setup +# Rust environment setup +ENV RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + PATH=/usr/local/cargo/bin:$PATH \ + RUST_VERSION=1.90.0 + +# Define Rust target based on ARCH_ALT ARG +ARG RUSTARCH=${ARCH_ALT}-unknown-linux-gnu + +# Install Rust +RUN wget --tries=3 --waitretry=5 "https://static.rust-lang.org/rustup/archive/1.28.1/${RUSTARCH}/rustup-init" && \ + chmod +x rustup-init && \ + ./rustup-init -y --no-modify-path --profile minimal --default-toolchain $RUST_VERSION --default-host ${RUSTARCH} && \ + rm rustup-init && \ + chmod -R a+w $RUSTUP_HOME $CARGO_HOME + + +################################## +##### Wheel Build Image ########## +################################## + +# Redeclare ARCH_ALT ARG so it's available for interpolation in the FROM instruction +ARG ARCH_ALT + +FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS wheel_builder + +# Redeclare ARGs for this stage +ARG ARCH +ARG ARCH_ALT +ARG CARGO_BUILD_JOBS + +WORKDIR /workspace + +# Copy CUDA from base stage +COPY --from=dynamo_base /usr/local/cuda /usr/local/cuda +COPY --from=dynamo_base /etc/ld.so.conf.d/hpcx.conf /etc/ld.so.conf.d/hpcx.conf + +# Set environment variables first so they can be used in COPY commands +ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} \ + RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + CARGO_TARGET_DIR=/opt/dynamo/target \ + PATH=/usr/local/cargo/bin:$PATH + +# Copy artifacts from base stage +COPY --from=dynamo_base $RUSTUP_HOME $RUSTUP_HOME +COPY 
--from=dynamo_base $CARGO_HOME $CARGO_HOME +# Install system dependencies +RUN yum groupinstall -y 'Development Tools' && \ + dnf install -y almalinux-release-synergy && \ + dnf config-manager --set-enabled powertools && \ + dnf install -y \ + # Build tools + cmake \ + ninja-build \ + clang-devel \ + gcc-c++ \ + flex \ + wget \ + # Kernel module build dependencies + dkms \ + # Protobuf support + protobuf-compiler \ + # RDMA/InfiniBand support (required for UCX build with --with-verbs) + libibverbs \ + libibverbs-devel \ + rdma-core \ + rdma-core-devel \ + libibumad \ + libibumad-devel \ + librdmacm-devel \ + numactl-devel + +# Ensure a modern protoc is available (required for --experimental_allow_proto3_optional) +RUN set -eux; \ + PROTOC_VERSION=25.3; \ + case "${ARCH_ALT}" in \ + x86_64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-x86_64.zip" ;; \ + aarch64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-aarch_64.zip" ;; \ + *) echo "Unsupported architecture: ${ARCH_ALT}" >&2; exit 1 ;; \ + esac; \ + wget --tries=3 --waitretry=5 -O /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/${PROTOC_ZIP}"; \ + rm -f /usr/local/bin/protoc /usr/bin/protoc; \ + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc include/*; \ + chmod +x /usr/local/bin/protoc; \ + ln -s /usr/local/bin/protoc /usr/bin/protoc; \ + protoc --version + +# Point build tools explicitly at the modern protoc +ENV PROTOC=/usr/local/bin/protoc + +ENV CUDA_PATH=/usr/local/cuda \ + PATH=/usr/local/cuda/bin:$PATH \ + LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH \ + NVIDIA_DRIVER_CAPABILITIES=video,compute,utility + +# Create virtual environment for building wheels +ARG PYTHON_VERSION +ENV VIRTUAL_ENV=/workspace/.venv +RUN uv venv ${VIRTUAL_ENV} --python $PYTHON_VERSION && \ + uv pip install --upgrade meson pybind11 patchelf maturin[patchelf] + +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + +# Build and install gdrcopy +RUN git clone --depth 1 --branch ${NIXL_GDRCOPY_REF} https://github.com/NVIDIA/gdrcopy.git && \ + cd gdrcopy/packages && \ + CUDA=/usr/local/cuda ./build-rpm-packages.sh && \ + rpm -Uvh gdrcopy-kmod-*.el8.noarch.rpm && \ + rpm -Uvh gdrcopy-*.el8.${ARCH_ALT}.rpm && \ + rpm -Uvh gdrcopy-devel-*.el8.noarch.rpm + +# Install SCCACHE if requested +ARG USE_SCCACHE +ARG SCCACHE_BUCKET +ARG SCCACHE_REGION +COPY container/use-sccache.sh /tmp/use-sccache.sh +RUN if [ "$USE_SCCACHE" = "true" ]; then \ + /tmp/use-sccache.sh install; \ + fi + +# Set SCCACHE environment variables +ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ + SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION}} + +# Build and install UCX +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout $NIXL_UCX_REF && \ + ./autogen.sh && \ + ./contrib/configure-release \ + --prefix=/usr/local/ucx \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs \ + --with-dm \ + --with-gdrcopy=/usr/local \ + 
--with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + /tmp/use-sccache.sh show-stats "UCX" && \ + echo "/usr/local/ucx/lib" > /etc/ld.so.conf.d/ucx.conf && \ + echo "/usr/local/ucx/lib/ucx" >> /etc/ld.so.conf.d/ucx.conf && \ + ldconfig + +# build and install nixl +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + git clone --depth 1 --branch ${NIXL_REF} "https://github.com/ai-dynamo/nixl.git" && \ + cd nixl && \ + mkdir build && \ + meson setup build/ --prefix=/opt/nvidia/nvda_nixl --buildtype=release \ + -Dcudapath_lib="/usr/local/cuda/lib64" \ + -Dcudapath_inc="/usr/local/cuda/include" \ + -Ducx_path="/usr/local/ucx" && \ + cd build && \ + ninja && \ + ninja install && \ + /tmp/use-sccache.sh show-stats "NIXL" + +ENV NIXL_LIB_DIR=/opt/nvidia/nvda_nixl/lib64 \ + NIXL_PLUGIN_DIR=/opt/nvidia/nvda_nixl/lib64/plugins \ + NIXL_PREFIX=/opt/nvidia/nvda_nixl +ENV LD_LIBRARY_PATH=${NIXL_LIB_DIR}:${NIXL_PLUGIN_DIR}:/usr/local/ucx/lib:/usr/local/ucx/lib/ucx:${LD_LIBRARY_PATH} + +RUN echo "$NIXL_LIB_DIR" > /etc/ld.so.conf.d/nixl.conf && \ + echo "$NIXL_PLUGIN_DIR" >> /etc/ld.so.conf.d/nixl.conf && \ + ldconfig + +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /workspace/nixl && \ + uv build . 
--out-dir /opt/dynamo/dist/nixl --python $PYTHON_VERSION + +# Copy source code (order matters for layer caching) +COPY pyproject.toml README.md LICENSE Cargo.toml Cargo.lock rust-toolchain.toml hatch_build.py /opt/dynamo/ +COPY launch/ /opt/dynamo/launch/ +COPY lib/ /opt/dynamo/lib/ +COPY components/ /opt/dynamo/components/ + +# Build dynamo wheels +ARG ENABLE_KVBM +ARG USE_SCCACHE +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export RUSTC_WRAPPER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + cd /opt/dynamo && \ + uv build --wheel --out-dir /opt/dynamo/dist && \ + cd /opt/dynamo/lib/bindings/python && \ + if [ "$ENABLE_MEDIA_NIXL" = "true" ]; then \ + maturin build --release --features dynamo-llm/media-nixl --out /opt/dynamo/dist; \ + else \ + maturin build --release --out /opt/dynamo/dist; \ + fi && \ + if [ "$ENABLE_KVBM" = "true" ]; then \ + cd /opt/dynamo/lib/bindings/kvbm && \ + maturin build --release --out target/wheels && \ + auditwheel repair \ + --exclude libnixl.so \ + --exclude libnixl_build.so \ + --exclude libnixl_common.so \ + --plat manylinux_2_28_${ARCH_ALT} \ + --wheel-dir /opt/dynamo/dist \ + target/wheels/*.whl; \ + fi && \ + /tmp/use-sccache.sh show-stats "Dynamo" + ################################################## ########## Framework Builder Stage ############## ################################################## @@ -148,7 +457,7 @@ RUN if [ "$HAS_TRTLLM_CONTEXT" = "1" ]; then \ sed -i 's/pip3 install/uv pip install/g' /tmp/install_tensorrt.sh && \ bash /tmp/install_tensorrt.sh && \ # Install TensorRT-LLM wheel from the provided index URL, allow dependencies from PyPI - # TRTLLM 1.2.0rc2 has issues installing from pypi with uv, installing from direct wheel link works best + # TRTLLM 1.2.0rc5 has issues installing from pypi with uv, installing from direct wheel link works best # explicitly installing triton 3.5.0 as trtllm only lists triton as dependency on x64_64 for some reason if echo "${TENSORRTLLM_PIP_WHEEL}" | grep -q '^tensorrt-llm=='; then \ TRTLLM_VERSION=$(echo "${TENSORRTLLM_PIP_WHEEL}" | sed -E 's/tensorrt-llm==([0-9a-zA-Z.+-]+).*/\1/'); \ @@ -238,8 +547,12 @@ RUN userdel -r ubuntu > /dev/null 2>&1 || true \ && useradd -m -s /bin/bash -g 0 dynamo \ && [ `id -u dynamo` -eq 1000 ] \ && mkdir -p /home/dynamo/.cache /opt/dynamo \ - && chown -R dynamo: /workspace /home/dynamo /opt/dynamo \ - && chmod -R g+w /workspace /home/dynamo/.cache /opt/dynamo + # Non-recursive chown - only the directories themselves, not contents + && chown dynamo:0 /home/dynamo /home/dynamo/.cache /opt/dynamo /workspace \ + # No chmod needed: umask 002 handles new files, COPY --chmod handles copied content + # Set umask globally for all subsequent RUN commands (must be done as root before USER dynamo) + # NOTE: Setting ENV UMASK=002 does NOT work - umask is a shell builtin, not an environment variable + && mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh # Install Python, build-essential and python3-dev as apt dependencies ARG PYTHON_VERSION @@ -291,6 +604,9 @@ RUN if [ ${ARCH_ALT} = "x86_64" ]; then \ # Switch to dynamo user USER dynamo ENV HOME=/home/dynamo +# This picks up the umask 002 from the /etc/profile.d/00-umask.sh file for subsequent RUN 
commands +SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"] + ENV DYNAMO_HOME=/workspace ENV NIXL_PREFIX=/opt/nvidia/nvda_nixl ENV NIXL_LIB_DIR=$NIXL_PREFIX/lib/${ARCH_ALT}-linux-gnu @@ -301,13 +617,17 @@ COPY --from=framework /usr/local/tensorrt /usr/local/tensorrt COPY --from=framework /usr/lib/${ARCH_ALT}-linux-gnu/libgomp.so* /usr/lib/${ARCH_ALT}-linux-gnu/ # Copy pre-built venv with PyTorch and TensorRT-LLM from framework stage -COPY --chown=dynamo: --from=framework ${VIRTUAL_ENV} ${VIRTUAL_ENV} +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 --from=framework ${VIRTUAL_ENV} ${VIRTUAL_ENV} # Copy UCX from framework image as plugin for NIXL # Copy NIXL source from framework image -# Copy dynamo wheels for gitlab artifacts -COPY --chown=dynamo: --from=dynamo_base /usr/local/ucx /usr/local/ucx -COPY --chown=dynamo: --from=dynamo_base $NIXL_PREFIX $NIXL_PREFIX +# Copy dynamo wheels for gitlab artifacts (read-only, no group-write needed) +COPY --chown=dynamo: --from=wheel_builder /usr/local/ucx /usr/local/ucx +COPY --chown=dynamo: --from=wheel_builder $NIXL_PREFIX $NIXL_PREFIX +COPY --chown=dynamo: --from=wheel_builder /opt/nvidia/nvda_nixl/lib64/. ${NIXL_LIB_DIR}/ +COPY --chown=dynamo: --from=wheel_builder /opt/dynamo/dist/nixl/ /opt/dynamo/wheelhouse/nixl/ +COPY --chown=dynamo: --from=wheel_builder /workspace/nixl/build/src/bindings/python/nixl-meta/nixl-*.whl /opt/dynamo/wheelhouse/nixl/ ENV TENSORRT_LIB_DIR=/usr/local/tensorrt/targets/${ARCH_ALT}-linux-gnu/lib ENV PATH="/usr/local/ucx/bin:${VIRTUAL_ENV}/bin:/opt/hpcx/ompi/bin:/usr/local/bin/etcd/:/usr/local/cuda/bin:/usr/local/cuda/nvvm/bin:$PATH" @@ -324,22 +644,30 @@ $TENSORRT_LIB_DIR:\ $LD_LIBRARY_PATH ENV OPAL_PREFIX=/opt/hpcx/ompi -COPY --chown=dynamo: ATTRIBUTION* LICENSE /workspace/ -COPY --chown=dynamo: benchmarks/ /workspace/benchmarks/ +COPY --chmod=664 --chown=dynamo:0 ATTRIBUTION* LICENSE /workspace/ +COPY --chmod=775 --chown=dynamo:0 benchmarks/ /workspace/benchmarks/ # Install dynamo, NIXL, and dynamo-specific dependencies +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not ARG ENABLE_KVBM -COPY --chown=dynamo: --from=dynamo_base /opt/dynamo/wheelhouse/ /opt/dynamo/wheelhouse/ +COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /opt/dynamo/dist/*.whl /opt/dynamo/wheelhouse/ RUN uv pip install \ --no-cache \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ - /opt/dynamo/wheelhouse/nixl/nixl*.whl \ - && if [ "${ENABLE_KVBM}" = "true" ]; then \ - uv pip install --no-cache /opt/dynamo/wheelhouse/kvbm*.whl; \ - fi \ - && cd /workspace/benchmarks \ - && UV_GIT_LFS=1 uv pip install --no-cache . + /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "${ENABLE_KVBM}" = "true" ]; then \ + KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ + if [ -z "$KVBM_WHEEL" ]; then \ + echo "ERROR: ENABLE_KVBM is true but no KVBM wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install --no-cache "$KVBM_WHEEL"; \ + fi && \ + cd /workspace/benchmarks && \ + UV_GIT_LFS=1 uv pip install --no-cache . 
&& \ + # pip/uv bypasses umask when creating .egg-info files, but chmod -R is fast here (small directory) + chmod -R g+w /workspace/benchmarks # Install common and test dependencies RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ @@ -352,12 +680,13 @@ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi --requirement /tmp/requirements.test.txt \ cupy-cuda13x -# Copy tests, benchmarks, deploy and components for CI -COPY --chown=dynamo: tests /workspace/tests -COPY --chown=dynamo: examples /workspace/examples -COPY --chown=dynamo: deploy /workspace/deploy -COPY --chown=dynamo: components/ /workspace/components/ -COPY --chown=dynamo: recipes/ /workspace/recipes/ +# Copy tests, deploy and components for CI with correct ownership +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 tests /workspace/tests +COPY --chmod=775 --chown=dynamo:0 examples /workspace/examples +COPY --chmod=775 --chown=dynamo:0 deploy /workspace/deploy +COPY --chmod=775 --chown=dynamo:0 components/ /workspace/components/ +COPY --chmod=775 --chown=dynamo:0 recipes/ /workspace/recipes/ # Setup launch banner in common directory accessible to all users RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/dynamo/launch_message.txt \ @@ -365,7 +694,10 @@ RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/ # Setup environment for all users USER root -RUN chmod 755 /opt/dynamo/.launch_screen && \ +# Fix directory permissions: COPY --chmod only affects contents, not the directory itself +RUN chmod g+w ${VIRTUAL_ENV} /workspace /workspace/* /opt/dynamo /opt/dynamo/* && \ + chown dynamo:0 ${VIRTUAL_ENV} /workspace /opt/dynamo/ && \ + chmod 755 /opt/dynamo/.launch_screen && \ echo 'source /opt/dynamo/venv/bin/activate' >> /etc/bash.bashrc && \ echo 'cat /opt/dynamo/.launch_screen' >> /etc/bash.bashrc @@ -426,6 +758,10 @@ RUN apt-get update -y && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* +# Set umask for group-writable files in dev stage (runs as root) +RUN mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh +SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"] + # Set workspace directory variable ENV WORKSPACE_DIR=${WORKSPACE_DIR} \ DYNAMO_HOME=${WORKSPACE_DIR} \ @@ -435,8 +771,11 @@ ENV WORKSPACE_DIR=${WORKSPACE_DIR} \ VIRTUAL_ENV=/opt/dynamo/venv \ PATH=/usr/local/cargo/bin:$PATH -COPY --from=dynamo_base /usr/local/rustup /usr/local/rustup -COPY --from=dynamo_base /usr/local/cargo /usr/local/cargo +# Copy rust installation from dynamo_base to avoid duplication efforts +# Pattern: COPY --chmod=775 ; chmod g+w because COPY --chmod only affects /*, not +COPY --from=dynamo_base --chmod=775 /usr/local/rustup /usr/local/rustup +COPY --from=dynamo_base --chmod=775 /usr/local/cargo /usr/local/cargo +RUN chmod g+w /usr/local/rustup /usr/local/cargo # Install maturin, for maturin develop RUN uv pip install --no-cache maturin[patchelf] diff --git a/container/Dockerfile.vllm b/container/Dockerfile.vllm index fda3069610..1898466bfa 100644 --- a/container/Dockerfile.vllm +++ b/container/Dockerfile.vllm @@ -1,32 +1,48 @@ # syntax=docker/dockerfile:1.10.0 # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +# +# NOTE FOR dynamo_base AND wheel_builder STAGES: +# +# All changes to dynamo_base and wheel_builder stages should be replicated across +# Dockerfile and Dockerfile. images.: +# - Dockerfile +# - Dockerfile.vllm +# - Dockerfile.sglang +# - Dockerfile.trtllm +# This duplication was introduced purposely to quickly enable Docker layer caching and +# deduplication. Please ensure these stages stay in sync until the duplication can be +# addressed. +# +# Throughout this file, we make certain paths group-writable because this allows +# both the dynamo user (UID 1000) and Dev Container users (UID != 1000) to work +# properly without needing slow chown -R operations (which can add 2-10 extra +# minutes). +# +# DEVELOPMENT PATHS THAT MUST BE GROUP-WRITABLE (for virtualenv containers): +# /workspace - Users create/modify project files +# /home/dynamo - Users create config/cache files +# /opt/dynamo/venv - vLLM uses venv, so entire venv must be writable for pip install +# +# HOW TO ACHIEVE GROUP-WRITABLE PERMISSIONS: +# 1. SHELL + /etc/profile.d - Login shell sources umask 002 globally for all RUN commands (775/664) +# 2. COPY --chmod=775 - Sets permissions on copied children (not destination) +# 3. chmod g+w (no -R) - Fixes destination dirs only (milliseconds vs minutes) + +################################## +########## Build Arguments ######## +################################## + +# This section contains build arguments that are common and shared across various +# Dockerfile., so they should NOT have a default. The source of truth is from build.sh. -# This section contains build arguments that are common and shared with -# the plain Dockerfile, so they should NOT have a default. The source of truth is from build.sh. ARG BASE_IMAGE ARG BASE_IMAGE_TAG ARG PYTHON_VERSION ARG ENABLE_KVBM - -ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" -ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" -ARG CUDA_VERSION="12.8" - -# Make sure to update the dependency version in pyproject.toml when updating this -ARG VLLM_REF="v0.11.0" -# FlashInfer only respected when building vLLM from source, ie when VLLM_REF does not start with 'v' or for arm64 builds -ARG FLASHINF_REF="v0.3.1" -ARG TORCH_BACKEND="cu128" - -# If left blank, then we will fallback to vLLM defaults -ARG DEEPGEMM_REF="" - -# sccache configuration - inherit from base build -ARG USE_SCCACHE -ARG SCCACHE_BUCKET="" -ARG SCCACHE_REGION="" +ARG ENABLE_MEDIA_NIXL +ARG CARGO_BUILD_JOBS # Define general architecture ARGs for supporting both x86 and aarch64 builds. # ARCH: Used for package suffixes (e.g., amd64, arm64) @@ -37,17 +53,304 @@ ARG SCCACHE_REGION="" # # For arm64/aarch64, build with: # --build-arg ARCH=arm64 --build-arg ARCH_ALT=aarch64 -# -# NOTE: There isn't an easy way to define one of these values based on the other value -# without adding if statements everywhere, so just define both as ARGs for now. 
+#TODO OPS-592: Leverage uname -m to determine ARCH instead of passing it as an arg ARG ARCH=amd64 ARG ARCH_ALT=x86_64 -ARG DYNAMO_BASE_IMAGE="dynamo:latest-none" -FROM ${DYNAMO_BASE_IMAGE} AS dynamo_base +# SCCACHE configuration +ARG USE_SCCACHE +ARG SCCACHE_BUCKET="" +ARG SCCACHE_REGION="" + +# NIXL configuration +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + +ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" +ARG RUNTIME_IMAGE_TAG="12.9.0-runtime-ubuntu24.04" +ARG CUDA_VERSION="12.9" + +# Make sure to update the dependency version in pyproject.toml when updating this +ARG VLLM_REF="v0.12.0" +# FlashInfer Ref used to install flashinfer-cubin and flashinfer-jit-cache +ARG FLASHINF_REF="v0.5.3" + +# If left blank, then we will fallback to vLLM defaults +ARG DEEPGEMM_REF="" +ARG LMCACHE_REF="0.3.10" + +################################## +########## Base Image ############ +################################## + +FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS dynamo_base + +ARG ARCH +ARG ARCH_ALT + +USER root +WORKDIR /opt/dynamo + +# Install uv package manager +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +# Install NATS server +ENV NATS_VERSION="v2.10.28" +RUN --mount=type=cache,target=/var/cache/apt \ + wget --tries=3 --waitretry=5 https://github.com/nats-io/nats-server/releases/download/${NATS_VERSION}/nats-server-${NATS_VERSION}-${ARCH}.deb && \ + dpkg -i nats-server-${NATS_VERSION}-${ARCH}.deb && rm nats-server-${NATS_VERSION}-${ARCH}.deb + +# Install etcd +ENV ETCD_VERSION="v3.5.21" +RUN wget --tries=3 --waitretry=5 https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-${ARCH}.tar.gz -O /tmp/etcd.tar.gz && \ + mkdir -p /usr/local/bin/etcd && \ + tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1 && \ + rm /tmp/etcd.tar.gz +ENV PATH=/usr/local/bin/etcd/:$PATH + +# Rust Setup +# Rust environment setup +ENV RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + PATH=/usr/local/cargo/bin:$PATH \ + RUST_VERSION=1.90.0 + +# Define Rust target based on ARCH_ALT ARG +ARG RUSTARCH=${ARCH_ALT}-unknown-linux-gnu + +# Install Rust +RUN wget --tries=3 --waitretry=5 "https://static.rust-lang.org/rustup/archive/1.28.1/${RUSTARCH}/rustup-init" && \ + chmod +x rustup-init && \ + ./rustup-init -y --no-modify-path --profile minimal --default-toolchain $RUST_VERSION --default-host ${RUSTARCH} && \ + rm rustup-init && \ + chmod -R a+w $RUSTUP_HOME $CARGO_HOME + + +################################## +##### Wheel Build Image ########## +################################## + +# Redeclare ARCH_ALT ARG so it's available for interpolation in the FROM instruction +ARG ARCH_ALT + +FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS wheel_builder + +# Redeclare ARGs for this stage +ARG ARCH +ARG ARCH_ALT +ARG CARGO_BUILD_JOBS + +WORKDIR /workspace + +# Copy CUDA from base stage +COPY --from=dynamo_base /usr/local/cuda /usr/local/cuda +COPY --from=dynamo_base /etc/ld.so.conf.d/hpcx.conf /etc/ld.so.conf.d/hpcx.conf + +# Set environment variables first so they can be used in COPY commands +ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} \ + RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + CARGO_TARGET_DIR=/opt/dynamo/target \ + PATH=/usr/local/cargo/bin:$PATH + +# Copy artifacts from base stage +COPY --from=dynamo_base $RUSTUP_HOME $RUSTUP_HOME +COPY --from=dynamo_base $CARGO_HOME $CARGO_HOME +# Install system dependencies +RUN yum groupinstall -y 'Development Tools' && \ + dnf install -y almalinux-release-synergy && \ + dnf config-manager 
--set-enabled powertools && \ + dnf install -y \ + # Build tools + cmake \ + ninja-build \ + clang-devel \ + gcc-c++ \ + flex \ + wget \ + # Kernel module build dependencies + dkms \ + # Protobuf support + protobuf-compiler \ + # RDMA/InfiniBand support (required for UCX build with --with-verbs) + libibverbs \ + libibverbs-devel \ + rdma-core \ + rdma-core-devel \ + libibumad \ + libibumad-devel \ + librdmacm-devel \ + numactl-devel + +# Ensure a modern protoc is available (required for --experimental_allow_proto3_optional) +RUN set -eux; \ + PROTOC_VERSION=25.3; \ + case "${ARCH_ALT}" in \ + x86_64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-x86_64.zip" ;; \ + aarch64) PROTOC_ZIP="protoc-${PROTOC_VERSION}-linux-aarch_64.zip" ;; \ + *) echo "Unsupported architecture: ${ARCH_ALT}" >&2; exit 1 ;; \ + esac; \ + wget --tries=3 --waitretry=5 -O /tmp/protoc.zip "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/${PROTOC_ZIP}"; \ + rm -f /usr/local/bin/protoc /usr/bin/protoc; \ + unzip -o /tmp/protoc.zip -d /usr/local bin/protoc include/*; \ + chmod +x /usr/local/bin/protoc; \ + ln -s /usr/local/bin/protoc /usr/bin/protoc; \ + protoc --version + +# Point build tools explicitly at the modern protoc +ENV PROTOC=/usr/local/bin/protoc + +ENV CUDA_PATH=/usr/local/cuda \ + PATH=/usr/local/cuda/bin:$PATH \ + LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/lib:/usr/local/lib64:$LD_LIBRARY_PATH \ + NVIDIA_DRIVER_CAPABILITIES=video,compute,utility + +# Create virtual environment for building wheels +ARG PYTHON_VERSION +ENV VIRTUAL_ENV=/workspace/.venv +RUN uv venv ${VIRTUAL_ENV} --python $PYTHON_VERSION && \ + uv pip install --upgrade meson pybind11 patchelf maturin[patchelf] + +ARG NIXL_UCX_REF +ARG NIXL_REF +ARG NIXL_GDRCOPY_REF + +# Build and install gdrcopy +RUN git clone --depth 1 --branch ${NIXL_GDRCOPY_REF} https://github.com/NVIDIA/gdrcopy.git && \ + cd gdrcopy/packages && \ + CUDA=/usr/local/cuda ./build-rpm-packages.sh && \ + rpm -Uvh gdrcopy-kmod-*.el8.noarch.rpm && \ + rpm -Uvh gdrcopy-*.el8.${ARCH_ALT}.rpm && \ + rpm -Uvh gdrcopy-devel-*.el8.noarch.rpm + +# Install SCCACHE if requested +ARG USE_SCCACHE +ARG SCCACHE_BUCKET +ARG SCCACHE_REGION +COPY container/use-sccache.sh /tmp/use-sccache.sh +RUN if [ "$USE_SCCACHE" = "true" ]; then \ + /tmp/use-sccache.sh install; \ + fi + +# Set SCCACHE environment variables +ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ + SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION}} -# Copy cuda tools and libs from base image -FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS base +# Build and install UCX +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /usr/local/src && \ + git clone https://github.com/openucx/ucx.git && \ + cd ucx && \ + git checkout $NIXL_UCX_REF && \ + ./autogen.sh && \ + ./contrib/configure-release \ + --prefix=/usr/local/ucx \ + --enable-shared \ + --disable-static \ + --disable-doxygen-doc \ + --enable-optimizations \ + --enable-cma \ + --enable-devel-headers \ + --with-cuda=/usr/local/cuda \ + --with-verbs \ + --with-dm \ + --with-gdrcopy=/usr/local \ + --with-efa \ + --enable-mt && \ + make -j && \ + make -j install-strip && \ + /tmp/use-sccache.sh show-stats 
"UCX" && \ + echo "/usr/local/ucx/lib" > /etc/ld.so.conf.d/ucx.conf && \ + echo "/usr/local/ucx/lib/ucx" >> /etc/ld.so.conf.d/ucx.conf && \ + ldconfig + +# build and install nixl +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + git clone --depth 1 --branch ${NIXL_REF} "https://github.com/ai-dynamo/nixl.git" && \ + cd nixl && \ + mkdir build && \ + meson setup build/ --prefix=/opt/nvidia/nvda_nixl --buildtype=release \ + -Dcudapath_lib="/usr/local/cuda/lib64" \ + -Dcudapath_inc="/usr/local/cuda/include" \ + -Ducx_path="/usr/local/ucx" && \ + cd build && \ + ninja && \ + ninja install && \ + /tmp/use-sccache.sh show-stats "NIXL" + +ENV NIXL_LIB_DIR=/opt/nvidia/nvda_nixl/lib64 \ + NIXL_PLUGIN_DIR=/opt/nvidia/nvda_nixl/lib64/plugins \ + NIXL_PREFIX=/opt/nvidia/nvda_nixl +ENV LD_LIBRARY_PATH=${NIXL_LIB_DIR}:${NIXL_PLUGIN_DIR}:/usr/local/ucx/lib:/usr/local/ucx/lib/ucx:${LD_LIBRARY_PATH} + +RUN echo "$NIXL_LIB_DIR" > /etc/ld.so.conf.d/nixl.conf && \ + echo "$NIXL_PLUGIN_DIR" >> /etc/ld.so.conf.d/nixl.conf && \ + ldconfig + +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX="${SCCACHE_S3_KEY_PREFIX:-${ARCH}}" && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CUDA_COMPILER_LAUNCHER="sccache"; \ + fi && \ + cd /workspace/nixl && \ + uv build . 
--out-dir /opt/dynamo/dist/nixl --python $PYTHON_VERSION + +# Copy source code (order matters for layer caching) +COPY pyproject.toml README.md LICENSE Cargo.toml Cargo.lock rust-toolchain.toml hatch_build.py /opt/dynamo/ +COPY launch/ /opt/dynamo/launch/ +COPY lib/ /opt/dynamo/lib/ +COPY components/ /opt/dynamo/components/ + +# Build dynamo wheels +ARG ENABLE_KVBM +RUN --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ + --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ + export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ + if [ "$USE_SCCACHE" = "true" ]; then \ + export CMAKE_C_COMPILER_LAUNCHER="sccache" && \ + export CMAKE_CXX_COMPILER_LAUNCHER="sccache" && \ + export RUSTC_WRAPPER="sccache"; \ + fi && \ + source ${VIRTUAL_ENV}/bin/activate && \ + cd /opt/dynamo && \ + uv build --wheel --out-dir /opt/dynamo/dist && \ + cd /opt/dynamo/lib/bindings/python && \ + if [ "$ENABLE_MEDIA_NIXL" == "true" ]; then \ + maturin build --release --features dynamo-llm/media-nixl --out /opt/dynamo/dist; \ + else \ + maturin build --release --out /opt/dynamo/dist; \ + fi && \ + if [ "$ENABLE_KVBM" == "true" ]; then \ + cd /opt/dynamo/lib/bindings/kvbm && \ + maturin build --release --out target/wheels && \ + auditwheel repair \ + --exclude libnixl.so \ + --exclude libnixl_build.so \ + --exclude libnixl_common.so \ + --plat manylinux_2_28_${ARCH_ALT} \ + --wheel-dir /opt/dynamo/dist \ + target/wheels/*.whl; \ + fi && \ + /tmp/use-sccache.sh show-stats "Dynamo" ######################################################## ########## Framework Development Image ################ @@ -110,42 +413,19 @@ ARG VLLM_REF ARG VLLM_GIT_URL ARG DEEPGEMM_REF ARG FLASHINF_REF -ARG TORCH_BACKEND +ARG LMCACHE_REF ARG CUDA_VERSION ARG MAX_JOBS=16 ENV MAX_JOBS=$MAX_JOBS ENV CUDA_HOME=/usr/local/cuda -# Install sccache if requested -COPY container/use-sccache.sh /tmp/use-sccache.sh -# Install sccache if requested -ARG USE_SCCACHE -ARG ARCH_ALT -ARG SCCACHE_BUCKET -ARG SCCACHE_REGION - -ENV ARCH_ALT=${ARCH_ALT} -RUN if [ "$USE_SCCACHE" = "true" ]; then \ - /tmp/use-sccache.sh install; \ - fi - -# Set environment variables - they'll be empty strings if USE_SCCACHE=false -ENV SCCACHE_BUCKET=${USE_SCCACHE:+${SCCACHE_BUCKET}} \ - SCCACHE_REGION=${USE_SCCACHE:+${SCCACHE_REGION}} \ - CMAKE_C_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} \ - CMAKE_CXX_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} \ - CMAKE_CUDA_COMPILER_LAUNCHER=${USE_SCCACHE:+sccache} # Install VLLM and related dependencies RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \ --mount=type=cache,target=/root/.cache/uv \ - --mount=type=secret,id=aws-key-id,env=AWS_ACCESS_KEY_ID \ - --mount=type=secret,id=aws-secret-id,env=AWS_SECRET_ACCESS_KEY \ - export SCCACHE_S3_KEY_PREFIX=${SCCACHE_S3_KEY_PREFIX:-${ARCH}} && \ - cp /tmp/deps/vllm/install_vllm.sh /tmp/install_vllm.sh && \ - chmod +x /tmp/install_vllm.sh && \ - /tmp/install_vllm.sh --editable --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} --torch-backend $TORCH_BACKEND --cuda-version $CUDA_VERSION && \ - /tmp/use-sccache.sh show-stats "vLLM"; + cp /tmp/deps/vllm/install_vllm.sh /tmp/install_vllm.sh && \ + chmod +x /tmp/install_vllm.sh && \ + /tmp/install_vllm.sh --vllm-ref $VLLM_REF --max-jobs $MAX_JOBS --arch $ARCH --installation-dir /opt ${DEEPGEMM_REF:+--deepgemm-ref "$DEEPGEMM_REF"} ${FLASHINF_REF:+--flashinf-ref "$FLASHINF_REF"} 
${LMCACHE_REF:+--lmcache-ref "$LMCACHE_REF"} --cuda-version $CUDA_VERSION ENV LD_LIBRARY_PATH=\ /opt/vllm/tools/ep_kernels/ep_kernels_workspace/nvshmem_install/lib:\ @@ -181,13 +461,13 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" ENV CUDA_DEVICE_ORDER=PCI_BUS_ID # Copy CUDA development tools (nvcc, headers, dependencies, etc.) from base devel image -COPY --from=base /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc -COPY --from=base /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++ -COPY --from=base /usr/local/cuda/bin/ptxas /usr/local/cuda/bin/ptxas -COPY --from=base /usr/local/cuda/bin/fatbinary /usr/local/cuda/bin/fatbinary -COPY --from=base /usr/local/cuda/include/ /usr/local/cuda/include/ -COPY --from=base /usr/local/cuda/nvvm /usr/local/cuda/nvvm -COPY --from=base /usr/local/cuda/lib64/libcudart.so* /usr/local/cuda/lib64/ +COPY --from=dynamo_base /usr/local/cuda/bin/nvcc /usr/local/cuda/bin/nvcc +COPY --from=dynamo_base /usr/local/cuda/bin/cudafe++ /usr/local/cuda/bin/cudafe++ +COPY --from=dynamo_base /usr/local/cuda/bin/ptxas /usr/local/cuda/bin/ptxas +COPY --from=dynamo_base /usr/local/cuda/bin/fatbinary /usr/local/cuda/bin/fatbinary +COPY --from=dynamo_base /usr/local/cuda/include/ /usr/local/cuda/include/ +COPY --from=dynamo_base /usr/local/cuda/nvvm /usr/local/cuda/nvvm +COPY --from=dynamo_base /usr/local/cuda/lib64/libcudart.so* /usr/local/cuda/lib64/ RUN ln -s /usr/local/cuda/lib64/libcublas.so.12 /usr/local/cuda/lib64/libcublas.so RUN ln -s /usr/local/cuda/lib64/libcublasLt.so.12 /usr/local/cuda/lib64/libcublasLt.so @@ -210,8 +490,12 @@ RUN userdel -r ubuntu > /dev/null 2>&1 || true \ && useradd -m -s /bin/bash -g 0 dynamo \ && [ `id -u dynamo` -eq 1000 ] \ && mkdir -p /home/dynamo/.cache /opt/dynamo \ - && chown -R dynamo: /workspace /home/dynamo /opt/dynamo \ - && chmod -R g+w /workspace /home/dynamo/.cache /opt/dynamo + # Non-recursive chown - only the directories themselves, not contents + && chown dynamo:0 /home/dynamo /home/dynamo/.cache /opt/dynamo /workspace \ + # No chmod needed: umask 002 handles new files, COPY --chmod handles copied content + # Set umask globally for all subsequent RUN commands (must be done as root before USER dynamo) + # NOTE: Setting ENV UMASK=002 does NOT work - umask is a shell builtin, not an environment variable + && mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh ARG ARCH_ALT ARG PYTHON_VERSION @@ -236,25 +520,32 @@ RUN apt-get update && \ # prometheus dependencies ca-certificates \ # DeepGemm uses 'cuobjdump' which does not come with CUDA image - cuda-command-line-tools-12-8 && \ + cuda-command-line-tools-12-9 && \ rm -rf /var/lib/apt/lists/* USER dynamo ENV HOME=/home/dynamo +# This picks up the umask 002 from the /etc/profile.d/00-umask.sh file for subsequent RUN commands +SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"] + ENV NIXL_PREFIX=/opt/nvidia/nvda_nixl ENV NIXL_LIB_DIR=$NIXL_PREFIX/lib/${ARCH_ALT}-linux-gnu ENV NIXL_PLUGIN_DIR=$NIXL_LIB_DIR/plugins ### VIRTUAL ENVIRONMENT SETUP ### # Copy entire virtual environment from framework container with correct ownership -COPY --chown=dynamo: --from=framework ${VIRTUAL_ENV} ${VIRTUAL_ENV} - -# Copy vllm with correct ownership -COPY --chown=dynamo: --from=framework /opt/vllm /opt/vllm - -# Copy UCX and NIXL to system directories -COPY --chown=dynamo: --from=dynamo_base /usr/local/ucx /usr/local/ucx -COPY --chown=dynamo: --from=dynamo_base $NIXL_PREFIX $NIXL_PREFIX +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects 
/*, not +COPY --chmod=775 --chown=dynamo:0 --from=framework ${VIRTUAL_ENV} ${VIRTUAL_ENV} + +# Copy vllm with correct ownership (read-only, no group-write needed) +COPY --chown=dynamo:0 --from=framework /opt/vllm /opt/vllm + +# Copy UCX and NIXL to system directories (read-only, no group-write needed) +COPY --from=wheel_builder /usr/local/ucx /usr/local/ucx +COPY --chown=dynamo: --from=wheel_builder $NIXL_PREFIX $NIXL_PREFIX +COPY --chown=dynamo: --from=wheel_builder /opt/nvidia/nvda_nixl/lib64/. ${NIXL_LIB_DIR}/ +COPY --chown=dynamo: --from=wheel_builder /opt/dynamo/dist/nixl/ /opt/dynamo/wheelhouse/nixl/ +COPY --chown=dynamo: --from=wheel_builder /workspace/nixl/build/src/bindings/python/nixl-meta/nixl-*.whl /opt/dynamo/wheelhouse/nixl/ ENV PATH=/usr/local/ucx/bin:$PATH ENV LD_LIBRARY_PATH=\ @@ -265,22 +556,31 @@ $NIXL_PLUGIN_DIR:\ /usr/local/ucx/lib/ucx:\ $LD_LIBRARY_PATH -# Copy local files -COPY --chown=dynamo: ATTRIBUTION* LICENSE /workspace/ -COPY --chown=dynamo: benchmarks/ /workspace/benchmarks/ +# Copy attribution files +COPY --chmod=664 --chown=dynamo:0 ATTRIBUTION* LICENSE /workspace/ +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 benchmarks/ /workspace/benchmarks/ # Install dynamo, NIXL, and dynamo-specific dependencies +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not ARG ENABLE_KVBM -COPY --chown=dynamo: --from=dynamo_base /opt/dynamo/wheelhouse/ /opt/dynamo/wheelhouse/ +COPY --chmod=775 --chown=dynamo:0 --from=wheel_builder /opt/dynamo/dist/*.whl /opt/dynamo/wheelhouse/ RUN uv pip install \ /opt/dynamo/wheelhouse/ai_dynamo_runtime*.whl \ /opt/dynamo/wheelhouse/ai_dynamo*any.whl \ - /opt/dynamo/wheelhouse/nixl/nixl*.whl \ - && if [ "${ENABLE_KVBM}" = "true" ]; then \ - uv pip install /opt/dynamo/wheelhouse/kvbm*.whl; \ - fi \ - && cd /workspace/benchmarks \ - && UV_GIT_LFS=1 uv pip install --no-cache . + /opt/dynamo/wheelhouse/nixl/nixl*.whl && \ + if [ "${ENABLE_KVBM}" = "true" ]; then \ + KVBM_WHEEL=$(ls /opt/dynamo/wheelhouse/kvbm*.whl 2>/dev/null | head -1); \ + if [ -z "$KVBM_WHEEL" ]; then \ + echo "ERROR: ENABLE_KVBM is true but no KVBM wheel found in wheelhouse" >&2; \ + exit 1; \ + fi; \ + uv pip install "$KVBM_WHEEL"; \ + fi && \ + cd /workspace/benchmarks && \ + UV_GIT_LFS=1 uv pip install --no-cache . 
&& \ + # pip/uv bypasses umask when creating .egg-info files, but chmod -R is fast here (small directory) + chmod -R g+w /workspace/benchmarks # Install common and test dependencies RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ @@ -290,13 +590,14 @@ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requi --requirement /tmp/requirements.txt \ --requirement /tmp/requirements.test.txt -# Copy tests, benchmarks, deploy and components for CI -COPY --chown=dynamo: tests /workspace/tests -COPY --chown=dynamo: examples /workspace/examples -COPY --chown=dynamo: deploy /workspace/deploy -COPY --chown=dynamo: recipes/ /workspace/recipes/ -COPY --chown=dynamo: components/ /workspace/components/ -COPY --chown=dynamo: lib/ /workspace/lib/ +# Copy tests, deploy and components for CI with correct ownership +# Pattern: COPY --chmod=775 ; chmod g+w done later as root because COPY --chmod only affects /*, not +COPY --chmod=775 --chown=dynamo:0 tests /workspace/tests +COPY --chmod=775 --chown=dynamo:0 examples /workspace/examples +COPY --chmod=775 --chown=dynamo:0 deploy /workspace/deploy +COPY --chmod=775 --chown=dynamo:0 recipes/ /workspace/recipes/ +COPY --chmod=775 --chown=dynamo:0 components/ /workspace/components/ +COPY --chmod=775 --chown=dynamo:0 lib/ /workspace/lib/ # Setup launch banner in common directory accessible to all users RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/dynamo/launch_message.txt \ @@ -304,7 +605,9 @@ RUN --mount=type=bind,source=./container/launch_message/runtime.txt,target=/opt/ # Setup environment for all users USER root -RUN chmod 755 /opt/dynamo/.launch_screen && \ +# Fix directory permissions: COPY --chmod only affects contents, not the directory itself +RUN chmod g+w /workspace /workspace/* /opt/dynamo /opt/dynamo/* ${VIRTUAL_ENV} && \ + chmod 755 /opt/dynamo/.launch_screen && \ echo 'source /opt/dynamo/venv/bin/activate' >> /etc/bash.bashrc && \ echo 'cat /opt/dynamo/.launch_screen' >> /etc/bash.bashrc @@ -363,6 +666,10 @@ RUN apt-get update -y && \ protobuf-compiler && \ rm -rf /var/lib/apt/lists/* +# Set umask for group-writable files in dev stage (runs as root) +RUN mkdir -p /etc/profile.d && echo 'umask 002' > /etc/profile.d/00-umask.sh +SHELL ["/bin/bash", "-l", "-o", "pipefail", "-c"] + # Set workspace directory variable ENV WORKSPACE_DIR=${WORKSPACE_DIR} \ DYNAMO_HOME=${WORKSPACE_DIR} \ @@ -372,11 +679,15 @@ ENV WORKSPACE_DIR=${WORKSPACE_DIR} \ VIRTUAL_ENV=/opt/dynamo/venv \ PATH=/usr/local/cargo/bin:$PATH -COPY --from=dynamo_base /usr/local/rustup /usr/local/rustup -COPY --from=dynamo_base /usr/local/cargo /usr/local/cargo +# Copy rust installation from dynamo_base to avoid duplication efforts +# Pattern: COPY --chmod=775 ; chmod g+w because COPY --chmod only affects /*, not +COPY --from=dynamo_base --chmod=775 /usr/local/rustup /usr/local/rustup +COPY --from=dynamo_base --chmod=775 /usr/local/cargo /usr/local/cargo +RUN chmod g+w /usr/local/rustup /usr/local/cargo # Install maturin, for maturin develop # Editable install of dynamo +COPY pyproject.toml README.md hatch_build.py /workspace/ RUN uv pip install maturin[patchelf] && \ uv pip install --no-deps -e . 
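Both Dockerfiles above redeclare the same ARGs (`ARCH`, `ARCH_ALT`, `ENABLE_KVBM`, and so on) inside nearly every stage. That is a consequence of Dockerfile ARG scoping rather than copy-paste noise; a minimal illustration with made-up stage names:

```dockerfile
# ARG scoping demo: a value declared before the first FROM is visible only to FROM lines;
# each stage must redeclare the ARGs it wants to use in its own RUN/ENV/COPY instructions.
ARG ARCH_ALT=x86_64
FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS builder
ARG ARCH_ALT
# Redeclared above, so the value (default or --build-arg override) is available here.
RUN echo "building for ${ARCH_ALT}"

FROM ubuntu:24.04 AS runtime
# Not redeclared in this stage, so this prints an empty string.
RUN echo "ARCH_ALT here is '${ARCH_ALT}'"
```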
diff --git a/container/README.md b/container/README.md index 00220d6021..458193c3ce 100644 --- a/container/README.md +++ b/container/README.md @@ -18,6 +18,48 @@ The NVIDIA Dynamo project uses containerized development and deployment to maint - `Dockerfile.frontend` - For Kubernetes Gateway API Inference Extension integration with EPP - `Dockerfile.epp` - For building the Endpoint Picker (EPP) image +### Stage Summary for Frameworks + +
+<details> +<summary>Show Stage Summary Table</summary> + +#### Dockerfile.${FRAMEWORK} General Structure + +Below is a summary of the general file structure for the framework Dockerfile stages. Some exceptions exist. + +| Stage/Filepath | Target | +| --- | --- | +| **STAGE: dynamo_base** | **FROM ${BASE_IMAGE}** | +| /bin/uv, /bin/uvx | COPY from ghcr.io/astral-sh/uv:latest (→ framework, runtime) | +| /usr/bin/nats-server | Downloaded from GitHub (→ runtime) | +| /usr/local/bin/etcd/ | Downloaded from GitHub (→ runtime) | +| /usr/local/rustup/ | Installed via rustup-init (→ wheel_builder, dev) | +| /usr/local/cargo/ | Installed via rustup-init (→ wheel_builder, dev) | +| /usr/local/cuda/ | Inherited from BASE_IMAGE (→ wheel_builder, runtime) | +| **STAGE: wheel_builder** | **FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT}** | +| /usr/local/ucx/ | Built from source (→ runtime) | +| /opt/nvidia/nvda_nixl/ | Built from source (→ runtime) | +| /opt/nvidia/nvda_nixl/lib64/ | Built from source (→ runtime) | +| /opt/dynamo/target/ | Cargo build output (→ runtime) | +| /opt/dynamo/dist/*.whl | Built wheels (→ runtime) | +| /opt/dynamo/dist/nixl/ | Built nixl wheels (→ runtime) | +| **STAGE: framework** | **FROM ${BASE_IMAGE}** | +| /opt/dynamo/venv/ | Created with uv venv (→ runtime) | +| /${FRAMEWORK_INSTALL} | Built framework (→ runtime) | +| **STAGE: runtime** | **FROM ${RUNTIME_IMAGE}** | +| /usr/local/cuda/{bin,include,nvvm}/ | COPY from dynamo_base | +| /usr/bin/nats-server | COPY from dynamo_base | +| /usr/local/bin/etcd/ | COPY from dynamo_base | +| /usr/local/ucx/ | COPY from wheel_builder | +| /opt/nvidia/nvda_nixl/ | COPY from wheel_builder | +| /opt/dynamo/wheelhouse/ | COPY from wheel_builder | +| /opt/dynamo/venv/ | COPY from framework | +| /opt/vllm/ | COPY from framework | +| /workspace/{tests,examples,deploy}/ | COPY from build context | +| **STAGE: dev** | **FROM runtime** | +| /usr/local/rustup/ | COPY from dynamo_base | +| /usr/local/cargo/ | COPY from dynamo_base | +</details>
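Read top to bottom, the table corresponds to a stage graph roughly like the following. This is only a schematic sketch of how the pieces connect (placeholder contents stand in for the real artifacts such as nats-server, etcd, UCX/NIXL, the framework venv, and the Rust toolchain); the real stages carry the full set of ARGs, permission handling, and framework installs shown in the Dockerfile diffs above:

```dockerfile
# Schematic only - stage names match the table; real contents are replaced by placeholders.
FROM ubuntu:24.04 AS dynamo_base
# Really: uv, nats-server, etcd, rustup/cargo, CUDA inherited from BASE_IMAGE
RUN mkdir -p /usr/local/demo-toolchain && echo toolchain > /usr/local/demo-toolchain/marker

FROM quay.io/pypa/manylinux_2_28_x86_64 AS wheel_builder
# Really: UCX, NIXL, and the dynamo wheels built in a manylinux environment
RUN mkdir -p /opt/dynamo/dist && echo wheel > /opt/dynamo/dist/placeholder.whl

FROM ubuntu:24.04 AS framework
# Really: the vLLM / TensorRT-LLM / SGLang environment
RUN mkdir -p /opt/dynamo/venv && echo venv > /opt/dynamo/venv/marker

FROM ubuntu:24.04 AS runtime
# The runtime stage assembles artifacts from the stages above
COPY --from=dynamo_base /usr/local/demo-toolchain /usr/local/demo-toolchain
COPY --from=wheel_builder /opt/dynamo/dist/ /opt/dynamo/wheelhouse/
COPY --from=framework /opt/dynamo/venv /opt/dynamo/venv

FROM runtime AS dev
# dev extends runtime with toolchains and developer utilities
RUN echo "dev adds Rust toolchain and dev tools on top of runtime"
```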
+ ### Why Containerization? Each inference framework (vLLM, TensorRT-LLM, SGLang) has specific CUDA versions, Python dependencies, and system libraries. Containers provide consistent environments, framework isolation, and proper GPU configurations across development and production. @@ -37,102 +79,34 @@ The `build.sh` and `run.sh` scripts are convenience wrappers that simplify commo ## Development Targets Feature Matrix -These targets are specified with `build.sh --target ` and correspond to Docker multi-stage build targets defined in the Dockerfiles (e.g., `FROM somebase AS `). Some commonly used targets include: - -- `runtime` - For running pre-built containers without development tools (minimal size, runs as non-root `dynamo` user with UID 1000 and GID 0) -- `dev` - For development (inferencing/benchmarking/etc, runs as root user for maximum flexibility) -- `local-dev` - For development with local user permissions matching host UID/GID. This is useful when mounting host partitions (with local user permissions) to Docker partitions. The `dynamo` user UID/GID is remapped to match the host user. - -Additional targets are available in the Dockerfiles for specific build stages and use cases. - -| Feature | **dev + `run.sh`** | **local-dev + `run.sh`** | **local-dev + Dev Container** | -|---------|-------------------|--------------------------|-------------------------------| -| **Default User** | root | dynamo (matched to host UID/GID) | dynamo (matched to host UID/GID) | -| **User Setup** | None (root) | Matches UID/GID of `build.sh` user | Matches UID/GID of `build.sh` user | -| **Permissions** | root | dynamo with sudo | dynamo with sudo | -| **Home Directory** | `/root` | `/home/dynamo` | `/home/dynamo` | -| **Working Directory** | `/workspace` | `/workspace` | `/workspace` | -| **Rust Toolchain** | System install (`/usr/local/rustup`, `/usr/local/cargo`) | User install (`~/.rustup`, `~/.cargo`) | User install (`~/.rustup`, `~/.cargo`) | -| **Python Env** | dynamo user owned | dynamo owned venv | dynamo owned venv | -| **File Permissions** | root-level | user-level (dynamo), safe | user-level (dynamo), safe | -| **Compatibility** | Legacy workflows, maximum flexibility | workspace writable on NFS, non-root security | workspace writable on NFS, non-root security | +**Note**: In Dynamo, "targets" and "Docker stages" are synonymous. Each target corresponds to a stage in the multi-stage Docker build. Similarly, "frameworks" and "engines" are synonymous (vLLM, TensorRT-LLM, SGLang). -## Environment Variables Across Build Stages - -Understanding how environment variables change across different build stages is crucial for development and debugging. The Dynamo build system uses a multi-stage Docker build process where environment variables are set, inherited, and overridden at different stages. 
- -### Build Stage Flow - -``` -Dockerfile โ†’ base โ†’ dev (dynamo-base image) - โ†“ -Dockerfile.vllm โ†’ framework โ†’ runtime โ†’ dev (vllm dev image) - โ†“ -Dockerfile.local_dev โ†’ local-dev (from vllm dev image) -``` - -### Environment Variables by Stage - -| Variable | **base** | **baseโ†’dev** | **vllmโ†’framework** | **vllmโ†’runtime** | **vllmโ†’dev** | **local-dev** | -|----------------------|---------------------|----------------------|--------------------|--------------------|--------------|--------------------| -| **DYNAMO_HOME** | โŒ Not set | `/opt/dynamo` | โŒ Not set | `/opt/dynamo` | `/workspace` โœ… **OVERRIDE** | `/workspace` (inherited) | -| **WORKSPACE_DIR** | โŒ Not set | โŒ Not set | โŒ Not set | โŒ Not set | `/workspace` | `/workspace` (inherited) | -| **CARGO_TARGET_DIR** | โŒ Not set | `/opt/dynamo/target` | โŒ Not set | โŒ Not set | `/workspace/target` โœ… **OVERRIDE** | `/workspace/target` (inherited) | -| **VIRTUAL_ENV** | `/opt/dynamo/venv` | (inherited) | `/opt/dynamo/venv` | `/opt/dynamo/venv` | `/opt/dynamo/venv` โœ… **REDEFINE** | `/opt/dynamo/venv` (inherited) | -| **RUSTUP_HOME** | `/usr/local/rustup` | (inherited) | โŒ Not set | โŒ Not set | `/usr/local/rustup` | `/home/dynamo/.rustup` โœ… **OVERRIDE** | -| **CARGO_HOME** | `/usr/local/cargo` | (inherited) | โŒ Not set | โŒ Not set | `/usr/local/cargo` | `/home/dynamo/.cargo` โœ… **OVERRIDE** | -| **USERNAME** | โŒ Not set | `dynamo` | โŒ Not set | `dynamo` | โŒ Not set | `dynamo` | -| **HOME** | (system default) | `/home/dynamo` | (system default) | `/home/dynamo` | (system default) | `/home/dynamo` | -| **PATH** | (includes cargo) | (inherited) | (system default) | (includes venv, etcd, ucx) | `/usr/local/cargo/bin:$PATH` | `/home/dynamo/.cargo/bin:$PATH` โœ… **OVERRIDE** | - -### Key Insights - -**1. DYNAMO_HOME Dual Purpose:** -- `baseโ†’dev` and `vllmโ†’runtime`: `/opt/dynamo` - For **installed/packaged** Dynamo (CI, production) -- `vllmโ†’dev` and `local-dev`: `/workspace` - For **development** with source code mounted from host - -**2. Rust Toolchain Location:** -- `dev` target: System-wide at `/usr/local/rustup` and `/usr/local/cargo` (suitable for root) -- `local-dev` target: User-specific at `/home/dynamo/.rustup` and `/home/dynamo/.cargo` (proper UID/GID ownership) - -**3. Build Artifacts Location:** -- `baseโ†’dev`: `/opt/dynamo/target` - Build artifacts with installed package -- `vllmโ†’dev` onward: `/workspace/target` - Build artifacts in mounted workspace for persistence - -**4. Variables That Stay Constant:** -- `VIRTUAL_ENV`: Always `/opt/dynamo/venv` (ownership changes in local-dev via rsync) -- `WORKSPACE_DIR`: Always `/workspace` once set in vllmโ†’dev -- `DYNAMO_HOME`: Always `/workspace` once overridden in vllmโ†’dev (for development) - -**5. local-dev Specific Changes:** -From `Dockerfile.local_dev`, the Rust toolchain is moved to user home because: -- Workspace mount points may change, breaking toolchain paths -- User needs ownership of cargo binaries and registry for package installation -- Toolchain requires consistent system paths that don't depend on workspace location - -The Python venv ownership is also updated via rsync in local-dev to match the user's UID/GID, ensuring package installation permissions work correctly. - -**6. 
Non-Root User Architecture:** -Dynamo containers implement a multi-stage user strategy: -- **runtime stage**: Runs as non-root `dynamo` user (UID 1000, GID 0) for production workloads -- **dev stage**: Runs as root for maximum development flexibility (builds on runtime but switches to root) -- **local-dev stage**: Runs as `dynamo` user with UID/GID matched to host user for safe file system operations -- **Security**: Runtime and local-dev use non-root execution to reduce attack surface -- **File Ownership**: All application files, virtual environments, and build artifacts are owned by `dynamo:root` (1000:0) in runtime stage -- **Environment Setup**: Launch banner moved to `/opt/dynamo/.launch_screen` (shared across all users) and venv activation configured in `/etc/bash.bashrc` for system-wide availability. This replaces the previous per-user `~/.launch_screen` and `~/.bashrc` approach. +| Feature | **runtime + `run.sh`** | **local-dev (`run.sh` or Dev Container)** | **dev + `run.sh`** (legacy) | +|---------|----------------------|-------------------------------------------|--------------------------| +| **Usage** | Benchmarking inference and deployments, non-root | Development, compilation, testing locally | Legacy workflows, root user, use with caution | +| **User** | dynamo (UID 1000) | dynamo (UID=host user) with sudo | root (UID 0, use with caution) | +| **Home Directory** | `/home/dynamo` | `/home/dynamo` | `/root` | +| **Working Directory** | `/workspace` (in-container or mounted) | `/workspace` (must be mounted w/ `--mount-workspace`) | `/workspace` (must be mounted w/ `--mount-workspace`) | +| **Rust Toolchain** | None (uses pre-built wheels) | System install (`/usr/local/rustup`, `/usr/local/cargo`) | System install (`/usr/local/rustup`, `/usr/local/cargo`) | +| **Cargo Target** | None | `/workspace/target` | `/workspace/target` | +| **Python Env** | venv (`/opt/dynamo/venv`) for vllm/trtllm, system site-packages for sglang | venv (`/opt/dynamo/venv`) for vllm/trtllm, system site-packages for sglang | venv (`/opt/dynamo/venv`) for vllm/trtllm, system site-packages for sglang | ## Usage Guidelines -- **Use runtime target**: for production deployments. Runs as non-root `dynamo` user (UID 1000, GID 0) for security -- **Use dev + `run.sh`**: for command-line testing and inferencing. Runs as root for maximum flexibility -- **Use local-dev + `run.sh`**: for command-line development and Docker mounted partitions. Runs as `dynamo` user with UID/GID matched to your local user. Add `-it` flag for interactive sessions -- **Use local-dev + Dev Container**: VS Code/Cursor Dev Container Plugin, using `dynamo` user with UID/GID matched to your local user +- **Use runtime target**: for benchmarking inference and deployments. Runs as non-root `dynamo` user (UID 1000, GID 0) for security +- **Use local-dev + `run.sh`**: for command-line development and Docker mounted partitions. Runs as `dynamo` user with UID matched to your local user, GID 0. Add `-it` flag for interactive sessions +- **Use local-dev + Dev Container**: VS Code/Cursor Dev Container Plugin, using `dynamo` user with UID matched to your local user, GID 0 +- **Use dev + `run.sh`**: Root user, use with caution. Runs as root for backward compatibility with early workflows ## Example Commands -### 1. dev + `run.sh` (runs as root): +### 1. runtime target (runs as non-root dynamo user): ```bash -run.sh ... 
+# Build runtime image +./build.sh --framework vllm --target runtime + +# Run runtime container +./run.sh --image dynamo:latest-vllm-runtime -it ``` ### 2. local-dev + `run.sh` (runs as dynamo user with matched host UID/GID): @@ -141,16 +115,7 @@ run.sh --mount-workspace -it --image dynamo:latest-vllm-local-dev ... ``` ### 3. local-dev + Dev Container Extension: -Use VS Code/Cursor Dev Container Extension with devcontainer.json configuration. The `dynamo` user UID/GID is automatically matched to your local user. - -### 4. runtime target (runs as non-root dynamo user): -```bash -# Build runtime image -./build.sh --framework vllm --target runtime - -# Run runtime container -./run.sh --image dynamo:latest-vllm-runtime -``` +Use VS Code/Cursor Dev Container Extension with devcontainer.json configuration. The `dynamo` user UID is automatically matched to your local user. ## Build and Run Scripts Overview @@ -196,23 +161,6 @@ The `build.sh` script is responsible for building Docker images for different AI ./build.sh --build-arg CUSTOM_ARG=value ``` -### build.sh --dev-image - Local Development Image Builder - -The `build.sh --dev-image` option takes a dev image and then builds a local-dev image, which contains proper local user permissions. It also includes extra developer utilities (debugging tools, text editors, system monitors, etc.). - -**Common Usage Examples:** - -```bash -# Build local-dev image from dev image dynamo:latest-vllm -./build.sh --dev-image dynamo:latest-vllm --framework vllm - -# Build with custom tag from dev image dynamo:latest-vllm -./build.sh --dev-image dynamo:latest-vllm --framework vllm --tag my-local:dev - -# Dry run to see what would be built -./build.sh --dev-image dynamo:latest-vllm --framework vllm --dry-run -``` - ### Building the Frontend Image The frontend image is a specialized container that includes the Dynamo components (NATS, etcd, dynamo, NIXL, etc) along with the Endpoint Picker (EPP) for Kubernetes Gateway API Inference Extension integration. This image is primarily used for inference gateway deployments. @@ -230,6 +178,7 @@ Follow the instructions in [`deploy/inference-gateway/README.md`](../deploy/infe The base image contains the core Dynamo runtime components, NATS server, etcd, and Python dependencies: ```bash # Build the base dev image (framework=none for frontend-only deployment) +# Note: --framework none defaults ENABLE_MEDIA_NIXL=false ./build.sh --framework none --target dev ``` diff --git a/container/build.sh b/container/build.sh index aac74644fd..4680add783 100755 --- a/container/build.sh +++ b/container/build.sh @@ -89,7 +89,7 @@ DEFAULT_TENSORRTLLM_PIP_WHEEL_DIR="/tmp/trtllm_wheel/" # TensorRT-LLM commit to use for building the trtllm wheel if not provided. # Important Note: This commit is not used in our CI pipeline. See the CI # variables to learn how to run a pipeline with a specific commit. -DEFAULT_EXPERIMENTAL_TRTLLM_COMMIT="31116825b39f4e6a6a1e127001f5204b73d1dc32" # 1.2.0rc2 +DEFAULT_EXPERIMENTAL_TRTLLM_COMMIT="e4c707845ff58fcc0b1d87afb4dd0e64885c780a" # 1.2.0rc5 TRTLLM_COMMIT="" TRTLLM_USE_NIXL_KVCACHE_EXPERIMENTAL="0" TRTLLM_GIT_URL="" @@ -98,7 +98,7 @@ TRTLLM_GIT_URL="" DEFAULT_TENSORRTLLM_INDEX_URL="https://pypi.nvidia.com/" # TODO: Remove the version specification from here and use the ai-dynamo[trtllm] package. # Need to update the Dockerfile.trtllm to use the ai-dynamo[trtllm] package. 
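Picking up the ENABLE_MEDIA_NIXL note above: with the `--enable-media-nixl` flag introduced later in this change, a frontend-only base build can opt back in explicitly. A hedged sketch, reusing the framework/target values from the example above:

```bash
# Default for --framework none: media NIXL support is off
./build.sh --framework none --target dev

# Opt in explicitly with the new flag
./build.sh --framework none --target dev --enable-media-nixl
```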
-DEFAULT_TENSORRTLLM_PIP_WHEEL="tensorrt-llm==1.2.0rc3" +DEFAULT_TENSORRTLLM_PIP_WHEEL="tensorrt-llm==1.2.0rc5" TENSORRTLLM_PIP_WHEEL="" VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" @@ -106,7 +106,7 @@ VLLM_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" # Please check https://github.com/ai-dynamo/dynamo/pull/1065 # for details and reproducer to manually test if the image # can be updated to later versions. -VLLM_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" +VLLM_BASE_IMAGE_TAG="25.04-cuda12.9-devel-ubuntu24.04" NONE_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" NONE_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" @@ -122,9 +122,14 @@ SGLANG_FRAMEWORK_IMAGE_TAG="${SGLANG_CUDA_VERSION}-cudnn-devel-ubuntu24.04" NIXL_REF=0.7.1 NIXL_UCX_REF=v1.19.0 NIXL_UCX_EFA_REF=9d2b88a1f67faf9876f267658bd077b379b8bb76 +NIXL_GDRCOPY_REF=v2.5.1 NO_CACHE="" +# KVBM (KV Cache Block Manager) - default disabled, enabled automatically for VLLM/TRTLLM +# or can be explicitly enabled via --enable-kvbm flag +ENABLE_KVBM=false + # sccache configuration for S3 USE_SCCACHE="" SCCACHE_BUCKET="" @@ -194,7 +199,6 @@ get_options() { fi ;; --base-image) - # Note: --base-image cannot be used with --dev-image if [ "$2" ]; then BASE_IMAGE=$2 shift @@ -218,14 +222,6 @@ get_options() { missing_requirement "$1" fi ;; - --dev-image) - if [ "$2" ]; then - DEV_IMAGE_INPUT=$2 - shift - else - missing_requirement "$1" - fi - ;; --uid) if [ "$2" ]; then CUSTOM_UID=$2 @@ -272,7 +268,7 @@ get_options() { ;; --cache-from) if [ "$2" ]; then - CACHE_FROM="--cache-from $2" + CACHE_FROM+="--cache-from $2 " shift else missing_requirement "$1" @@ -280,7 +276,7 @@ get_options() { ;; --cache-to) if [ "$2" ]; then - CACHE_TO="--cache-to $2" + CACHE_TO+="--cache-to $2 " shift else missing_requirement "$1" @@ -297,6 +293,9 @@ get_options() { --enable-kvbm) ENABLE_KVBM=true ;; + --enable-media-nixl) + ENABLE_MEDIA_NIXL=true + ;; --make-efa) NIXL_UCX_REF=$NIXL_UCX_EFA_REF ;; @@ -345,20 +344,10 @@ get_options() { shift done - # Validate argument combinations - if [[ -n "${DEV_IMAGE_INPUT:-}" && -n "${BASE_IMAGE:-}" ]]; then - error "ERROR: --dev-image cannot be used with --base-image. Use --dev-image to build from existing images or --base-image to build new images." - fi - - # Validate that --target and --dev-image cannot be used together - if [[ -n "${DEV_IMAGE_INPUT:-}" && -n "${TARGET:-}" ]]; then - error "ERROR: --target cannot be used with --dev-image. Use --target to build from scratch or --dev-image to build from existing images." 
- fi - - # Validate that --uid and --gid are only used with local-dev related options + # Validate that --uid and --gid are only used with local-dev target if [[ -n "${CUSTOM_UID:-}" || -n "${CUSTOM_GID:-}" ]]; then - if [[ -z "${DEV_IMAGE_INPUT:-}" && "${TARGET:-}" != "local-dev" ]]; then - error "ERROR: --uid and --gid can only be used with --dev-image or --target local-dev" + if [[ "${TARGET:-}" != "local-dev" ]]; then + error "ERROR: --uid and --gid can only be used with --target local-dev" fi fi @@ -460,15 +449,15 @@ show_help() { echo " [--cache-from cache location to start from]" echo " [--cache-to location where to cache the build output]" echo " [--tag tag for image]" - echo " [--dev-image dev image to build local-dev from]" - echo " [--uid user ID for local-dev images (only with --dev-image or --target local-dev)]" - echo " [--gid group ID for local-dev images (only with --dev-image or --target local-dev)]" + echo " [--uid user ID for local-dev images (only with --target local-dev)]" + echo " [--gid group ID for local-dev images (only with --target local-dev)]" echo " [--no-cache disable docker build cache]" echo " [--dry-run print docker commands without running]" echo " [--build-context name=path to add build context]" echo " [--release-build perform a release build]" echo " [--make-efa Enables EFA support for NIXL]" echo " [--enable-kvbm Enables KVBM support in Python 3.12]" + echo " [--enable-media-nixl Enable media processing with NIXL support (default: true for frameworks, false for none)]" echo " [--use-sccache enable sccache for Rust/C/C++ compilation caching]" echo " [--sccache-bucket S3 bucket name for sccache (required with --use-sccache)]" echo " [--sccache-region S3 region for sccache (required with --use-sccache)]" @@ -543,17 +532,13 @@ fi # Add NIXL_REF as a build argument BUILD_ARGS+=" --build-arg NIXL_REF=${NIXL_REF} " -# Function to build local-dev image with header +# Function to build local-dev image build_local_dev_with_header() { local dev_base_image="$1" local tags="$2" local success_msg="$3" local header_title="$4" - echo "======================================" - echo "$header_title" - echo "======================================" - # Get user info right before using it USER_UID=${CUSTOM_UID:-$(id -u)} USER_GID=${CUSTOM_GID:-$(id -g)} @@ -566,7 +551,8 @@ build_local_dev_with_header() { exit 1 fi - echo "Building new local-dev image from: $dev_base_image" + echo "" + echo "Now building new local-dev image from: $dev_base_image" echo "User 'dynamo' will have UID: $USER_UID, GID: $USER_GID" # Show the docker command being executed if not in dry-run mode @@ -593,8 +579,8 @@ build_local_dev_with_header() { # Show usage instructions echo "" echo "To run the local-dev image as the local user ($USER_UID/$USER_GID):" - # Extract the last tag from the tags string - last_tag=$(echo "$tags" | grep -o -- '--tag [^ ]*' | tail -1 | cut -d' ' -f2) + # Extract the first tag from the tags string (the full version tag, not the latest tag) + last_tag=$(echo "$tags" | grep -o -- '--tag [^ ]*' | head -1 | cut -d' ' -f2) # Calculate relative path to run.sh from current working directory # Get the directory where build.sh is located build_dir="$(dirname "${BASH_SOURCE[0]}")" @@ -798,24 +784,41 @@ fi # ENABLE_KVBM: Used in base Dockerfile for block-manager feature. # Declared but not currently used in Dockerfile.{vllm,trtllm}. 
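Since `--cache-from`/`--cache-to` now accumulate (note the `+=` above) rather than overwrite, the flags can be passed more than once per invocation. A rough sketch with placeholder registry refs (the refs themselves are assumptions):

```bash
./build.sh --framework vllm --target runtime \
  --cache-from type=registry,ref=registry.example.com/ai-dynamo/dynamo:vllm-cache \
  --cache-from type=registry,ref=registry.example.com/ai-dynamo/dynamo:main-vllm \
  --cache-to type=inline
```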
+# Force KVBM to be enabled for VLLM and TRTLLM frameworks if [[ $FRAMEWORK == "VLLM" ]] || [[ $FRAMEWORK == "TRTLLM" ]]; then echo "Forcing enable_kvbm to true in ${FRAMEWORK} image build" ENABLE_KVBM=true -else - ENABLE_KVBM=false fi +# For other frameworks, ENABLE_KVBM defaults to false unless --enable-kvbm flag was provided -if [ ! -z ${ENABLE_KVBM} ]; then - echo "Enabling the KVBM in the dynamo image" +if [[ ${ENABLE_KVBM} == "true" ]]; then + echo "Enabling KVBM in the dynamo image" BUILD_ARGS+=" --build-arg ENABLE_KVBM=${ENABLE_KVBM} " fi -# NIXL_UCX_REF: Used in base Dockerfile only. -# Passed to framework Dockerfile.{vllm,sglang,...} where it's NOT used. +# ENABLE_MEDIA_NIXL: Enable media processing with NIXL support +# Used in base Dockerfile for maturin build feature flag. +# Can be explicitly overridden with --enable-media-nixl flag +if [ -z "${ENABLE_MEDIA_NIXL}" ]; then + if [[ $FRAMEWORK == "VLLM" ]] || [[ $FRAMEWORK == "TRTLLM" ]] || [[ $FRAMEWORK == "SGLANG" ]]; then + ENABLE_MEDIA_NIXL=true + else + ENABLE_MEDIA_NIXL=false + fi +fi +BUILD_ARGS+=" --build-arg ENABLE_MEDIA_NIXL=${ENABLE_MEDIA_NIXL} " + +# NIXL_UCX_REF: Used in dynamo base stages. if [ -n "${NIXL_UCX_REF}" ]; then BUILD_ARGS+=" --build-arg NIXL_UCX_REF=${NIXL_UCX_REF} " fi +# NIXL_GDRCOPY_REF: Used in dynamo base stages. +if [ -n "${NIXL_GDRCOPY_REF}" ]; then + BUILD_ARGS+=" --build-arg NIXL_GDRCOPY_REF=${NIXL_GDRCOPY_REF} " + +fi + # MAX_JOBS is only used by Dockerfile.vllm if [ -n "${MAX_JOBS}" ]; then BUILD_ARGS+=" --build-arg MAX_JOBS=${MAX_JOBS} " @@ -853,117 +856,27 @@ fi show_image_options -if [ -z "$RUN_PREFIX" ]; then - set -x -fi - - -# Skip Build 1 and Build 2 if DEV_IMAGE_INPUT is set (we'll handle it at the bottom) -if [[ -z "${DEV_IMAGE_INPUT:-}" ]]; then - # Follow 2-step build process for all frameworks - if [[ $FRAMEWORK != "NONE" ]]; then - # Define base image tag with framework suffix to prevent clobbering - # Different frameworks require different base configurations: - # - VLLM: Python 3.12, ENABLE_KVBM=true, BASE_IMAGE=cuda-dl-base - # - SGLANG: Python 3.10, BASE_IMAGE=cuda-dl-base - # - TRTLLM: Python 3.12, ENABLE_KVBM=true, BASE_IMAGE=pytorch - # Without unique tags, building different frameworks would overwrite each other's names - DYNAMO_BASE_IMAGE="dynamo-base:${VERSION}-${FRAMEWORK,,}" - # Start base image build - echo "======================================" - echo "Starting Build 1: Base Image" - echo "======================================" - - # Create build log directory for BuildKit reports - BUILD_LOG_DIR="${BUILD_CONTEXT}/build-logs" - mkdir -p "${BUILD_LOG_DIR}" - BASE_BUILD_LOG="${BUILD_LOG_DIR}/base-image-build.log" - - # Use BuildKit for enhanced metadata - if [ -z "$RUN_PREFIX" ]; then - if docker buildx version &>/dev/null; then - docker buildx build --progress=plain --load -f "${SOURCE_DIR}/Dockerfile" --target runtime $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO --tag $DYNAMO_BASE_IMAGE $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${BASE_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - else - DOCKER_BUILDKIT=1 docker build --progress=plain -f "${SOURCE_DIR}/Dockerfile" --target runtime $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO --tag $DYNAMO_BASE_IMAGE $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${BASE_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - fi - - if [ ${BUILD_EXIT_CODE} -ne 0 ]; then - exit ${BUILD_EXIT_CODE} - fi - else - $RUN_PREFIX docker build -f "${SOURCE_DIR}/Dockerfile" --target runtime $PLATFORM $BUILD_ARGS $CACHE_FROM 
$CACHE_TO --tag $DYNAMO_BASE_IMAGE $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE - fi - - # Start framework build - echo "======================================" - echo "Starting Build 2: Framework Image" - echo "======================================" - - FRAMEWORK_BUILD_LOG="${BUILD_LOG_DIR}/framework-${FRAMEWORK,,}-build.log" - - BUILD_ARGS+=" --build-arg DYNAMO_BASE_IMAGE=${DYNAMO_BASE_IMAGE}" - - # Use BuildKit for enhanced metadata - if [ -z "$RUN_PREFIX" ]; then - if docker buildx version &>/dev/null; then - docker buildx build --progress=plain --load -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${FRAMEWORK_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - else - DOCKER_BUILDKIT=1 docker build --progress=plain -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${FRAMEWORK_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - fi - - if [ ${BUILD_EXIT_CODE} -ne 0 ]; then - exit ${BUILD_EXIT_CODE} - fi - else - $RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE - fi - else - # Create build log directory for BuildKit reports - BUILD_LOG_DIR="${BUILD_CONTEXT}/build-logs" - mkdir -p "${BUILD_LOG_DIR}" - SINGLE_BUILD_LOG="${BUILD_LOG_DIR}/single-stage-build.log" - - # Use BuildKit for enhanced metadata - if [ -z "$RUN_PREFIX" ]; then - if docker buildx version &>/dev/null; then - docker buildx build --progress=plain --load -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${SINGLE_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - else - DOCKER_BUILDKIT=1 docker build --progress=plain -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${SINGLE_BUILD_LOG}" - BUILD_EXIT_CODE=${PIPESTATUS[0]} - fi +# Always build the main image first +# Create build log directory for BuildKit reports +BUILD_LOG_DIR="${BUILD_CONTEXT}/build-logs" +mkdir -p "${BUILD_LOG_DIR}" +SINGLE_BUILD_LOG="${BUILD_LOG_DIR}/single-stage-build.log" - if [ ${BUILD_EXIT_CODE} -ne 0 ]; then - exit ${BUILD_EXIT_CODE} - fi - else - $RUN_PREFIX docker build -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE - fi - fi +# Use BuildKit for enhanced metadata +if docker buildx version &>/dev/null; then + $RUN_PREFIX docker buildx build --progress=plain --load -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${SINGLE_BUILD_LOG}" + BUILD_EXIT_CODE=${PIPESTATUS[0]} +else + $RUN_PREFIX DOCKER_BUILDKIT=1 docker build --progress=plain -f $DOCKERFILE $TARGET_STR $PLATFORM $BUILD_ARGS $CACHE_FROM $CACHE_TO $TAG $LATEST_TAG $BUILD_CONTEXT_ARG $BUILD_CONTEXT $NO_CACHE 2>&1 | tee "${SINGLE_BUILD_LOG}" + BUILD_EXIT_CODE=${PIPESTATUS[0]} fi -# Handle --dev-image option (build local-dev from existing dev image) -if [[ -n "${DEV_IMAGE_INPUT:-}" ]]; then - # Validate that the dev image is not already a local-dev image - if [[ "$DEV_IMAGE_INPUT" == *"-local-dev" ]]; then - echo "ERROR: Cannot use local-dev image as dev image input: '$DEV_IMAGE_INPUT'" - exit 1 - fi - - # Build tag arguments - always add -local-dev 
suffix for --dev-image - # Generate local-dev tag from input image - if [[ "$DEV_IMAGE_INPUT" == *:* ]]; then - LOCAL_DEV_TAG="--tag ${DEV_IMAGE_INPUT}-local-dev" - else - LOCAL_DEV_TAG="--tag ${DEV_IMAGE_INPUT}:latest-local-dev" - fi +if [ ${BUILD_EXIT_CODE} -ne 0 ]; then + exit ${BUILD_EXIT_CODE} +fi - build_local_dev_with_header "$DEV_IMAGE_INPUT" "$LOCAL_DEV_TAG" "Successfully built local-dev image: ${LOCAL_DEV_TAG#--tag }" "Building Local-Dev Image" -elif [[ "${LOCAL_DEV_BUILD:-}" == "true" ]]; then +# Handle local-dev target +if [[ "${LOCAL_DEV_BUILD:-}" == "true" ]]; then # Use the first tag name (TAG) if available, otherwise use latest if [[ -n "$TAG" ]]; then DEV_IMAGE=$(echo "$TAG" | sed 's/--tag //' | sed 's/-local-dev$//') @@ -985,8 +898,10 @@ elif [[ "${LOCAL_DEV_BUILD:-}" == "true" ]]; then LOCAL_DEV_TAGS+=" --tag ${LATEST_TAG_NAME}-local-dev" fi - build_local_dev_with_header "$DEV_IMAGE" "$LOCAL_DEV_TAGS" "Successfully built local-dev images" "Starting Build 3: Local-Dev Image" + # Extract first tag for success message + FIRST_TAG=$(echo "$LOCAL_DEV_TAGS" | grep -o -- '--tag [^ ]*' | head -1 | cut -d' ' -f2) + build_local_dev_with_header "$DEV_IMAGE" "$LOCAL_DEV_TAGS" "Successfully built $FIRST_TAG" "Building Local-Dev Image" fi -{ set +x; } 2>/dev/null \ No newline at end of file +{ set +x; } 2>/dev/null diff --git a/container/deps/requirements.txt b/container/deps/requirements.txt index 08d3cd6c9f..9b1b2a0447 100644 --- a/container/deps/requirements.txt +++ b/container/deps/requirements.txt @@ -11,9 +11,9 @@ # maximum versions available on different platforms (x86_64 vs aarch64, different CUDA versions) # For Multimodal EPD (required for device_map="auto" in vision model loading) -accelerate==1.12.0 -aiconfigurator @ git+https://github.com/ai-dynamo/aiconfigurator.git@5554d2eb8206738c66048bf2d72183e9bcd85759 -aiofiles==24.1.0 +accelerate +aiconfigurator[webapp] @ git+https://github.com/ai-dynamo/aiconfigurator.git@7f7ad5e248f3eaa4a0b74a069095828a4f356e60 +aiofiles aiperf @ git+https://github.com/ai-dynamo/aiperf.git@4d3fa29403c8f75da22a14f1f7b3aeb27db9288f av==15.0.0 fastapi==0.120.1 @@ -30,6 +30,7 @@ msgspec==0.19.0 mypy==1.18.2 nvidia-ml-py<=13.580.65 # NVIDIA/CUDA related, may vary by driver version opentelemetry-api<=1.38.0 # May need to stay in sync with other components +opentelemetry-exporter-otlp<=1.38.0 # May need to stay in sync with other components opentelemetry-sdk<=1.38.0 # May need to stay in sync with other components pip<=25.0.1 # System pip, varies by platform pmdarima==2.1.1 @@ -38,7 +39,7 @@ prometheus-api-client==0.6.0 prometheus_client==0.23.1 prophet==1.2.1 protobuf==5.29.5 -pydantic>=2.11.4,<2.12 # Required by aiconfigurator==0.4.0 +pydantic>=2.11.4,<2.13 # vllm==0.12.0 depends on pydantic>=2.12.0 pyright==1.1.407 PyYAML==6.0.3 scikit-learn==1.7.2 @@ -51,8 +52,8 @@ tensorboard==2.19.0 tensorboardX==2.6.2.2 # Transformers version constraint for container builds # - vLLM 0.11.0: >=4.55.2, vLLM 0.11.2: >=4.56.0,<5 -# - TensorRT-LLM 1.2.0rc2/rc3: ==4.56.0 -# - SGLang 0.5.4.post3: ==4.57.1 +# - TensorRT-LLM 1.2.0rc5: ==4.56.0 +# - SGLang 0.5.6: ==4.57.1 # Using >=4.56.0 and <=4.57.1 to satisfy all frameworks transformers>=4.56.0,<=4.57.1 types-aiofiles==25.1.0.20251011 diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index 0ebbb58823..8365deecf6 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -2,18 +2,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# This script is used to install vLLM and its dependencies -# If installing vLLM from a release tag, we will use pip to manage the install -# Otherwise, we will use git to checkout the vLLM source code and build it from source. -# The dependencies are installed in the following order: -# 1. vLLM -# 2. LMCache +# This script installs vLLM and its dependencies from PyPI (release versions only). +# Installation order: +# 1. LMCache (installed first so vLLM's dependencies take precedence) +# 2. vLLM # 3. DeepGEMM # 4. EP kernels set -euo pipefail -VLLM_REF="v0.11.0" +VLLM_REF="v0.12.0" # Basic Configurations ARCH=$(uname -m) @@ -21,34 +19,19 @@ MAX_JOBS=16 INSTALLATION_DIR=/tmp # VLLM and Dependency Configurations -TORCH_BACKEND="cu128" TORCH_CUDA_ARCH_LIST="9.0;10.0" # For EP Kernels DEEPGEMM_REF="" -CUDA_VERSION="12.8" # For DEEPGEMM - -# These flags are applicable when installing vLLM from source code -EDITABLE=true -VLLM_GIT_URL="https://github.com/vllm-project/vllm.git" -FLASHINF_REF="v0.3.1" +CUDA_VERSION="12.9" +FLASHINF_REF="v0.5.3" +# LMCache version - 0.3.9+ required for vLLM 0.11.2 compatibility +LMCACHE_REF="0.3.10" while [[ $# -gt 0 ]]; do case $1 in - --editable) - EDITABLE=true - shift - ;; - --no-editable) - EDITABLE=false - shift - ;; --vllm-ref) VLLM_REF="$2" shift 2 ;; - --vllm-git-url) - VLLM_GIT_URL="$2" - shift 2 - ;; --max-jobs) MAX_JOBS="$2" shift 2 @@ -69,8 +52,8 @@ while [[ $# -gt 0 ]]; do FLASHINF_REF="$2" shift 2 ;; - --torch-backend) - TORCH_BACKEND="$2" + --lmcache-ref) + LMCACHE_REF="$2" shift 2 ;; --torch-cuda-arch-list) @@ -82,19 +65,17 @@ while [[ $# -gt 0 ]]; do shift 2 ;; -h|--help) - echo "Usage: $0 [--editable|--no-editable] [--vllm-ref REF] [--max-jobs NUM] [--arch ARCH] [--deepgemm-ref REF] [--flashinf-ref REF] [--torch-backend BACKEND] [--torch-cuda-arch-list LIST] [--cuda-version VERSION]" + echo "Usage: $0 [--vllm-ref REF] [--max-jobs NUM] [--arch ARCH] [--deepgemm-ref REF] [--flashinf-ref REF] [--lmcache-ref REF] [--torch-cuda-arch-list LIST] [--cuda-version VERSION]" echo "Options:" - echo " --editable Install vllm in editable mode (default)" - echo " --no-editable Install vllm in non-editable mode" - echo " --vllm-ref REF Git reference to checkout (default: ${VLLM_REF})" - echo " --max-jobs NUM Maximum number of parallel jobs (default: ${MAX_JOBS})" - echo " --arch ARCH Architecture (amd64|arm64, default: auto-detect)" - echo " --installation-dir DIR Directory to install vllm (default: ${INSTALLATION_DIR})" - echo " --deepgemm-ref REF Git reference for DeepGEMM (default: ${DEEPGEMM_REF})" - echo " --flashinf-ref REF Git reference for Flash Infer (default: ${FLASHINF_REF})" - echo " --torch-backend BACKEND Torch backend to use (default: ${TORCH_BACKEND})" - echo " --torch-cuda-arch-list LIST CUDA architectures to compile for (default: ${TORCH_CUDA_ARCH_LIST})" - echo " --cuda-version VERSION CUDA version to use (default: ${CUDA_VERSION})" + echo " --vllm-ref REF vLLM release version (default: ${VLLM_REF})" + echo " --max-jobs NUM Maximum parallel jobs (default: ${MAX_JOBS})" + echo " --arch ARCH Architecture amd64|arm64 (default: auto-detect)" + echo " --installation-dir DIR Install directory (default: ${INSTALLATION_DIR})" + echo " --deepgemm-ref REF DeepGEMM git ref (default: ${DEEPGEMM_REF})" + echo " --flashinf-ref REF FlashInfer version (default: ${FLASHINF_REF})" + echo " --lmcache-ref REF LMCache version (default: ${LMCACHE_REF})" + echo " 
--torch-cuda-arch-list LIST CUDA architectures (default: ${TORCH_CUDA_ARCH_LIST})" + echo " --cuda-version VERSION CUDA version (default: ${CUDA_VERSION})" exit 0 ;; *) @@ -114,119 +95,43 @@ fi export MAX_JOBS=$MAX_JOBS export CUDA_HOME=/usr/local/cuda +# Derive torch backend from CUDA version (e.g., "12.9" -> "cu129") +TORCH_BACKEND="cu$(echo $CUDA_VERSION | tr -d '.')" + echo "=== Installing prerequisites ===" uv pip install pip cuda-python echo "\n=== Configuration Summary ===" -echo " VLLM_REF=$VLLM_REF | EDITABLE=$EDITABLE | ARCH=$ARCH" -echo " MAX_JOBS=$MAX_JOBS | TORCH_BACKEND=$TORCH_BACKEND | CUDA_VERSION=$CUDA_VERSION" -echo " TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST" -echo " DEEPGEMM_REF=$DEEPGEMM_REF | FLASHINF_REF=$FLASHINF_REF" -echo " INSTALLATION_DIR=$INSTALLATION_DIR | VLLM_GIT_URL=$VLLM_GIT_URL" +echo " VLLM_REF=$VLLM_REF | ARCH=$ARCH | CUDA_VERSION=$CUDA_VERSION | TORCH_BACKEND=$TORCH_BACKEND" +echo " FLASHINF_REF=$FLASHINF_REF | LMCACHE_REF=$LMCACHE_REF | DEEPGEMM_REF=$DEEPGEMM_REF" +echo " TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST | INSTALLATION_DIR=$INSTALLATION_DIR" + +echo "\n=== Installing LMCache ===" +if [ "$ARCH" = "amd64" ]; then + # LMCache installation currently fails on arm64 due to CUDA dependency issues + # Install LMCache BEFORE vLLM so vLLM's dependencies take precedence + uv pip install lmcache==${LMCACHE_REF} --torch-backend=${TORCH_BACKEND} + echo "โœ“ LMCache ${LMCACHE_REF} installed" +else + echo "โš  Skipping LMCache on ARM64 (compatibility issues)" +fi echo "\n=== Cloning vLLM repository ===" -# We need to clone to install dependencies +# Clone needed for DeepGEMM and EP kernels install scripts cd $INSTALLATION_DIR -git clone $VLLM_GIT_URL vllm +git clone https://github.com/vllm-project/vllm.git vllm cd vllm git checkout $VLLM_REF -# TODO leave this here in case we need to do cherry-picks in future -# GIT_COMMITTER_NAME="Container Build" GIT_COMMITTER_EMAIL="container@buildkitsandbox.local" git cherry-pick 740f064 - echo "\n=== Installing vLLM & FlashInfer ===" +echo "Installing vLLM $VLLM_REF from PyPI..." -if [[ $VLLM_REF =~ ^v ]] && { [ "$ARCH" = "amd64" ] || { [ "$ARCH" = "arm64" ] && [ "$TORCH_BACKEND" = "cu129" ]; }; }; then - # VLLM_REF starts with 'v' and either amd64, or arm64 with cu129 backend - use PyPI install - echo "Installing vLLM $VLLM_REF from PyPI... (ARCH=$ARCH, TORCH_BACKEND=$TORCH_BACKEND)" - - uv pip install vllm[flashinfer]==$VLLM_REF --torch-backend=$TORCH_BACKEND - -else - # VLLM_REF does not start with 'v' or amd64 - use git checkout path - if [ "$ARCH" = "arm64" ]; then - - # torch 2.8.0 doesn't have a aarch wheel for cu128, vLLM uses torch 2.8.0 nightly wheel builds to compile its aarch wheel against - # nightly can be unstable so we will not use it here - # for now we will use torch 2.7.1+cu128 but this requires a recompilation from source - - echo "Building vLLM from source for ARM64 architecture..." - - # Try to install specific PyTorch version first - echo "Attempting to install pinned PyTorch nightly versions..." - if ! uv pip install torch==2.7.1+cu128 torchaudio==2.7.1 torchvision==0.22.1 --index-url https://download.pytorch.org/whl/cu128; then - echo "Pinned versions failed" - exit 1 - fi - - # Create constraints file to pin all PyTorch-related versions - echo "Creating constraints file to preserve PyTorch ecosystem versions..." 
- TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") - TORCHAUDIO_VERSION=$(python -c "import torchaudio; print(torchaudio.__version__)") - TORCHVISION_VERSION=$(python -c "import torchvision; print(torchvision.__version__)") - - rm -rf /tmp/torch_constraints.txt - echo "torch==$TORCH_VERSION" > /tmp/torch_constraints.txt - echo "torchaudio==$TORCHAUDIO_VERSION" >> /tmp/torch_constraints.txt - echo "torchvision==$TORCHVISION_VERSION" >> /tmp/torch_constraints.txt - - echo "Pinned versions:" - echo " - torch==$TORCH_VERSION" - echo " - torchaudio==$TORCHAUDIO_VERSION" - echo " - torchvision==$TORCHVISION_VERSION" - - python use_existing_torch.py - uv pip install -c /tmp/torch_constraints.txt -r requirements/build.txt - - if [ "$EDITABLE" = "true" ]; then - MAX_JOBS=${MAX_JOBS} uv pip install --no-build-isolation -c /tmp/torch_constraints.txt -e . -v - else - MAX_JOBS=${MAX_JOBS} uv pip install --no-build-isolation -c /tmp/torch_constraints.txt . -v - fi - - echo "\n=== Installing FlashInfer from source ===" - cd $INSTALLATION_DIR - git clone https://github.com/flashinfer-ai/flashinfer.git --recursive - cd flashinfer - git checkout $FLASHINF_REF - - # Install with constraints to prevent PyTorch upgrade - uv pip install -v --no-build-isolation -c /tmp/torch_constraints.txt . - - else - echo "Building vLLM from source for AMD64 architecture..." - - # When updating above VLLM_REF make sure precompiled wheel file URL is correct. Run this command: - # aws s3 ls s3://vllm-wheels/${VLLM_REF}/ --region us-west-2 --no-sign-request - export VLLM_PRECOMPILED_WHEEL_LOCATION="https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_REF}/vllm-0.10.2-cp38-abi3-manylinux1_x86_64.whl" - - if [ "$EDITABLE" = "true" ]; then - uv pip install -e . --torch-backend=$TORCH_BACKEND - else - uv pip install . --torch-backend=$TORCH_BACKEND - fi - - echo "\n=== Installing FlashInfer from PyPI ===" - uv pip install flashinfer-python==$FLASHINF_REF - - fi -fi +uv pip install vllm[flashinfer]==$VLLM_REF --torch-backend=${TORCH_BACKEND} +uv pip install flashinfer-cubin==$FLASHINF_REF +uv pip install flashinfer-jit-cache==$FLASHINF_REF --extra-index-url https://flashinfer.ai/whl/${TORCH_BACKEND} echo "โœ“ vLLM installation completed" -echo "\n=== Installing LMCache ===" -if [ "$ARCH" = "amd64" ]; then - # LMCache installation currently fails on arm64 due to CUDA dependency issues: - # OSError: CUDA_HOME environment variable is not set. Please set it to your CUDA install root. - # TODO: Re-enable for arm64 after verifying lmcache compatibility and resolving the build issue. - - # Alec: Likely lmcache was compiled witha different version of torch and need to install it from source for arm64 - uv pip install lmcache==0.3.7 - echo "โœ“ LMCache installed" -else - echo "โš  Skipping LMCache on ARM64 (compatibility issues)" -fi - echo "\n=== Installing DeepGEMM ===" cd $INSTALLATION_DIR/vllm/tools @@ -239,6 +144,7 @@ echo "โœ“ DeepGEMM installation completed" echo "\n=== Installing EP Kernels (PPLX and DeepEP) ===" cd ep_kernels/ +# TODO we will be able to specify which pplx and deepep commit we want in future TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" bash install_python_libraries.sh echo "\nโœ… All installations completed successfully!" 
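For reference, the TORCH_BACKEND derivation added above maps a CUDA version string to a torch wheel suffix. A small standalone check, followed by an illustrative invocation that pins the new defaults explicitly:

```bash
# "12.9" -> "cu129", mirroring the derivation in install_vllm.sh
CUDA_VERSION="12.9"
TORCH_BACKEND="cu$(echo "$CUDA_VERSION" | tr -d '.')"
echo "$TORCH_BACKEND"   # prints: cu129

# Illustrative run with the defaults introduced in this change
./container/deps/vllm/install_vllm.sh \
  --vllm-ref v0.12.0 \
  --cuda-version 12.9 \
  --flashinf-ref v0.5.3 \
  --lmcache-ref 0.3.10
```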
diff --git a/deploy/cloud/helm/crds/templates/nvidia.com_dynamocomponentdeployments.yaml b/deploy/cloud/helm/crds/templates/nvidia.com_dynamocomponentdeployments.yaml index 558a5b973d..c90e3bdfe7 100644 --- a/deploy/cloud/helm/crds/templates/nvidia.com_dynamocomponentdeployments.yaml +++ b/deploy/cloud/helm/crds/templates/nvidia.com_dynamocomponentdeployments.yaml @@ -77,12 +77,13 @@ spec: (such as Pod, Service, and Ingress when applicable). type: object autoscaling: - description: Autoscaling config for this component (replica range, target utilization, etc.). + description: |- + Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter + with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md + for migration guidance. This field will be removed in a future API version. properties: behavior: - description: |- - HorizontalPodAutoscalerBehavior configures the scaling behavior of the target - in both Up and Down directions (scaleUp and scaleDown fields respectively). + description: 'Deprecated: This field is ignored.' properties: scaleDown: description: |- @@ -231,10 +232,13 @@ spec: type: object type: object enabled: + description: 'Deprecated: This field is ignored.' type: boolean maxReplicas: + description: 'Deprecated: This field is ignored.' type: integer metrics: + description: 'Deprecated: This field is ignored.' items: description: |- MetricSpec specifies how to scale based on a single metric @@ -665,6 +669,7 @@ spec: type: object type: array minReplicas: + description: 'Deprecated: This field is ignored.' type: integer type: object backendFramework: @@ -10184,8 +10189,12 @@ spec: type: integer type: object replicas: - description: Replicas is the desired number of Pods for this component when autoscaling is not used. + description: |- + Replicas is the desired number of Pods for this component. + When scalingAdapter is enabled (default), this field is managed by the + DynamoGraphDeploymentScalingAdapter and should not be modified directly. format: int32 + minimum: 0 type: integer resources: description: |- @@ -10264,6 +10273,20 @@ spec: type: string type: object type: object + scalingAdapter: + description: |- + ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter. + When enabled (default), replicas are managed via DGDSA and external autoscalers can scale + the service using the Scale subresource. When disabled, replicas can be modified directly. + properties: + disable: + default: false + description: |- + Disable indicates whether the ScalingAdapter should be disabled for this service. + When false (default), a DGDSA is created and owns the replicas field. + When true, no DGDSA is created and replicas can be modified directly in the DGD. + type: boolean + type: object serviceName: description: The name of the component type: string diff --git a/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeployments.yaml b/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeployments.yaml index ba2b19fef9..4db1e902b8 100644 --- a/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeployments.yaml +++ b/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeployments.yaml @@ -219,12 +219,13 @@ spec: (such as Pod, Service, and Ingress when applicable). type: object autoscaling: - description: Autoscaling config for this component (replica range, target utilization, etc.). + description: |- + Deprecated: This field is deprecated and ignored. 
Use DynamoGraphDeploymentScalingAdapter + with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md + for migration guidance. This field will be removed in a future API version. properties: behavior: - description: |- - HorizontalPodAutoscalerBehavior configures the scaling behavior of the target - in both Up and Down directions (scaleUp and scaleDown fields respectively). + description: 'Deprecated: This field is ignored.' properties: scaleDown: description: |- @@ -373,10 +374,13 @@ spec: type: object type: object enabled: + description: 'Deprecated: This field is ignored.' type: boolean maxReplicas: + description: 'Deprecated: This field is ignored.' type: integer metrics: + description: 'Deprecated: This field is ignored.' items: description: |- MetricSpec specifies how to scale based on a single metric @@ -807,6 +811,7 @@ spec: type: object type: array minReplicas: + description: 'Deprecated: This field is ignored.' type: integer type: object componentType: @@ -10319,8 +10324,12 @@ spec: type: integer type: object replicas: - description: Replicas is the desired number of Pods for this component when autoscaling is not used. + description: |- + Replicas is the desired number of Pods for this component. + When scalingAdapter is enabled (default), this field is managed by the + DynamoGraphDeploymentScalingAdapter and should not be modified directly. format: int32 + minimum: 0 type: integer resources: description: |- @@ -10399,6 +10408,20 @@ spec: type: string type: object type: object + scalingAdapter: + description: |- + ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter. + When enabled (default), replicas are managed via DGDSA and external autoscalers can scale + the service using the Scale subresource. When disabled, replicas can be modified directly. + properties: + disable: + default: false + description: |- + Disable indicates whether the ScalingAdapter should be disabled for this service. + When false (default), a DGDSA is created and owns the replicas field. + When true, no DGDSA is created and replicas can be modified directly in the DGD. + type: boolean + type: object serviceName: description: The name of the component type: string diff --git a/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeploymentscalingadapters.yaml b/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeploymentscalingadapters.yaml new file mode 100644 index 0000000000..f822bb91db --- /dev/null +++ b/deploy/cloud/helm/crds/templates/nvidia.com_dynamographdeploymentscalingadapters.yaml @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
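The CRD defined below exposes the Kubernetes scale subresource (shortname `dgdsa`), so standard tooling and autoscalers can drive replicas through the adapter instead of editing the DGD directly. A hedged sketch; the adapter name is a made-up example:

```bash
# List adapters and their current replica counts
kubectl get dgdsa

# Scale one service through the adapter's scale subresource
# (the same mechanism HPA/KEDA/Planner use); "my-dgd-decode" is hypothetical
kubectl scale dgdsa my-dgd-decode --replicas=3
```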
+ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.16.4 + helm.sh/resource-policy: keep + name: dynamographdeploymentscalingadapters.nvidia.com +spec: + group: nvidia.com + names: + kind: DynamoGraphDeploymentScalingAdapter + listKind: DynamoGraphDeploymentScalingAdapterList + plural: dynamographdeploymentscalingadapters + shortNames: + - dgdsa + singular: dynamographdeploymentscalingadapter + scope: Namespaced + versions: + - additionalPrinterColumns: + - description: DynamoGraphDeployment name + jsonPath: .spec.dgdRef.name + name: DGD + type: string + - description: Service name + jsonPath: .spec.dgdRef.serviceName + name: SERVICE + type: string + - description: Current replicas + jsonPath: .status.replicas + name: REPLICAS + type: integer + - jsonPath: .metadata.creationTimestamp + name: AGE + type: date + name: v1alpha1 + schema: + openAPIV3Schema: + description: |- + DynamoGraphDeploymentScalingAdapter provides a scaling interface for individual services + within a DynamoGraphDeployment. It implements the Kubernetes scale + subresource, enabling integration with HPA, KEDA, and custom autoscalers. + + The adapter acts as an intermediary between autoscalers and the DGD, + ensuring that only the adapter controller modifies the DGD's service replicas. + This prevents conflicts when multiple autoscaling mechanisms are in play. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: DynamoGraphDeploymentScalingAdapterSpec defines the desired state of DynamoGraphDeploymentScalingAdapter + properties: + dgdRef: + description: DGDRef references the DynamoGraphDeployment and the specific service to scale. + properties: + name: + description: Name of the DynamoGraphDeployment + minLength: 1 + type: string + serviceName: + description: ServiceName is the key name of the service within the DGD's spec.services map to scale + minLength: 1 + type: string + required: + - name + - serviceName + type: object + replicas: + description: |- + Replicas is the desired number of replicas for the target service. + This field is modified by external autoscalers (HPA/KEDA/Planner) or manually by users. + format: int32 + minimum: 0 + type: integer + required: + - dgdRef + - replicas + type: object + status: + description: DynamoGraphDeploymentScalingAdapterStatus defines the observed state of DynamoGraphDeploymentScalingAdapter + properties: + lastScaleTime: + description: LastScaleTime is the last time the adapter scaled the target service. + format: date-time + type: string + replicas: + description: |- + Replicas is the current number of replicas for the target service. + This is synced from the DGD's service replicas and is required for the scale subresource. 
+ format: int32 + type: integer + selector: + description: |- + Selector is a label selector string for the pods managed by this adapter. + Required for HPA compatibility via the scale subresource. + type: string + type: object + type: object + served: true + storage: true + subresources: + scale: + labelSelectorPath: .status.selector + specReplicasPath: .spec.replicas + statusReplicasPath: .status.replicas + status: {} diff --git a/deploy/cloud/helm/platform/components/operator/templates/manager-rbac.yaml b/deploy/cloud/helm/platform/components/operator/templates/manager-rbac.yaml index 8ab42c0988..7ae1eb6c5d 100644 --- a/deploy/cloud/helm/platform/components/operator/templates/manager-rbac.yaml +++ b/deploy/cloud/helm/platform/components/operator/templates/manager-rbac.yaml @@ -369,6 +369,7 @@ rules: - dynamocomponentdeployments - dynamographdeploymentrequests - dynamographdeployments + - dynamographdeploymentscalingadapters - dynamomodels verbs: - create @@ -393,6 +394,7 @@ rules: - dynamocomponentdeployments/status - dynamographdeploymentrequests/status - dynamographdeployments/status + - dynamographdeploymentscalingadapters/status - dynamomodels/status verbs: - get diff --git a/deploy/cloud/helm/platform/components/operator/templates/planner.yaml b/deploy/cloud/helm/platform/components/operator/templates/planner.yaml index 11f60b5a48..a893a5afdf 100644 --- a/deploy/cloud/helm/platform/components/operator/templates/planner.yaml +++ b/deploy/cloud/helm/platform/components/operator/templates/planner.yaml @@ -39,6 +39,9 @@ rules: - apiGroups: ["nvidia.com"] resources: ["dynamocomponentdeployments", "dynamographdeployments"] verbs: ["get", "list", "create", "update", "patch"] +- apiGroups: ["nvidia.com"] + resources: ["dynamographdeploymentscalingadapters/scale"] + verbs: ["patch"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding @@ -68,4 +71,7 @@ rules: - apiGroups: ["nvidia.com"] resources: ["dynamocomponentdeployments", "dynamographdeployments"] verbs: ["get", "list", "create", "update", "patch"] -{{- end }} \ No newline at end of file +- apiGroups: ["nvidia.com"] + resources: ["dynamographdeploymentscalingadapters/scale"] + verbs: ["patch"] +{{- end }} diff --git a/deploy/cloud/helm/platform/components/operator/templates/prometheus.yaml b/deploy/cloud/helm/platform/components/operator/templates/prometheus.yaml index 87a05aa575..d1576b41c1 100644 --- a/deploy/cloud/helm/platform/components/operator/templates/prometheus.yaml +++ b/deploy/cloud/helm/platform/components/operator/templates/prometheus.yaml @@ -57,6 +57,11 @@ spec: - interval: 5s path: /metrics port: system + relabelings: + - action: replace + sourceLabels: + - __meta_kubernetes_pod_label_nvidia_com_dynamo_namespace + targetLabel: dynamo_namespace selector: matchLabels: nvidia.com/dynamo-component-type: worker diff --git a/deploy/cloud/helm/platform/values.yaml b/deploy/cloud/helm/platform/values.yaml index f3e78c4750..7702d76797 100644 --- a/deploy/cloud/helm/platform/values.yaml +++ b/deploy/cloud/helm/platform/values.yaml @@ -260,6 +260,12 @@ etcd: # Whether to enable liveness probes (disabled to reduce startup complexity) enabled: false + # Pod Disruption Budget configuration + # Should be enabled for HA deployments with 3+ replicas + pdb: + # Whether to create a PodDisruptionBudget (disabled for single-node deployments) + create: false + # Node tolerations for etcd pods (allows scheduling on specific nodes) tolerations: [] diff --git a/deploy/cloud/operator/api/v1alpha1/common.go 
b/deploy/cloud/operator/api/v1alpha1/common.go index 5673fd5cfd..b68dd818c0 100644 --- a/deploy/cloud/operator/api/v1alpha1/common.go +++ b/deploy/cloud/operator/api/v1alpha1/common.go @@ -53,12 +53,20 @@ type VolumeMount struct { UseAsCompilationCache bool `json:"useAsCompilationCache,omitempty"` } +// Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter +// with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md +// for migration guidance. This field will be removed in a future API version. type Autoscaling struct { - Enabled bool `json:"enabled,omitempty"` - MinReplicas int `json:"minReplicas,omitempty"` - MaxReplicas int `json:"maxReplicas,omitempty"` - Behavior *autoscalingv2.HorizontalPodAutoscalerBehavior `json:"behavior,omitempty"` - Metrics []autoscalingv2.MetricSpec `json:"metrics,omitempty"` + // Deprecated: This field is ignored. + Enabled bool `json:"enabled,omitempty"` + // Deprecated: This field is ignored. + MinReplicas int `json:"minReplicas,omitempty"` + // Deprecated: This field is ignored. + MaxReplicas int `json:"maxReplicas,omitempty"` + // Deprecated: This field is ignored. + Behavior *autoscalingv2.HorizontalPodAutoscalerBehavior `json:"behavior,omitempty"` + // Deprecated: This field is ignored. + Metrics []autoscalingv2.MetricSpec `json:"metrics,omitempty"` } type SharedMemorySpec struct { @@ -115,3 +123,15 @@ type ExtraPodSpec struct { *corev1.PodSpec `json:",inline"` MainContainer *corev1.Container `json:"mainContainer,omitempty"` } + +// ScalingAdapter configures whether a service uses the DynamoGraphDeploymentScalingAdapter +// for replica management. When enabled (default), the DGDSA owns the replicas field and +// external autoscalers (HPA, KEDA, Planner) can control scaling via the Scale subresource. +type ScalingAdapter struct { + // Disable indicates whether the ScalingAdapter should be disabled for this service. + // When false (default), a DGDSA is created and owns the replicas field. + // When true, no DGDSA is created and replicas can be modified directly in the DGD. + // +optional + // +kubebuilder:default=false + Disable bool `json:"disable,omitempty"` +} diff --git a/deploy/cloud/operator/api/v1alpha1/dynamocomponentdeployment_types.go b/deploy/cloud/operator/api/v1alpha1/dynamocomponentdeployment_types.go index 8f484057ab..8a2abb78f2 100644 --- a/deploy/cloud/operator/api/v1alpha1/dynamocomponentdeployment_types.go +++ b/deploy/cloud/operator/api/v1alpha1/dynamocomponentdeployment_types.go @@ -74,7 +74,9 @@ type DynamoComponentDeploymentSharedSpec struct { // Resources requested and limits for this component, including CPU, memory, // GPUs/devices, and any runtime-specific resources. Resources *Resources `json:"resources,omitempty"` - // Autoscaling config for this component (replica range, target utilization, etc.). + // Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter + // with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md + // for migration guidance. This field will be removed in a future API version. Autoscaling *Autoscaling `json:"autoscaling,omitempty"` // Envs defines additional environment variables to inject into the component containers. Envs []corev1.EnvVar `json:"envs,omitempty"` @@ -108,10 +110,18 @@ type DynamoComponentDeploymentSharedSpec struct { LivenessProbe *corev1.Probe `json:"livenessProbe,omitempty"` // ReadinessProbe to signal when the container is ready to receive traffic. 
ReadinessProbe *corev1.Probe `json:"readinessProbe,omitempty"` - // Replicas is the desired number of Pods for this component when autoscaling is not used. + // Replicas is the desired number of Pods for this component. + // When scalingAdapter is enabled (default), this field is managed by the + // DynamoGraphDeploymentScalingAdapter and should not be modified directly. + // +kubebuilder:validation:Minimum=0 Replicas *int32 `json:"replicas,omitempty"` // Multinode is the configuration for multinode components. Multinode *MultinodeSpec `json:"multinode,omitempty"` + // ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter. + // When enabled (default), replicas are managed via DGDSA and external autoscalers can scale + // the service using the Scale subresource. When disabled, replicas can be modified directly. + // +optional + ScalingAdapter *ScalingAdapter `json:"scalingAdapter,omitempty"` } type MultinodeSpec struct { diff --git a/deploy/cloud/operator/api/v1alpha1/dynamographdeploymentscalingadapter_types.go b/deploy/cloud/operator/api/v1alpha1/dynamographdeploymentscalingadapter_types.go new file mode 100644 index 0000000000..d4da1a0ccf --- /dev/null +++ b/deploy/cloud/operator/api/v1alpha1/dynamographdeploymentscalingadapter_types.go @@ -0,0 +1,102 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package v1alpha1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// DynamoGraphDeploymentScalingAdapterSpec defines the desired state of DynamoGraphDeploymentScalingAdapter +type DynamoGraphDeploymentScalingAdapterSpec struct { + // Replicas is the desired number of replicas for the target service. + // This field is modified by external autoscalers (HPA/KEDA/Planner) or manually by users. + // +kubebuilder:validation:Required + // +kubebuilder:validation:Minimum=0 + Replicas int32 `json:"replicas"` + + // DGDRef references the DynamoGraphDeployment and the specific service to scale. + // +kubebuilder:validation:Required + DGDRef DynamoGraphDeploymentServiceRef `json:"dgdRef"` +} + +// DynamoGraphDeploymentServiceRef identifies a specific service within a DynamoGraphDeployment +type DynamoGraphDeploymentServiceRef struct { + // Name of the DynamoGraphDeployment + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + Name string `json:"name"` + + // ServiceName is the key name of the service within the DGD's spec.services map to scale + // +kubebuilder:validation:Required + // +kubebuilder:validation:MinLength=1 + ServiceName string `json:"serviceName"` +} + +// DynamoGraphDeploymentScalingAdapterStatus defines the observed state of DynamoGraphDeploymentScalingAdapter +type DynamoGraphDeploymentScalingAdapterStatus struct { + // Replicas is the current number of replicas for the target service. 
+ // This is synced from the DGD's service replicas and is required for the scale subresource. + // +optional + Replicas int32 `json:"replicas,omitempty"` + + // Selector is a label selector string for the pods managed by this adapter. + // Required for HPA compatibility via the scale subresource. + // +optional + Selector string `json:"selector,omitempty"` + + // LastScaleTime is the last time the adapter scaled the target service. + // +optional + LastScaleTime *metav1.Time `json:"lastScaleTime,omitempty"` +} + +// +kubebuilder:object:root=true +// +kubebuilder:subresource:status +// +kubebuilder:subresource:scale:specpath=.spec.replicas,statuspath=.status.replicas,selectorpath=.status.selector +// +kubebuilder:printcolumn:name="DGD",type="string",JSONPath=".spec.dgdRef.name",description="DynamoGraphDeployment name" +// +kubebuilder:printcolumn:name="SERVICE",type="string",JSONPath=".spec.dgdRef.serviceName",description="Service name" +// +kubebuilder:printcolumn:name="REPLICAS",type="integer",JSONPath=".status.replicas",description="Current replicas" +// +kubebuilder:printcolumn:name="AGE",type="date",JSONPath=".metadata.creationTimestamp" +// +kubebuilder:resource:shortName={dgdsa} + +// DynamoGraphDeploymentScalingAdapter provides a scaling interface for individual services +// within a DynamoGraphDeployment. It implements the Kubernetes scale +// subresource, enabling integration with HPA, KEDA, and custom autoscalers. +// +// The adapter acts as an intermediary between autoscalers and the DGD, +// ensuring that only the adapter controller modifies the DGD's service replicas. +// This prevents conflicts when multiple autoscaling mechanisms are in play. +type DynamoGraphDeploymentScalingAdapter struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec DynamoGraphDeploymentScalingAdapterSpec `json:"spec,omitempty"` + Status DynamoGraphDeploymentScalingAdapterStatus `json:"status,omitempty"` +} + +// +kubebuilder:object:root=true + +// DynamoGraphDeploymentScalingAdapterList contains a list of DynamoGraphDeploymentScalingAdapter +type DynamoGraphDeploymentScalingAdapterList struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []DynamoGraphDeploymentScalingAdapter `json:"items"` +} + +func init() { + SchemeBuilder.Register(&DynamoGraphDeploymentScalingAdapter{}, &DynamoGraphDeploymentScalingAdapterList{}) +} diff --git a/deploy/cloud/operator/api/v1alpha1/zz_generated.deepcopy.go b/deploy/cloud/operator/api/v1alpha1/zz_generated.deepcopy.go index 56d33cd498..d3ecbb44ec 100644 --- a/deploy/cloud/operator/api/v1alpha1/zz_generated.deepcopy.go +++ b/deploy/cloud/operator/api/v1alpha1/zz_generated.deepcopy.go @@ -371,6 +371,11 @@ func (in *DynamoComponentDeploymentSharedSpec) DeepCopyInto(out *DynamoComponent *out = new(MultinodeSpec) **out = **in } + if in.ScalingAdapter != nil { + in, out := &in.ScalingAdapter, &out.ScalingAdapter + *out = new(ScalingAdapter) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoComponentDeploymentSharedSpec. @@ -599,6 +604,115 @@ func (in *DynamoGraphDeploymentRequestStatus) DeepCopy() *DynamoGraphDeploymentR return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *DynamoGraphDeploymentScalingAdapter) DeepCopyInto(out *DynamoGraphDeploymentScalingAdapter) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + out.Spec = in.Spec + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoGraphDeploymentScalingAdapter. +func (in *DynamoGraphDeploymentScalingAdapter) DeepCopy() *DynamoGraphDeploymentScalingAdapter { + if in == nil { + return nil + } + out := new(DynamoGraphDeploymentScalingAdapter) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *DynamoGraphDeploymentScalingAdapter) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DynamoGraphDeploymentScalingAdapterList) DeepCopyInto(out *DynamoGraphDeploymentScalingAdapterList) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]DynamoGraphDeploymentScalingAdapter, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoGraphDeploymentScalingAdapterList. +func (in *DynamoGraphDeploymentScalingAdapterList) DeepCopy() *DynamoGraphDeploymentScalingAdapterList { + if in == nil { + return nil + } + out := new(DynamoGraphDeploymentScalingAdapterList) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *DynamoGraphDeploymentScalingAdapterList) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DynamoGraphDeploymentScalingAdapterSpec) DeepCopyInto(out *DynamoGraphDeploymentScalingAdapterSpec) { + *out = *in + out.DGDRef = in.DGDRef +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoGraphDeploymentScalingAdapterSpec. +func (in *DynamoGraphDeploymentScalingAdapterSpec) DeepCopy() *DynamoGraphDeploymentScalingAdapterSpec { + if in == nil { + return nil + } + out := new(DynamoGraphDeploymentScalingAdapterSpec) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *DynamoGraphDeploymentScalingAdapterStatus) DeepCopyInto(out *DynamoGraphDeploymentScalingAdapterStatus) { + *out = *in + if in.LastScaleTime != nil { + in, out := &in.LastScaleTime, &out.LastScaleTime + *out = (*in).DeepCopy() + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoGraphDeploymentScalingAdapterStatus. +func (in *DynamoGraphDeploymentScalingAdapterStatus) DeepCopy() *DynamoGraphDeploymentScalingAdapterStatus { + if in == nil { + return nil + } + out := new(DynamoGraphDeploymentScalingAdapterStatus) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *DynamoGraphDeploymentServiceRef) DeepCopyInto(out *DynamoGraphDeploymentServiceRef) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DynamoGraphDeploymentServiceRef. +func (in *DynamoGraphDeploymentServiceRef) DeepCopy() *DynamoGraphDeploymentServiceRef { + if in == nil { + return nil + } + out := new(DynamoGraphDeploymentServiceRef) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *DynamoGraphDeploymentSpec) DeepCopyInto(out *DynamoGraphDeploymentSpec) { *out = *in @@ -1085,6 +1199,21 @@ func (in *Resources) DeepCopy() *Resources { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ScalingAdapter) DeepCopyInto(out *ScalingAdapter) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ScalingAdapter. +func (in *ScalingAdapter) DeepCopy() *ScalingAdapter { + if in == nil { + return nil + } + out := new(ScalingAdapter) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *SharedMemorySpec) DeepCopyInto(out *SharedMemorySpec) { *out = *in diff --git a/deploy/cloud/operator/cmd/main.go b/deploy/cloud/operator/cmd/main.go index 4d79cfe3f0..dc1a33b262 100644 --- a/deploy/cloud/operator/cmd/main.go +++ b/deploy/cloud/operator/cmd/main.go @@ -578,6 +578,16 @@ func main() { os.Exit(1) } + if err = (&controller.DynamoGraphDeploymentScalingAdapterReconciler{ + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Recorder: mgr.GetEventRecorderFor("dgdscalingadapter"), + Config: ctrlConfig, + }).SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to create controller", "controller", "DGDScalingAdapter") + os.Exit(1) + } + if err = (&controller.DynamoGraphDeploymentRequestReconciler{ Client: mgr.GetClient(), Recorder: mgr.GetEventRecorderFor("dynamographdeploymentrequest"), diff --git a/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamocomponentdeployments.yaml b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamocomponentdeployments.yaml index 558a5b973d..c90e3bdfe7 100644 --- a/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamocomponentdeployments.yaml +++ b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamocomponentdeployments.yaml @@ -77,12 +77,13 @@ spec: (such as Pod, Service, and Ingress when applicable). type: object autoscaling: - description: Autoscaling config for this component (replica range, target utilization, etc.). + description: |- + Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter + with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md + for migration guidance. This field will be removed in a future API version. properties: behavior: - description: |- - HorizontalPodAutoscalerBehavior configures the scaling behavior of the target - in both Up and Down directions (scaleUp and scaleDown fields respectively). + description: 'Deprecated: This field is ignored.' properties: scaleDown: description: |- @@ -231,10 +232,13 @@ spec: type: object type: object enabled: + description: 'Deprecated: This field is ignored.' type: boolean maxReplicas: + description: 'Deprecated: This field is ignored.' 
type: integer metrics: + description: 'Deprecated: This field is ignored.' items: description: |- MetricSpec specifies how to scale based on a single metric @@ -665,6 +669,7 @@ spec: type: object type: array minReplicas: + description: 'Deprecated: This field is ignored.' type: integer type: object backendFramework: @@ -10184,8 +10189,12 @@ spec: type: integer type: object replicas: - description: Replicas is the desired number of Pods for this component when autoscaling is not used. + description: |- + Replicas is the desired number of Pods for this component. + When scalingAdapter is enabled (default), this field is managed by the + DynamoGraphDeploymentScalingAdapter and should not be modified directly. format: int32 + minimum: 0 type: integer resources: description: |- @@ -10264,6 +10273,20 @@ spec: type: string type: object type: object + scalingAdapter: + description: |- + ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter. + When enabled (default), replicas are managed via DGDSA and external autoscalers can scale + the service using the Scale subresource. When disabled, replicas can be modified directly. + properties: + disable: + default: false + description: |- + Disable indicates whether the ScalingAdapter should be disabled for this service. + When false (default), a DGDSA is created and owns the replicas field. + When true, no DGDSA is created and replicas can be modified directly in the DGD. + type: boolean + type: object serviceName: description: The name of the component type: string diff --git a/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeployments.yaml b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeployments.yaml index ba2b19fef9..4db1e902b8 100644 --- a/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeployments.yaml +++ b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeployments.yaml @@ -219,12 +219,13 @@ spec: (such as Pod, Service, and Ingress when applicable). type: object autoscaling: - description: Autoscaling config for this component (replica range, target utilization, etc.). + description: |- + Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter + with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md + for migration guidance. This field will be removed in a future API version. properties: behavior: - description: |- - HorizontalPodAutoscalerBehavior configures the scaling behavior of the target - in both Up and Down directions (scaleUp and scaleDown fields respectively). + description: 'Deprecated: This field is ignored.' properties: scaleDown: description: |- @@ -373,10 +374,13 @@ spec: type: object type: object enabled: + description: 'Deprecated: This field is ignored.' type: boolean maxReplicas: + description: 'Deprecated: This field is ignored.' type: integer metrics: + description: 'Deprecated: This field is ignored.' items: description: |- MetricSpec specifies how to scale based on a single metric @@ -807,6 +811,7 @@ spec: type: object type: array minReplicas: + description: 'Deprecated: This field is ignored.' type: integer type: object componentType: @@ -10319,8 +10324,12 @@ spec: type: integer type: object replicas: - description: Replicas is the desired number of Pods for this component when autoscaling is not used. + description: |- + Replicas is the desired number of Pods for this component. 
+ When scalingAdapter is enabled (default), this field is managed by the + DynamoGraphDeploymentScalingAdapter and should not be modified directly. format: int32 + minimum: 0 type: integer resources: description: |- @@ -10399,6 +10408,20 @@ spec: type: string type: object type: object + scalingAdapter: + description: |- + ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter. + When enabled (default), replicas are managed via DGDSA and external autoscalers can scale + the service using the Scale subresource. When disabled, replicas can be modified directly. + properties: + disable: + default: false + description: |- + Disable indicates whether the ScalingAdapter should be disabled for this service. + When false (default), a DGDSA is created and owns the replicas field. + When true, no DGDSA is created and replicas can be modified directly in the DGD. + type: boolean + type: object serviceName: description: The name of the component type: string diff --git a/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeploymentscalingadapters.yaml b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeploymentscalingadapters.yaml new file mode 100644 index 0000000000..f822bb91db --- /dev/null +++ b/deploy/cloud/operator/config/crd/bases/nvidia.com_dynamographdeploymentscalingadapters.yaml @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.16.4 + helm.sh/resource-policy: keep + name: dynamographdeploymentscalingadapters.nvidia.com +spec: + group: nvidia.com + names: + kind: DynamoGraphDeploymentScalingAdapter + listKind: DynamoGraphDeploymentScalingAdapterList + plural: dynamographdeploymentscalingadapters + shortNames: + - dgdsa + singular: dynamographdeploymentscalingadapter + scope: Namespaced + versions: + - additionalPrinterColumns: + - description: DynamoGraphDeployment name + jsonPath: .spec.dgdRef.name + name: DGD + type: string + - description: Service name + jsonPath: .spec.dgdRef.serviceName + name: SERVICE + type: string + - description: Current replicas + jsonPath: .status.replicas + name: REPLICAS + type: integer + - jsonPath: .metadata.creationTimestamp + name: AGE + type: date + name: v1alpha1 + schema: + openAPIV3Schema: + description: |- + DynamoGraphDeploymentScalingAdapter provides a scaling interface for individual services + within a DynamoGraphDeployment. It implements the Kubernetes scale + subresource, enabling integration with HPA, KEDA, and custom autoscalers. + + The adapter acts as an intermediary between autoscalers and the DGD, + ensuring that only the adapter controller modifies the DGD's service replicas. + This prevents conflicts when multiple autoscaling mechanisms are in play. 
+ properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: DynamoGraphDeploymentScalingAdapterSpec defines the desired state of DynamoGraphDeploymentScalingAdapter + properties: + dgdRef: + description: DGDRef references the DynamoGraphDeployment and the specific service to scale. + properties: + name: + description: Name of the DynamoGraphDeployment + minLength: 1 + type: string + serviceName: + description: ServiceName is the key name of the service within the DGD's spec.services map to scale + minLength: 1 + type: string + required: + - name + - serviceName + type: object + replicas: + description: |- + Replicas is the desired number of replicas for the target service. + This field is modified by external autoscalers (HPA/KEDA/Planner) or manually by users. + format: int32 + minimum: 0 + type: integer + required: + - dgdRef + - replicas + type: object + status: + description: DynamoGraphDeploymentScalingAdapterStatus defines the observed state of DynamoGraphDeploymentScalingAdapter + properties: + lastScaleTime: + description: LastScaleTime is the last time the adapter scaled the target service. + format: date-time + type: string + replicas: + description: |- + Replicas is the current number of replicas for the target service. + This is synced from the DGD's service replicas and is required for the scale subresource. + format: int32 + type: integer + selector: + description: |- + Selector is a label selector string for the pods managed by this adapter. + Required for HPA compatibility via the scale subresource. 
+ type: string + type: object + type: object + served: true + storage: true + subresources: + scale: + labelSelectorPath: .status.selector + specReplicasPath: .spec.replicas + statusReplicasPath: .status.replicas + status: {} diff --git a/deploy/cloud/operator/config/rbac/role.yaml b/deploy/cloud/operator/config/rbac/role.yaml index b473aa1ad7..2a3a00c6f8 100644 --- a/deploy/cloud/operator/config/rbac/role.yaml +++ b/deploy/cloud/operator/config/rbac/role.yaml @@ -182,6 +182,7 @@ rules: - dynamocomponentdeployments - dynamographdeploymentrequests - dynamographdeployments + - dynamographdeploymentscalingadapters - dynamomodels verbs: - create @@ -206,6 +207,7 @@ rules: - dynamocomponentdeployments/status - dynamographdeploymentrequests/status - dynamographdeployments/status + - dynamographdeploymentscalingadapters/status - dynamomodels/status verbs: - get diff --git a/deploy/cloud/operator/internal/consts/consts.go b/deploy/cloud/operator/internal/consts/consts.go index 882f9f18d9..6dd3bc0712 100644 --- a/deploy/cloud/operator/internal/consts/consts.go +++ b/deploy/cloud/operator/internal/consts/consts.go @@ -7,8 +7,6 @@ import ( ) const ( - HPACPUDefaultAverageUtilization = 80 - DefaultUserId = "default" DefaultOrgId = "default" diff --git a/deploy/cloud/operator/internal/controller/common.go b/deploy/cloud/operator/internal/controller/common.go index 70a70fdead..e41cbe1deb 100644 --- a/deploy/cloud/operator/internal/controller/common.go +++ b/deploy/cloud/operator/internal/controller/common.go @@ -53,3 +53,43 @@ type dockerSecretRetriever interface { // returns a list of secret names associated with the docker registry GetSecrets(namespace, registry string) ([]string, error) } + +// getServiceKeys returns the keys of the services map for logging purposes +func getServiceKeys(services map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec) []string { + keys := make([]string, 0, len(services)) + for k := range services { + keys = append(keys, k) + } + return keys +} + +// servicesEqual compares two services maps to detect changes in replica counts +func servicesEqual(old, new map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec) bool { + if len(old) != len(new) { + return false + } + + for key, oldSvc := range old { + newSvc, exists := new[key] + if !exists { + return false + } + + // Compare replicas + oldReplicas := int32(1) + if oldSvc.Replicas != nil { + oldReplicas = *oldSvc.Replicas + } + + newReplicas := int32(1) + if newSvc.Replicas != nil { + newReplicas = *newSvc.Replicas + } + + if oldReplicas != newReplicas { + return false + } + } + + return true +} diff --git a/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller.go b/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller.go index 307bf7ac05..88d92e2f42 100644 --- a/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller.go +++ b/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller.go @@ -338,21 +338,6 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req } deployment = obj - - // create or update api-server hpa - modified_, _, err = commonController.SyncResource(ctx, r, dynamoComponentDeployment, func(ctx context.Context) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) { - return r.generateHPA(generateResourceOption{ - dynamoComponentDeployment: dynamoComponentDeployment, - }) - }) - if err != nil { - return ctrl.Result{}, err - } - - if modified_ { - modified = true - } - } // create or 
update api-server service @@ -1114,63 +1099,6 @@ type generateResourceOption struct { instanceID *int } -func (r *DynamoComponentDeploymentReconciler) generateHPA(opt generateResourceOption) (*autoscalingv2.HorizontalPodAutoscaler, bool, error) { - labels := r.getKubeLabels(opt.dynamoComponentDeployment) - - annotations := r.getKubeAnnotations(opt.dynamoComponentDeployment) - - kubeName := r.getKubeName(opt.dynamoComponentDeployment, false) - - kubeNs := opt.dynamoComponentDeployment.Namespace - - hpaConf := opt.dynamoComponentDeployment.Spec.Autoscaling - - kubeHpa := &autoscalingv2.HorizontalPodAutoscaler{ - ObjectMeta: metav1.ObjectMeta{ - Name: kubeName, - Namespace: kubeNs, - Labels: labels, - Annotations: annotations, - }, - } - - if hpaConf == nil || !hpaConf.Enabled { - // if hpa is not enabled, we need to delete the hpa - return kubeHpa, true, nil - } - - minReplica := int32(hpaConf.MinReplicas) - - kubeHpa.Spec = autoscalingv2.HorizontalPodAutoscalerSpec{ - MinReplicas: &minReplica, - MaxReplicas: int32(hpaConf.MaxReplicas), - ScaleTargetRef: autoscalingv2.CrossVersionObjectReference{ - APIVersion: "apps/v1", - Kind: "Deployment", - Name: kubeName, - }, - Metrics: hpaConf.Metrics, - } - - if len(kubeHpa.Spec.Metrics) == 0 { - averageUtilization := int32(commonconsts.HPACPUDefaultAverageUtilization) - kubeHpa.Spec.Metrics = []autoscalingv2.MetricSpec{ - { - Type: autoscalingv2.ResourceMetricSourceType, - Resource: &autoscalingv2.ResourceMetricSource{ - Name: corev1.ResourceCPU, - Target: autoscalingv2.MetricTarget{ - Type: autoscalingv2.UtilizationMetricType, - AverageUtilization: &averageUtilization, - }, - }, - }, - } - } - - return kubeHpa, false, nil -} - //nolint:gocyclo,nakedret func (r *DynamoComponentDeploymentReconciler) generatePodTemplateSpec(ctx context.Context, opt generateResourceOption, role dynamo.Role) (podTemplateSpec *corev1.PodTemplateSpec, err error) { podLabels := r.getKubeLabels(opt.dynamoComponentDeployment) diff --git a/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller_test.go b/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller_test.go index 807c0abcc0..f3ea278946 100644 --- a/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller_test.go +++ b/deploy/cloud/operator/internal/controller/dynamocomponentdeployment_controller_test.go @@ -827,6 +827,7 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing. Args: []string{"ray start --head --port=6379 && some dynamo command --tensor-parallel-size 4 --pipeline-parallel-size 1"}, Env: []corev1.EnvVar{ {Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker}, + {Name: "DYN_HEALTH_CHECK_ENABLED", Value: "true"}, {Name: commonconsts.DynamoNamespaceEnvVar, Value: "default"}, {Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-lws-deploy"}, {Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Value: "default"}, @@ -955,6 +956,7 @@ func TestDynamoComponentDeploymentReconciler_generateLeaderWorkerSet(t *testing. 
Args: []string{"ray start --address=$LWS_LEADER_ADDRESS:6379 --block"}, Env: []corev1.EnvVar{ {Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker}, + {Name: "DYN_HEALTH_CHECK_ENABLED", Value: "true"}, {Name: commonconsts.DynamoNamespaceEnvVar, Value: "default"}, {Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-lws-deploy"}, {Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Value: "default"}, diff --git a/deploy/cloud/operator/internal/controller/dynamographdeployment_controller.go b/deploy/cloud/operator/internal/controller/dynamographdeployment_controller.go index 22dcdb5490..823818ac1e 100644 --- a/deploy/cloud/operator/internal/controller/dynamographdeployment_controller.go +++ b/deploy/cloud/operator/internal/controller/dynamographdeployment_controller.go @@ -86,6 +86,7 @@ type DynamoGraphDeploymentReconciler struct { // +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments/status,verbs=get;update;patch // +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments/finalizers,verbs=update +// +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeploymentscalingadapters,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=grove.io,resources=podcliquesets,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=grove.io,resources=podcliques/scale,verbs=get;update;patch // +kubebuilder:rbac:groups=grove.io,resources=podcliquescalinggroups/scale,verbs=get;update;patch @@ -225,6 +226,13 @@ func (r *DynamoGraphDeploymentReconciler) reconcileResources(ctx context.Context return "", "", "", fmt.Errorf("failed to reconcile top-level PVCs: %w", err) } + // Reconcile DynamoGraphDeploymentScalingAdapters for each service + err = r.reconcileScalingAdapters(ctx, dynamoDeployment) + if err != nil { + logger.Error(err, "Failed to reconcile scaling adapters") + return "", "", "", fmt.Errorf("failed to reconcile scaling adapters: %w", err) + } + // Reconcile the SA, Role and RoleBinding if k8s discovery is enabled err = r.reconcileK8sDiscoveryResources(ctx, dynamoDeployment) if err != nil { @@ -607,6 +615,89 @@ func (r *DynamoGraphDeploymentReconciler) reconcilePVCs(ctx context.Context, dyn return nil } +// reconcileScalingAdapters ensures a DynamoGraphDeploymentScalingAdapter exists for each service in the DGD +// that has scaling adapter enabled (default). Services with scalingAdapter.disable=true will not have a DGDSA. +// This enables pluggable autoscaling via HPA, KEDA, or Planner. 
+func (r *DynamoGraphDeploymentReconciler) reconcileScalingAdapters(ctx context.Context, dynamoDeployment *nvidiacomv1alpha1.DynamoGraphDeployment) error { + logger := log.FromContext(ctx) + + // Process each service - SyncResource handles create, update, and delete via toDelete flag + for serviceName, component := range dynamoDeployment.Spec.Services { + // Check if scaling adapter is disabled for this service + scalingAdapterDisabled := component.ScalingAdapter != nil && component.ScalingAdapter.Disable + + // Get current replicas (default to 1 if not set) + currentReplicas := int32(1) + if component.Replicas != nil { + currentReplicas = *component.Replicas + } + + // Use SyncResource to handle creation/updates/deletion + // When toDelete=true, SyncResource will delete the existing resource if it exists + _, _, err := commonController.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapter, bool, error) { + adapterName := generateAdapterName(dynamoDeployment.Name, serviceName) + adapter := &nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: adapterName, + Namespace: dynamoDeployment.Namespace, + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: dynamoDeployment.Name, + consts.KubeLabelDynamoComponent: serviceName, + }, + }, + Spec: nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: currentReplicas, + DGDRef: nvidiacomv1alpha1.DynamoGraphDeploymentServiceRef{ + Name: dynamoDeployment.Name, + ServiceName: serviceName, + }, + }, + } + // Return toDelete=true if scaling adapter is disabled + return adapter, scalingAdapterDisabled, nil + }) + + if err != nil { + logger.Error(err, "Failed to sync DynamoGraphDeploymentScalingAdapter", "service", serviceName) + return err + } + } + + // Clean up adapters for services that were removed from DGD entirely + adapterList := &nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapterList{} + if err := r.List(ctx, adapterList, + client.InNamespace(dynamoDeployment.Namespace), + client.MatchingLabels{consts.KubeLabelDynamoGraphDeploymentName: dynamoDeployment.Name}, + ); err != nil { + logger.Error(err, "Failed to list DynamoGraphDeploymentScalingAdapters") + return err + } + + for i := range adapterList.Items { + adapter := &adapterList.Items[i] + serviceName := adapter.Spec.DGDRef.ServiceName + + // Delete adapter if service no longer exists in DGD + if _, exists := dynamoDeployment.Spec.Services[serviceName]; !exists { + logger.Info("Deleting orphaned DynamoGraphDeploymentScalingAdapter", "adapter", adapter.Name, "service", serviceName) + if err := r.Delete(ctx, adapter); err != nil && !errors.IsNotFound(err) { + logger.Error(err, "Failed to delete orphaned adapter", "adapter", adapter.Name) + return err + } + r.Recorder.Eventf(dynamoDeployment, corev1.EventTypeNormal, "AdapterDeleted", + "Deleted orphaned scaling adapter %s for removed service %s", adapter.Name, serviceName) + } + } + + return nil +} + +// generateAdapterName creates a consistent name for a DynamoGraphDeploymentScalingAdapter +// Service names are lowercased to comply with Kubernetes DNS subdomain naming requirements +func generateAdapterName(dgdName, serviceName string) string { + return fmt.Sprintf("%s-%s", dgdName, strings.ToLower(serviceName)) +} + func (r *DynamoGraphDeploymentReconciler) FinalizeResource(ctx context.Context, dynamoDeployment *nvidiacomv1alpha1.DynamoGraphDeployment) error { // for now doing nothing return nil @@ 
-626,6 +717,13 @@ func (r *DynamoGraphDeploymentReconciler) SetupWithManager(mgr ctrl.Manager) err UpdateFunc: func(de event.UpdateEvent) bool { return true }, GenericFunc: func(ge event.GenericEvent) bool { return true }, })). + Owns(&nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapter{}, builder.WithPredicates(predicate.Funcs{ + // ignore creation cause we don't want to be called again after we create the adapter + CreateFunc: func(ce event.CreateEvent) bool { return false }, + DeleteFunc: func(de event.DeleteEvent) bool { return true }, + UpdateFunc: func(de event.UpdateEvent) bool { return false }, // Adapter updates are handled by adapter controller + GenericFunc: func(ge event.GenericEvent) bool { return false }, + })). Owns(&corev1.PersistentVolumeClaim{}, builder.WithPredicates(predicate.Funcs{ // ignore creation cause we don't want to be called again after we create the PVC CreateFunc: func(ce event.CreateEvent) bool { return false }, diff --git a/deploy/cloud/operator/internal/controller/dynamographdeployment_controller_test.go b/deploy/cloud/operator/internal/controller/dynamographdeployment_controller_test.go new file mode 100644 index 0000000000..a217fd403c --- /dev/null +++ b/deploy/cloud/operator/internal/controller/dynamographdeployment_controller_test.go @@ -0,0 +1,321 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package controller + +import ( + "context" + "testing" + + "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" + "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/tools/record" + "k8s.io/utils/ptr" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestDynamoGraphDeploymentReconciler_reconcileScalingAdapters(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + tests := []struct { + name string + dgd *v1alpha1.DynamoGraphDeployment + existingAdapters []v1alpha1.DynamoGraphDeploymentScalingAdapter + expectedAdapterCount int + expectedAdapters map[string]int32 // map of adapter name to expected replicas + expectDeleted []string // adapter names that should be deleted + }{ + { + name: "creates adapters for all services", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + }, + "decode": { + Replicas: ptr.To(int32(3)), + }, + }, + }, + }, + expectedAdapterCount: 2, + expectedAdapters: map[string]int32{ + "test-dgd-frontend": 2, + "test-dgd-decode": 3, + }, + }, + { + name: "uses default replicas when not specified", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "worker": {}, + }, + }, + }, + expectedAdapterCount: 1, + expectedAdapters: map[string]int32{ + "test-dgd-worker": 1, // default replicas + }, + }, + { + name: "skips adapter creation when disabled", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + }, + "decode": { + Replicas: ptr.To(int32(3)), + ScalingAdapter: &v1alpha1.ScalingAdapter{ + Disable: true, + }, + }, + }, + }, + }, + expectedAdapterCount: 1, + expectedAdapters: map[string]int32{ + "test-dgd-frontend": 2, + }, + }, + { + name: "deletes adapter when service is removed", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + UID: "test-uid", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + }, + }, + }, + }, + existingAdapters: []v1alpha1.DynamoGraphDeploymentScalingAdapter{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "test-dgd", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "nvidia.com/v1alpha1", + Kind: "DynamoGraphDeployment", + Name: "test-dgd", + UID: "test-uid", + }, + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 2, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + 
}, + }, + }, + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-removed", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "test-dgd", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "nvidia.com/v1alpha1", + Kind: "DynamoGraphDeployment", + Name: "test-dgd", + UID: "test-uid", + }, + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 1, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "removed", + }, + }, + }, + }, + expectedAdapterCount: 1, + expectedAdapters: map[string]int32{ + "test-dgd-frontend": 2, + }, + expectDeleted: []string{"test-dgd-removed"}, + }, + { + name: "deletes adapter when scalingAdapter.disable is set to true", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + UID: "test-uid", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + ScalingAdapter: &v1alpha1.ScalingAdapter{ + Disable: true, + }, + }, + }, + }, + }, + existingAdapters: []v1alpha1.DynamoGraphDeploymentScalingAdapter{ + { + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "test-dgd", + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "nvidia.com/v1alpha1", + Kind: "DynamoGraphDeployment", + Name: "test-dgd", + UID: "test-uid", + }, + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 2, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + }, + }, + }, + }, + expectedAdapterCount: 0, + expectedAdapters: map[string]int32{}, + expectDeleted: []string{"test-dgd-frontend"}, + }, + { + name: "adapter name uses lowercase service name", + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "MyService": { + Replicas: ptr.To(int32(1)), + }, + }, + }, + }, + expectedAdapterCount: 1, + expectedAdapters: map[string]int32{ + "my-dgd-myservice": 1, // lowercase + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build initial objects + var initObjs []client.Object + initObjs = append(initObjs, tt.dgd) + for i := range tt.existingAdapters { + initObjs = append(initObjs, &tt.existingAdapters[i]) + } + + // Create fake client + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(initObjs...). 
+ Build() + + // Create reconciler + r := &DynamoGraphDeploymentReconciler{ + Client: fakeClient, + Recorder: record.NewFakeRecorder(10), + } + + // Run reconcileScalingAdapters + ctx := context.Background() + err := r.reconcileScalingAdapters(ctx, tt.dgd) + if err != nil { + t.Fatalf("reconcileScalingAdapters() error = %v", err) + } + + // Verify adapters + adapterList := &v1alpha1.DynamoGraphDeploymentScalingAdapterList{} + if err := fakeClient.List(ctx, adapterList, client.InNamespace("default")); err != nil { + t.Fatalf("Failed to list adapters: %v", err) + } + + if len(adapterList.Items) != tt.expectedAdapterCount { + t.Errorf("Expected %d adapters, got %d", tt.expectedAdapterCount, len(adapterList.Items)) + } + + // Check expected adapters exist with correct replicas + for name, expectedReplicas := range tt.expectedAdapters { + adapter := &v1alpha1.DynamoGraphDeploymentScalingAdapter{} + err := fakeClient.Get(ctx, types.NamespacedName{Name: name, Namespace: "default"}, adapter) + if err != nil { + t.Errorf("Expected adapter %s to exist, but got error: %v", name, err) + continue + } + if adapter.Spec.Replicas != expectedReplicas { + t.Errorf("Adapter %s has replicas=%d, expected %d", name, adapter.Spec.Replicas, expectedReplicas) + } + } + + // Check that deleted adapters don't exist + for _, name := range tt.expectDeleted { + adapter := &v1alpha1.DynamoGraphDeploymentScalingAdapter{} + err := fakeClient.Get(ctx, types.NamespacedName{Name: name, Namespace: "default"}, adapter) + if err == nil { + t.Errorf("Expected adapter %s to be deleted, but it still exists", name) + } + } + }) + } +} diff --git a/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller.go b/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller.go new file mode 100644 index 0000000000..edaa4323ae --- /dev/null +++ b/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller.go @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package controller + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/builder" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" + "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts" + commonController "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/controller_common" +) + +// DynamoGraphDeploymentScalingAdapterReconciler reconciles a DynamoGraphDeploymentScalingAdapter object +type DynamoGraphDeploymentScalingAdapterReconciler struct { + client.Client + Scheme *runtime.Scheme + Recorder record.EventRecorder + Config commonController.Config +} + +// +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeploymentscalingadapters,verbs=get;list;watch;create;update;patch;delete +// +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeploymentscalingadapters/status,verbs=get;update;patch +// +kubebuilder:rbac:groups=nvidia.com,resources=dynamographdeployments,verbs=get;list;watch;update;patch + +// Reconcile implements the reconciliation loop for DynamoGraphDeploymentScalingAdapter +func (r *DynamoGraphDeploymentScalingAdapterReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + logger := log.FromContext(ctx) + + // 1. Fetch the DynamoGraphDeploymentScalingAdapter + adapter := &nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapter{} + if err := r.Get(ctx, req.NamespacedName, adapter); err != nil { + return ctrl.Result{}, client.IgnoreNotFound(err) + } + + // Skip reconciliation if being deleted + if !adapter.GetDeletionTimestamp().IsZero() { + logger.V(1).Info("Adapter is being deleted, skipping reconciliation") + return ctrl.Result{}, nil + } + + // 2. Fetch the referenced DGD + dgd := &nvidiacomv1alpha1.DynamoGraphDeployment{} + dgdKey := types.NamespacedName{ + Name: adapter.Spec.DGDRef.Name, + Namespace: adapter.Namespace, + } + if err := r.Get(ctx, dgdKey, dgd); err != nil { + if errors.IsNotFound(err) { + logger.Error(err, "Referenced DGD not found", "dgd", dgdKey) + // DGD doesn't exist, can't proceed + return ctrl.Result{}, err + } + return ctrl.Result{}, err + } + + // 3. Find the target service in DGD's spec.services map + component, exists := dgd.Spec.Services[adapter.Spec.DGDRef.ServiceName] + if !exists || component == nil { + logger.Error(nil, "Service not found in DGD", + "service", adapter.Spec.DGDRef.ServiceName, + "dgd", dgd.Name, + "availableServices", getServiceKeys(dgd.Spec.Services)) + return ctrl.Result{}, fmt.Errorf("service %s not found in DGD", adapter.Spec.DGDRef.ServiceName) + } + + // Get current replicas from DGD (default to 1 if not set) + currentReplicas := int32(1) + if component.Replicas != nil { + currentReplicas = *component.Replicas + } + + // 4. 
Update DGD if replicas changed (DGDSA is the source of truth) + if currentReplicas != adapter.Spec.Replicas { + // Update the service's replicas in DGD + component.Replicas = &adapter.Spec.Replicas + dgd.Spec.Services[adapter.Spec.DGDRef.ServiceName] = component + + if err := r.Update(ctx, dgd); err != nil { + logger.Error(err, "Failed to update DGD") + r.Recorder.Eventf(adapter, corev1.EventTypeWarning, "UpdateFailed", + "Failed to update DGD %s: %v", dgd.Name, err) + return ctrl.Result{}, err + } + + logger.Info("Scaled service", + "dgd", dgd.Name, + "service", adapter.Spec.DGDRef.ServiceName, + "from", currentReplicas, + "to", adapter.Spec.Replicas) + + r.Recorder.Eventf(adapter, corev1.EventTypeNormal, "Scaled", + "Scaled service %s from %d to %d replicas", adapter.Spec.DGDRef.ServiceName, currentReplicas, adapter.Spec.Replicas) + + // Record scaling event + now := metav1.Now() + adapter.Status.LastScaleTime = &now + } + + // 5. Update adapter status + adapter.Status.Replicas = adapter.Spec.Replicas + adapter.Status.Selector = r.buildPodSelector(dgd, adapter.Spec.DGDRef.ServiceName) + + if err := r.Status().Update(ctx, adapter); err != nil { + logger.Error(err, "Failed to update adapter status") + return ctrl.Result{}, err + } + + return ctrl.Result{}, nil +} + +// buildPodSelector constructs a label selector for the pods managed by this service +func (r *DynamoGraphDeploymentScalingAdapterReconciler) buildPodSelector(dgd *nvidiacomv1alpha1.DynamoGraphDeployment, serviceName string) string { + // Pods are labeled with: + // - nvidia.com/dynamo-graph-deployment-name = dgd.Name + // - nvidia.com/dynamo-component = serviceName (the key from spec.services map) + return fmt.Sprintf("%s=%s,%s=%s", + consts.KubeLabelDynamoGraphDeploymentName, dgd.Name, + consts.KubeLabelDynamoComponent, serviceName) +} + +// SetupWithManager sets up the controller with the Manager +func (r *DynamoGraphDeploymentScalingAdapterReconciler) SetupWithManager(mgr ctrl.Manager) error { + return ctrl.NewControllerManagedBy(mgr). + For(&nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapter{}, builder.WithPredicates( + predicate.GenerationChangedPredicate{}, + )). + Named("dgdscalingadapter"). + // Watch DGDs to sync status when DGD service replicas change + Watches( + &nvidiacomv1alpha1.DynamoGraphDeployment{}, + handler.EnqueueRequestsFromMapFunc(r.findAdaptersForDGD), + builder.WithPredicates(predicate.Funcs{ + CreateFunc: func(ce event.CreateEvent) bool { return false }, + DeleteFunc: func(de event.DeleteEvent) bool { return true }, + UpdateFunc: func(ue event.UpdateEvent) bool { + // Only trigger on spec changes (not status) + oldDGD, okOld := ue.ObjectOld.(*nvidiacomv1alpha1.DynamoGraphDeployment) + newDGD, okNew := ue.ObjectNew.(*nvidiacomv1alpha1.DynamoGraphDeployment) + if !okOld || !okNew { + return false + } + // Trigger if services map changed + return !servicesEqual(oldDGD.Spec.Services, newDGD.Spec.Services) + }, + GenericFunc: func(ge event.GenericEvent) bool { return false }, + }), + ). + WithEventFilter(commonController.EphemeralDeploymentEventFilter(r.Config)). 
+ Complete(r) +} + +// findAdaptersForDGD maps DGD changes to adapter reconcile requests +// Uses label selector to efficiently query only adapters for this specific DGD +func (r *DynamoGraphDeploymentScalingAdapterReconciler) findAdaptersForDGD(ctx context.Context, obj client.Object) []reconcile.Request { + dgd, ok := obj.(*nvidiacomv1alpha1.DynamoGraphDeployment) + if !ok { + return nil + } + + // Use label selector to filter at API level (more efficient than in-memory filtering) + adapterList := &nvidiacomv1alpha1.DynamoGraphDeploymentScalingAdapterList{} + if err := r.List(ctx, adapterList, + client.InNamespace(dgd.Namespace), + client.MatchingLabels{consts.KubeLabelDynamoGraphDeploymentName: dgd.Name}, + ); err != nil { + log.FromContext(ctx).Error(err, "Failed to list adapters for DGD", "dgd", dgd.Name) + return nil + } + + // All returned adapters are guaranteed to belong to this DGD + requests := make([]reconcile.Request, 0, len(adapterList.Items)) + for i := range adapterList.Items { + requests = append(requests, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Name: adapterList.Items[i].Name, + Namespace: adapterList.Items[i].Namespace, + }, + }) + } + + return requests +} diff --git a/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller_test.go b/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller_test.go new file mode 100644 index 0000000000..33c6b9f5e8 --- /dev/null +++ b/deploy/cloud/operator/internal/controller/dynamographdeploymentscalingadapter_controller_test.go @@ -0,0 +1,512 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package controller + +import ( + "context" + "testing" + + "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" + "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/consts" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/tools/record" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestDynamoGraphDeploymentScalingAdapterReconciler_Reconcile(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + tests := []struct { + name string + adapter *v1alpha1.DynamoGraphDeploymentScalingAdapter + dgd *v1alpha1.DynamoGraphDeployment + expectedDGDReplicas int32 + expectedStatusReplicas int32 + expectError bool + expectRequeue bool + }{ + { + name: "updates DGD replicas when DGDSA spec differs", + adapter: &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 5, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + }, + }, + }, + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + }, + }, + }, + }, + expectedDGDReplicas: 5, + expectedStatusReplicas: 5, + expectError: false, + }, + { + name: "no update when replicas already match", + adapter: &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 3, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + }, + }, + }, + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(3)), + }, + }, + }, + }, + expectedDGDReplicas: 3, + expectedStatusReplicas: 3, + expectError: false, + }, + { + name: "uses default replicas (1) when DGD service has no replicas set", + adapter: &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-worker", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 4, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "worker", + }, + }, + }, + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "worker": {}, // no replicas set + }, + }, + }, + expectedDGDReplicas: 4, + expectedStatusReplicas: 4, + expectError: false, + }, + { + name: "error when service not found in DGD", + adapter: &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-missing", + Namespace: "default", + }, + Spec: 
v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 2, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "nonexistent", + }, + }, + }, + dgd: &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(1)), + }, + }, + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Build initial objects + var initObjs []client.Object + initObjs = append(initObjs, tt.adapter, tt.dgd) + + // Create fake client with status subresource support + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(initObjs...). + WithStatusSubresource(&v1alpha1.DynamoGraphDeploymentScalingAdapter{}). + Build() + + // Create reconciler + r := &DynamoGraphDeploymentScalingAdapterReconciler{ + Client: fakeClient, + Scheme: scheme.Scheme, + Recorder: record.NewFakeRecorder(10), + } + + // Run Reconcile + ctx := context.Background() + req := ctrl.Request{ + NamespacedName: types.NamespacedName{ + Name: tt.adapter.Name, + Namespace: tt.adapter.Namespace, + }, + } + + result, err := r.Reconcile(ctx, req) + + // Check error expectation + if tt.expectError && err == nil { + t.Errorf("Expected error, but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Skip further checks if error was expected + if tt.expectError { + return + } + + // Check requeue + if tt.expectRequeue && result.RequeueAfter == 0 { + t.Errorf("Expected requeue, but got none") + } + + // Verify DGD replicas were updated + updatedDGD := &v1alpha1.DynamoGraphDeployment{} + if err := fakeClient.Get(ctx, types.NamespacedName{Name: tt.dgd.Name, Namespace: tt.dgd.Namespace}, updatedDGD); err != nil { + t.Fatalf("Failed to get updated DGD: %v", err) + } + + service, exists := updatedDGD.Spec.Services[tt.adapter.Spec.DGDRef.ServiceName] + if !exists { + t.Fatalf("Service %s not found in updated DGD", tt.adapter.Spec.DGDRef.ServiceName) + } + + actualReplicas := int32(1) + if service.Replicas != nil { + actualReplicas = *service.Replicas + } + + if actualReplicas != tt.expectedDGDReplicas { + t.Errorf("DGD service replicas = %d, expected %d", actualReplicas, tt.expectedDGDReplicas) + } + + // Verify adapter status was updated + updatedAdapter := &v1alpha1.DynamoGraphDeploymentScalingAdapter{} + if err := fakeClient.Get(ctx, types.NamespacedName{Name: tt.adapter.Name, Namespace: tt.adapter.Namespace}, updatedAdapter); err != nil { + t.Fatalf("Failed to get updated adapter: %v", err) + } + + if updatedAdapter.Status.Replicas != tt.expectedStatusReplicas { + t.Errorf("Adapter status.replicas = %d, expected %d", updatedAdapter.Status.Replicas, tt.expectedStatusReplicas) + } + + // Verify selector is set + if updatedAdapter.Status.Selector == "" { + t.Errorf("Adapter status.selector is empty, expected non-empty") + } + }) + } +} + +func TestDynamoGraphDeploymentScalingAdapterReconciler_Reconcile_NotFound(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + // Create fake client with no objects + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). 
+ Build() + + r := &DynamoGraphDeploymentScalingAdapterReconciler{ + Client: fakeClient, + Scheme: scheme.Scheme, + Recorder: record.NewFakeRecorder(10), + } + + ctx := context.Background() + req := ctrl.Request{ + NamespacedName: types.NamespacedName{ + Name: "nonexistent", + Namespace: "default", + }, + } + + // Should return no error when adapter not found (client.IgnoreNotFound) + result, err := r.Reconcile(ctx, req) + if err != nil { + t.Errorf("Expected no error for not found adapter, got: %v", err) + } + if result.RequeueAfter != 0 { + t.Errorf("Expected no requeueAfter for not found adapter, got: %v", result.RequeueAfter) + } +} + +func TestDynamoGraphDeploymentScalingAdapterReconciler_Reconcile_DGDNotFound(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + adapter := &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 5, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "nonexistent-dgd", + ServiceName: "Frontend", + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(adapter). + Build() + + r := &DynamoGraphDeploymentScalingAdapterReconciler{ + Client: fakeClient, + Scheme: scheme.Scheme, + Recorder: record.NewFakeRecorder(10), + } + + ctx := context.Background() + req := ctrl.Request{ + NamespacedName: types.NamespacedName{ + Name: adapter.Name, + Namespace: adapter.Namespace, + }, + } + + // Should return error when DGD not found + _, err := r.Reconcile(ctx, req) + if err == nil { + t.Errorf("Expected error when DGD not found, got none") + } +} + +func TestDynamoGraphDeploymentScalingAdapterReconciler_Reconcile_BeingDeleted(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + now := metav1.Now() + adapter := &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + DeletionTimestamp: &now, + Finalizers: []string{"test-finalizer"}, // Required for deletion timestamp to be set + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + Replicas: 5, + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + }, + }, + } + + dgd := &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + Spec: v1alpha1.DynamoGraphDeploymentSpec{ + Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{ + "Frontend": { + Replicas: ptr.To(int32(2)), + }, + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(adapter, dgd). 
+ Build() + + r := &DynamoGraphDeploymentScalingAdapterReconciler{ + Client: fakeClient, + Scheme: scheme.Scheme, + Recorder: record.NewFakeRecorder(10), + } + + ctx := context.Background() + req := ctrl.Request{ + NamespacedName: types.NamespacedName{ + Name: adapter.Name, + Namespace: adapter.Namespace, + }, + } + + // Should return no error and skip reconciliation + result, err := r.Reconcile(ctx, req) + if err != nil { + t.Errorf("Expected no error for deleting adapter, got: %v", err) + } + if result.RequeueAfter != 0 { + t.Errorf("Expected no requeueAfter for deleting adapter, got: %v", result.RequeueAfter) + } + + // DGD replicas should NOT be updated (still 2) + updatedDGD := &v1alpha1.DynamoGraphDeployment{} + if err := fakeClient.Get(ctx, types.NamespacedName{Name: dgd.Name, Namespace: dgd.Namespace}, updatedDGD); err != nil { + t.Fatalf("Failed to get DGD: %v", err) + } + + if *updatedDGD.Spec.Services["Frontend"].Replicas != 2 { + t.Errorf("DGD replicas should remain unchanged, got %d", *updatedDGD.Spec.Services["Frontend"].Replicas) + } +} + +func TestDynamoGraphDeploymentScalingAdapterReconciler_findAdaptersForDGD(t *testing.T) { + // Register custom types with the scheme + if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { + t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) + } + + dgd := &v1alpha1.DynamoGraphDeployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd", + Namespace: "default", + }, + } + + // Adapters belonging to test-dgd + adapter1 := &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-frontend", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "test-dgd", + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "Frontend", + }, + }, + } + + adapter2 := &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-dgd-decode", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "test-dgd", + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "test-dgd", + ServiceName: "decode", + }, + }, + } + + // Adapter belonging to different DGD + adapterOther := &v1alpha1.DynamoGraphDeploymentScalingAdapter{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other-dgd-frontend", + Namespace: "default", + Labels: map[string]string{ + consts.KubeLabelDynamoGraphDeploymentName: "other-dgd", + }, + }, + Spec: v1alpha1.DynamoGraphDeploymentScalingAdapterSpec{ + DGDRef: v1alpha1.DynamoGraphDeploymentServiceRef{ + Name: "other-dgd", + ServiceName: "Frontend", + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme.Scheme). + WithObjects(adapter1, adapter2, adapterOther). 
+ Build() + + r := &DynamoGraphDeploymentScalingAdapterReconciler{ + Client: fakeClient, + } + + ctx := context.Background() + requests := r.findAdaptersForDGD(ctx, dgd) + + // Should return 2 requests (for test-dgd adapters only) + if len(requests) != 2 { + t.Errorf("findAdaptersForDGD() returned %d requests, expected 2", len(requests)) + } + + // Verify correct adapters are returned + expectedNames := map[string]bool{ + "test-dgd-frontend": true, + "test-dgd-decode": true, + } + + for _, req := range requests { + if !expectedNames[req.Name] { + t.Errorf("Unexpected adapter in results: %s", req.Name) + } + } +} diff --git a/deploy/cloud/operator/internal/controller_common/pod.go b/deploy/cloud/operator/internal/controller_common/pod.go deleted file mode 100644 index 48415b4451..0000000000 --- a/deploy/cloud/operator/internal/controller_common/pod.go +++ /dev/null @@ -1,294 +0,0 @@ -package controller_common - -import ( - "sort" - - corev1 "k8s.io/api/core/v1" -) - -// CanonicalizePodSpec sorts the pod spec in a way that is deterministic and easy to reason about. -// -//nolint:gocyclo -func CanonicalizePodSpec(podSpec *corev1.PodSpec) *corev1.PodSpec { - // Helper function to get EnvFromSource sort key - envFromKey := func(e corev1.EnvFromSource) string { - if e.ConfigMapRef != nil { - return "cm:" + e.ConfigMapRef.Name + ":" + e.Prefix - } - if e.SecretRef != nil { - return "sec:" + e.SecretRef.Name + ":" + e.Prefix - } - return "other:" + e.Prefix - } - - // Helper function to sort container-like fields (works for both Container and EphemeralContainer) - sortContainerFields := func(env []corev1.EnvVar, envFrom []corev1.EnvFromSource, ports []corev1.ContainerPort, volumeMounts []corev1.VolumeMount, securityContext *corev1.SecurityContext) { - // Sort env vars by name - if len(env) > 1 { - sort.Slice(env, func(i, j int) bool { return env[i].Name < env[j].Name }) - } - - // Sort envFrom by referenced source and prefix - if len(envFrom) > 1 { - sort.Slice(envFrom, func(i, j int) bool { - return envFromKey(envFrom[i]) < envFromKey(envFrom[j]) - }) - } - - // Sort ports by name then port number - if len(ports) > 1 { - sort.Slice(ports, func(i, j int) bool { - if ports[i].Name == ports[j].Name { - return ports[i].ContainerPort < ports[j].ContainerPort - } - return ports[i].Name < ports[j].Name - }) - } - - // Sort volume mounts by name then mount path - if len(volumeMounts) > 1 { - sort.Slice(volumeMounts, func(i, j int) bool { - if volumeMounts[i].Name == volumeMounts[j].Name { - return volumeMounts[i].MountPath < volumeMounts[j].MountPath - } - return volumeMounts[i].Name < volumeMounts[j].Name - }) - } - - // Sort security context capability lists - if securityContext != nil && securityContext.Capabilities != nil { - if caps := securityContext.Capabilities.Add; len(caps) > 1 { - sort.Slice(caps, func(i, j int) bool { return string(caps[i]) < string(caps[j]) }) - } - if caps := securityContext.Capabilities.Drop; len(caps) > 1 { - sort.Slice(caps, func(i, j int) bool { return string(caps[i]) < string(caps[j]) }) - } - } - } - - // Sort regular containers - for i := range podSpec.Containers { - c := &podSpec.Containers[i] - sortContainerFields(c.Env, c.EnvFrom, c.Ports, c.VolumeMounts, c.SecurityContext) - } - if len(podSpec.Containers) > 1 { - sort.Slice(podSpec.Containers, func(i, j int) bool { - return podSpec.Containers[i].Name < podSpec.Containers[j].Name - }) - } - - // Sort init containers - for i := range podSpec.InitContainers { - c := &podSpec.InitContainers[i] - 
sortContainerFields(c.Env, c.EnvFrom, c.Ports, c.VolumeMounts, c.SecurityContext) - } - if len(podSpec.InitContainers) > 1 { - sort.Slice(podSpec.InitContainers, func(i, j int) bool { - return podSpec.InitContainers[i].Name < podSpec.InitContainers[j].Name - }) - } - - // Sort ephemeral containers - for i := range podSpec.EphemeralContainers { - ec := &podSpec.EphemeralContainers[i] - sortContainerFields(ec.Env, ec.EnvFrom, ec.Ports, ec.VolumeMounts, ec.SecurityContext) - } - if len(podSpec.EphemeralContainers) > 1 { - sort.Slice(podSpec.EphemeralContainers, func(i, j int) bool { - return podSpec.EphemeralContainers[i].Name < podSpec.EphemeralContainers[j].Name - }) - } - - // Sort image pull secrets - if len(podSpec.ImagePullSecrets) > 1 { - uniqueSecrets := ensureUniqueImagePullSecrets(podSpec.ImagePullSecrets) - sort.Slice(uniqueSecrets, func(i, j int) bool { - return uniqueSecrets[i].Name < uniqueSecrets[j].Name - }) - podSpec.ImagePullSecrets = uniqueSecrets - } - - // Sort volumes and their nested items - sortKeyToPathItems := func(items []corev1.KeyToPath) { - if len(items) > 1 { - sort.Slice(items, func(i, j int) bool { - if items[i].Key == items[j].Key { - return items[i].Path < items[j].Path - } - return items[i].Key < items[j].Key - }) - } - } - - for i := range podSpec.Volumes { - v := &podSpec.Volumes[i] - - // ConfigMap items - if v.ConfigMap != nil { - sortKeyToPathItems(v.ConfigMap.Items) - } - - // Secret items - if v.Secret != nil { - sortKeyToPathItems(v.Secret.Items) - } - - // DownwardAPI items - if v.DownwardAPI != nil && len(v.DownwardAPI.Items) > 1 { - sort.Slice(v.DownwardAPI.Items, func(i, j int) bool { - return v.DownwardAPI.Items[i].Path < v.DownwardAPI.Items[j].Path - }) - } - - // Projected sources - if v.Projected != nil { - // Sort projected sources - if len(v.Projected.Sources) > 1 { - sort.Slice(v.Projected.Sources, func(i, j int) bool { - getProjectionKey := func(p corev1.VolumeProjection) string { - if p.ConfigMap != nil { - return "cm:" + p.ConfigMap.Name - } - if p.Secret != nil { - return "sec:" + p.Secret.Name - } - if p.DownwardAPI != nil { - return "downward:" - } - if p.ServiceAccountToken != nil { - return "sat:" + p.ServiceAccountToken.Audience - } - return "z:other" - } - return getProjectionKey(v.Projected.Sources[i]) < getProjectionKey(v.Projected.Sources[j]) - }) - } - - // Sort nested items for each projection - for j := range v.Projected.Sources { - p := &v.Projected.Sources[j] - if p.ConfigMap != nil { - sortKeyToPathItems(p.ConfigMap.Items) - } - if p.Secret != nil { - sortKeyToPathItems(p.Secret.Items) - } - if p.DownwardAPI != nil && len(p.DownwardAPI.Items) > 1 { - sort.Slice(p.DownwardAPI.Items, func(i, j int) bool { - return p.DownwardAPI.Items[i].Path < p.DownwardAPI.Items[j].Path - }) - } - } - } - } - - // Sort volumes by name - if len(podSpec.Volumes) > 1 { - sort.Slice(podSpec.Volumes, func(i, j int) bool { - return podSpec.Volumes[i].Name < podSpec.Volumes[j].Name - }) - } - - // Sort tolerations - if len(podSpec.Tolerations) > 1 { - sort.Slice(podSpec.Tolerations, func(i, j int) bool { - a, b := podSpec.Tolerations[i], podSpec.Tolerations[j] - - if a.Key != b.Key { - return a.Key < b.Key - } - if string(a.Operator) != string(b.Operator) { - return string(a.Operator) < string(b.Operator) - } - if a.Value != b.Value { - return a.Value < b.Value - } - if string(a.Effect) != string(b.Effect) { - return string(a.Effect) < string(b.Effect) - } - - // Handle TolerationSeconds (could be nil) - aSec, bSec := int64(0), int64(0) - if 
a.TolerationSeconds != nil { - aSec = *a.TolerationSeconds - } - if b.TolerationSeconds != nil { - bSec = *b.TolerationSeconds - } - return aSec < bSec - }) - } - - // Sort topology spread constraints - if len(podSpec.TopologySpreadConstraints) > 1 { - sort.Slice(podSpec.TopologySpreadConstraints, func(i, j int) bool { - a, b := podSpec.TopologySpreadConstraints[i], podSpec.TopologySpreadConstraints[j] - if a.TopologyKey != b.TopologyKey { - return a.TopologyKey < b.TopologyKey - } - if string(a.WhenUnsatisfiable) != string(b.WhenUnsatisfiable) { - return string(a.WhenUnsatisfiable) < string(b.WhenUnsatisfiable) - } - return a.MaxSkew < b.MaxSkew - }) - } - - // Sort host aliases - if len(podSpec.HostAliases) > 1 { - // First sort hostnames within each alias - for i := range podSpec.HostAliases { - if len(podSpec.HostAliases[i].Hostnames) > 1 { - sort.Strings(podSpec.HostAliases[i].Hostnames) - } - } - // Then sort aliases by IP - sort.Slice(podSpec.HostAliases, func(i, j int) bool { - return podSpec.HostAliases[i].IP < podSpec.HostAliases[j].IP - }) - } - - // Sort DNS config - if podSpec.DNSConfig != nil { - // Sort DNS options - if len(podSpec.DNSConfig.Options) > 1 { - sort.Slice(podSpec.DNSConfig.Options, func(i, j int) bool { - if podSpec.DNSConfig.Options[i].Name == podSpec.DNSConfig.Options[j].Name { - vi, vj := "", "" - if podSpec.DNSConfig.Options[i].Value != nil { - vi = *podSpec.DNSConfig.Options[i].Value - } - if podSpec.DNSConfig.Options[j].Value != nil { - vj = *podSpec.DNSConfig.Options[j].Value - } - return vi < vj - } - return podSpec.DNSConfig.Options[i].Name < podSpec.DNSConfig.Options[j].Name - }) - } - - // Sort nameservers and search domains - if len(podSpec.DNSConfig.Nameservers) > 1 { - sort.Strings(podSpec.DNSConfig.Nameservers) - } - if len(podSpec.DNSConfig.Searches) > 1 { - sort.Strings(podSpec.DNSConfig.Searches) - } - } - - return podSpec -} - -func ensureUniqueImagePullSecrets(secrets []corev1.LocalObjectReference) []corev1.LocalObjectReference { - if len(secrets) == 0 { - return nil - } - uniqueSecrets := make(map[string]corev1.LocalObjectReference) - for _, secret := range secrets { - uniqueSecrets[secret.Name] = secret - } - uniqueSecretsList := make([]corev1.LocalObjectReference, 0, len(uniqueSecrets)) - for secretName := range uniqueSecrets { - uniqueSecretsList = append(uniqueSecretsList, corev1.LocalObjectReference{Name: secretName}) - } - return uniqueSecretsList -} diff --git a/deploy/cloud/operator/internal/controller_common/pod_test.go b/deploy/cloud/operator/internal/controller_common/pod_test.go deleted file mode 100644 index 94a646cf67..0000000000 --- a/deploy/cloud/operator/internal/controller_common/pod_test.go +++ /dev/null @@ -1,891 +0,0 @@ -package controller_common - -import ( - "testing" - - "github.com/stretchr/testify/assert" - corev1 "k8s.io/api/core/v1" -) - -func TestCanonicalizePodSpec(t *testing.T) { - tests := []struct { - name string - input *corev1.PodSpec - expected *corev1.PodSpec - }{ - { - name: "sorts containers by name", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - {Name: "zebra"}, - {Name: "alpha"}, - {Name: "beta"}, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - {Name: "alpha"}, - {Name: "beta"}, - {Name: "zebra"}, - }, - }, - }, - { - name: "sorts init containers by name", - input: &corev1.PodSpec{ - InitContainers: []corev1.Container{ - {Name: "init-zebra"}, - {Name: "init-alpha"}, - }, - }, - expected: &corev1.PodSpec{ - InitContainers: []corev1.Container{ - {Name: 
"init-alpha"}, - {Name: "init-zebra"}, - }, - }, - }, - { - name: "sorts ephemeral containers by name", - input: &corev1.PodSpec{ - EphemeralContainers: []corev1.EphemeralContainer{ - {EphemeralContainerCommon: corev1.EphemeralContainerCommon{Name: "debug-zebra"}}, - {EphemeralContainerCommon: corev1.EphemeralContainerCommon{Name: "debug-alpha"}}, - }, - }, - expected: &corev1.PodSpec{ - EphemeralContainers: []corev1.EphemeralContainer{ - {EphemeralContainerCommon: corev1.EphemeralContainerCommon{Name: "debug-alpha"}}, - {EphemeralContainerCommon: corev1.EphemeralContainerCommon{Name: "debug-zebra"}}, - }, - }, - }, - { - name: "sorts environment variables by name", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Env: []corev1.EnvVar{ - {Name: "ZOO", Value: "zebra"}, - {Name: "ALPHA", Value: "apple"}, - {Name: "BETA", Value: "banana"}, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Env: []corev1.EnvVar{ - {Name: "ALPHA", Value: "apple"}, - {Name: "BETA", Value: "banana"}, - {Name: "ZOO", Value: "zebra"}, - }, - }, - }, - }, - }, - { - name: "sorts envFrom by source type and name", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - EnvFrom: []corev1.EnvFromSource{ - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-z"}}}, - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}}}, - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-a"}}}, - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-z"}}}, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - EnvFrom: []corev1.EnvFromSource{ - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}}}, - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-z"}}}, - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-a"}}}, - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-z"}}}, - }, - }, - }, - }, - }, - { - name: "sorts container ports by name then port number", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Ports: []corev1.ContainerPort{ - {Name: "http", ContainerPort: 8080}, - {Name: "grpc", ContainerPort: 9090}, - {Name: "grpc", ContainerPort: 8080}, - {Name: "debug", ContainerPort: 8080}, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - Ports: []corev1.ContainerPort{ - {Name: "debug", ContainerPort: 8080}, - {Name: "grpc", ContainerPort: 8080}, - {Name: "grpc", ContainerPort: 9090}, - {Name: "http", ContainerPort: 8080}, - }, - }, - }, - }, - }, - { - name: "sorts volume mounts by name then mount path", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - VolumeMounts: []corev1.VolumeMount{ - {Name: "vol1", MountPath: "/data2"}, - {Name: "vol2", MountPath: "/data1"}, - {Name: "vol1", MountPath: "/data1"}, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - VolumeMounts: []corev1.VolumeMount{ - {Name: "vol1", MountPath: "/data1"}, - {Name: "vol1", MountPath: "/data2"}, - {Name: "vol2", MountPath: 
"/data1"}, - }, - }, - }, - }, - }, - { - name: "sorts security context capabilities", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - SecurityContext: &corev1.SecurityContext{ - Capabilities: &corev1.Capabilities{ - Add: []corev1.Capability{"SYS_ADMIN", "NET_ADMIN", "CHOWN"}, - Drop: []corev1.Capability{"ALL", "SETUID", "KILL"}, - }, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - SecurityContext: &corev1.SecurityContext{ - Capabilities: &corev1.Capabilities{ - Add: []corev1.Capability{"CHOWN", "NET_ADMIN", "SYS_ADMIN"}, - Drop: []corev1.Capability{"ALL", "KILL", "SETUID"}, - }, - }, - }, - }, - }, - }, - { - name: "sorts image pull secrets by name", - input: &corev1.PodSpec{ - ImagePullSecrets: []corev1.LocalObjectReference{ - {Name: "registry-z"}, - {Name: "registry-a"}, - {Name: "registry-b"}, - {Name: "registry-a"}, - }, - }, - expected: &corev1.PodSpec{ - ImagePullSecrets: []corev1.LocalObjectReference{ - {Name: "registry-a"}, - {Name: "registry-b"}, - {Name: "registry-z"}, - }, - }, - }, - { - name: "sorts nil image pull secrets", - input: &corev1.PodSpec{ - ImagePullSecrets: nil, - }, - expected: &corev1.PodSpec{ - ImagePullSecrets: nil, - }, - }, - { - name: "sorts volumes by name", - input: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - {Name: "vol-z"}, - {Name: "vol-a"}, - {Name: "vol-b"}, - }, - }, - expected: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - {Name: "vol-a"}, - {Name: "vol-b"}, - {Name: "vol-z"}, - }, - }, - }, - { - name: "sorts configmap volume items by key then path", - input: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "config", - VolumeSource: corev1.VolumeSource{ - ConfigMap: &corev1.ConfigMapVolumeSource{ - Items: []corev1.KeyToPath{ - {Key: "app.conf", Path: "config/app.conf"}, - {Key: "db.conf", Path: "config/db.conf"}, - {Key: "app.conf", Path: "backup/app.conf"}, - }, - }, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "config", - VolumeSource: corev1.VolumeSource{ - ConfigMap: &corev1.ConfigMapVolumeSource{ - Items: []corev1.KeyToPath{ - {Key: "app.conf", Path: "backup/app.conf"}, - {Key: "app.conf", Path: "config/app.conf"}, - {Key: "db.conf", Path: "config/db.conf"}, - }, - }, - }, - }, - }, - }, - }, - { - name: "sorts secret volume items by key then path", - input: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "secret", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - Items: []corev1.KeyToPath{ - {Key: "tls.key", Path: "tls/server.key"}, - {Key: "tls.crt", Path: "tls/server.crt"}, - {Key: "tls.key", Path: "backup/server.key"}, - }, - }, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "secret", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - Items: []corev1.KeyToPath{ - {Key: "tls.crt", Path: "tls/server.crt"}, - {Key: "tls.key", Path: "backup/server.key"}, - {Key: "tls.key", Path: "tls/server.key"}, - }, - }, - }, - }, - }, - }, - }, - { - name: "sorts downward API items by path", - input: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "downward", - VolumeSource: corev1.VolumeSource{ - DownwardAPI: &corev1.DownwardAPIVolumeSource{ - Items: []corev1.DownwardAPIVolumeFile{ - {Path: "metadata/name"}, - {Path: "metadata/annotations"}, - {Path: "limits/cpu"}, - }, - }, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "downward", - VolumeSource: 
corev1.VolumeSource{ - DownwardAPI: &corev1.DownwardAPIVolumeSource{ - Items: []corev1.DownwardAPIVolumeFile{ - {Path: "limits/cpu"}, - {Path: "metadata/annotations"}, - {Path: "metadata/name"}, - }, - }, - }, - }, - }, - }, - }, - { - name: "sorts projected volume sources and their items", - input: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "projected", - VolumeSource: corev1.VolumeSource{ - Projected: &corev1.ProjectedVolumeSource{ - Sources: []corev1.VolumeProjection{ - { - Secret: &corev1.SecretProjection{ - LocalObjectReference: corev1.LocalObjectReference{Name: "secret-z"}, - Items: []corev1.KeyToPath{ - {Key: "password", Path: "auth/password"}, - {Key: "username", Path: "auth/username"}, - }, - }, - }, - { - ConfigMap: &corev1.ConfigMapProjection{ - LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}, - Items: []corev1.KeyToPath{ - {Key: "db.conf", Path: "config/db.conf"}, - {Key: "app.conf", Path: "config/app.conf"}, - }, - }, - }, - { - DownwardAPI: &corev1.DownwardAPIProjection{ - Items: []corev1.DownwardAPIVolumeFile{ - {Path: "metadata/name"}, - {Path: "limits/cpu"}, - }, - }, - }, - { - ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ - Audience: "api.example.com", - Path: "tokens/api", - }, - }, - }, - }, - }, - }, - }, - }, - expected: &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "projected", - VolumeSource: corev1.VolumeSource{ - Projected: &corev1.ProjectedVolumeSource{ - Sources: []corev1.VolumeProjection{ - { - ConfigMap: &corev1.ConfigMapProjection{ - LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}, - Items: []corev1.KeyToPath{ - {Key: "app.conf", Path: "config/app.conf"}, - {Key: "db.conf", Path: "config/db.conf"}, - }, - }, - }, - { - DownwardAPI: &corev1.DownwardAPIProjection{ - Items: []corev1.DownwardAPIVolumeFile{ - {Path: "limits/cpu"}, - {Path: "metadata/name"}, - }, - }, - }, - { - ServiceAccountToken: &corev1.ServiceAccountTokenProjection{ - Audience: "api.example.com", - Path: "tokens/api", - }, - }, - { - Secret: &corev1.SecretProjection{ - LocalObjectReference: corev1.LocalObjectReference{Name: "secret-z"}, - Items: []corev1.KeyToPath{ - {Key: "password", Path: "auth/password"}, - {Key: "username", Path: "auth/username"}, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - name: "sorts tolerations by key, operator, value, effect, seconds", - input: &corev1.PodSpec{ - Tolerations: []corev1.Toleration{ - { - Key: "node-type", - Operator: corev1.TolerationOpEqual, - Value: "gpu", - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "node-role", - Operator: corev1.TolerationOpEqual, - Value: "master", - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "node-role", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - expected: &corev1.PodSpec{ - Tolerations: []corev1.Toleration{ - { - Key: "node-role", - Operator: corev1.TolerationOpEqual, - Value: "master", - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "node-role", - Operator: corev1.TolerationOpExists, - Effect: corev1.TaintEffectNoSchedule, - }, - { - Key: "node-type", - Operator: corev1.TolerationOpEqual, - Value: "gpu", - Effect: corev1.TaintEffectNoSchedule, - }, - }, - }, - }, - { - name: "sorts topology spread constraints by topology key, when unsatisfiable, max skew", - input: &corev1.PodSpec{ - TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ - { - TopologyKey: "kubernetes.io/zone", - WhenUnsatisfiable: corev1.DoNotSchedule, - MaxSkew: 2, - }, - { - 
TopologyKey: "kubernetes.io/hostname", - WhenUnsatisfiable: corev1.DoNotSchedule, - MaxSkew: 1, - }, - { - TopologyKey: "kubernetes.io/hostname", - WhenUnsatisfiable: corev1.ScheduleAnyway, - MaxSkew: 1, - }, - }, - }, - expected: &corev1.PodSpec{ - TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ - { - TopologyKey: "kubernetes.io/hostname", - WhenUnsatisfiable: corev1.DoNotSchedule, - MaxSkew: 1, - }, - { - TopologyKey: "kubernetes.io/hostname", - WhenUnsatisfiable: corev1.ScheduleAnyway, - MaxSkew: 1, - }, - { - TopologyKey: "kubernetes.io/zone", - WhenUnsatisfiable: corev1.DoNotSchedule, - MaxSkew: 2, - }, - }, - }, - }, - { - name: "sorts host aliases by IP and hostnames within each alias", - input: &corev1.PodSpec{ - HostAliases: []corev1.HostAlias{ - { - IP: "192.168.1.2", - Hostnames: []string{"web2.example.com", "api2.example.com"}, - }, - { - IP: "192.168.1.1", - Hostnames: []string{"web1.example.com", "api1.example.com", "admin1.example.com"}, - }, - }, - }, - expected: &corev1.PodSpec{ - HostAliases: []corev1.HostAlias{ - { - IP: "192.168.1.1", - Hostnames: []string{"admin1.example.com", "api1.example.com", "web1.example.com"}, - }, - { - IP: "192.168.1.2", - Hostnames: []string{"api2.example.com", "web2.example.com"}, - }, - }, - }, - }, - { - name: "sorts DNS config options, nameservers, and searches", - input: &corev1.PodSpec{ - DNSConfig: &corev1.PodDNSConfig{ - Options: []corev1.PodDNSConfigOption{ - {Name: "timeout", Value: func() *string { s := "5"; return &s }()}, - {Name: "attempts", Value: func() *string { s := "3"; return &s }()}, - {Name: "ndots", Value: func() *string { s := "2"; return &s }()}, - }, - Nameservers: []string{"8.8.8.8", "1.1.1.1", "8.8.4.4"}, - Searches: []string{"example.com", "cluster.local", "app.local"}, - }, - }, - expected: &corev1.PodSpec{ - DNSConfig: &corev1.PodDNSConfig{ - Options: []corev1.PodDNSConfigOption{ - {Name: "attempts", Value: func() *string { s := "3"; return &s }()}, - {Name: "ndots", Value: func() *string { s := "2"; return &s }()}, - {Name: "timeout", Value: func() *string { s := "5"; return &s }()}, - }, - Nameservers: []string{"1.1.1.1", "8.8.4.4", "8.8.8.8"}, - Searches: []string{"app.local", "cluster.local", "example.com"}, - }, - }, - }, - { - name: "handles nil pointer values gracefully", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - SecurityContext: &corev1.SecurityContext{ - Capabilities: nil, - }, - }, - }, - DNSConfig: &corev1.PodDNSConfig{ - Options: []corev1.PodDNSConfigOption{ - {Name: "timeout", Value: nil}, - {Name: "attempts", Value: func() *string { s := "3"; return &s }()}, - }, - }, - Tolerations: []corev1.Toleration{ - { - Key: "test", - TolerationSeconds: nil, - }, - { - Key: "test2", - TolerationSeconds: func() *int64 { s := int64(300); return &s }(), - }, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - SecurityContext: &corev1.SecurityContext{ - Capabilities: nil, - }, - }, - }, - DNSConfig: &corev1.PodDNSConfig{ - Options: []corev1.PodDNSConfigOption{ - {Name: "attempts", Value: func() *string { s := "3"; return &s }()}, - {Name: "timeout", Value: nil}, - }, - }, - Tolerations: []corev1.Toleration{ - { - Key: "test", - TolerationSeconds: nil, - }, - { - Key: "test2", - TolerationSeconds: func() *int64 { s := int64(300); return &s }(), - }, - }, - }, - }, - { - name: "returns original podspec when already sorted", - input: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "alpha", - Env: 
[]corev1.EnvVar{ - {Name: "A", Value: "1"}, - {Name: "B", Value: "2"}, - }, - }, - {Name: "beta"}, - }, - ImagePullSecrets: []corev1.LocalObjectReference{ - {Name: "secret-a"}, - {Name: "secret-b"}, - }, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "alpha", - Env: []corev1.EnvVar{ - {Name: "A", Value: "1"}, - {Name: "B", Value: "2"}, - }, - }, - {Name: "beta"}, - }, - ImagePullSecrets: []corev1.LocalObjectReference{ - {Name: "secret-a"}, - {Name: "secret-b"}, - }, - }, - }, - { - name: "handles empty slices gracefully", - input: &corev1.PodSpec{ - Containers: []corev1.Container{}, - InitContainers: []corev1.Container{}, - ImagePullSecrets: []corev1.LocalObjectReference{}, - Volumes: []corev1.Volume{}, - Tolerations: []corev1.Toleration{}, - }, - expected: &corev1.PodSpec{ - Containers: []corev1.Container{}, - InitContainers: []corev1.Container{}, - ImagePullSecrets: []corev1.LocalObjectReference{}, - Volumes: []corev1.Volume{}, - Tolerations: []corev1.Toleration{}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := CanonicalizePodSpec(tt.input) - - // Verify the function returns the same instance - assert.Same(t, tt.input, result, "function should return the same PodSpec instance") - - // Verify the sorting is correct - assert.Equal(t, tt.expected, result, "PodSpec should be sorted correctly") - }) - } -} - -func TestCanonicalizePodSpec_Idempotent(t *testing.T) { - // Create a complex, unsorted PodSpec - podSpec := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "zebra", - Env: []corev1.EnvVar{ - {Name: "Z_VAR", Value: "z"}, - {Name: "A_VAR", Value: "a"}, - }, - Ports: []corev1.ContainerPort{ - {Name: "http", ContainerPort: 8080}, - {Name: "grpc", ContainerPort: 9090}, - }, - VolumeMounts: []corev1.VolumeMount{ - {Name: "vol2", MountPath: "/data2"}, - {Name: "vol1", MountPath: "/data1"}, - }, - }, - {Name: "alpha"}, - }, - InitContainers: []corev1.Container{ - {Name: "init-zebra"}, - {Name: "init-alpha"}, - }, - ImagePullSecrets: []corev1.LocalObjectReference{ - {Name: "secret-z"}, - {Name: "secret-a"}, - }, - Volumes: []corev1.Volume{ - {Name: "vol-z"}, - {Name: "vol-a"}, - }, - Tolerations: []corev1.Toleration{ - {Key: "node-z"}, - {Key: "node-a"}, - }, - } - - // First canonicalization - result1 := CanonicalizePodSpec(podSpec) - - // Second canonicalization on the same object - result2 := CanonicalizePodSpec(result1) - - // Should be identical after second canonicalization - assert.Equal(t, result1, result2, "CanonicalizePodSpec should be idempotent") - - // Verify containers are sorted - assert.Equal(t, "alpha", result2.Containers[0].Name) - assert.Equal(t, "zebra", result2.Containers[1].Name) - - // Verify env vars within containers are sorted - assert.Equal(t, "A_VAR", result2.Containers[1].Env[0].Name) - assert.Equal(t, "Z_VAR", result2.Containers[1].Env[1].Name) - - // Verify ports are sorted - assert.Equal(t, "grpc", result2.Containers[1].Ports[0].Name) - assert.Equal(t, "http", result2.Containers[1].Ports[1].Name) - - // Verify volume mounts are sorted - assert.Equal(t, "vol1", result2.Containers[1].VolumeMounts[0].Name) - assert.Equal(t, "vol2", result2.Containers[1].VolumeMounts[1].Name) -} - -func TestCanonicalizePodSpec_EnvFromSortPriority(t *testing.T) { - podSpec := &corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test", - EnvFrom: []corev1.EnvFromSource{ - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-b"}}}, - 
{ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-b"}}}, - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-a"}}}, - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}}}, - // Test duplicate names for secondary sort - {ConfigMapRef: &corev1.ConfigMapEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}}}, - {SecretRef: &corev1.SecretEnvSource{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-a"}}}, - }, - }, - }, - } - - result := CanonicalizePodSpec(podSpec) - - // ConfigMaps should come before Secrets (cm: < sec:) - // Within each type, sorted by name - expected := []string{ - "cm:config-a:", // ConfigMap config-a - "cm:config-a:", // ConfigMap config-a (duplicate) - "cm:config-b:", // ConfigMap config-b - "sec:secret-a:", // Secret secret-a - "sec:secret-a:", // Secret secret-a (duplicate) - "sec:secret-b:", // Secret secret-b - } - - envFromKey := func(e corev1.EnvFromSource) string { - if e.ConfigMapRef != nil { - return "cm:" + e.ConfigMapRef.Name + ":" - } - if e.SecretRef != nil { - return "sec:" + e.SecretRef.Name + ":" - } - return "other:" - } - - for i, envFrom := range result.Containers[0].EnvFrom { - assert.Equal(t, expected[i], envFromKey(envFrom), "EnvFrom at index %d should match expected sort order", i) - } -} - -func TestCanonicalizePodSpec_TolerationSecondsHandling(t *testing.T) { - sec300 := int64(300) - sec600 := int64(600) - - podSpec := &corev1.PodSpec{ - Tolerations: []corev1.Toleration{ - {Key: "key1", TolerationSeconds: &sec600}, - {Key: "key1", TolerationSeconds: nil}, - {Key: "key1", TolerationSeconds: &sec300}, - }, - } - - result := CanonicalizePodSpec(podSpec) - - // Should be sorted by TolerationSeconds: nil (0) < 300 < 600 - assert.Nil(t, result.Tolerations[0].TolerationSeconds) - assert.Equal(t, int64(300), *result.Tolerations[1].TolerationSeconds) - assert.Equal(t, int64(600), *result.Tolerations[2].TolerationSeconds) -} - -func TestCanonicalizePodSpec_ProjectedVolumeSourcePriority(t *testing.T) { - podSpec := &corev1.PodSpec{ - Volumes: []corev1.Volume{ - { - Name: "projected", - VolumeSource: corev1.VolumeSource{ - Projected: &corev1.ProjectedVolumeSource{ - Sources: []corev1.VolumeProjection{ - {Secret: &corev1.SecretProjection{LocalObjectReference: corev1.LocalObjectReference{Name: "secret-a"}}}, - {ServiceAccountToken: &corev1.ServiceAccountTokenProjection{Audience: "zz.example.com"}}, - {DownwardAPI: &corev1.DownwardAPIProjection{}}, - {ConfigMap: &corev1.ConfigMapProjection{LocalObjectReference: corev1.LocalObjectReference{Name: "config-z"}}}, - {ServiceAccountToken: &corev1.ServiceAccountTokenProjection{Audience: "aa.example.com"}}, - {ConfigMap: &corev1.ConfigMapProjection{LocalObjectReference: corev1.LocalObjectReference{Name: "config-a"}}}, - }, - }, - }, - }, - }, - } - - result := CanonicalizePodSpec(podSpec) - - // Expected sort order: cm: < downward: < sat: < sec: - // Within same type, sorted by name/audience - getProjectionKey := func(p corev1.VolumeProjection) string { - if p.ConfigMap != nil { - return "cm:" + p.ConfigMap.Name - } - if p.Secret != nil { - return "sec:" + p.Secret.Name - } - if p.DownwardAPI != nil { - return "downward:" - } - if p.ServiceAccountToken != nil { - return "sat:" + p.ServiceAccountToken.Audience - } - return "z:other" - } - - expected := []string{ - "cm:config-a", - "cm:config-z", - "downward:", - 
"sat:aa.example.com", - "sat:zz.example.com", - "sec:secret-a", - } - - sources := result.Volumes[0].VolumeSource.Projected.Sources - for i, source := range sources { - assert.Equal(t, expected[i], getProjectionKey(source), "Projected source at index %d should match expected sort order", i) - } -} diff --git a/deploy/cloud/operator/internal/controller_common/podgangset.go b/deploy/cloud/operator/internal/controller_common/podgangset.go deleted file mode 100644 index 871fa502df..0000000000 --- a/deploy/cloud/operator/internal/controller_common/podgangset.go +++ /dev/null @@ -1,19 +0,0 @@ -package controller_common - -import ( - "sort" - - grovev1alpha1 "github.com/NVIDIA/grove/operator/api/core/v1alpha1" -) - -func CanonicalizePodCliqueSet(gangSet *grovev1alpha1.PodCliqueSet) *grovev1alpha1.PodCliqueSet { - // sort cliques by name - sort.Slice(gangSet.Spec.Template.Cliques, func(i, j int) bool { - return gangSet.Spec.Template.Cliques[i].Name < gangSet.Spec.Template.Cliques[j].Name - }) - // sort scaling groups by name - sort.Slice(gangSet.Spec.Template.PodCliqueScalingGroupConfigs, func(i, j int) bool { - return gangSet.Spec.Template.PodCliqueScalingGroupConfigs[i].Name < gangSet.Spec.Template.PodCliqueScalingGroupConfigs[j].Name - }) - return gangSet -} diff --git a/deploy/cloud/operator/internal/controller_common/resource.go b/deploy/cloud/operator/internal/controller_common/resource.go index 66f5bffac7..4088958c02 100644 --- a/deploy/cloud/operator/internal/controller_common/resource.go +++ b/deploy/cloud/operator/internal/controller_common/resource.go @@ -496,6 +496,24 @@ func getGPUResourceName(resourceItem *v1alpha1.ResourceItem) corev1.ResourceName return corev1.ResourceName(consts.KubeResourceGPUNvidia) } +// AppendUniqueImagePullSecrets appends secrets to existing, skipping any that already exist by name. 
+func AppendUniqueImagePullSecrets(existing, additional []corev1.LocalObjectReference) []corev1.LocalObjectReference { + if len(additional) == 0 { + return existing + } + seen := make(map[string]bool, len(existing)) + for _, s := range existing { + seen[s.Name] = true + } + for _, s := range additional { + if !seen[s.Name] { + existing = append(existing, s) + seen[s.Name] = true + } + } + return existing +} + type Resource struct { client.Object isReady func() (bool, string) diff --git a/deploy/cloud/operator/internal/controller_common/resource_test.go b/deploy/cloud/operator/internal/controller_common/resource_test.go index 98b7cd9ad9..d3f4e86bb9 100644 --- a/deploy/cloud/operator/internal/controller_common/resource_test.go +++ b/deploy/cloud/operator/internal/controller_common/resource_test.go @@ -532,3 +532,75 @@ func TestGetResourcesConfig(t *testing.T) { }) } } + +func TestAppendUniqueImagePullSecrets(t *testing.T) { + tests := []struct { + name string + existing []corev1.LocalObjectReference + additional []corev1.LocalObjectReference + expected []corev1.LocalObjectReference + }{ + { + name: "empty existing, empty additional", + existing: []corev1.LocalObjectReference{}, + additional: []corev1.LocalObjectReference{}, + expected: []corev1.LocalObjectReference{}, + }, + { + name: "empty existing, some additional", + existing: []corev1.LocalObjectReference{}, + additional: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + }, + { + name: "some existing, empty additional", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}}, + additional: []corev1.LocalObjectReference{}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}}, + }, + { + name: "no duplicates", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}}, + additional: []corev1.LocalObjectReference{{Name: "secret-b"}, {Name: "secret-c"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}, {Name: "secret-c"}}, + }, + { + name: "all duplicates", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + additional: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + }, + { + name: "some duplicates", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}}, + additional: []corev1.LocalObjectReference{{Name: "secret-b"}, {Name: "secret-c"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}, {Name: "secret-c"}}, + }, + { + name: "duplicates within additional", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}}, + additional: []corev1.LocalObjectReference{{Name: "secret-b"}, {Name: "secret-b"}, {Name: "secret-c"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}, {Name: "secret-b"}, {Name: "secret-c"}}, + }, + { + name: "nil existing", + existing: nil, + additional: []corev1.LocalObjectReference{{Name: "secret-a"}}, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}}, + }, + { + name: "nil additional", + existing: []corev1.LocalObjectReference{{Name: "secret-a"}}, + additional: nil, + expected: []corev1.LocalObjectReference{{Name: "secret-a"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := gomega.NewGomegaWithT(t) + result := AppendUniqueImagePullSecrets(tt.existing, tt.additional) + 
g.Expect(result).To(gomega.Equal(tt.expected)) + }) + } +} diff --git a/deploy/cloud/operator/internal/dynamo/component_worker.go b/deploy/cloud/operator/internal/dynamo/component_worker.go index f0ad70eba0..a80ae8e256 100644 --- a/deploy/cloud/operator/internal/dynamo/component_worker.go +++ b/deploy/cloud/operator/internal/dynamo/component_worker.go @@ -86,6 +86,10 @@ func (w *WorkerDefaults) GetBaseContainer(context ComponentContext) (corev1.Cont Name: "DYN_SYSTEM_PORT", Value: fmt.Sprintf("%d", commonconsts.DynamoSystemPort), }, + { + Name: "DYN_HEALTH_CHECK_ENABLED", + Value: "true", + }, }...) return container, nil diff --git a/deploy/cloud/operator/internal/dynamo/graph.go b/deploy/cloud/operator/internal/dynamo/graph.go index 706dcec234..2feaa90d6f 100644 --- a/deploy/cloud/operator/internal/dynamo/graph.go +++ b/deploy/cloud/operator/internal/dynamo/graph.go @@ -903,10 +903,10 @@ func GenerateBasePodSpec( podSpec.Containers = append(podSpec.Containers, container) podSpec.Volumes = append(podSpec.Volumes, volumes...) - podSpec.ImagePullSecrets = append(podSpec.ImagePullSecrets, imagePullSecrets...) + podSpec.ImagePullSecrets = controller_common.AppendUniqueImagePullSecrets(podSpec.ImagePullSecrets, imagePullSecrets) backend.UpdatePodSpec(&podSpec, numberOfNodes, role, component, serviceName) - return controller_common.CanonicalizePodSpec(&podSpec), nil + return &podSpec, nil } func setMetricsLabels(labels map[string]string, dynamoGraphDeployment *v1alpha1.DynamoGraphDeployment) { @@ -1034,7 +1034,7 @@ func GenerateGrovePodCliqueSet( PodSpec: *podSpec, }, } - labels, err := generateLabels(component, dynamoDeployment, r.Name) + labels, err := generateLabels(component, dynamoDeployment, serviceName) if err != nil { return nil, fmt.Errorf("failed to generate labels: %w", err) } @@ -1068,13 +1068,14 @@ func GenerateGrovePodCliqueSet( gangSet.Spec.Template.PodCliqueScalingGroupConfigs = scalingGroups } - return controller_common.CanonicalizePodCliqueSet(gangSet), nil + return gangSet, nil } func generateLabels(component *v1alpha1.DynamoComponentDeploymentSharedSpec, dynamoDeployment *v1alpha1.DynamoGraphDeployment, componentName string) (map[string]string, error) { labels := make(map[string]string) labels[commonconsts.KubeLabelDynamoSelector] = GetDynamoComponentName(dynamoDeployment, componentName) labels[commonconsts.KubeLabelDynamoGraphDeploymentName] = dynamoDeployment.Name + labels[commonconsts.KubeLabelDynamoComponent] = componentName if component.DynamoNamespace != nil { labels[commonconsts.KubeLabelDynamoNamespace] = *component.DynamoNamespace } diff --git a/deploy/cloud/operator/internal/dynamo/graph_test.go b/deploy/cloud/operator/internal/dynamo/graph_test.go index d93a60459b..9218ba88ec 100644 --- a/deploy/cloud/operator/internal/dynamo/graph_test.go +++ b/deploy/cloud/operator/internal/dynamo/graph_test.go @@ -121,7 +121,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { commonconsts.KubeLabelDynamoNamespace: "default-test-dynamographdeployment", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamographdeployment", }, - Autoscaling: nil, }, }, }, @@ -153,7 +152,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { Custom: map[string]string{}, }, }, - Autoscaling: nil, }, }, }, @@ -229,7 +227,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { commonconsts.KubeLabelDynamoNamespace: "default-test-dynamographdeployment", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamographdeployment", }, - Autoscaling: nil, }, }, 
}, @@ -261,7 +258,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { Custom: map[string]string{}, }, }, - Autoscaling: nil, }, }, }, @@ -341,7 +337,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { commonconsts.KubeLabelDynamoNamespace: "default-test-dynamographdeployment", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamographdeployment", }, - Autoscaling: nil, Ingress: &v1alpha1.IngressSpec{ Enabled: true, Host: "test-dynamographdeployment", @@ -377,7 +372,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { Custom: map[string]string{}, }, }, - Autoscaling: nil, }, }, }, @@ -465,7 +459,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { commonconsts.KubeLabelDynamoNamespace: "default-test-dynamographdeployment", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamographdeployment", }, - Autoscaling: nil, Envs: []corev1.EnvVar{ { Name: "DYN_DEPLOYMENT_CONFIG", @@ -503,7 +496,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { Custom: map[string]string{}, }, }, - Autoscaling: nil, Envs: []corev1.EnvVar{ { Name: "DYN_DEPLOYMENT_CONFIG", @@ -599,7 +591,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { commonconsts.KubeLabelDynamoNamespace: "default-test-dynamographdeployment", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamographdeployment", }, - Autoscaling: nil, ExtraPodSpec: &v1alpha1.ExtraPodSpec{ MainContainer: &corev1.Container{ Command: []string{"sh", "-c"}, @@ -644,7 +635,6 @@ func TestGenerateDynamoComponentsDeployments(t *testing.T) { Custom: map[string]string{}, }, }, - Autoscaling: nil, Envs: []corev1.EnvVar{ { Name: "TEST_ENV", @@ -1307,6 +1297,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: "frontend", Labels: map[string]string{ commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend", + commonconsts.KubeLabelDynamoComponent: "Frontend", commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeFrontend, commonconsts.KubeLabelDynamoSubComponentType: "test-sub-component", @@ -1483,6 +1474,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Labels: map[string]string{ commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner", + commonconsts.KubeLabelDynamoComponent: "Planner", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypePlanner, commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", @@ -1884,8 +1876,9 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker, commonconsts.KubeLabelDynamoSubComponentType: "test-sub-component", commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, - commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-ldr", + commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", + commonconsts.KubeLabelDynamoComponent: "worker", commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", "nvidia.com/label1": "label1", "nvidia.com/label2": "label2", @@ -1970,6 +1963,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: commonconsts.DynamoComponentEnvVar, Value: 
commonconsts.ComponentTypeWorker, }, + { + Name: "DYN_HEALTH_CHECK_ENABLED", + Value: "true", + }, { Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-dynamo-graph-deployment", @@ -2059,8 +2056,9 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker, commonconsts.KubeLabelDynamoSubComponentType: "test-sub-component", commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, - commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-wkr", + commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", + commonconsts.KubeLabelDynamoComponent: "worker", commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", "nvidia.com/label1": "label1", "nvidia.com/label2": "label2", @@ -2146,6 +2144,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker, }, + { + Name: "DYN_HEALTH_CHECK_ENABLED", + Value: "true", + }, { Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-dynamo-graph-deployment", @@ -2200,6 +2202,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend", commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeFrontend, commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", + commonconsts.KubeLabelDynamoComponent: "Frontend", commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", }, Annotations: map[string]string{}, @@ -2358,6 +2361,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: "planner", Labels: map[string]string{ commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner", + commonconsts.KubeLabelDynamoComponent: "Planner", commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypePlanner, @@ -2779,7 +2783,8 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { { Name: "worker-ldr", Labels: map[string]string{ - commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-ldr", + commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker", + commonconsts.KubeLabelDynamoComponent: "worker", commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker, commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", @@ -2867,6 +2872,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker, }, + { + Name: "DYN_HEALTH_CHECK_ENABLED", + Value: "true", + }, { Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-dynamo-graph-deployment", @@ -2943,7 +2952,8 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Labels: map[string]string{ commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypeWorker, commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, - commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker-wkr", + commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-worker", + commonconsts.KubeLabelDynamoComponent: "worker", commonconsts.KubeLabelDynamoGraphDeploymentName: 
"test-dynamo-graph-deployment", commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", "nvidia.com/label1": "label1", @@ -3030,6 +3040,10 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Name: commonconsts.DynamoComponentEnvVar, Value: commonconsts.ComponentTypeWorker, }, + { + Name: "DYN_HEALTH_CHECK_ENABLED", + Value: "true", + }, { Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-dynamo-graph-deployment", @@ -3084,6 +3098,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-frontend", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", + commonconsts.KubeLabelDynamoComponent: "Frontend", commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", }, Annotations: map[string]string{}, @@ -3243,6 +3258,7 @@ func TestGenerateGrovePodCliqueSet(t *testing.T) { Labels: map[string]string{ commonconsts.KubeLabelMetricsEnabled: commonconsts.KubeLabelValueTrue, commonconsts.KubeLabelDynamoSelector: "test-dynamo-graph-deployment-planner", + commonconsts.KubeLabelDynamoComponent: "Planner", commonconsts.KubeLabelDynamoGraphDeploymentName: "test-dynamo-graph-deployment", commonconsts.KubeLabelDynamoComponentType: commonconsts.ComponentTypePlanner, commonconsts.KubeLabelDynamoNamespace: "test-namespace-test-dynamo-graph-deployment", @@ -4989,6 +5005,7 @@ func TestGenerateBasePodSpec_Worker(t *testing.T) { {Name: "ANOTHER_COMPONENTENV", Value: "true"}, {Name: "ANOTHER_CONTAINER_ENV", Value: "true"}, {Name: commonconsts.DynamoComponentEnvVar, Value: "worker"}, + {Name: "DYN_HEALTH_CHECK_ENABLED", Value: "true"}, {Name: commonconsts.DynamoNamespaceEnvVar, Value: ""}, {Name: "DYN_PARENT_DGD_K8S_NAME", Value: "test-deployment"}, {Name: "DYN_PARENT_DGD_K8S_NAMESPACE", Value: "default"}, diff --git a/deploy/cloud/operator/internal/webhook/common.go b/deploy/cloud/operator/internal/webhook/common.go index 6333738739..c18edd98f4 100644 --- a/deploy/cloud/operator/internal/webhook/common.go +++ b/deploy/cloud/operator/internal/webhook/common.go @@ -19,7 +19,9 @@ package webhook import ( "context" + "strings" + authenticationv1 "k8s.io/api/authentication/v1" "k8s.io/apimachinery/pkg/runtime" "sigs.k8s.io/controller-runtime/pkg/client" logf "sigs.k8s.io/controller-runtime/pkg/log" @@ -118,3 +120,54 @@ func (v *LeaseAwareValidator) shouldSkipValidation(obj runtime.Object) bool { return false } + +// DGDReplicasModifierSuffixes defines suffixes for service accounts that are authorized +// to modify DGD replicas when scaling adapter is enabled. +// Service accounts matching any of these suffixes are allowed regardless of namespace. +var DGDReplicasModifierSuffixes = []string{ + // Dynamo operator controller manager (handles DGDSA reconciliation) + // Example: "dynamo-platform-dynamo-operator-controller-manager" + "-dynamo-operator-controller-manager", + + // Planner service account (manages DGD replicas for autoscaling) + // Example: "planner-serviceaccount" + "planner-serviceaccount", +} + +// CanModifyDGDReplicas checks if the request comes from a service account authorized +// to modify DGD replicas when scaling adapter is enabled. 
+// Service accounts are identified by username format: system:serviceaccount:<namespace>:<name>
+//
+// Authorized service accounts (by suffix):
+//   - *-dynamo-operator-controller-manager (for DGDSA reconciliation)
+//   - *planner-serviceaccount (for Planner autoscaling)
+func CanModifyDGDReplicas(userInfo authenticationv1.UserInfo) bool {
+	username := userInfo.Username
+
+	// Service accounts have username format: system:serviceaccount:<namespace>:<name>
+	if !strings.HasPrefix(username, "system:serviceaccount:") {
+		return false
+	}
+
+	// Parse: system:serviceaccount:<namespace>:<name>
+	parts := strings.Split(username, ":")
+	if len(parts) != 4 {
+		return false
+	}
+
+	namespace := parts[2]
+	saName := parts[3]
+
+	// Check against authorized suffixes
+	for _, suffix := range DGDReplicasModifierSuffixes {
+		if strings.HasSuffix(saName, suffix) {
+			webhookCommonLog.V(1).Info("allowing DGD replicas modification",
+				"serviceAccount", saName,
+				"namespace", namespace,
+				"matchedSuffix", suffix)
+			return true
+		}
+	}
+
+	return false
+}
diff --git a/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment.go b/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment.go
index c77303fde2..c0e0628834 100644
--- a/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment.go
+++ b/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment.go
@@ -42,13 +42,10 @@ func NewDynamoComponentDeploymentValidator(deployment *nvidiacomv1alpha1.DynamoC
 func (v *DynamoComponentDeploymentValidator) Validate() (admission.Warnings, error) {
 	// Validate shared spec fields using SharedSpecValidator
 	sharedValidator := NewSharedSpecValidator(&v.deployment.Spec.DynamoComponentDeploymentSharedSpec, "spec")
-	if err := sharedValidator.Validate(); err != nil {
-		return nil, err
-	}
 
 	// DCD-specific validation would go here (currently none)
 
-	return nil, nil
+	return sharedValidator.Validate()
 }
 
 // ValidateUpdate performs stateful validation comparing old and new DynamoComponentDeployment.
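For readers skimming the webhook change above, a minimal, self-contained sketch of the authorization check that CanModifyDGDReplicas performs: parse the Kubernetes service-account username (system:serviceaccount:<namespace>:<name>) and allow the request only when the account name ends in one of the permitted suffixes. The suffix values mirror DGDReplicasModifierSuffixes from the diff; the package name, helper name, and sample usernames below are illustrative assumptions, not part of the change.

package main

import (
	"fmt"
	"strings"
)

// Suffixes mirrored from DGDReplicasModifierSuffixes in the diff above.
var allowedSuffixes = []string{
	"-dynamo-operator-controller-manager",
	"planner-serviceaccount",
}

// canModifyDGDReplicas is an illustrative re-implementation of the webhook helper:
// it accepts only requests whose username is a service account with an allowed name suffix.
func canModifyDGDReplicas(username string) bool {
	// Service-account usernames have the form system:serviceaccount:<namespace>:<name>.
	if !strings.HasPrefix(username, "system:serviceaccount:") {
		return false
	}
	parts := strings.Split(username, ":")
	if len(parts) != 4 {
		return false
	}
	saName := parts[3]
	for _, suffix := range allowedSuffixes {
		if strings.HasSuffix(saName, suffix) {
			return true
		}
	}
	return false
}

func main() {
	// Hypothetical usernames for illustration only.
	fmt.Println(canModifyDGDReplicas("system:serviceaccount:dynamo:dynamo-platform-dynamo-operator-controller-manager")) // true
	fmt.Println(canModifyDGDReplicas("system:serviceaccount:default:some-app-serviceaccount"))                           // false
	fmt.Println(canModifyDGDReplicas("kubernetes-admin"))                                                                 // false
}

Matching by suffix rather than exact name appears to keep the check agnostic to the release/deployment prefix of the operator's service account (e.g. "dynamo-platform-dynamo-operator-controller-manager" in the example comment above), though the diff itself does not state this rationale explicitly.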
diff --git a/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment_test.go b/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment_test.go index 0324856dfd..f38240c8ee 100644 --- a/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment_test.go +++ b/deploy/cloud/operator/internal/webhook/validation/dynamocomponentdeployment_test.go @@ -47,11 +47,6 @@ func TestDynamoComponentDeploymentValidator_Validate(t *testing.T) { Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{ DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ Replicas: &validReplicas, - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 1, - MaxReplicas: 10, - }, }, BackendFramework: "sglang", }, @@ -74,26 +69,6 @@ func TestDynamoComponentDeploymentValidator_Validate(t *testing.T) { wantErr: true, errMsg: "spec.replicas must be non-negative", }, - { - name: "invalid autoscaling", - deployment: &nvidiacomv1alpha1.DynamoComponentDeployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-deployment", - Namespace: "default", - }, - Spec: nvidiacomv1alpha1.DynamoComponentDeploymentSpec{ - DynamoComponentDeploymentSharedSpec: nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 5, - MaxReplicas: 3, - }, - }, - }, - }, - wantErr: true, - errMsg: "spec.autoscaling.maxReplicas must be > minReplicas", - }, { name: "invalid ingress", deployment: &nvidiacomv1alpha1.DynamoComponentDeployment{ diff --git a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment.go b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment.go index e6bf9e3893..00a1668806 100644 --- a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment.go +++ b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment.go @@ -22,6 +22,8 @@ import ( "fmt" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" + internalwebhook "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/webhook" + authenticationv1 "k8s.io/api/authentication/v1" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" ) @@ -51,30 +53,106 @@ func (v *DynamoGraphDeploymentValidator) Validate() (admission.Warnings, error) return nil, err } + var allWarnings admission.Warnings + // Validate each service for serviceName, service := range v.deployment.Spec.Services { - if err := v.validateService(serviceName, service); err != nil { + warnings, err := v.validateService(serviceName, service) + if err != nil { return nil, err } + allWarnings = append(allWarnings, warnings...) } - return nil, nil + return allWarnings, nil } // ValidateUpdate performs stateful validation comparing old and new DynamoGraphDeployment. +// userInfo is used for identity-based validation (replica protection). +// If userInfo is nil, replica changes for DGDSA-enabled services are rejected (fail closed). // Returns warnings and error. 
-func (v *DynamoGraphDeploymentValidator) ValidateUpdate(old *nvidiacomv1alpha1.DynamoGraphDeployment) (admission.Warnings, error) { - // Validate that BackendFramework is not changed (immutable) +func (v *DynamoGraphDeploymentValidator) ValidateUpdate(old *nvidiacomv1alpha1.DynamoGraphDeployment, userInfo *authenticationv1.UserInfo) (admission.Warnings, error) { + var warnings admission.Warnings + + // Validate immutable fields + if err := v.validateImmutableFields(old, &warnings); err != nil { + return warnings, err + } + + // Validate replicas changes for services with scaling adapter enabled + // Pass userInfo (may be nil - will fail closed for DGDSA-enabled services) + if err := v.validateReplicasChanges(old, userInfo); err != nil { + return warnings, err + } + + return warnings, nil +} + +// validateImmutableFields checks that immutable fields have not been changed. +// Appends warnings to the provided slice. +func (v *DynamoGraphDeploymentValidator) validateImmutableFields(old *nvidiacomv1alpha1.DynamoGraphDeployment, warnings *admission.Warnings) error { if v.deployment.Spec.BackendFramework != old.Spec.BackendFramework { - warning := "Changing spec.backendFramework may cause unexpected behavior" - return admission.Warnings{warning}, fmt.Errorf("spec.backendFramework is immutable and cannot be changed after creation") + *warnings = append(*warnings, "Changing spec.backendFramework may cause unexpected behavior") + return fmt.Errorf("spec.backendFramework is immutable and cannot be changed after creation") } + return nil +} - return nil, nil +// validateReplicasChanges checks if replicas were changed for services with scaling adapter enabled. +// Only authorized service accounts (operator controller, planner) can modify these fields. +// If userInfo is nil, all replica changes for DGDSA-enabled services are rejected (fail closed). +func (v *DynamoGraphDeploymentValidator) validateReplicasChanges(old *nvidiacomv1alpha1.DynamoGraphDeployment, userInfo *authenticationv1.UserInfo) error { + // If the request comes from an authorized service account, allow the change + if userInfo != nil && internalwebhook.CanModifyDGDReplicas(*userInfo) { + return nil + } + + var errs []error + + for serviceName, newService := range v.deployment.Spec.Services { + // Check if scaling adapter is enabled for this service (enabled by default) + scalingAdapterEnabled := true + if newService.ScalingAdapter != nil && newService.ScalingAdapter.Disable { + scalingAdapterEnabled = false + } + + if !scalingAdapterEnabled { + // Scaling adapter is disabled, users can modify replicas directly + continue + } + + // Get old service (if exists) + oldService, exists := old.Spec.Services[serviceName] + if !exists { + // New service, no comparison needed + continue + } + + // Check if replicas changed + oldReplicas := int32(1) // default + if oldService.Replicas != nil { + oldReplicas = *oldService.Replicas + } + + newReplicas := int32(1) // default + if newService.Replicas != nil { + newReplicas = *newService.Replicas + } + + if oldReplicas != newReplicas { + errs = append(errs, fmt.Errorf( + "spec.services[%s].replicas cannot be modified directly when scaling adapter is enabled; "+ + "scale or update the related DynamoGraphDeploymentScalingAdapter instead", + serviceName)) + } + } + + return errors.Join(errs...) } // validateService validates a single service configuration using SharedSpecValidator. 
-func (v *DynamoGraphDeploymentValidator) validateService(serviceName string, service *nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec) error { +// Returns warnings and error. +func (v *DynamoGraphDeploymentValidator) validateService(serviceName string, service *nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec) (admission.Warnings, error) { // Use SharedSpecValidator to validate service spec (which is a DynamoComponentDeploymentSharedSpec) fieldPath := fmt.Sprintf("spec.services[%s]", serviceName) sharedValidator := NewSharedSpecValidator(service, fieldPath) diff --git a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_handler.go b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_handler.go index 074a4c5cc2..e98bd03442 100644 --- a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_handler.go +++ b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_handler.go @@ -23,6 +23,7 @@ import ( nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" internalwebhook "github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/webhook" + authenticationv1 "k8s.io/api/authentication/v1" "k8s.io/apimachinery/pkg/runtime" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/manager" @@ -91,9 +92,24 @@ func (h *DynamoGraphDeploymentHandler) ValidateUpdate(ctx context.Context, oldOb return warnings, err } - // Validate stateful rules (immutability) - updateWarnings, err := validator.ValidateUpdate(oldDeployment) + // Get user info from admission request context for identity-based validation + var userInfo *authenticationv1.UserInfo + req, err := admission.RequestFromContext(ctx) if err != nil { + logger.Error(err, "failed to get admission request from context, replica changes for DGDSA-enabled services will be rejected") + // userInfo remains nil - validateReplicasChanges will fail closed + } else { + userInfo = &req.UserInfo + } + + // Validate stateful rules (immutability + replicas protection) + updateWarnings, err := validator.ValidateUpdate(oldDeployment, userInfo) + if err != nil { + username := "" + if userInfo != nil { + username = userInfo.Username + } + logger.Info("validation failed", "error", err.Error(), "user", username) return updateWarnings, err } diff --git a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_test.go b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_test.go index 75c18dd33f..71228327b6 100644 --- a/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_test.go +++ b/deploy/cloud/operator/internal/webhook/validation/dynamographdeployment_test.go @@ -93,28 +93,6 @@ func TestDynamoGraphDeploymentValidator_Validate(t *testing.T) { wantErr: true, errMsg: "spec.services[main].replicas must be non-negative", }, - { - name: "service with invalid autoscaling", - deployment: &nvidiacomv1alpha1.DynamoGraphDeployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-graph", - Namespace: "default", - }, - Spec: nvidiacomv1alpha1.DynamoGraphDeploymentSpec{ - Services: map[string]*nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ - "prefill": { - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 10, - MaxReplicas: 5, - }, - }, - }, - }, - }, - wantErr: true, - errMsg: "spec.services[prefill].autoscaling.maxReplicas must be > minReplicas", - }, { name: "service with invalid ingress", deployment: &nvidiacomv1alpha1.DynamoGraphDeployment{ @@ -441,7 
+419,8 @@ func TestDynamoGraphDeploymentValidator_ValidateUpdate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { validator := NewDynamoGraphDeploymentValidator(tt.newDeployment) - warnings, err := validator.ValidateUpdate(tt.oldDeployment) + // Pass nil userInfo - these tests don't modify replicas, so it's safe + warnings, err := validator.ValidateUpdate(tt.oldDeployment, nil) if (err != nil) != tt.wantErr { t.Errorf("DynamoGraphDeploymentValidator.ValidateUpdate() error = %v, wantErr %v", err, tt.wantErr) diff --git a/deploy/cloud/operator/internal/webhook/validation/shared.go b/deploy/cloud/operator/internal/webhook/validation/shared.go index 5348193f3f..30edb0500d 100644 --- a/deploy/cloud/operator/internal/webhook/validation/shared.go +++ b/deploy/cloud/operator/internal/webhook/validation/shared.go @@ -21,6 +21,7 @@ import ( "fmt" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" ) // SharedSpecValidator validates DynamoComponentDeploymentSharedSpec fields. @@ -41,61 +42,45 @@ func NewSharedSpecValidator(spec *nvidiacomv1alpha1.DynamoComponentDeploymentSha } // Validate performs validation on the shared spec fields. -// Returns an error if validation fails. -func (v *SharedSpecValidator) Validate() error { +// Returns warnings (e.g., deprecation notices) and error if validation fails. +func (v *SharedSpecValidator) Validate() (admission.Warnings, error) { // Validate replicas if specified if v.spec.Replicas != nil && *v.spec.Replicas < 0 { - return fmt.Errorf("%s.replicas must be non-negative", v.fieldPath) - } - - // Validate autoscaling configuration if specified - if v.spec.Autoscaling != nil { - if err := v.validateAutoscaling(); err != nil { - return err - } + return nil, fmt.Errorf("%s.replicas must be non-negative", v.fieldPath) } // Validate ingress configuration if enabled if v.spec.Ingress != nil && v.spec.Ingress.Enabled { if err := v.validateIngress(); err != nil { - return err + return nil, err } } // Validate volume mounts if err := v.validateVolumeMounts(); err != nil { - return err + return nil, err } // Validate shared memory if v.spec.SharedMemory != nil { if err := v.validateSharedMemory(); err != nil { - return err + return nil, err } } - return nil -} - -// validateAutoscaling validates the autoscaling configuration. -func (v *SharedSpecValidator) validateAutoscaling() error { - autoscaling := v.spec.Autoscaling - - if !autoscaling.Enabled { - return nil - } - - // Validate minReplicas - if autoscaling.MinReplicas < 1 { - return fmt.Errorf("%s.autoscaling.minReplicas must be >= 1", v.fieldPath) - } + // Collect warnings (e.g., deprecation notices) + var warnings admission.Warnings - // Validate maxReplicas - if autoscaling.MaxReplicas <= autoscaling.MinReplicas { - return fmt.Errorf("%s.autoscaling.maxReplicas must be > minReplicas", v.fieldPath) + // Check for deprecated autoscaling field + //nolint:staticcheck // SA1019: Intentionally checking deprecated field to warn users + if v.spec.Autoscaling != nil { + warnings = append(warnings, fmt.Sprintf( + "%s.autoscaling is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter "+ + "with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md", + v.fieldPath)) } - return nil + return warnings, nil } // validateIngress validates the ingress configuration. 
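A short sketch of how a caller consumes the new `(warnings, error)` contract of `SharedSpecValidator.Validate` introduced above. Illustrative only: the import paths mirror those used in this patch and are assumptions; inside the operator module this would normally live in a test or example file.

```go
package main

import (
	"fmt"

	nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/cloud/operator/api/v1alpha1"
	"github.com/ai-dynamo/dynamo/deploy/cloud/operator/internal/webhook/validation"
)

func main() {
	replicas := int32(3)
	spec := &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{
		Replicas: &replicas,
		// Deprecated field: now produces a warning instead of min/max validation errors.
		Autoscaling: &nvidiacomv1alpha1.Autoscaling{Enabled: true, MinReplicas: 1, MaxReplicas: 10},
	}

	warnings, err := validation.NewSharedSpecValidator(spec, "spec").Validate()
	if err != nil {
		// Hard failures (e.g. negative replicas) still reject the object.
		fmt.Println("rejected:", err)
		return
	}
	// Warnings are surfaced in the admission response but do not block the request.
	for _, w := range warnings {
		fmt.Println("warning:", w)
	}
}
```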
diff --git a/deploy/cloud/operator/internal/webhook/validation/shared_test.go b/deploy/cloud/operator/internal/webhook/validation/shared_test.go index 472bb7d990..b7a2687cbd 100644 --- a/deploy/cloud/operator/internal/webhook/validation/shared_test.go +++ b/deploy/cloud/operator/internal/webhook/validation/shared_test.go @@ -41,11 +41,6 @@ func TestSharedSpecValidator_Validate(t *testing.T) { name: "valid spec with all fields", spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ Replicas: &validReplicas, - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 1, - MaxReplicas: 10, - }, Ingress: &nvidiacomv1alpha1.IngressSpec{ Enabled: true, Host: "example.com", @@ -77,44 +72,6 @@ func TestSharedSpecValidator_Validate(t *testing.T) { wantErr: true, errMsg: "spec.replicas must be non-negative", }, - { - name: "autoscaling minReplicas too low", - spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 0, - MaxReplicas: 10, - }, - }, - fieldPath: "spec", - wantErr: true, - errMsg: "spec.autoscaling.minReplicas must be >= 1", - }, - { - name: "autoscaling maxReplicas less than minReplicas", - spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: true, - MinReplicas: 5, - MaxReplicas: 3, - }, - }, - fieldPath: "spec", - wantErr: true, - errMsg: "spec.autoscaling.maxReplicas must be > minReplicas", - }, - { - name: "autoscaling disabled - no validation", - spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ - Autoscaling: &nvidiacomv1alpha1.Autoscaling{ - Enabled: false, - MinReplicas: 0, - MaxReplicas: 0, - }, - }, - fieldPath: "spec", - wantErr: false, - }, { name: "ingress enabled without host", spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ @@ -227,7 +184,7 @@ func TestSharedSpecValidator_Validate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { validator := NewSharedSpecValidator(tt.spec, tt.fieldPath) - err := validator.Validate() + _, err := validator.Validate() if (err != nil) != tt.wantErr { t.Errorf("SharedSpecValidator.Validate() error = %v, wantErr %v", err, tt.wantErr) @@ -240,3 +197,53 @@ func TestSharedSpecValidator_Validate(t *testing.T) { }) } } + +func TestSharedSpecValidator_Validate_Warnings(t *testing.T) { + validReplicas := int32(3) + + tests := []struct { + name string + spec *nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec + fieldPath string + wantWarnings int + }{ + { + name: "no warnings for spec without autoscaling", + spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ + Replicas: &validReplicas, + }, + fieldPath: "spec", + wantWarnings: 0, + }, + { + name: "warning for deprecated autoscaling field enabled", + spec: &nvidiacomv1alpha1.DynamoComponentDeploymentSharedSpec{ + Replicas: &validReplicas, + //nolint:staticcheck // SA1019: Intentionally testing deprecated field + Autoscaling: &nvidiacomv1alpha1.Autoscaling{ + Enabled: true, + MinReplicas: 1, + MaxReplicas: 10, + }, + }, + fieldPath: "spec", + wantWarnings: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewSharedSpecValidator(tt.spec, tt.fieldPath) + warnings, err := validator.Validate() + + if err != nil { + t.Errorf("SharedSpecValidator.Validate() unexpected error = %v", err) + return + } + + if len(warnings) != tt.wantWarnings { + t.Errorf("SharedSpecValidator.Validate() warnings count = %d, want %d", 
len(warnings), tt.wantWarnings) + } + }) + } +} diff --git a/deploy/observability/k8s/grafana-planner-dashboard-configmap.yaml b/deploy/observability/k8s/grafana-planner-dashboard-configmap.yaml new file mode 100644 index 0000000000..c536514332 --- /dev/null +++ b/deploy/observability/k8s/grafana-planner-dashboard-configmap.yaml @@ -0,0 +1,1525 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +apiVersion: v1 +kind: ConfigMap +metadata: + name: grafana-planner-dashboard + namespace: monitoring + labels: + grafana_dashboard: "1" +data: + planner-dashboard.json: |- + { + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Dynamo Planner metrics dashboard - Worker counts, observed/predicted metrics, and correction factors", + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 1, + "id": null, + "links": [], + "panels": [ + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 100, + "panels": [], + "title": "๐Ÿ–ฅ๏ธ Worker Counts & GPU Usage", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Current number of prefill workers", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "#6E40AA", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 0, + "y": 1 + }, + "id": 1, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": ["lastNotNull"], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:num_p_workers{namespace=~\"$namespace\"}", + "legendFormat": "Prefill Workers", + "range": true, + "refId": "A" + } + ], + "title": "Prefill Workers", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Current number of decode workers", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "#1FA8C9", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 4, + "y": 1 + }, + "id": 2, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": ["lastNotNull"], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:num_d_workers{namespace=~\"$namespace\"}", + "legendFormat": "Decode Workers", + "range": true, + "refId": "A" + } + ], + "title": "Decode Workers", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Cumulative GPU hours used since planner start", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 
2, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "#76B900", + "value": null + } + ] + }, + "unit": "h" + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 8, + "y": 1 + }, + "id": 3, + "options": { + "colorMode": "value", + "graphMode": "area", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": ["lastNotNull"], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:gpu_hours{namespace=~\"$namespace\"}", + "legendFormat": "GPU Hours", + "range": true, + "refId": "A" + } + ], + "title": "Cumulative GPU Hours", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Worker count history over time", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Workers", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 20, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "stepAfter", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 0, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Prefill Workers" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#6E40AA", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Decode Workers" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#1FA8C9", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 5, + "w": 12, + "x": 12, + "y": 1 + }, + "id": 4, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean"], + "displayMode": "table", + "placement": "right", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:num_p_workers{namespace=~\"$namespace\"}", + "legendFormat": "Prefill Workers", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:num_d_workers{namespace=~\"$namespace\"}", + "legendFormat": "Decode Workers", + "range": true, + "refId": "B" + } + ], + "title": "Worker Count History", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 6 + }, + "id": 101, + "panels": [], + "title": "๐Ÿ“Š Observed Metrics", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Observed time to first token and inter-token latency", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Latency (ms)", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + 
"fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "ms" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "TTFT" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#FF6B6B", + "mode": "fixed" + } + }, + { + "id": "custom.axisPlacement", + "value": "left" + }, + { + "id": "custom.axisLabel", + "value": "TTFT (ms)" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "ITL" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#4ECDC4", + "mode": "fixed" + } + }, + { + "id": "custom.axisPlacement", + "value": "right" + }, + { + "id": "custom.axisLabel", + "value": "ITL (ms)" + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 7 + }, + "id": 10, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean", "max"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:observed_ttft{namespace=~\"$namespace\"}", + "legendFormat": "TTFT", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:observed_itl{namespace=~\"$namespace\"}", + "legendFormat": "ITL", + "range": true, + "refId": "B" + } + ], + "title": "Observed Latency (TTFT & ITL)", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Observed request rate and duration", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Request Rate" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#F9A825", + "mode": "fixed" + } + }, + { + "id": "unit", + "value": "reqps" + }, + { + "id": "custom.axisPlacement", + "value": "left" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Request Duration" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#AB47BC", + "mode": "fixed" + } + }, + { + "id": "unit", + "value": "s" + }, + { + "id": "custom.axisPlacement", + "value": "right" + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 7 + }, + "id": 11, + "options": { + 
"legend": { + "calcs": ["lastNotNull", "mean", "max"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:observed_request_rate{namespace=~\"$namespace\"}", + "legendFormat": "Request Rate", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:observed_request_duration{namespace=~\"$namespace\"}", + "legendFormat": "Request Duration", + "range": true, + "refId": "B" + } + ], + "title": "Observed Request Rate & Duration", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Observed input and output sequence lengths", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Tokens", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "ISL" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#26A69A", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "OSL" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#5C6BC0", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 16, + "y": 7 + }, + "id": 12, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean", "max"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:observed_isl{namespace=~\"$namespace\"}", + "legendFormat": "ISL", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:observed_osl{namespace=~\"$namespace\"}", + "legendFormat": "OSL", + "range": true, + "refId": "B" + } + ], + "title": "Observed Sequence Lengths (ISL & OSL)", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 14 + }, + "id": 102, + "panels": [], + "title": "๐Ÿ”ฎ Predicted Metrics", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Predicted request rate", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Request Rate", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + 
"lineInterpolation": "smooth", + "lineStyle": { + "dash": [10, 10], + "fill": "dash" + }, + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "reqps" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Predicted Request Rate" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#FFB74D", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 0, + "y": 15 + }, + "id": 20, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:predicted_request_rate{namespace=~\"$namespace\"}", + "legendFormat": "Predicted Request Rate", + "range": true, + "refId": "A" + } + ], + "title": "Predicted Request Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Predicted input and output sequence lengths", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Tokens", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineStyle": { + "dash": [10, 10], + "fill": "dash" + }, + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Predicted ISL" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#80CBC4", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Predicted OSL" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#9FA8DA", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 8, + "y": 15 + }, + "id": 22, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:predicted_isl{namespace=~\"$namespace\"}", + "legendFormat": "Predicted ISL", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:predicted_osl{namespace=~\"$namespace\"}", + "legendFormat": "Predicted OSL", + "range": true, + "refId": "B" + } + ], + "title": "Predicted Sequence Lengths (ISL & OSL)", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + 
}, + "description": "Predicted number of prefill and decode replicas", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Replicas", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 20, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "stepAfter", + "lineStyle": { + "dash": [10, 10], + "fill": "dash" + }, + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "never", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "decimals": 0, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Predicted Prefill" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#B388FF", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Predicted Decode" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#64B5F6", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 8, + "x": 16, + "y": 15 + }, + "id": 21, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean", "max"], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:predicted_num_p{namespace=~\"$namespace\"}", + "legendFormat": "Predicted Prefill", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:predicted_num_d{namespace=~\"$namespace\"}", + "legendFormat": "Predicted Decode", + "range": true, + "refId": "B" + } + ], + "title": "Predicted Replica Counts", + "type": "timeseries" + }, + { + "collapsed": false, + "gridPos": { + "h": 1, + "w": 24, + "x": 0, + "y": 22 + }, + "id": 103, + "panels": [], + "title": "โš™๏ธ Correction Factors", + "type": "row" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Current prefill correction factor (TTFT observed / TTFT expected)", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 3, + "mappings": [], + "max": 2, + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "#73BF69", + "value": null + }, + { + "color": "#FF9830", + "value": 1.2 + }, + { + "color": "#F2495C", + "value": 1.5 + } + ] + }, + "unit": "none" + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 0, + "y": 23 + }, + "id": 30, + "options": { + "minVizHeight": 75, + "minVizWidth": 75, + "orientation": "auto", + "reduceOptions": { + "calcs": ["lastNotNull"], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "sizing": "auto" + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:p_correction_factor{namespace=~\"$namespace\"}", + "legendFormat": "Prefill CF", + "range": true, + "refId": "A" + } + ], + "title": "Prefill Correction Factor", + "type": "gauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": 
"${datasource}" + }, + "description": "Current decode correction factor (ITL observed / ITL expected)", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "decimals": 3, + "mappings": [], + "max": 2, + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "#73BF69", + "value": null + }, + { + "color": "#FF9830", + "value": 1.2 + }, + { + "color": "#F2495C", + "value": 1.5 + } + ] + }, + "unit": "none" + }, + "overrides": [] + }, + "gridPos": { + "h": 5, + "w": 4, + "x": 4, + "y": 23 + }, + "id": 31, + "options": { + "minVizHeight": 75, + "minVizWidth": 75, + "orientation": "auto", + "reduceOptions": { + "calcs": ["lastNotNull"], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "sizing": "auto" + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:d_correction_factor{namespace=~\"$namespace\"}", + "legendFormat": "Decode CF", + "range": true, + "refId": "A" + } + ], + "title": "Decode Correction Factor", + "type": "gauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "description": "Correction factor history over time. Values close to 1.0 indicate accurate predictions.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "Factor", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 10, + "gradientMode": "opacity", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "smooth", + "lineWidth": 2, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "line+area" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "transparent", + "value": null + }, + { + "color": "rgba(255, 152, 48, 0.1)", + "value": 1.2 + }, + { + "color": "rgba(242, 73, 92, 0.1)", + "value": 1.5 + } + ] + }, + "unit": "none" + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Prefill CF" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#CE93D8", + "mode": "fixed" + } + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Decode CF" + }, + "properties": [ + { + "id": "color", + "value": { + "fixedColor": "#81D4FA", + "mode": "fixed" + } + } + ] + } + ] + }, + "gridPos": { + "h": 5, + "w": 16, + "x": 8, + "y": 23 + }, + "id": 32, + "options": { + "legend": { + "calcs": ["lastNotNull", "mean", "min", "max"], + "displayMode": "table", + "placement": "right", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "none" + } + }, + "pluginVersion": "12.0.1", + "targets": [ + { + "editorMode": "code", + "expr": "planner:p_correction_factor{namespace=~\"$namespace\"}", + "legendFormat": "Prefill CF", + "range": true, + "refId": "A" + }, + { + "editorMode": "code", + "expr": "planner:d_correction_factor{namespace=~\"$namespace\"}", + "legendFormat": "Decode CF", + "range": true, + "refId": "B" + } + ], + "title": "Correction Factor History", + "type": "timeseries" + } + ], + "refresh": "", + "schemaVersion": 41, + "tags": ["dynamo", "planner"], + "templating": { + "list": [ + { + "current": { + "text": 
"default", + "value": "default" + }, + "label": "Data source", + "name": "datasource", + "options": [], + "query": "prometheus", + "refresh": 1, + "regex": "", + "type": "datasource" + }, + { + "current": { + "selected": true, + "text": ["All"], + "value": ["$__all"] + }, + "datasource": { + "type": "prometheus", + "uid": "${datasource}" + }, + "definition": "label_values(planner:num_p_workers, namespace)", + "hide": 0, + "includeAll": true, + "label": "Namespace", + "multi": true, + "name": "namespace", + "options": [], + "query": "label_values(planner:num_p_workers, namespace)", + "refresh": 2, + "regex": "", + "skipUrlSync": false, + "sort": 1, + "type": "query" + } + ] + }, + "time": { + "from": "now-30m", + "to": "now" + }, + "timepicker": { + "refresh_intervals": ["5s", "10s", "30s", "1m", "5m", "15m", "30m", "1h"] + }, + "timezone": "browser", + "title": "Dynamo Planner Dashboard", + "uid": "dynamo-planner-dashboard", + "version": 1 + } + diff --git a/deploy/observability/tempo.yaml b/deploy/observability/tempo.yaml index d5656245ee..a150aca64c 100644 --- a/deploy/observability/tempo.yaml +++ b/deploy/observability/tempo.yaml @@ -9,7 +9,7 @@ distributor: otlp: protocols: grpc: - endpoint: 0.0.0.0:4317 + endpoint: 0.0.0.0:4317 # Receives from OTEL collector http: endpoint: 0.0.0.0:4318 diff --git a/deploy/sanity_check.py b/deploy/sanity_check.py index 165eeab8d5..51fd1fd5f2 100755 --- a/deploy/sanity_check.py +++ b/deploy/sanity_check.py @@ -92,6 +92,7 @@ --thorough-check Enable thorough checking (file permissions, directory sizes, HuggingFace model details) --terse Enable terse output mode (show only essential info and errors) --runtime-check Skip compile-time dependency checks (Rust, Cargo, Maturin) for runtime containers + and validate ai-dynamo packages (ai-dynamo-runtime and ai-dynamo) """ import datetime @@ -299,10 +300,12 @@ def __init__( thorough_check: bool = False, terse: bool = False, runtime_check: bool = False, + no_gpu_check: bool = False, ): self.thorough_check = thorough_check self.terse = terse self.runtime_check = runtime_check + self.no_gpu_check = no_gpu_check if hostname is None: hostname = platform.node() @@ -325,9 +328,10 @@ def __init__( self.add_child(OSInfo()) self.add_child(UserInfo()) - # Add GPU info (always show, even if not found) - gpu_info = GPUInfo() - self.add_child(gpu_info) + # Add GPU info (always show, even if not found) unless --no-gpu-check + if not self.no_gpu_check: + gpu_info = GPUInfo() + self.add_child(gpu_info) # Add Framework info (vllm, sglang, tensorrt_llm) self.add_child(FrameworkInfo()) @@ -359,7 +363,11 @@ def __init__( self._add_error_only_components() # Add Dynamo workspace info (always show, even if not found) - self.add_child(DynamoInfo(thorough_check=self.thorough_check)) + self.add_child( + DynamoInfo( + thorough_check=self.thorough_check, runtime_check=self.runtime_check + ) + ) def _get_ip_address(self) -> Optional[str]: """Get the primary IP address of the system.""" @@ -1094,13 +1102,23 @@ def _check_dynamo_directory_permissions(self): dynamo_root = DynamoInfo.find_workspace() if not dynamo_root: - self.add_child( - NodeInfo( - label="Dynamo workspace", - desc="workspace not found", - status=NodeStatus.ERROR, + # In runtime check mode, workspace not being found is expected + if self.runtime_check: + self.add_child( + NodeInfo( + label="Dynamo workspace", + desc="not needed for runtime container", + status=NodeStatus.INFO, + ) + ) + else: + self.add_child( + NodeInfo( + label="Dynamo workspace", + desc="workspace not 
found", + status=NodeStatus.ERROR, + ) ) - ) return if not DynamoInfo.is_dynamo_workspace(dynamo_root): @@ -1840,25 +1858,78 @@ def __init__(self): ] frameworks_found = 0 + gpu_dependent_found = 0 for module_name, display_name in frameworks_to_check: - # Regular import for all frameworks - try: - module = __import__(module_name) - version = getattr(module, "__version__", "installed") - frameworks_found += 1 + # First check if module exists without importing (for GPU-dependent modules) + import importlib.metadata + import importlib.util - # Get module path - module_path = None - if hasattr(module, "__file__") and module.__file__: - module_path = self._replace_home_with_var(module.__file__) + spec = importlib.util.find_spec(module_name) + if not spec: + # Module not installed at all + continue - # Get executable path - exec_path = None - exec_path_raw = shutil.which(module_name) + # Module exists, try to get version from metadata (doesn't require import) + version = None + try: + version = importlib.metadata.version(module_name) + except Exception: + # Try alternative package names + alt_names = { + "tensorrt_llm": "tensorrt-llm", + "sglang": "sglang", + "vllm": "vllm", + } + if module_name in alt_names: + try: + version = importlib.metadata.version(alt_names[module_name]) + except Exception: + pass + + # Get module path from spec + module_path = None + if spec.origin: + module_path = self._replace_home_with_var(spec.origin) + + # Get executable path (special handling for each framework) + exec_path = None + exec_names = { + "vllm": "vllm", + "sglang": "sglang", + "tensorrt_llm": "trtllm-build", + } + if module_name in exec_names: + exec_path_raw = shutil.which(exec_names[module_name]) if exec_path_raw: exec_path = self._replace_home_with_var(exec_path_raw) + # Now try to import to get runtime version if needed + gpu_required = False + try: + module = __import__(module_name) + # Get version from module if not already found + if not version: + version = getattr(module, "__version__", "installed") + except ImportError as e: + # Check if it's a GPU-related error + error_msg = str(e).lower() + if "libcuda" in error_msg or "cuda" in error_msg: + gpu_required = True + gpu_dependent_found += 1 + except Exception: + pass + + # If we found the module (either importable or just installed) + if spec: + frameworks_found += 1 + if not version: + version = "installed" + + # Add status indicator to version for GPU-dependent modules + if gpu_required: + version = f"{version} (requires GPU)" + package_info = PythonPackageInfo( package_name=display_name, version=version, @@ -1868,9 +1939,6 @@ def __init__(self): is_installed=True, ) self.add_child(package_info) - except (ImportError, Exception): - # Framework not installed - don't add it - pass # If no frameworks found, set status to ERROR (X) and show what's missing if frameworks_found == 0: @@ -1881,6 +1949,9 @@ def __init__(self): missing_frameworks.append(f"no {module_name}") missing_text = ", ".join(missing_frameworks) self.desc = missing_text + elif gpu_dependent_found > 0: + # At least one framework needs GPU + self.status = NodeStatus.WARNING class PythonPackageInfo(NodeInfo): @@ -1962,8 +2033,14 @@ def __init__(self, pythonpath: str): class DynamoRuntimeInfo(NodeInfo): """Dynamo runtime components information""" - def __init__(self, workspace_dir: str, thorough_check: bool = False): + def __init__( + self, + workspace_dir: Optional[str], + thorough_check: bool = False, + runtime_check: bool = False, + ): self.thorough_check = thorough_check + 
self.runtime_check = runtime_check # Try to get package version import importlib.metadata @@ -1993,20 +2070,30 @@ def __init__(self, workspace_dir: str, thorough_check: bool = False): if pth_file: self.add_child(pth_file) - # Check for multiple _core*.so files - multiple_so_warning = self._check_multiple_core_so(workspace_dir) - if multiple_so_warning: - self.add_child(multiple_so_warning) + # Check for multiple _core*.so files (only if workspace exists) + if workspace_dir: + multiple_so_warning = self._check_multiple_core_so(workspace_dir) + if multiple_so_warning: + self.add_child(multiple_so_warning) # Discover runtime components from source components = self._discover_runtime_components(workspace_dir) + # For runtime check, always try to import the core modules + if self.runtime_check: + # Force check of essential runtime modules + essential_components = ["dynamo._core", "dynamo.runtime"] + for comp in essential_components: + if comp not in components: + components.append(comp) + # Find where each component actually is and add them if components: # Calculate max width for alignment max_len = max(len(comp) for comp in components) components_found = False + import_failures = [] for component in components: try: # Try to import to find actual location @@ -2041,16 +2128,31 @@ def __init__(self, workspace_dir: str, thorough_check: bool = False): label=padded_name, desc=error_msg, status=NodeStatus.ERROR ) self.add_child(module_node) + import_failures.append(component) # Don't set components_found to True for failed imports # Update status and value based on whether we found components if components_found: - self.status = NodeStatus.OK + # For runtime check, fail if any essential component failed to import + if self.runtime_check and import_failures: + essential_failed = any( + comp in import_failures + for comp in ["dynamo._core", "dynamo.runtime"] + ) + if essential_failed: + self.status = NodeStatus.ERROR + self.desc = "ai-dynamo-runtime - FAILED (essential modules not importable)" + else: + self.status = NodeStatus.OK + else: + self.status = NodeStatus.OK # If not installed but components work via PYTHONPATH, update the message - if not is_installed: + if not is_installed and self.status == NodeStatus.OK: self.desc = "ai-dynamo-runtime (via PYTHONPATH)" else: self.status = NodeStatus.ERROR + if self.runtime_check: + self.desc = "ai-dynamo-runtime - FAILED (no components found)" else: # No components discovered at all self.status = NodeStatus.ERROR @@ -2102,7 +2204,7 @@ def _check_multiple_core_so(self, workspace_dir: str) -> Optional[NodeInfo]: return None - def _discover_runtime_components(self, workspace_dir: str) -> list: + def _discover_runtime_components(self, workspace_dir: Optional[str]) -> list: """Discover ai-dynamo-runtime components from filesystem. 
Returns: @@ -2195,8 +2297,14 @@ def _find_pth_file(self) -> Optional[NodeInfo]: class DynamoFrameworkInfo(NodeInfo): """Dynamo framework components information""" - def __init__(self, workspace_dir: str, thorough_check: bool = False): + def __init__( + self, + workspace_dir: Optional[str], + thorough_check: bool = False, + runtime_check: bool = False, + ): self.thorough_check = thorough_check + self.runtime_check = runtime_check # Try to get package version import importlib.metadata @@ -2248,6 +2356,16 @@ def __init__(self, workspace_dir: str, thorough_check: bool = False): # Discover framework components from source components = self._discover_framework_components(workspace_dir) + # For runtime check, always try to import at least one framework component + if self.runtime_check and not components: + # Try common framework components even if not discovered + components = [ + "dynamo.frontend", + "dynamo.vllm", + "dynamo.sglang", + "dynamo.trtllm", + ] + # Find where each component actually is and add them if components: # Sort components for consistent output @@ -2257,6 +2375,7 @@ def __init__(self, workspace_dir: str, thorough_check: bool = False): max_len = max(len(comp) for comp in components) components_found = False + import_failures = [] for component in components: try: # Try to import to find actual location @@ -2281,21 +2400,29 @@ def __init__(self, workspace_dir: str, thorough_check: bool = False): label=padded_name, desc=error_msg, status=NodeStatus.ERROR ) self.add_child(component_node) + import_failures.append(component) # Don't set components_found to True for failed imports # Update status and value based on whether we found components if components_found: - self.status = NodeStatus.OK + # For runtime check, we need at least one component to work + if self.runtime_check and len(import_failures) == len(components): + self.status = NodeStatus.ERROR + self.desc = "ai-dynamo - FAILED (no components importable)" + else: + self.status = NodeStatus.OK # If not installed but components work via PYTHONPATH, update the message - if not is_installed: + if not is_installed and self.status == NodeStatus.OK: self.desc = "ai-dynamo (via PYTHONPATH)" else: self.status = NodeStatus.ERROR + if self.runtime_check: + self.desc = "ai-dynamo - FAILED (no components found)" else: # No components discovered at all self.status = NodeStatus.ERROR - def _discover_framework_components(self, workspace_dir: str) -> list: + def _discover_framework_components(self, workspace_dir: Optional[str]) -> list: """Discover ai-dynamo framework components from filesystem. 
Returns: @@ -2326,12 +2453,37 @@ def _discover_framework_components(self, workspace_dir: str) -> list: class DynamoInfo(NodeInfo): """Dynamo workspace information""" - def __init__(self, thorough_check: bool = False): + def __init__(self, thorough_check: bool = False, runtime_check: bool = False): self.thorough_check = thorough_check + self.runtime_check = runtime_check # Find workspace directory workspace_dir = DynamoInfo.find_workspace() + # For runtime check, we don't need a workspace - just check packages + if self.runtime_check and not workspace_dir: + super().__init__( + label="Dynamo", + desc="Runtime container - checking installed packages", + status=NodeStatus.INFO, + ) + # Check runtime components even without workspace + runtime_info = DynamoRuntimeInfo( + None, + thorough_check=self.thorough_check, + runtime_check=self.runtime_check, + ) + self.add_child(runtime_info) + + # Check framework components even without workspace + framework_info = DynamoFrameworkInfo( + None, + thorough_check=self.thorough_check, + runtime_check=self.runtime_check, + ) + self.add_child(framework_info) + return + if not workspace_dir: # Show error when workspace is not found super().__init__( @@ -2368,13 +2520,17 @@ def __init__(self, thorough_check: bool = False): # Always add runtime components runtime_info = DynamoRuntimeInfo( - workspace_dir, thorough_check=self.thorough_check + workspace_dir, + thorough_check=self.thorough_check, + runtime_check=self.runtime_check, ) self.add_child(runtime_info) # Always add framework components framework_info = DynamoFrameworkInfo( - workspace_dir, thorough_check=self.thorough_check + workspace_dir, + thorough_check=self.thorough_check, + runtime_check=self.runtime_check, ) self.add_child(framework_info) @@ -2513,8 +2669,14 @@ def main(): ) parser.add_argument( "--runtime-check", + "--runtime", + action="store_true", + help="Skip compile-time dependency checks (Rust, Cargo, Maturin) for runtime containers and validate ai-dynamo packages", + ) + parser.add_argument( + "--no-gpu-check", action="store_true", - help="Skip compile-time dependency checks (Rust, Cargo, Maturin) for runtime containers", + help="Skip GPU detection and information collection (useful for CI environments without GPU access)", ) args = parser.parse_args() @@ -2527,6 +2689,7 @@ def main(): thorough_check=args.thorough_check, terse=args.terse, runtime_check=args.runtime_check, + no_gpu_check=args.no_gpu_check, ) tree.print_tree() diff --git a/docs/_sections/k8s_deployment.rst b/docs/_sections/k8s_deployment.rst index 81d06513cb..cdd7d2029a 100644 --- a/docs/_sections/k8s_deployment.rst +++ b/docs/_sections/k8s_deployment.rst @@ -10,3 +10,4 @@ Deployment Guide Webhooks <../kubernetes/webhooks> Minikube Setup <../kubernetes/deployment/minikube> Managing Models with DynamoModel <../kubernetes/deployment/dynamomodel-guide> + Autoscaling <../kubernetes/autoscaling> diff --git a/docs/agents/tool-calling.md b/docs/agents/tool-calling.md index 0326d57bf2..dd0d116215 100644 --- a/docs/agents/tool-calling.md +++ b/docs/agents/tool-calling.md @@ -38,6 +38,7 @@ Parser to Model Mapping | deepseek_v3 | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-0528 | | deepseek_v3_1 | deepseek-ai/DeepSeek-V3.1 | | pythonic | meta-llama/Llama-4-* | +| jamba | ai21labs/AI21-Jamba-*-1.5, ai21labs/AI21-Jamba-*-1.6, ai21labs/AI21-Jamba-*-1.7, | ## Examples diff --git a/docs/api/nixl_connect/connector.md b/docs/api/nixl_connect/connector.md index 05db27be03..3c64d99eae 100644 --- 
a/docs/api/nixl_connect/connector.md +++ b/docs/api/nixl_connect/connector.md @@ -47,7 +47,6 @@ The metadata contains required information (identifiers, keys, etc.) which enabl @async_on_start async def async_init(self): self.connector = dynamo.nixl_connect.Connector() - await self.connector.initialize() ``` > [!Tip] @@ -109,7 +108,7 @@ Use [`.wait_for_completion()`](write_operation.md#wait_for_completion) to block ### `create_readable` ```python -def create_readable( +async def create_readable( self, local_descriptors: Descriptor | list[Descriptor], ) -> ReadableOperation: @@ -130,7 +129,7 @@ Use [`.wait_for_completion()`](readable_operation.md#wait_for_completion) to blo ### `create_writable` ```python -def create_writable( +async def create_writable( self, local_descriptors: Descriptor | list[Descriptor], ) -> WritableOperation: @@ -151,6 +150,15 @@ Use [`.wait_for_completion()`](writable_operation.md#wait_for_completion) to blo ## Properties +### `hostname` + +```python +@property +def hostname(self) -> str: +``` + +Gets the name of the current worker's host. + ### `is_cuda_available` ```python @@ -169,22 +177,6 @@ def name(self) -> str | None: Gets the Dynamo component name used by the connector. -### `namespace` - -```python -@property -def namespace(self) -> str: -``` - -Gets the Dynamo namespace used by the connector. - -### `runtime` - -```python -def runtime(self) -> dynamo.runtime.DistributedRuntime: -``` - -Gets the Dynamo distributed runtime instance associated with the connector. ## Related Classes diff --git a/docs/api/nixl_connect/read_operation.md b/docs/api/nixl_connect/read_operation.md index f01e925498..71b9e22fd9 100644 --- a/docs/api/nixl_connect/read_operation.md +++ b/docs/api/nixl_connect/read_operation.md @@ -38,7 +38,7 @@ therefore the operation should be awaited until completed unless cancellation is ) -> None: descriptor = dynamo.nixl_connect.Descriptor(local_tensor) - with self.connector.begin_read(descriptor, remote_metadata) as read_op: + with await self.connector.begin_read(remote_metadata, descriptor) as read_op: # Wait for the operation to complete writing data from the remote worker to local_tensor. await read_op.wait_for_completion() ``` diff --git a/docs/api/nixl_connect/readable_operation.md b/docs/api/nixl_connect/readable_operation.md index f112c77b3b..1e66a33b57 100644 --- a/docs/api/nixl_connect/readable_operation.md +++ b/docs/api/nixl_connect/readable_operation.md @@ -37,7 +37,7 @@ therefore the operation should be awaited until completed unless cancellation is ) -> None: descriptor = dynamo.nixl_connect.Descriptor(local_tensor) - with self.connector.create_readable(descriptor) as read_op: + with await self.connector.create_readable(descriptor) as read_op: op_metadata = read_op.metadata() # Send the metadata to the remote worker via sideband communication. diff --git a/docs/api/nixl_connect/writable_operation.md b/docs/api/nixl_connect/writable_operation.md index 4d57bb0808..d191f7d733 100644 --- a/docs/api/nixl_connect/writable_operation.md +++ b/docs/api/nixl_connect/writable_operation.md @@ -38,7 +38,7 @@ Cancellation is handled asynchronously. ) -> None: descriptor = dynamo.nixl_connect.Descriptor(local_tensor) - with self.connector.create_writable(descriptor) as write_op: + with await self.connector.create_writable(descriptor) as write_op: op_metadata = write_op.metadata() # Send the metadata to the remote worker via sideband communication. 
diff --git a/docs/api/nixl_connect/write_operation.md b/docs/api/nixl_connect/write_operation.md index 48c6729417..4740f13987 100644 --- a/docs/api/nixl_connect/write_operation.md +++ b/docs/api/nixl_connect/write_operation.md @@ -39,7 +39,7 @@ Cancellation is handled asynchronously. ) -> None: descriptor = dynamo.nixl_connect.Descriptor(local_tensor) - with self.connector.begin_write(descriptor, remote_metadata) as write_op: + with await self.connector.begin_write(descriptor, remote_metadata) as write_op: # Wait for the operation to complete writing local_tensor to the remote worker. await write_op.wait_for_completion() ``` diff --git a/docs/backends/sglang/multimodal_sglang_guide.md b/docs/backends/sglang/multimodal_sglang_guide.md new file mode 100644 index 0000000000..43848584cb --- /dev/null +++ b/docs/backends/sglang/multimodal_sglang_guide.md @@ -0,0 +1,324 @@ + + +# SGLang Multimodal Guide + +This document provides a comprehensive guide for multimodal inference using SGLang backend in Dynamo. For more details on the multimodal examples, see [Multimodal Examples Documentation](./multimodal_epd.md). + +## Multimodal Support Matrix + +| Modality | Input Format | Aggregated | Disaggregated | Notes | +|----------|--------------|------------|---------------|-------| +| **Image** | HTTP/HTTPS URL | โœ… Yes | โœ… Yes | Vision encoder generates embeddings | +| **Image** | Data URL (Base64) | โŒ No | โŒ No | Not supported | +| **Video** | HTTP/HTTPS URL | โŒ No | โŒ No | Not implemented | +| **Audio** | HTTP/HTTPS URL | โŒ No | โŒ No | Not implemented | + +## Architecture Comparison + +SGLang multimodal supports two deployment patterns: + +```text +AGGREGATED (E->PD): + Client โ†’ Frontend (Rust) โ†’ Processor โ†’ Encoder [NIXL] โ†’ PD Worker โ†’ Response + โ€ข 3 components โ€ข Vision encoder in Python โ€ข NIXL embeddings transfer + +DISAGGREGATED (E->P->D): + Client โ†’ Frontend โ†’ Processor โ†’ Encoder [NIXL] โ†’ Prefill [bootstrap] โ†’ Decode โ†’ Response + โ€ข 4 components โ€ข Vision encoder in Python โ€ข KV cache transfer via bootstrap mechanism +``` + +## Aggregated Mode (E->PD) + +In aggregated mode, encoding happens in a separate worker, but prefill and decode share the same engine. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Processor (Python - ModelInput.Text - REGISTERED) + โ†“ tokenizes with chat template, extracts image URL +Encode Worker (Python - NOT registered) + โ†“ downloads image, runs vision encoder, generates embeddings, NIXL transfer +PD Worker (Python - NOT registered) + โ†“ receives embeddings via NIXL, prefill + decode +Response โ†’ Processor โ†’ Frontend +``` + +### Components + +| Component | Flag | ModelInput | Registered | Has SGLang Engine? 
| Purpose | +|-----------|------|-----------|------------|-------------------|---------| +| Processor | `--multimodal-processor` | Text | โœ… Yes | โŒ No | HTTP entry, OpenAIโ†’SGLang conversion | +| Encode Worker | `--multimodal-encode-worker` | N/A | โŒ No | โŒ No | Vision encoder, embeddings generation | +| PD Worker | `--multimodal-worker` | N/A | โŒ No | โœ… Yes | Prefill + Decode with embeddings | + +### Key Characteristics + +- **Vision Encoder in Python**: Encode worker loads vision model (AutoModel) and image processor (AutoImageProcessor) +- **Token Expansion**: Single `<|image_pad|>` token replaced with N tokens based on embedding shape +- **NIXL Transfer**: Embeddings transferred from Encoder โ†’ PD Worker using NIXL +- **No Rust Processing**: All tokenization and image handling happens in Python + +## Disaggregated Mode (E->P->D) + +In disaggregated mode, encoding, prefill, and decode are handled by separate workers using SGLang's bootstrap coordination. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Processor (Python - ModelInput.Text - REGISTERED) + โ†“ tokenizes with chat template, extracts image URL +Encode Worker (Python - NOT registered) + โ†“ downloads image, runs vision encoder, generates embeddings, NIXL transfer +Prefill Worker (Python - NOT registered) + โ†“ receives embeddings via NIXL, prefill only, returns bootstrap info +Decode Worker (Python - NOT registered) + โ†“ uses bootstrap info, decode only, token generation +Response โ†’ Processor โ†’ Frontend +``` + +### Components + +| Component | Flag | ModelInput | Registered | Has SGLang Engine? | Purpose | +|-----------|------|-----------|------------|-------------------|---------| +| Processor | `--multimodal-processor` | Text | โœ… Yes | โŒ No | HTTP entry, OpenAIโ†’SGLang conversion | +| Encode Worker | `--multimodal-encode-worker` | N/A | โŒ No | โŒ No | Vision encoder, embeddings generation | +| Decode Worker | `--multimodal-worker --serving-mode=decode` | N/A | โŒ No | โœ… Yes | **Entry point for disaggregation**, calls Prefill | +| Prefill Worker | `--multimodal-worker --serving-mode=prefill` | N/A | โŒ No | โœ… Yes | Called by Decode, bootstrap coordination | + +### Bootstrap Coordination + +SGLang disaggregation uses a bootstrap mechanism for P->D coordination: + +**Request Flow (Important):** +```text +Client โ†’ Frontend โ†’ Processor โ†’ Encode โ†’ DECODE Worker โ†’ Prefill Worker + โ†‘ + Entry point for disaggregation! +``` + +**Bootstrap Process:** +1. **Decode Worker** receives request from Encode Worker +2. **Decode Worker** calls Prefill Worker via NATS to request bootstrap info +3. **Prefill Worker** generates `{host, port, room}` and returns immediately +4. **Both workers** connect to same "room" using bootstrap coordinates +5. 
**SGLang internally** transfers KV cache state via bootstrap connection (not NIXL) + +**Key Difference from vLLM:** +- vLLM: Frontend โ†’ Prefill โ†’ Decode (Prefill is entry point) +- SGLang: Frontend โ†’ Processor โ†’ Encode โ†’ **Decode โ†’ Prefill** (Decode is entry point) + +## ModelInput Types and Registration + +**Only the Processor registers with Dynamo Rust.** + +### Registration Pattern + +```python +# ONLY Processor registers with Dynamo Rust +await register_llm_with_readiness_gate( + None, # No engine for processor + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Text, # Receives raw OpenAI format + readiness_gate=ready_event, +) + +# Workers do NOT register - they are internal components +# They communicate via NATS clients created in main.py +``` + +### Component Initialization + +```python +# Encode Worker - connects to downstream PD worker +pd_worker_client = ( + await runtime.namespace(dynamo_args.namespace) + .component("backend") + .endpoint("generate") + .client() +) + +# PD Worker (Decode mode) - connects to upstream Prefill worker +prefill_client = ( + await runtime.namespace(dynamo_args.namespace) + .component("prefill") + .endpoint("generate") + .client() +) +``` + +## Inter-Component Communication + +### Control Flow (NATS) + +All component-to-component communication happens via NATS: + +**Aggregated Mode (Eโ†’PD):** +```text +Processor โ†’ Encode Worker โ†’ PD Worker + (NATS) (NATS + NIXL embeddings) +``` + +**Disaggregated Mode (Eโ†’Pโ†’D):** +```text +Processor โ†’ Encode Worker โ†’ DECODE Worker โ†’ Prefill Worker + (NATS) (NATS) (NATS) + โ†“ + Decode requests bootstrap + โ†“ + Prefill returns {host, port, room} + โ†“ + Both connect via bootstrap + โ†“ + SGLang internal KV cache transfer +``` + +**Detailed Message Flow:** + +```text +Processor โ†’ Encode Worker: + - NATS round_robin with SglangMultimodalRequest + - Contains: tokenized input_ids, image URL, sampling params + +Encode Worker โ†’ Decode/PD Worker: + - NATS round_robin to "backend" component + - Contains: expanded token_ids, NIXL metadata, embeddings shape + - NIXL transfer: embeddings tensor + +Decode Worker โ†’ Prefill Worker (disagg only): + - NATS call to "prefill" component + - Decode requests bootstrap coordinates + - Prefill returns: {bootstrap_host, bootstrap_port, bootstrap_room} + +Prefill โ†” Decode (via bootstrap): + - SGLang internal connection (not NATS) + - KV cache state shared via bootstrap mechanism +``` + +### Data Transfer (NIXL) + +NIXL is used only for embedding transfer: + +```python +Encode Worker: + descriptor = connect.Descriptor(precomputed_embeddings) + with await connector.create_readable(descriptor) as readable: + request.serialized_request = readable.metadata() + # Send request with NIXL metadata + await pd_worker_client.round_robin(request) + await readable.wait_for_completion() + +PD Worker: + embeddings = torch.empty(request.embeddings_shape, dtype=torch.float16) + descriptor = connect.Descriptor(embeddings) + read_op = await connector.begin_read(request.serialized_request, descriptor) + await read_op.wait_for_completion() +``` + +## Vision Encoding Details + +### Encode Worker Components + +The encode worker loads and runs the vision model in Python: + +```python +# Vision components loaded in encode worker +self.image_processor = AutoImageProcessor.from_pretrained( + model_path, trust_remote_code=True +) +self.vision_model = AutoModel.from_pretrained( + model_path, + device_map="auto", + torch_dtype=torch.float16, + trust_remote_code=True 
+) +``` + +### Token Expansion Process + +1. Processor inserts single image token (e.g., `<|image_pad|>`) +2. Encode worker generates embeddings: `shape = (batch, num_patches, hidden_dim)` +3. Encode worker replaces single token with `num_patches` tokens +4. Downstream worker receives expanded token sequence + +Example: +```python +# Before: ["Hello", "<|image_pad|>", "world"] +# After: ["Hello", "<|image_pad|>", "<|image_pad|>", ...(576 tokens), "world"] +``` + +## Chat Template Processing + +SGLang uses its own chat template system: + +```python +from sglang.srt.parser.conversation import chat_templates + +conv = chat_templates["qwen2-vl"].copy() +conv.append_message(conv.roles[0], f"{conv.image_token} Describe this image") +processed = tokenizer(text=conv.get_prompt(), return_tensors="pt") +``` + +Supported templates: `qwen2-vl`, `llama-3`, `vicuna`, etc. + +## NIXL USE + +| Use Case | NIXL Used? | Data Transfer | Notes | +|----------|------------|---------------|-------| +| Eโ†’PD Aggregated | โœ… Yes | Encoder โ†’ PD (embeddings) | Vision encoder separate | +| Eโ†’Pโ†’D Disaggregated | โœ… Yes | Encoder โ†’ Prefill (embeddings) | KV cache via SGLang bootstrap | + +**Key Difference:** SGLang Pโ†’D uses bootstrap mechanism, not NIXL for KV cache like vLLM. + +## Known Limitations + +- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported +- **No pre-computed embeddings** - Cannot use `.pt`, `.pth`, `.bin` embedding files; vision encoder runs for every request +- **No video support** - No video encoder implementation +- **No audio support** - No audio encoder implementation +- **Only Processor registers with Dynamo** - Workers are internal components, frontend routes to Processor only +- **Disaggregated routing** - Decode Worker is the entry point (calls Prefill), cannot route directly to Prefill workers +- **Limited model generalization** - Token expansion logic is model-specific; adding new models may require implementation updates + +## Supported Models + +SGLang multimodal **only supports image-based vision-language models**: + +### โœ… Supported (Images Only) +- **Qwen2-VL** / **Qwen2.5-VL** (primary support) +- Models with `AutoImageProcessor` and vision tower +- Models compatible with SGLang's image embedding format + + +## Key Files + +| File | Description | +|------|-------------| +| `components/src/dynamo/sglang/main.py` | Component initialization, only Processor registers | +| `components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py` | Processor implementation, OpenAIโ†’SGLang | +| `components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py` | Vision encoder, embeddings generation | +| `components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py` | PD/Prefill/Decode workers, NIXL read | +| `components/src/dynamo/sglang/multimodal_utils/multimodal_chat_processor.py` | Chat template processing | +| `components/src/dynamo/sglang/protocol.py` | Request/response data structures | +| `components/src/dynamo/sglang/register.py` | Registration logic (only called for Processor) | + diff --git a/docs/backends/sglang/profiling.md b/docs/backends/sglang/profiling.md new file mode 100644 index 0000000000..40a1c5ced1 --- /dev/null +++ b/docs/backends/sglang/profiling.md @@ -0,0 +1,44 @@ + + +# Profiling SGLang Workers in Dynamo + +Dynamo exposes profiling endpoints for SGLang workers via the system server's `/engine/*` routes. 
This allows you to start and stop PyTorch profiling on running inference workers without restarting them. + +These endpoints wrap SGLang's internal `TokenizerManager.start_profile()` and `stop_profile()` methods. See SGLang's documentation for the full list of supported parameters. + +## Quick Start + +1. **Start profiling:** + +```bash +curl -X POST http://localhost:9090/engine/start_profile \ + -H "Content-Type: application/json" \ + -d '{"output_dir": "/tmp/profiler_output"}' +``` + +2. **Run some inference requests to generate profiling data** + +3. **Stop profiling:** + +```bash +curl -X POST http://localhost:9090/engine/stop_profile +``` + +4. **View the traces:** + +The profiler outputs Chrome trace files in the specified `output_dir`. You can view them using: +- Chrome's `chrome://tracing` +- [Perfetto UI](https://ui.perfetto.dev/) +- TensorBoard with the PyTorch Profiler plugin + +## Test Script + +A test script is provided at [`examples/backends/sglang/test_sglang_profile.py`](../../../examples/backends/sglang/test_sglang_profile.py) that demonstrates the full profiling workflow: + +```bash +python examples/backends/sglang/test_sglang_profile.py +``` + diff --git a/docs/backends/trtllm/multimodal_epd.md b/docs/backends/trtllm/multimodal_epd.md deleted file mode 100644 index 17839e5826..0000000000 --- a/docs/backends/trtllm/multimodal_epd.md +++ /dev/null @@ -1,139 +0,0 @@ -# Encode-Prefill-Decode (EPD) Flow with NIXL - -For high-performance multimodal inference with large embeddings, Dynamo supports a specialized **Encode-Prefill-Decode (EPD)** flow using **NIXL (RDMA)** for zero-copy tensor transfer. - -## Enabling the Feature - -This is an experimental feature that requires using a specific TensorRT-LLM commit. -To enable it build the dynamo container with the `--tensorrtllm-commit` flag, followed by the commit hash: - -```bash -./container/build.sh --framework trtllm --tensorrtllm-git-url https://github.com/NVIDIA/TensorRT-LLM.git --tensorrtllm-commit v1.2.0rc2 -``` - -## Key Features - -- **High Performance**: Zero-copy RDMA transfer for embeddings -- **Dynamic Shape Allocation**: Automatically handles variable embedding shapes per image -- **Multi-Format Support**: Works with tensor files (`.pt`) and dictionary-based embeddings -- **Hybrid Transfer**: Large tensors via NIXL, small metadata via JSON - -## How to use - -```bash -cd $DYNAMO_HOME/examples/backends/trtllm - -# Launch 3-worker EPD flow with NIXL. -./launch/epd_disagg.sh -``` - -## Pre-requsites - -This script is specifically designed to work on 8 node H200 and `Llama-4-Maverick-17B-128E-Instruct` model with assumption that you already have a model specific embedding file ready. - -## Configuration - -The EPD flow uses a dedicated **Encode Worker** that runs separately from the Prefill and Decode workers. The `ENCODE_ENDPOINT` environment variable specifies how the Prefill worker communicates with the Encode worker: - -```bash -export ENCODE_ENDPOINT="dyn://dynamo.tensorrt_llm_encode.generate" -``` - -This endpoint follows Dynamo's standard format: `dyn://namespace.component.endpoint` where the Encode worker registers itself as `dynamo.tensorrt_llm_encode.generate`. - -For local embedding file access, use the `--allowed-local-media-path "$ALLOWED_LOCAL_MEDIA_PATH"` parameter to specify the secure directory path where embedding files can be loaded from (default: `/tmp`). This prevents path traversal attacks while allowing flexible file access within the designated directory. 
- -```bash -export ALLOWED_LOCAL_MEDIA_PATH="/tmp" -``` - -For tensor file size protection, use the `--max-file-size-mb "$MAX_FILE_SIZE_MB"` parameter to limit the maximum size of downloadable embedding files/Image URLs (default: `50MB`). This prevents Denial of Service (DoS) attacks from maliciously large files while accommodating typical embedding file sizes. - -```bash -export MAX_FILE_SIZE_MB=50 -``` - -## Architecture Overview - -The EPD flow implements a **3-worker architecture** for high-performance multimodal inference: - -- **Encode Worker**: Loads and processes multimodal embeddings -- **Prefill Worker**: Handles initial context processing and KV-cache generation -- **Decode Worker**: Performs streaming token generation - -## Request Flow Diagram - -```mermaid -sequenceDiagram - participant Client - participant Frontend - participant PrefillWorker as "Prefill Worker
(PrefillHandler)" - participant EncodeWorker as "Encode Worker
(EncodeHandler)" - participant DecodeWorker as "Decode Worker
(DecodeHandler)" - participant NIXL as "NIXL
(RDMA Transfer)" - - Note over Client,NIXL: Unified Frontend: Context processing followed by streaming generation - - Client->>Frontend: POST /v1/chat/completions
(multimodal request) - Frontend->>PrefillWorker: Route to prefill worker - - Note over PrefillWorker: Check for multimodal content - PrefillWorker->>EncodeWorker: Send request
(contains embedding paths) - - Note over EncodeWorker: Load embeddings from file/url
- EncodeWorker->>NIXL: Create readable operation
- EncodeWorker->>PrefillWorker: Send metadata + NIXL info
(JSON: shape, dtype, aux_data) - - Note over PrefillWorker: Allocate tensor with dynamic shape - PrefillWorker->>NIXL: Begin read operation - NIXL-->>PrefillWorker: Zero-copy transfer complete
- - Note over PrefillWorker: Reconstruct embeddings
(mm_embeddings + special_tokens + offsets) - Note over PrefillWorker: Process full context
(text + multimodal embeddings) - Note over PrefillWorker: Generate KV-cache
(max_tokens=1 in prefill mode) - - PrefillWorker->>Frontend: Return prefill response
(disaggregated_params) - - Frontend->>DecodeWorker: Route to decode worker
with disaggregated_params - - Note over DecodeWorker: Continue generation
(streaming tokens) - DecodeWorker->>Frontend: Stream response chunk 1 - Frontend->>Client: Response chunk 1 - DecodeWorker->>Frontend: Stream response chunk 2 - Frontend->>Client: Response chunk 2 - DecodeWorker->>Frontend: ... (continue streaming) - Frontend->>Client: ... (continue streaming) - DecodeWorker->>Frontend: Final response + [DONE] - Frontend->>Client: Final response + [DONE] -``` - -## How the System Works - -1. **Request Processing**: Multimodal requests containing embedding file paths or URLs are routed by the frontend to prefill workers -2. **Multimodal Loading**: EncodeWorker loads large embedding files and extracts auxiliary metadata -3. **NIXL Transfer**: Main tensors transferred via zero-copy RDMA, small metadata via JSON for efficiency -4. **Dynamic Allocation**: Consumer workers allocate tensors with exact shapes received from EncodeWorker -5. **Reconstruction**: Original embedding format (dictionary or tensor) is reconstructed for model processing - -## Example Request - -The request format is identical to regular multimodal requests: - -```bash -curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ - "model": "meta-llama/Llama-4-Maverick-17B-128E-Instruct", - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Describe the image"}, - { - "type": "image_url", - "image_url": {"url": "/path/to/embeddings.pt"} - } - ] - } - ], - "max_tokens": 160 -}' -``` diff --git a/docs/backends/trtllm/multimodal_support.md b/docs/backends/trtllm/multimodal_support.md index 7f90874be7..cc58f924b9 100644 --- a/docs/backends/trtllm/multimodal_support.md +++ b/docs/backends/trtllm/multimodal_support.md @@ -92,23 +92,41 @@ In general, disaggregated serving can run on a single node, provided the model f To deploy `Llama-4-Maverick-17B-128E-Instruct` in disaggregated mode, you will need to follow the multi-node setup instructions, which can be found [here](./multinode/multinode-multimodal-example.md). -## Using Pre-computed Embeddings (Experimental) +## Pre-computed Embeddings with EPD Flow -Dynamo with TensorRT-LLM supports providing pre-computed embeddings directly in an inference request. This bypasses the need for the model to process an image and generate embeddings itself, which is useful for performance optimization or when working with custom, pre-generated embeddings. +For high-performance multimodal inference, Dynamo supports pre-computed embeddings with an **Encode-Prefill-Decode (EPD)** flow using **NIXL (RDMA)** for zero-copy tensor transfer. -### How to Use +### Supported File Types -Once the container is built, you can send requests with paths to local embedding files. +- `.pt` - PyTorch tensor files +- `.pth` - PyTorch checkpoint files +- `.bin` - Binary tensor files -- **Format:** Provide the embedding as part of the `messages` array, using the `image_url` content type. -- **URL:** The `url` field should contain the absolute or relative path to your embedding file on the local filesystem. -- **File Types:** Supported embedding file extensions are `.pt`, `.pth`, and `.bin`. Dynamo will automatically detect these extensions. +### How to Launch -When a request with a supported embedding file is received, Dynamo will load the tensor from the file and pass it directly to the model for inference, skipping the image-to-embedding pipeline. 
+```bash +cd $DYNAMO_HOME/examples/backends/trtllm -### Example Request +# Launch 3-worker EPD flow with NIXL +./launch/epd_disagg.sh +``` -Here is an example of how to send a request with a pre-computed embedding file. +> **Note:** This script is designed for 8-node H200 with `Llama-4-Scout-17B-16E-Instruct` model and assumes you have a model-specific embedding file ready. + +### Configuration + +```bash +# Encode endpoint for Prefill โ†’ Encode communication +export ENCODE_ENDPOINT="dyn://dynamo.tensorrt_llm_encode.generate" + +# Security: Allowed directory for embedding files (default: /tmp) +export ALLOWED_LOCAL_MEDIA_PATH="/tmp" + +# Security: Max file size to prevent DoS attacks (default: 50MB) +export MAX_FILE_SIZE_MB=50 +``` + +### Example Request ```bash curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ @@ -117,27 +135,47 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d ' { "role": "user", "content": [ - { - "type": "text", - "text": "Describe the content represented by the embeddings" - }, - { - "type": "image_url", - "image_url": { - "url": "/path/to/your/embedding.pt" - } - } + {"type": "text", "text": "Describe the image"}, + {"type": "image_url", "image_url": {"url": "/path/to/embedding.pt"}} ] } ], - "stream": false, "max_tokens": 160 }' ``` -## Encode-Prefill-Decode (EPD) Flow with NIXL -Dynamo with the TensorRT-LLM backend supports multimodal models in Encode -> Decode -> Prefill fashion, enabling you to process embeddings seperately in a seperate worker. For detailed setup instructions, example requests, and best practices, see the [Multimodal EPD Support Guide](./multimodal_epd.md). +### Architecture + +The EPD flow implements a **3-worker architecture**: + +- **Encode Worker**: Loads pre-computed embeddings, transfers via NIXL +- **Prefill Worker**: Receives embeddings, handles context processing and KV-cache generation +- **Decode Worker**: Performs streaming token generation + +### Request Flow + +```mermaid +sequenceDiagram + participant Client + participant Frontend + participant PrefillWorker as "Prefill Worker" + participant EncodeWorker as "Encode Worker" + participant DecodeWorker as "Decode Worker" + participant NIXL as "NIXL (RDMA)" + + Client->>Frontend: POST /v1/chat/completions + Frontend->>PrefillWorker: Route to prefill worker + PrefillWorker->>EncodeWorker: Send request (embedding paths) + EncodeWorker->>NIXL: Create readable operation + EncodeWorker->>PrefillWorker: Send metadata + NIXL info + PrefillWorker->>NIXL: Begin read operation + NIXL-->>PrefillWorker: Zero-copy transfer complete + PrefillWorker->>Frontend: Return prefill response + Frontend->>DecodeWorker: Route to decode worker + DecodeWorker->>Frontend: Stream response chunks + Frontend->>Client: Stream response +``` ## Supported Multimodal Models -Multimodel models listed [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/inputs/utils.py#L221) are supported by dynamo. \ No newline at end of file +Multimodal models listed in [TensorRT-LLM supported models](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md) are supported by Dynamo. 
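The `/path/to/embedding.pt` in the request above is a pre-computed embedding file that the Encode Worker loads from `ALLOWED_LOCAL_MEDIA_PATH`. A minimal sketch of producing such a file follows; the shape and dtype are illustrative only and must match the target model in practice.

```python
import torch

# Illustrative only: real embeddings must match the target model's hidden size
# and visual token count.
embedding = torch.rand(1, 576, 4096, dtype=torch.float16)

# Save under the default ALLOWED_LOCAL_MEDIA_PATH (/tmp) so the Encode Worker can load it.
torch.save(embedding, "/tmp/embedding.pt")
```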
diff --git a/docs/backends/trtllm/multimodal_trtllm_guide.md b/docs/backends/trtllm/multimodal_trtllm_guide.md new file mode 100644 index 0000000000..3bd3b9b0fe --- /dev/null +++ b/docs/backends/trtllm/multimodal_trtllm_guide.md @@ -0,0 +1,270 @@ + + +# TRT-LLM Multimodal Guide + +This document provides a comprehensive guide for multimodal inference using TensorRT-LLM backend in Dynamo. For more details on the multimodal examples, see [Multimodal Examples Documentation](./multimodal_support.md). + +## Multimodal Support Matrix + +| Modality | Input Format | Aggregated | Disaggregated | Notes | +|----------|--------------|------------|---------------|-------| +| **Image** | HTTP/HTTPS URL | Yes | Yes | Full support for all image models | +| **Image** | Pre-computed Embeddings (.pt, .pth, .bin) | Yes | Yes | Direct embedding files | +| **Video** | HTTP/HTTPS URL | โŒ No | โŒ No | Not implemented | +| **Audio** | HTTP/HTTPS URL | โŒ No | โŒ No | Not implemented | + +## Architecture Comparison + +TRT-LLM multimodal supports three deployment patterns: + +```text +SIMPLE AGGREGATED (agg.sh): + Client โ†’ Frontend (Rust) โ†’ Worker [image load, encode, P+D] โ†’ Response + โ€ข 2 components โ€ข worker flag `--modality multimodal` โ€ข Easiest setup + +DISAGGREGATED P->D (disagg_multimodal.sh): + Client โ†’ Frontend โ†’ Prefill [image load, encode] โ†’ Decode โ†’ Response + โ€ข 3 components โ€ข worker flag `--disaggregation-mode prefill/decode` โ€ข Multi-GPU, KV transfer + +EPD DISAGGREGATED - WIP: + Client โ†’ Frontend โ†’ Encode [MultimodalEncoder] โ†’ Prefill [via params] โ†’ Decode โ†’ Response + โ€ข 4 components โ€ข worker flag `--disaggregation-mode encode/prefill/decode` โ€ข WIP PR #4668 +``` + +## Input Format Details + +### Supported URL Formats + +| Format | Example | Description | Support | +|--------|---------|-------------|---------| +| **HTTP/HTTPS** | `http://example.com/image.jpg` | Remote media files | โœ… | +| **Pre-computed Embeddings** | `/path/to/embedding.pt` | Local embedding files (.pt, .pth, .bin) | โœ… | + +## Simple Aggregated Mode (PD) + +In aggregated mode, all processing (image loading, encoding, prefill, decode) happens within a single worker. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +TRT-LLM Worker (Python - ModelInput.Tokens) + โ†“ downloads media, encodes, prefill + decode +Response +``` + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Worker | `--modality multimodal` | Tokens | Yes | Complete inference pipeline | + +### Launch Script + +Example: [`examples/backends/trtllm/launch/agg.sh`](../../../examples/backends/trtllm/launch/agg.sh) + +## Disaggregated Mode (P->D) + +In disaggregated mode, prefill and decode are handled by separate workers. The prefill worker handles image loading and encoding internally. 
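Both the aggregated and disaggregated deployments accept the standard OpenAI-style multimodal request shape. A minimal request sketch with a remote image URL is shown below; the model name, port, and URL are placeholders, not values taken from the launch scripts.

```python
import requests

payload = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image"},
                {"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
            ],
        }
    ],
    "max_tokens": 160,
}

# Assumes the frontend is listening on the default port 8000.
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
print(response.json()["choices"][0]["message"]["content"])
```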
+ +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Prefill Worker (Python - ModelInput.Tokens) + โ†“ downloads media, encodes, prefill, KV cache transfer +Decode Worker (Python - ModelInput.Tokens) + โ†“ decode only, token generation +Response +``` + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Prefill Worker | `--disaggregation-mode prefill` | Tokens | Yes | Image processing + Prefill | +| Decode Worker | `--disaggregation-mode decode` | Tokens | Yes | Decode only | + +### Launch Script + +Example: [`examples/backends/trtllm/launch/disagg_multimodal.sh`](../../../examples/backends/trtllm/launch/disagg_multimodal.sh) + +## Pre-computed Embeddings + +TRT-LLM supports providing pre-computed embeddings, bypassing image-to-embedding processing. + +### Supported File Types + +- `.pt` - PyTorch tensor files +- `.pth` - PyTorch checkpoint files +- `.bin` - Binary tensor files + +### Embedding File Formats + +TRT-LLM supports two formats for embedding files: + +#### 1. Simple Tensor Format + +- Direct tensor saved as `.pt` file +- Example: `llava_next_mm_embed_seashore.pt` +- Contains only the embedding tensor + +```python +# Example: Simple tensor format +embedding_tensor = torch.rand(1, 576, 4096) # [batch, seq_len, hidden_dim] +torch.save(embedding_tensor, "embedding.pt") +``` + +#### 2. Dictionary Format with Auxiliary Data + +- Dictionary containing multiple keys +- Used by models like Llama-4 that require additional metadata +- Must contain `mm_embeddings` key with the main tensor +- Can include auxiliary data like special tokens, offsets, etc. + +```python +# Example: Dictionary format (Llama-4 style) +embedding_dict = { + "mm_embeddings": torch.rand(1, 576, 4096), + "special_tokens": [128256, 128257], + "image_token_offsets": [[0, 576]], + # ... other model-specific metadata +} +torch.save(embedding_dict, "llama4_embedding.pt") +``` + +**How They're Used:** +- **Simple tensors**: Loaded directly and passed to `mm_embeddings` parameter +- **Dictionary format**: `mm_embeddings` key extracted as main tensor, other keys preserved as auxiliary data and transferred separately + +### Launch Script + +Example: [`examples/backends/trtllm/launch/epd_disagg.sh`](../../../examples/backends/trtllm/launch/epd_disagg.sh) + +### Security Considerations + +For EPD mode with local embedding files: + +- `--allowed-local-media-path` - Specify secure directory for embedding files (default: `/tmp`) +- `--max-file-size-mb` - Limit max file size to prevent DoS attacks (default: `50MB`) + +## EPD Disaggregated Mode (E->P->D) - WIP + +**Status:** Work In Progress (WIP PR #4668) - Full EPD flow with MultimodalEncoder + +In EPD mode, encoding, prefill, and decode are handled by separate workers. The encode worker uses TensorRT-LLM's `MultimodalEncoder` to process images and transfer embeddings via disaggregated parameters. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Encode Worker (Python - NOT registered, uses MultimodalEncoder) + โ†“ downloads image, encodes with vision model, transfers via disaggregated_params +Prefill Worker (Python - ModelInput.Tokens) + โ†“ receives embeddings via disaggregated_params, prefill only, KV cache transfer +Decode Worker (Python - ModelInput.Tokens) + โ†“ decode only, token generation +Response +``` + +**Note (WIP):** The encode worker uses `MultimodalEncoder` from TensorRT-LLM to actually encode images, not just load pre-computed embeddings. 
This is a significant change from the legacy NIXL-based embedding transfer. + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Encode Worker | `--disaggregation-mode encode` | N/A | No | Image encoding with MultimodalEncoder | +| Prefill Worker | `--disaggregation-mode prefill --encode-endpoint` | Tokens | Yes | Prefill only | +| Decode Worker | `--disaggregation-mode decode` | Tokens | Yes | Decode only | + + +## ModelInput Types and Registration + +### Understanding ModelInput + +TRT-LLM workers register with Dynamo using: + +| ModelInput Type | Preprocessing | Use Case | +|-----------------|---------------|----------| +| `ModelInput.Tokens` | Rust SDK tokenizes text (bypassed for multimodal) | All TRT-LLM workers | + +### Component Registration Pattern + +```python +# TRT-LLM Worker - Register with Tokens +await register_llm( + ModelInput.Tokens, # Rust does minimal preprocessing + model_type, # ModelType.Chat or ModelType.Prefill + generate_endpoint, + model_name, + ... +) +``` + +## Inter-Component Communication + +| Transfer Stage | Message | NIXL Transfer | +|----------------|--------------|---------------| +| **Frontend โ†’ Prefill** | Request with image URL or embedding path | No | +| **Encode โ†’ Prefill (pre-computed embeddings)** | NIXL metadata (pre-computed embeddings) | Yes (Embeddings tensor) | +| **Encode โ†’ Prefill (Image URL) (WIP)** | Disaggregated params with multimodal handles | No (Handles via params) | +| **Prefill โ†’ Decode** | Disaggregated params | Configurable (KV cache: NIXL default, UCX optional) | + + +## **NIXL USE** + +| Use Case | Script | NIXL Used? | Data Transfer | +|----------|--------|------------|---------------| +| Simple Aggregated | [`examples/backends/trtllm/launch/agg.sh`](../../../examples/backends/trtllm/launch/agg.sh) | โŒ No | All in one worker | +| P->D Disaggregated | [`examples/backends/trtllm/launch/disagg_multimodal.sh`](../../../examples/backends/trtllm/launch/disagg_multimodal.sh) | โš™๏ธ Optional | Prefill โ†’ Decode (KV cache via UCX or NIXL) | +| E->P->D Disaggregated (pre-computed embeddings) | [`examples/backends/trtllm/launch/epd_disagg.sh`](../../../examples/backends/trtllm/launch/epd_disagg.sh) | โœ… Yes | Encoder โ†’ Prefill (pre-computed embeddings via NIXL) | +| E->P->D Disaggregated (WIP) | X | โŒ No | Encoder โ†’ Prefill (multimodal handles via disaggregated_params)
Prefill โ†’ Decode (KV cache via UCX/NIXL) | + +**Note:** NIXL for KV cache transfer is currently beta and only supported on AMD64 (x86_64) architecture. + + +## Key Files + +| File | Description | +|------|-------------| +| `components/src/dynamo/trtllm/main.py` | Worker initialization and setup | +| `components/src/dynamo/trtllm/utils/trtllm_utils.py` | Command-line argument parsing | +| `components/src/dynamo/trtllm/multimodal_processor.py` | Multimodal request processing | +| `components/src/dynamo/trtllm/request_handlers/handlers.py` | Request handler factory | +| `components/src/dynamo/trtllm/request_handlers/handler_base.py` | Base handler and disaggregation modes | + +## Known Limitations + +- **No Data URL support** - Only HTTP/HTTPS URLs supported; `data:image/...` base64 URLs not supported +- **No video support** - No video encoder implementation +- **No audio support** - No audio encoder implementation +- **No Rust preprocessing** - All preprocessing happens in Python workers +- **E->P->D mode is WIP** - Full EPD with image URLs under development + +## Supported Models + +Multimodal models listed in [TensorRT-LLM supported models](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md) are supported by Dynamo. + +Common examples: +- Llama 4 Vision models (Maverick, Scout) +- Qwen2-VL models +- Other vision-language models with TRT-LLM support + diff --git a/docs/backends/vllm/README.md b/docs/backends/vllm/README.md index 4fcd4cf4a2..b04f582cea 100644 --- a/docs/backends/vllm/README.md +++ b/docs/backends/vllm/README.md @@ -165,6 +165,13 @@ bash launch/dep.sh Below we provide a selected list of advanced deployments. Please open up an issue if you'd like to see a specific example! +### Speculative Decoding with Aggregated Serving (Meta-Llama-3.1-8B-Instruct + Eagle3) + +Run **Meta-Llama-3.1-8B-Instruct** with **Eagle3** as a draft model using **aggregated speculative decoding** on a single node. +This setup demonstrates how to use Dynamo to create an instance using Eagle-based speculative decoding under the **VLLM aggregated serving framework** for faster inference while maintaining accuracy. + +**Guide:** [Speculative Decoding Quickstart](./speculative_decoding.md) + ### Kubernetes Deployment For complete Kubernetes deployment instructions, configurations, and troubleshooting, see [vLLM Kubernetes Deployment Guide](../../../examples/backends/vllm/deploy/README.md) diff --git a/docs/backends/vllm/multimodal_vllm_guide.md b/docs/backends/vllm/multimodal_vllm_guide.md new file mode 100644 index 0000000000..a68647d31f --- /dev/null +++ b/docs/backends/vllm/multimodal_vllm_guide.md @@ -0,0 +1,226 @@ + + +# vLLM Multimodal Guide + +This document provides a comprehensive guide for multimodal inference using vLLM backend in Dynamo. For more details on the multimodal examples, see [Multimodal Examples Documentation](./multimodal.md). 
+ +## Multimodal Support Matrix + +| Modality | Input Format | Aggregated | Disaggregated | Notes | +|----------|--------------|------------|---------------|-------| +| **Image** | HTTP/HTTPS URL | Yes | Yes | Full support for all image models | +| **Image** | Data URL (Base64) | Yes | Yes | Inline base64-encoded images | +| **Video** | HTTP/HTTPS URL | Yes | Yes | Frame extraction and processing | +| **Audio** | HTTP/HTTPS URL | Yes | Yes | Experimental - requires audio dependencies | + +## Architecture Comparison + +vLLM multimodal supports three deployment patterns: + +```text +SIMPLE AGGREGATED ([examples/backends/vllm/launch/agg_multimodal.sh](../../../examples/backends/vllm/launch/agg_multimodal.sh)): + Client โ†’ Frontend (Rust processor) โ†’ Worker [image load, encode, P+D] โ†’ Response + โ€ข 2 components โ€ข --connector none โ€ข Easiest setup + +EPD AGGREGATED ([examples/backends/vllm/launch/agg_multimodal_epd.sh](../../../examples/backends/vllm/launch/agg_multimodal_epd.sh)): + Client โ†’ Frontend โ†’ Processor โ†’ Encoder [NIXL] โ†’ PD Worker โ†’ Response + โ€ข 4 components โ€ข --multimodal-processor โ€ข Custom templates, NIXL + +DISAGGREGATED ([examples/backends/vllm/launch/disagg_multimodal_epd.sh](../../../examples/backends/vllm/launch/disagg_multimodal_epd.sh)): + Client โ†’ Frontend โ†’ Processor โ†’ Encoder [NIXL] โ†’ Prefill [NIXL] โ†’ Decode โ†’ Response + โ€ข 5 components โ€ข Separate P/D workers โ€ข Multi-node, max optimization +``` + +## Input Format Details + +### Supported URL Formats + +| Format | Example | Description | Support | +|--------|---------|-------------|---------| +| **HTTP/HTTPS** | `http://example.com/image.jpg` | Remote media files | โœ… | +| **Data URL** | `data:image/jpeg;base64,/9j/4AAQ...` | Base64-encoded inline data | โœ… | + +## Simple Aggregated Mode (PD) + +In simple aggregated mode, encoding, prefill, and decode happen within the same worker. + +### Architecture + +```text +HTTP Frontend with Rust processor + โ†“ +Worker (Python - ModelInput.Tokens) + โ†“ encode + prefill + decode +Response +``` + +## EPD Aggregated Mode (PD) + +In EPD aggregated mode, encoding happens in a separate worker and prefill and decode happen within the same pipeline. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Processor (Python - ModelInput.Text) + โ†“ tokenizes, extracts media URL +Encode Worker (Python - not registered) + โ†“ downloads media, generates embeddings, NIXL transfer +PD Worker (Python - ModelInput.Tokens) + โ†“ prefill + decode +Response +``` + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Processor | `--multimodal-processor` | Text | Yes | HTTP entry, tokenization | +| Encode Worker | `--multimodal-encode-worker` | N/A | No | Media encoding | +| PD Worker | `--multimodal-worker` | Tokens | Yes | Prefill + Decode | + +## EPD Disaggregated Mode (E->P->D) + +In EPD disaggregated mode, encoding, prefill, and decode are handled by separate workers. 
+ +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Processor (Python - ModelInput.Text) + โ†“ tokenizes, extracts media URL +Encode Worker (Python - not registered) + โ†“ downloads media, generates embeddings, NIXL transfer +Prefill Worker (Python - ModelInput.Tokens) + โ†“ prefill only, KV cache NIXL transfer +Decode Worker (Python - ModelInput.Tokens) + โ†“ decode only, token generation +Response +``` + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Processor | `--multimodal-processor` | Text | Yes | HTTP entry, tokenization | +| Encode Worker | `--multimodal-encode-worker` | N/A | No | Media encoding | +| Prefill Worker | `--multimodal-worker --is-prefill-worker` | Tokens | Yes | Prefill only | +| Decode Worker | `--multimodal-decode-worker` | Tokens | Yes | Decode only | + +## Traditional Disagg (EP->D) + +Llama 4 models don't support pre-computed embeddings, so they use a combined Encode+Prefill worker. + +### Architecture + +```text +HTTP Frontend (Rust) + โ†“ +Processor (Python - ModelInput.Text) + โ†“ tokenizes, extracts media URL +Encode+Prefill Worker (Python - ModelInput.Tokens) + โ†“ downloads media, encodes inline, prefill, KV cache NIXL transfer +Decode Worker (Python - ModelInput.Tokens) + โ†“ decode only, token generation +Response +``` + +### Components + +| Component | Flag | ModelInput | Registered | Purpose | +|-----------|------|-----------|------------|---------| +| Processor | `--multimodal-processor` | Text | Yes | HTTP entry, tokenization | +| Encode+Prefill | `--multimodal-encode-prefill-worker --is-prefill-worker` | Tokens | Yes | Encode + Prefill | +| Decode Worker | `--multimodal-decode-worker` | Tokens | Yes | Decode only | + +### Launch Script + +Example: [`examples/backends/vllm/launch/disagg_multimodal_llama.sh`](../../../examples/backends/vllm/launch/disagg_multimodal_llama.sh) + +## ModelInput Types and Registration + +### Understanding ModelInput + +Dynamo's Rust SDK supports two input types that determine how the HTTP frontend preprocesses requests: + +| ModelInput Type | Preprocessing | Use Case | +|-----------------|---------------|----------| +| `ModelInput.Text` | None (raw text passed through) | Components that tokenize themselves | +| `ModelInput.Tokens` | Rust SDK would tokenize (but bypassed in multimodal) | Components expecting pre-tokenized input | + +### Component Registration Pattern + +```python +# Processor - Entry point from HTTP frontend +await register_llm( + ModelInput.Text, # Frontend sends raw text + ModelType.Chat, + generate_endpoint, + model_name, + ... +) + +# Workers - Internal components +await register_llm( + ModelInput.Tokens, # Expect pre-tokenized input + ModelType.Chat, # or ModelType.Prefill for prefill workers + generate_endpoint, + model_name, + ... +) +``` + +## **NIXL USE** + +| Use Case | Script | NIXL Used? | Data Transfer | +|----------|--------|------------|---------------| +| Simple Aggregated | [`examples/backends/vllm/launch/agg_multimodal.sh`](../../../examples/backends/vllm/launch/agg_multimodal.sh) | โŒ No | All in one worker | +| E->PD Aggregated | [`examples/backends/vllm/launch/agg_multimodal_epd.sh`](../../../examples/backends/vllm/launch/agg_multimodal_epd.sh) | โœ… Yes | Encoder โ†’ PD (embeddings) | +| E->P->D Disaggregated | [`examples/backends/vllm/launch/disagg_multimodal_epd.sh`](../../../examples/backends/vllm/launch/disagg_multimodal_epd.sh) | โœ… Yes | Encoder โ†’ Prefill (embeddings)
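One input-format detail worth illustrating before the NIXL summary: as the support matrix above notes, the vLLM backend also accepts inline base64 data URLs. A minimal sketch of building that content entry is below; the file path is a placeholder.

```python
import base64

# Read a local image and encode it as an inline data URL (file path is a placeholder).
with open("example.jpg", "rb") as f:
    b64_image = base64.b64encode(f.read()).decode("utf-8")

image_content = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
}
# image_content is then placed in the "content" list of a /v1/chat/completions
# message, exactly as with a remote HTTP/HTTPS URL.
```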
Prefill โ†’ Decode (KV cache) | +| EP->D Disaggregated (Llama 4) | [`examples/backends/vllm/launch/disagg_multimodal_llama.sh`](../../../examples/backends/vllm/launch/disagg_multimodal_llama.sh) | โœ… Yes | Prefill โ†’ Decode (KV cache) | + + +## Known Limitations + +- **Disaggregated flows require Python Processor** - All multimodal disaggregation requires the Python Processor component (`ModelInput.Text`). + +## Supported Models + +The following models have been tested with Dynamo's vLLM multimodal backend: + +- **Qwen2.5-VL** - `Qwen/Qwen2.5-VL-7B-Instruct` +- **Qwen3-VL** - `Qwen/Qwen3-VL-30B-A3B-Instruct-FP8` +- **LLaVA 1.5** - `llava-hf/llava-1.5-7b-hf` +- **Llama 4 Maverick** - `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8` +- **LLaVA Next Video** - `llava-hf/LLaVA-NeXT-Video-7B-hf` +- **Qwen2-Audio** - `Qwen/Qwen2-Audio-7B-Instruct` + +For a complete list of multimodal models supported by vLLM, see [vLLM Supported Multimodal Models](https://docs.vllm.ai/en/latest/models/supported_models/#list-of-multimodal-language-models). Models listed there should work with Simple Aggregated Mode but may not be explicitly tested. + +## Key Files + +| File | Description | +|------|-------------| +| `components/src/dynamo/vllm/main.py` | Worker initialization and setup | +| `components/src/dynamo/vllm/args.py` | Command-line argument parsing | +| `components/src/dynamo/vllm/multimodal_handlers/processor_handler.py` | Processor implementation | +| `components/src/dynamo/vllm/multimodal_handlers/encode_worker_handler.py` | Encode worker implementation | +| `components/src/dynamo/vllm/multimodal_handlers/worker_handler.py` | PD/Prefill/Decode worker implementation | + diff --git a/docs/backends/vllm/speculative_decoding.md b/docs/backends/vllm/speculative_decoding.md new file mode 100644 index 0000000000..b0cdf65d8b --- /dev/null +++ b/docs/backends/vllm/speculative_decoding.md @@ -0,0 +1,121 @@ + +# Running **Meta-Llama-3.1-8B-Instruct** with Speculative Decoding (Eagle3) + +This guide walks through how to deploy **Meta-Llama-3.1-8B-Instruct** using **aggregated speculative decoding** with **Eagle3** on a single node. +Since the model is only **8B parameters**, you can run it on **any GPU with at least 16GB VRAM**. + + + +## Step 1: Set Up Your Docker Environment + +First, weโ€™ll initialize a Docker container using the VLLM backend. +You can refer to the [VLLM Quickstart Guide](./README.md#vllm-quick-start) โ€” or follow the full steps below. + +### 1. Launch Docker Compose + +```bash +docker compose -f deploy/docker-compose.yml up -d +``` + +### 2. Build the Container + +```bash +./container/build.sh --framework VLLM +``` + +### 3. Run the Container + +```bash +./container/run.sh -it --framework VLLM --mount-workspace +``` + + + +## Step 2: Get Access to the Llama-3 Model + +The **Meta-Llama-3.1-8B-Instruct** model is gated, so youโ€™ll need to request access on Hugging Face. +Go to the official [Meta-Llama-3.1-8B-Instruct repository](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and fill out the access form. +Approval usually takes around **5 minutes**. + +Once you have access, generate a **Hugging Face access token** with permission for gated repositories, then set it inside your container: + +```bash +export HUGGING_FACE_HUB_TOKEN="insert_your_token_here" +export HF_TOKEN=$HUGGING_FACE_HUB_TOKEN +``` + + + +## Step 3: Run Aggregated Speculative Decoding + +Now that your environment is ready, start the aggregated server with **speculative decoding**. 
+ +```bash +# Requires only one GPU +cd examples/backends/vllm +bash launch/agg_spec_decoding.sh +``` + +Once the weights finish downloading and serving begins, youโ€™ll be ready to send inference requests to your model. + + + + +## Step 4: Example Request + +To verify your setup, try sending a simple prompt to your model: + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "messages": [ + {"role": "user", "content": "Write a poem about why Sakura trees are beautiful."} + ], + "max_tokens": 250 + }' +``` + +### Example Output + +```json +{ + "id": "cmpl-3e87ea5c-010e-4dd2-bcc4-3298ebd845a8", + "choices": [ + { + "text": "In cherry blossomโ€™s gentle breeze ... A delicate balance of life and death, as petals fade, and new life breathes.", + "index": 0, + "finish_reason": "stop" + } + ], + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "usage": { + "prompt_tokens": 16, + "completion_tokens": 250, + "total_tokens": 266 + } +} +``` + + + +## Additional Resources + +* [VLLM Quickstart](./README.md#vllm-quick-start) +* [Meta-Llama-3.1-8B-Instruct on Hugging Face](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 0b34264c38..b717f1ea3f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -9,7 +9,10 @@ project = "NVIDIA Dynamo" copyright = "2024-2025, NVIDIA CORPORATION & AFFILIATES" author = "NVIDIA" -release = "latest" + +# Version is set via DYNAMO_DOCS_VERSION env var during build (e.g., "0.3.0") +# Defaults to "dev" for main branch and PR builds +release = os.environ.get("DYNAMO_DOCS_VERSION", "dev") # -- General configuration --------------------------------------------------- @@ -112,7 +115,7 @@ # -- Options for HTML output ------------------------------------------------- html_theme = "nvidia_sphinx_theme" html_static_path = ["_static"] -html_extra_path = ["project.json", "versions1.json"] +html_extra_path = ["project.json"] html_theme_options = { "collapse_navigation": False, "icon_links": [ @@ -123,7 +126,9 @@ } ], "switcher": { - "json_url": "versions1.json", + # Use single shared URL so all versions see the same switcher list + # When a new version is added, all old docs automatically see it + "json_url": "https://docs.nvidia.com/dynamo/versions1.json", "version_match": release, }, "extra_head": { diff --git a/docs/design_docs/distributed_runtime.md b/docs/design_docs/distributed_runtime.md index f61cf4f762..31f56f34fc 100644 --- a/docs/design_docs/distributed_runtime.md +++ b/docs/design_docs/distributed_runtime.md @@ -53,7 +53,7 @@ The hierarchy and naming in etcd and NATS may change over time, and this documen For etcd, it also creates a primary lease and spin up a background task to keep the lease alive. All objects registered under this `DistributedRuntime` use this lease_id to maintain their life cycle. There is also a cancellation token that is tied to the primary lease. When the cancellation token is triggered or the background task failed, the primary lease is revoked or expired and the kv pairs stored with this lease_id is removed. - `Namespace`: `Namespace`s are primarily a logical grouping mechanism and is not registered in etcd. It provides the root path for all components under this `Namespace`. -- `Component`: When a `Component` object is created, similar to `Namespace`, it isn't be registered in etcd. 
When `create_service` is called, it creates a NATS service group using `{namespace_name}.{service_name}` for metrics and registers a service in the registry of the `Component`, where the registry is an internal data structure that tracks all services and endpoints within the `DistributedRuntime`. +- `Component`: When a `Component` object is created, similar to `Namespace`, it isn't be registered in etcd. When `create_service` is called, it creates a NATS service group using `{namespace_name}.{service_name}` as the service identifier and registers a service in the registry of the `Component`, where the registry is an internal data structure that tracks all services and endpoints within the `DistributedRuntime`. - `Endpoint`: When an Endpoint object is created and started, it performs two key registrations: - NATS Registration: The endpoint is registered with the NATS service group created during service creation. The endpoint is assigned a unique subject following the naming: `{namespace_name}.{service_name}.{endpoint_name}-{lease_id_hex}`. - etcd Registration: The endpoint information is stored in etcd at a path following the naming: `/services/{namespace}/{component}/{endpoint}-{lease_id}`. Note that the endpoints of different workers of the same type (i.e., two `VllmPrefillWorker`s in one deployment) share the same `Namespace`, `Component`, and `Endpoint` name. They are distinguished by their different primary `lease_id` of their `DistributedRuntime`. diff --git a/docs/generate_docs.py b/docs/generate_docs.py index 342fc2d1f2..a84379ae63 100755 --- a/docs/generate_docs.py +++ b/docs/generate_docs.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import os import re @@ -282,9 +283,23 @@ def change_directory(path): os.chdir(original_directory) +def update_project_json(): + """Update project.json with the current version from DYNAMO_DOCS_VERSION env var.""" + version = os.environ.get("DYNAMO_DOCS_VERSION", "dev") + project_json_path = os.path.join(dynamo_docs_abspath, "project.json") + + project_data = {"name": "NVIDIA Dynamo", "version": version} + + with open(project_json_path, "w") as f: + json.dump(project_data, f) + + log_message(f"Updated project.json with version: {version}") + + def main(): with change_directory(dynamo_docs_abspath): run_command("make clean") + update_project_json() preprocess_docs() run_command("make html") diff --git a/docs/hidden_toctree.rst b/docs/hidden_toctree.rst index 669ae0339c..9c8a50ad91 100644 --- a/docs/hidden_toctree.rst +++ b/docs/hidden_toctree.rst @@ -50,7 +50,7 @@ backends/trtllm/llama4_plus_eagle.md backends/trtllm/kv-cache-transfer.md backends/trtllm/multimodal_support.md - backends/trtllm/multimodal_epd.md + backends/trtllm/multimodal_trtllm_guide.md backends/trtllm/gemma3_sliding_window_attention.md backends/trtllm/gpt-oss.md backends/trtllm/prometheus.md @@ -61,6 +61,8 @@ backends/sglang/expert-distribution-eplb.md backends/sglang/gpt-oss.md backends/sglang/multimodal_epd.md + backends/sglang/multimodal_sglang_guide.md + backends/sglang/profiling.md backends/sglang/sgl-hicache-example.md backends/sglang/sglang-disaggregation.md backends/sglang/prometheus.md @@ -73,9 +75,12 @@ backends/vllm/deepseek-r1.md backends/vllm/gpt-oss.md + backends/vllm/LMCache_Integration.md backends/vllm/multi-node.md backends/vllm/multimodal.md + backends/vllm/multimodal_vllm_guide.md backends/vllm/prometheus.md + backends/vllm/speculative_decoding.md 
benchmarks/kv-router-ab-testing.md diff --git a/docs/kubernetes/api_reference.md b/docs/kubernetes/api_reference.md index 09e7415769..4ae3246155 100644 --- a/docs/kubernetes/api_reference.md +++ b/docs/kubernetes/api_reference.md @@ -37,6 +37,7 @@ Package v1alpha1 contains API Schema definitions for the nvidia.com v1alpha1 API - [DynamoComponentDeployment](#dynamocomponentdeployment) - [DynamoGraphDeployment](#dynamographdeployment) - [DynamoGraphDeploymentRequest](#dynamographdeploymentrequest) +- [DynamoGraphDeploymentScalingAdapter](#dynamographdeploymentscalingadapter) - [DynamoModel](#dynamomodel) @@ -45,7 +46,9 @@ Package v1alpha1 contains API Schema definitions for the nvidia.com v1alpha1 API - +Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter +with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md +for migration guidance. This field will be removed in a future API version. @@ -55,11 +58,11 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `enabled` _boolean_ | | | | -| `minReplicas` _integer_ | | | | -| `maxReplicas` _integer_ | | | | -| `behavior` _[HorizontalPodAutoscalerBehavior](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#horizontalpodautoscalerbehavior-v2-autoscaling)_ | | | | -| `metrics` _[MetricSpec](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#metricspec-v2-autoscaling) array_ | | | | +| `enabled` _boolean_ | Deprecated: This field is ignored. | | | +| `minReplicas` _integer_ | Deprecated: This field is ignored. | | | +| `maxReplicas` _integer_ | Deprecated: This field is ignored. | | | +| `behavior` _[HorizontalPodAutoscalerBehavior](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#horizontalpodautoscalerbehavior-v2-autoscaling)_ | Deprecated: This field is ignored. | | | +| `metrics` _[MetricSpec](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#metricspec-v2-autoscaling) array_ | Deprecated: This field is ignored. | | | @@ -165,7 +168,7 @@ _Appears in:_ | `dynamoNamespace` _string_ | DynamoNamespace is deprecated and will be removed in a future version.
The DGD Kubernetes namespace and DynamoGraphDeployment name are used to construct the Dynamo namespace for each component | | Optional: \{\}
| | `globalDynamoNamespace` _boolean_ | GlobalDynamoNamespace indicates that the Component will be placed in the global Dynamo namespace | | | | `resources` _[Resources](#resources)_ | Resources requested and limits for this component, including CPU, memory,
GPUs/devices, and any runtime-specific resources. | | | -| `autoscaling` _[Autoscaling](#autoscaling)_ | Autoscaling config for this component (replica range, target utilization, etc.). | | | +| `autoscaling` _[Autoscaling](#autoscaling)_ | Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter
with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md
for migration guidance. This field will be removed in a future API version. | | | | `envs` _[EnvVar](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#envvar-v1-core) array_ | Envs defines additional environment variables to inject into the component containers. | | | | `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as
environment variables in the component containers. | | | | `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | | @@ -176,8 +179,9 @@ _Appears in:_ | `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.
It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field
that allows overriding the main container configuration. | | | | `livenessProbe` _[Probe](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#probe-v1-core)_ | LivenessProbe to detect and restart unhealthy containers. | | | | `readinessProbe` _[Probe](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#probe-v1-core)_ | ReadinessProbe to signal when the container is ready to receive traffic. | | | -| `replicas` _integer_ | Replicas is the desired number of Pods for this component when autoscaling is not used. | | | +| `replicas` _integer_ | Replicas is the desired number of Pods for this component.
When scalingAdapter is enabled (default), this field is managed by the
DynamoGraphDeploymentScalingAdapter and should not be modified directly. | | Minimum: 0
| | `multinode` _[MultinodeSpec](#multinodespec)_ | Multinode is the configuration for multinode components. | | | +| `scalingAdapter` _[ScalingAdapter](#scalingadapter)_ | ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter.
When enabled (default), replicas are managed via DGDSA and external autoscalers can scale
the service using the Scale subresource. When disabled, replicas can be modified directly. | | | #### DynamoComponentDeploymentSpec @@ -202,7 +206,7 @@ _Appears in:_ | `dynamoNamespace` _string_ | DynamoNamespace is deprecated and will be removed in a future version.
The DGD Kubernetes namespace and DynamoGraphDeployment name are used to construct the Dynamo namespace for each component | | Optional: \{\}
| | `globalDynamoNamespace` _boolean_ | GlobalDynamoNamespace indicates that the Component will be placed in the global Dynamo namespace | | | | `resources` _[Resources](#resources)_ | Resources requested and limits for this component, including CPU, memory,
GPUs/devices, and any runtime-specific resources. | | | -| `autoscaling` _[Autoscaling](#autoscaling)_ | Autoscaling config for this component (replica range, target utilization, etc.). | | | +| `autoscaling` _[Autoscaling](#autoscaling)_ | Deprecated: This field is deprecated and ignored. Use DynamoGraphDeploymentScalingAdapter
with HPA, KEDA, or Planner for autoscaling instead. See docs/kubernetes/autoscaling.md
for migration guidance. This field will be removed in a future API version. | | | | `envs` _[EnvVar](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#envvar-v1-core) array_ | Envs defines additional environment variables to inject into the component containers. | | | | `envFromSecret` _string_ | EnvFromSecret references a Secret whose key/value pairs will be exposed as
environment variables in the component containers. | | | | `volumeMounts` _[VolumeMount](#volumemount) array_ | VolumeMounts references PVCs defined at the top level for volumes to be mounted by the component. | | | @@ -213,8 +217,9 @@ _Appears in:_ | `extraPodSpec` _[ExtraPodSpec](#extrapodspec)_ | ExtraPodSpec allows to override the main pod spec configuration.
It is a k8s standard PodSpec. It also contains a MainContainer (standard k8s Container) field
that allows overriding the main container configuration. | | | | `livenessProbe` _[Probe](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#probe-v1-core)_ | LivenessProbe to detect and restart unhealthy containers. | | | | `readinessProbe` _[Probe](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#probe-v1-core)_ | ReadinessProbe to signal when the container is ready to receive traffic. | | | -| `replicas` _integer_ | Replicas is the desired number of Pods for this component when autoscaling is not used. | | | +| `replicas` _integer_ | Replicas is the desired number of Pods for this component.
When scalingAdapter is enabled (default), this field is managed by the
DynamoGraphDeploymentScalingAdapter and should not be modified directly. | | Minimum: 0
| | `multinode` _[MultinodeSpec](#multinodespec)_ | Multinode is the configuration for multinode components. | | | +| `scalingAdapter` _[ScalingAdapter](#scalingadapter)_ | ScalingAdapter configures whether this service uses the DynamoGraphDeploymentScalingAdapter.
When enabled (default), replicas are managed via DGDSA and external autoscalers can scale
the service using the Scale subresource. When disabled, replicas can be modified directly. | | | #### DynamoGraphDeployment @@ -314,6 +319,83 @@ _Appears in:_ | `deployment` _[DeploymentStatus](#deploymentstatus)_ | Deployment tracks the auto-created DGD when AutoApply is true.
Contains name, namespace, state, and creation status of the managed DGD. | | Optional: \{\}
| +#### DynamoGraphDeploymentScalingAdapter + + + +DynamoGraphDeploymentScalingAdapter provides a scaling interface for individual services +within a DynamoGraphDeployment. It implements the Kubernetes scale +subresource, enabling integration with HPA, KEDA, and custom autoscalers. + +The adapter acts as an intermediary between autoscalers and the DGD, +ensuring that only the adapter controller modifies the DGD's service replicas. +This prevents conflicts when multiple autoscaling mechanisms are in play. + + + + + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `apiVersion` _string_ | `nvidia.com/v1alpha1` | | | +| `kind` _string_ | `DynamoGraphDeploymentScalingAdapter` | | | +| `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | +| `spec` _[DynamoGraphDeploymentScalingAdapterSpec](#dynamographdeploymentscalingadapterspec)_ | | | | +| `status` _[DynamoGraphDeploymentScalingAdapterStatus](#dynamographdeploymentscalingadapterstatus)_ | | | | + + +#### DynamoGraphDeploymentScalingAdapterSpec + + + +DynamoGraphDeploymentScalingAdapterSpec defines the desired state of DynamoGraphDeploymentScalingAdapter + + + +_Appears in:_ +- [DynamoGraphDeploymentScalingAdapter](#dynamographdeploymentscalingadapter) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `replicas` _integer_ | Replicas is the desired number of replicas for the target service.
This field is modified by external autoscalers (HPA/KEDA/Planner) or manually by users. | | Minimum: 0
Required: \{\}
| +| `dgdRef` _[DynamoGraphDeploymentServiceRef](#dynamographdeploymentserviceref)_ | DGDRef references the DynamoGraphDeployment and the specific service to scale. | | Required: \{\}
| + + +#### DynamoGraphDeploymentScalingAdapterStatus + + + +DynamoGraphDeploymentScalingAdapterStatus defines the observed state of DynamoGraphDeploymentScalingAdapter + + + +_Appears in:_ +- [DynamoGraphDeploymentScalingAdapter](#dynamographdeploymentscalingadapter) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `replicas` _integer_ | Replicas is the current number of replicas for the target service.
This is synced from the DGD's service replicas and is required for the scale subresource. | | | +| `selector` _string_ | Selector is a label selector string for the pods managed by this adapter.
Required for HPA compatibility via the scale subresource. | | | +| `lastScaleTime` _[Time](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#time-v1-meta)_ | LastScaleTime is the last time the adapter scaled the target service. | | | + + +#### DynamoGraphDeploymentServiceRef + + + +DynamoGraphDeploymentServiceRef identifies a specific service within a DynamoGraphDeployment + + + +_Appears in:_ +- [DynamoGraphDeploymentScalingAdapterSpec](#dynamographdeploymentscalingadapterspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `name` _string_ | Name of the DynamoGraphDeployment | | MinLength: 1
Required: \{\}
| +| `serviceName` _string_ | ServiceName is the key name of the service within the DGD's spec.services map to scale | | MinLength: 1
Required: \{\}
| + + #### DynamoGraphDeploymentSpec @@ -638,6 +720,25 @@ _Appears in:_ | `claims` _[ResourceClaim](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#resourceclaim-v1-core) array_ | Claims specifies resource claims for dynamic resource allocation | | | +#### ScalingAdapter + + + +ScalingAdapter configures whether a service uses the DynamoGraphDeploymentScalingAdapter +for replica management. When enabled (default), the DGDSA owns the replicas field and +external autoscalers (HPA, KEDA, Planner) can control scaling via the Scale subresource. + + + +_Appears in:_ +- [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec) +- [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `disable` _boolean_ | Disable indicates whether the ScalingAdapter should be disabled for this service.
When false (default), a DGDSA is created and owns the replicas field.
When true, no DGDSA is created and replicas can be modified directly in the DGD. | false | | + + #### SharedMemorySpec diff --git a/docs/kubernetes/autoscaling.md b/docs/kubernetes/autoscaling.md new file mode 100644 index 0000000000..8adaf09107 --- /dev/null +++ b/docs/kubernetes/autoscaling.md @@ -0,0 +1,733 @@ +# Autoscaling + +This guide explains how to configure autoscaling for DynamoGraphDeployment (DGD) services using the `sglang-agg` example from `examples/backends/sglang/deploy/agg.yaml`. + +## Example DGD + +All examples in this guide use the following DGD: + +```yaml +# examples/backends/sglang/deploy/agg.yaml +apiVersion: nvidia.com/v1alpha1 +kind: DynamoGraphDeployment +metadata: + name: sglang-agg + namespace: default +spec: + services: + Frontend: + dynamoNamespace: sglang-agg + componentType: frontend + replicas: 1 + + decode: + dynamoNamespace: sglang-agg + componentType: worker + replicas: 1 + resources: + limits: + gpu: "1" +``` + +**Key identifiers:** +- **DGD name**: `sglang-agg` +- **Namespace**: `default` +- **Services**: `Frontend`, `decode` +- **dynamo_namespace label**: `default-sglang-agg` (used for metric filtering) + +## Overview + +Dynamo provides flexible autoscaling through the `DynamoGraphDeploymentScalingAdapter` (DGDSA) resource. When you deploy a DGD, the operator automatically creates one adapter per service (unless explicitly disabled). These adapters implement the Kubernetes [Scale subresource](https://kubernetes.io/docs/tasks/extend-kubernetes/custom-resources/custom-resource-definitions/#scale-subresource), enabling integration with: + +| Autoscaler | Description | Best For | +|------------|-------------|----------| +| **KEDA** | Event-driven autoscaling (recommended) | Most use cases | +| **Kubernetes HPA** | Native horizontal pod autoscaling | Simple CPU/memory-based scaling | +| **Dynamo Planner** | LLM-aware autoscaling with SLA optimization | Production LLM workloads | +| **Custom Controllers** | Any scale-subresource-compatible controller | Custom requirements | + +> **โš ๏ธ Deprecation Notice**: The `spec.services[X].autoscaling` field in DGD is **deprecated and ignored**. Use DGDSA with HPA, KEDA, or Planner instead. If you have existing DGDs with `autoscaling` configured, you'll see a warning. Remove the field to silence the warning. 
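Before adopting the new adapters, it can help to check whether any existing DGDs still carry the deprecated field. A minimal sketch, assuming the `dgd` short name is registered and `jq` is installed:

```bash
# List DGDs whose services still set the deprecated spec.services[*].autoscaling field
kubectl get dgd -A -o json | jq -r '
  .items[]
  | . as $d
  | ($d.spec.services // {})
  | to_entries[]
  | select(.value.autoscaling != null)
  | "\($d.metadata.namespace)/\($d.metadata.name): service \(.key) still sets autoscaling"'
```

Removing the field from those DGDs silences the warning without changing behavior, since the field is already ignored.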
+ +## Architecture + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ DynamoGraphDeployment โ”‚ โ”‚ Scaling Adapters (auto-created) โ”‚ +โ”‚ "sglang-agg" โ”‚ โ”‚ (one per service) โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ โ”‚ โ”‚ โ”‚ +โ”‚ spec.services: โ”‚ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ โ”‚ โ”‚ โ”‚ sglang-agg-frontend โ”‚โ—„โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”‚ Autoscalers โ”‚ +โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”โ—„โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”‚ spec.replicas: 1 โ”‚ โ”‚ โ”‚ โ”‚ +โ”‚ โ”‚ Frontend: 1 replica โ”‚ โ”‚ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ”‚ โ€ข KEDA โ”‚ +โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ”‚ โ”‚ โ”‚ โ€ข HPA โ”‚ +โ”‚ โ”‚ โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ โ”‚ โ€ข Planner โ”‚ +โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”โ—„โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”‚ sglang-agg-decode โ”‚โ—„โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”‚ โ€ข Custom โ”‚ +โ”‚ โ”‚ decode: 1 replica โ”‚ โ”‚ โ”‚ โ”‚ spec.replicas: 1 โ”‚ โ”‚ โ”‚ โ”‚ +โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +โ”‚ โ”‚ โ”‚ โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +**How it works:** + +1. You deploy a DGD with services (Frontend, decode) +2. The operator auto-creates one DGDSA per service +3. Autoscalers (KEDA, HPA, Planner) target the adapters via `/scale` subresource +4. Adapter controller syncs replica changes to the DGD +5. DGD controller reconciles the underlying pods + +## Viewing Scaling Adapters + +After deploying the `sglang-agg` DGD, verify the auto-created adapters: + +```bash +kubectl get dgdsa -n default + +# Example output: +# NAME DGD SERVICE REPLICAS AGE +# sglang-agg-frontend sglang-agg Frontend 1 5m +# sglang-agg-decode sglang-agg decode 1 5m +``` + +## Replica Ownership Model + +When DGDSA is enabled (the default), it becomes the **source of truth** for replica counts. This follows the same pattern as Kubernetes Deployments owning ReplicaSets. + +### How It Works + +1. **DGDSA owns replicas**: Autoscalers (HPA, KEDA, Planner) update the DGDSA's `spec.replicas` +2. **DGDSA syncs to DGD**: The DGDSA controller writes the replica count to the DGD's service +3. **Direct DGD edits blocked**: A validating webhook prevents users from directly editing `spec.services[X].replicas` in the DGD +4. 
**Controllers allowed**: Only authorized controllers (operator, Planner) can modify DGD replicas + +### Manual Scaling with DGDSA Enabled + +When DGDSA is enabled, use `kubectl scale` on the adapter (not the DGD): + +```bash +# โœ… Correct - scale via DGDSA +kubectl scale dgdsa sglang-agg-decode --replicas=3 + +# โŒ Blocked - direct DGD edit rejected by webhook +kubectl patch dgd sglang-agg --type=merge -p '{"spec":{"services":{"decode":{"replicas":3}}}}' +# Error: spec.services[decode].replicas cannot be modified directly when scaling adapter is enabled; +# use 'kubectl scale dgdsa/sglang-agg-decode --replicas=3' or update the DynamoGraphDeploymentScalingAdapter instead +``` + +## Disabling DGDSA for a Service + +If you want to manage replicas directly in the DGD (without autoscaling), you can disable the scaling adapter per service: + +```yaml +apiVersion: nvidia.com/v1alpha1 +kind: DynamoGraphDeployment +metadata: + name: sglang-agg +spec: + services: + Frontend: + replicas: 2 + scalingAdapter: + disable: true # โ† No DGDSA created, direct edits allowed + + decode: + replicas: 1 # โ† DGDSA created by default, managed via adapter +``` + +**When to disable DGDSA:** +- You want simple, manual replica management +- You don't need autoscaling for that service +- You prefer direct DGD edits over adapter-based scaling + +**When to keep DGDSA enabled (default):** +- You want to use HPA, KEDA, or Planner for autoscaling +- You want a clear separation between "desired scale" (adapter) and "deployment config" (DGD) +- You want protection against accidental direct replica edits + +## Autoscaling with Dynamo Planner + +The Dynamo Planner is an LLM-aware autoscaler that optimizes scaling decisions based on inference-specific metrics like Time To First Token (TTFT), Inter-Token Latency (ITL), and KV cache utilization. + +**When to use Planner:** +- You want LLM-optimized autoscaling out of the box +- You need coordinated scaling across prefill/decode services +- You want SLA-driven scaling (e.g., target TTFT < 500ms) + +**How Planner works:** + +Planner is deployed as a service component within your DGD. It: +1. Queries Prometheus for frontend metrics (request rate, latency, etc.) +2. Uses profiling data to predict optimal replica counts +3. Scales prefill/decode workers to meet SLA targets + +**Deployment:** + +The recommended way to deploy Planner is via `DynamoGraphDeploymentRequest` (DGDR). See the [SLA Planner Quick Start](../planner/sla_planner_quickstart.md) for complete instructions. + +Example configurations with Planner: +- `examples/backends/vllm/deploy/disagg_planner.yaml` +- `examples/backends/sglang/deploy/disagg_planner.yaml` +- `examples/backends/trtllm/deploy/disagg_planner.yaml` + +For more details, see the [SLA Planner documentation](../planner/sla_planner.md). + +## Autoscaling with Kubernetes HPA + +The Horizontal Pod Autoscaler (HPA) is Kubernetes' native autoscaling solution. + +**When to use HPA:** +- You have simple, predictable scaling requirements +- You want to use standard Kubernetes tooling +- You need CPU or memory-based scaling + +> **Note**: For custom metrics (like TTFT or queue depth), consider using [KEDA](#autoscaling-with-keda-recommended) instead - it's simpler to configure. 
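The CPU-based example in the next section relies on the Kubernetes resource metrics API, typically provided by metrics-server. A quick sanity check before applying it (assuming a standard metrics-server installation):

```bash
# Confirm the resource metrics API is registered and serving
kubectl get apiservice v1beta1.metrics.k8s.io

# If this returns pod CPU/memory usage, CPU-based HPA has the data it needs
kubectl top pods -n default
```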
+ +### Basic HPA (CPU-based) + +```yaml +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: sglang-agg-frontend-hpa + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-frontend + minReplicas: 1 + maxReplicas: 10 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 + behavior: + scaleDown: + stabilizationWindowSeconds: 300 + scaleUp: + stabilizationWindowSeconds: 0 +``` + +### HPA with Dynamo Metrics + +Dynamo exports several metrics useful for autoscaling. These are available at the `/metrics` endpoint on each frontend pod. + +> **See also**: For a complete list of all Dynamo metrics, see the [Metrics Reference](../observability/metrics.md). For Prometheus and Grafana setup, see the [Prometheus and Grafana Setup Guide](../observability/prometheus-grafana.md). + +#### Available Dynamo Metrics + +| Metric | Type | Description | Good for scaling | +|--------|------|-------------|------------------| +| `dynamo_frontend_queued_requests` | Gauge | Requests waiting in HTTP queue | โœ… Workers | +| `dynamo_frontend_inflight_requests` | Gauge | Concurrent requests to engine | โœ… All services | +| `dynamo_frontend_time_to_first_token_seconds` | Histogram | TTFT latency | โœ… Workers | +| `dynamo_frontend_inter_token_latency_seconds` | Histogram | ITL latency | โœ… Decode | +| `dynamo_frontend_request_duration_seconds` | Histogram | Total request duration | โš ๏ธ General | +| `kvstats_gpu_cache_usage_percent` | Gauge | GPU KV cache usage (0-1) | โœ… Decode | + +#### Metric Labels + +Dynamo metrics include these labels for filtering: + +| Label | Description | Example | +|-------|-------------|---------| +| `dynamo_namespace` | Unique DGD identifier (`{k8s-namespace}-{dynamoNamespace}`) | `default-sglang-agg` | +| `model` | Model being served | `Qwen/Qwen3-0.6B` | + +> **Note**: When you have multiple DGDs in the same namespace, use `dynamo_namespace` to filter metrics for a specific DGD. + +#### Example: Scale Decode Service Based on TTFT + +Using HPA with Prometheus Adapter requires configuring external metrics. 
+ +**Step 1: Configure Prometheus Adapter** + +Add this to your Helm values file (e.g., `prometheus-adapter-values.yaml`): + +```yaml +# prometheus-adapter-values.yaml +prometheus: + url: http://prometheus-kube-prometheus-prometheus.monitoring.svc + port: 9090 + +rules: + external: + # TTFT p95 from frontend - used to scale decode + - seriesQuery: 'dynamo_frontend_time_to_first_token_seconds_bucket{namespace!=""}' + resources: + overrides: + namespace: {resource: "namespace"} + name: + as: "dynamo_ttft_p95_seconds" + metricsQuery: | + histogram_quantile(0.95, + sum(rate(dynamo_frontend_time_to_first_token_seconds_bucket{<<.LabelMatchers>>}[5m])) + by (le, namespace, dynamo_namespace) + ) +``` + +**Step 2: Install Prometheus Adapter** + +```bash +helm repo add prometheus-community https://prometheus-community.github.io/helm-charts +helm repo update + +helm upgrade --install prometheus-adapter prometheus-community/prometheus-adapter \ + -n monitoring --create-namespace \ + -f prometheus-adapter-values.yaml +``` + +**Step 3: Verify the metric is available** + +```bash +kubectl get --raw "/apis/external.metrics.k8s.io/v1beta1/namespaces//dynamo_ttft_p95_seconds" | jq +``` + +**Step 4: Create the HPA** + +```yaml +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: sglang-agg-decode-hpa +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-decode # โ† DGD name + service name (lowercase) + minReplicas: 1 + maxReplicas: 10 + metrics: + - type: External + external: + metric: + name: dynamo_ttft_p95_seconds + selector: + matchLabels: + dynamo_namespace: "default-sglang-agg" # โ† {namespace}-{dynamoNamespace} + target: + type: Value + value: "500m" # Scale up when TTFT p95 > 500ms + behavior: + scaleDown: + stabilizationWindowSeconds: 60 # Wait 1 min before scaling down + policies: + - type: Pods + value: 1 + periodSeconds: 30 + scaleUp: + stabilizationWindowSeconds: 0 # Scale up immediately + policies: + - type: Pods + value: 2 + periodSeconds: 30 +``` + +**How it works:** +1. Frontend pods export `dynamo_frontend_time_to_first_token_seconds` histogram +2. Prometheus Adapter calculates p95 TTFT per `dynamo_namespace` +3. HPA monitors this metric filtered by `dynamo_namespace: "default-sglang-agg"` +4. When TTFT p95 > 500ms, HPA scales up the `sglang-agg-decode` adapter +5. Adapter controller syncs the replica count to the DGD's `decode` service +6. 
More decode workers are created, reducing TTFT + +#### Example: Scale Based on Queue Depth + +Add this rule to your `prometheus-adapter-values.yaml` (alongside the TTFT rule): + +```yaml +# Add to rules.external in prometheus-adapter-values.yaml +- seriesQuery: 'dynamo_frontend_queued_requests{namespace!=""}' + resources: + overrides: + namespace: {resource: "namespace"} + name: + as: "dynamo_queued_requests" + metricsQuery: | + sum(<<.Series>>{<<.LabelMatchers>>}) by (namespace, dynamo_namespace) +``` + +Then create the HPA: + +```yaml +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: sglang-agg-decode-queue-hpa + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-decode + minReplicas: 1 + maxReplicas: 10 + metrics: + - type: External + external: + metric: + name: dynamo_queued_requests + selector: + matchLabels: + dynamo_namespace: "default-sglang-agg" + target: + type: Value + value: "10" # Scale up when queue > 10 requests +``` + +## Autoscaling with KEDA (Recommended) + +KEDA (Kubernetes Event-driven Autoscaling) extends Kubernetes with event-driven autoscaling, supporting 50+ scalers including Prometheus. + +**Advantages over HPA + Prometheus Adapter:** +- No Prometheus Adapter configuration needed +- PromQL queries are defined in the ScaledObject itself (declarative, per-deployment) +- Easy to update - just `kubectl apply` the ScaledObject +- Can scale to zero when idle +- Supports multiple triggers per object + +**When to use KEDA:** +- You want simpler configuration (no Prometheus Adapter to manage) +- You need event-driven scaling (e.g., queue depth, Kafka, etc.) +- You want to scale to zero when idle + +### Installing KEDA + +```bash +# Add KEDA Helm repo +helm repo add kedacore https://kedacore.github.io/charts +helm repo update + +# Install KEDA +helm install keda kedacore/keda \ + --namespace keda \ + --create-namespace + +# Verify installation +kubectl get pods -n keda +``` + +> **Note**: If you have Prometheus Adapter installed, either uninstall it first (`helm uninstall prometheus-adapter -n monitoring`) or install KEDA with `--set metricsServer.enabled=false` to avoid API conflicts. 
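Before creating ScaledObjects, it is worth confirming that KEDA's metrics server has registered the external metrics API that its generated HPAs will query. A quick check, assuming the default installation above:

```bash
# The KEDA metrics adapter should own this APIService and report Available=True
kubectl get apiservice v1beta1.external.metrics.k8s.io

# Confirm the KEDA CRDs (including ScaledObject) are installed
kubectl get crd scaledobjects.keda.sh
```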
+ +### Example: Scale Decode Based on TTFT + +Using the `sglang-agg` DGD from `examples/backends/sglang/deploy/agg.yaml`: + +```yaml +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: sglang-agg-decode-scaler + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-decode + minReplicaCount: 1 + maxReplicaCount: 10 + pollingInterval: 15 # Check metrics every 15 seconds + cooldownPeriod: 60 # Wait 60s before scaling down + triggers: + - type: prometheus + metadata: + # Update this URL to match your Prometheus service + serverAddress: http://prometheus-kube-prometheus-prometheus.monitoring.svc:9090 + metricName: dynamo_ttft_p95 + query: | + histogram_quantile(0.95, + sum(rate(dynamo_frontend_time_to_first_token_seconds_bucket{dynamo_namespace="default-sglang-agg"}[5m])) + by (le) + ) + threshold: "0.5" # Scale up when TTFT p95 > 500ms (0.5 seconds) + activationThreshold: "0.1" # Start scaling when TTFT > 100ms +``` + +Apply it: + +```bash +kubectl apply -f sglang-agg-decode-scaler.yaml +``` + +### Verify KEDA Scaling + +```bash +# Check ScaledObject status +kubectl get scaledobject -n default + +# KEDA creates an HPA under the hood - you can see it +kubectl get hpa -n default + +# Example output: +# NAME REFERENCE TARGETS MINPODS MAXPODS REPLICAS +# keda-hpa-sglang-agg-decode-scaler DynamoGraphDeploymentScalingAdapter/sglang-agg-decode 45m/500m 1 10 1 + +# Get detailed status +kubectl describe scaledobject sglang-agg-decode-scaler -n default +``` + +### Example: Scale Based on Queue Depth + +```yaml +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: sglang-agg-decode-queue-scaler + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-decode + minReplicaCount: 1 + maxReplicaCount: 10 + pollingInterval: 15 + cooldownPeriod: 60 + triggers: + - type: prometheus + metadata: + serverAddress: http://prometheus-kube-prometheus-prometheus.monitoring.svc:9090 + metricName: dynamo_queued_requests + query: | + sum(dynamo_frontend_queued_requests{dynamo_namespace="default-sglang-agg"}) + threshold: "10" # Scale up when queue > 10 requests +``` + +### How KEDA Works + +KEDA creates and manages an HPA under the hood: + +``` +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ You create: ScaledObject โ”‚ +โ”‚ - scaleTargetRef: sglang-agg-decode โ”‚ +โ”‚ - triggers: prometheus query โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ KEDA Operator automatically creates: HPA โ”‚ +โ”‚ - name: keda-hpa-sglang-agg-decode-scaler โ”‚ +โ”‚ - scaleTargetRef: sglang-agg-decode โ”‚ +โ”‚ - metrics: External (from KEDA metrics server) โ”‚ 
+โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ DynamoGraphDeploymentScalingAdapter: sglang-agg-decode โ”‚ +โ”‚ - spec.replicas: updated by HPA โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + โ”‚ + โ–ผ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ DynamoGraphDeployment: sglang-agg โ”‚ +โ”‚ - spec.services.decode.replicas: synced from adapter โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +## Mixed Autoscaling + +For disaggregated deployments (prefill + decode), you can use different autoscaling strategies for different services: + +```yaml +--- +# HPA for Frontend (CPU-based) +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: sglang-agg-frontend-hpa + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-frontend + minReplicas: 1 + maxReplicas: 5 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 + +--- +# KEDA for Decode (TTFT-based) +apiVersion: keda.sh/v1alpha1 +kind: ScaledObject +metadata: + name: sglang-agg-decode-scaler + namespace: default +spec: + scaleTargetRef: + apiVersion: nvidia.com/v1alpha1 + kind: DynamoGraphDeploymentScalingAdapter + name: sglang-agg-decode + minReplicaCount: 1 + maxReplicaCount: 10 + triggers: + - type: prometheus + metadata: + serverAddress: http://prometheus-kube-prometheus-prometheus.monitoring.svc:9090 + query: | + histogram_quantile(0.95, + sum(rate(dynamo_frontend_time_to_first_token_seconds_bucket{dynamo_namespace="default-sglang-agg"}[5m])) + by (le) + ) + threshold: "0.5" +``` + +## Manual Scaling + +### With DGDSA Enabled (Default) + +When DGDSA is enabled (the default), scale via the adapter: + +```bash +kubectl scale dgdsa sglang-agg-decode -n default --replicas=3 +``` + +Verify the scaling: + +```bash +kubectl get dgdsa sglang-agg-decode -n default + +# Output: +# NAME DGD SERVICE REPLICAS AGE +# sglang-agg-decode sglang-agg decode 3 10m +``` + +> **Note**: If an autoscaler (KEDA, HPA, Planner) is managing the adapter, your change will be overwritten on the next evaluation cycle. + +### With DGDSA Disabled + +If you've disabled the scaling adapter for a service, edit the DGD directly: + +```bash +kubectl patch dgd sglang-agg --type=merge -p '{"spec":{"services":{"decode":{"replicas":3}}}}' +``` + +Or edit the YAML: + +```yaml +spec: + services: + decode: + replicas: 3 + scalingAdapter: + disable: true +``` + +## Best Practices + +### 1. 
Choose One Autoscaler Per Service + +Avoid configuring multiple autoscalers for the same service: + +| Configuration | Status | +|---------------|--------| +| HPA for frontend, Planner for prefill/decode | โœ… Good | +| KEDA for all services | โœ… Good | +| Planner only (default) | โœ… Good | +| HPA + Planner both targeting decode | โŒ Bad - they will fight | + +### 2. Use Appropriate Metrics + +| Service Type | Recommended Metrics | Dynamo Metric | +|--------------|---------------------|---------------| +| Frontend | CPU utilization, request rate | `dynamo_frontend_requests_total` | +| Prefill | Queue depth, TTFT | `dynamo_frontend_queued_requests`, `dynamo_frontend_time_to_first_token_seconds` | +| Decode | KV cache utilization, ITL | `kvstats_gpu_cache_usage_percent`, `dynamo_frontend_inter_token_latency_seconds` | + +### 3. Configure Stabilization Windows + +Prevent thrashing with appropriate stabilization: + +```yaml +# HPA +behavior: + scaleDown: + stabilizationWindowSeconds: 300 # Wait 5 min before scaling down + scaleUp: + stabilizationWindowSeconds: 0 # Scale up immediately + +# KEDA +spec: + cooldownPeriod: 300 +``` + +### 4. Set Sensible Min/Max Replicas + +Always configure minimum and maximum replicas in your HPA/KEDA to prevent: +- Scaling to zero (unless intentional) +- Unbounded scaling that exhausts cluster resources + +## Troubleshooting + +### Adapters Not Created + +```bash +# Check DGD status +kubectl describe dgd sglang-agg -n default + +# Check operator logs +kubectl logs -n dynamo-system deployment/dynamo-operator +``` + +### Scaling Not Working + +```bash +# Check adapter status +kubectl describe dgdsa sglang-agg-decode -n default + +# Check HPA/KEDA status +kubectl describe hpa sglang-agg-decode-hpa -n default +kubectl describe scaledobject sglang-agg-decode-scaler -n default + +# Verify metrics are available in Kubernetes metrics API +kubectl get --raw /apis/external.metrics.k8s.io/v1beta1 +``` + +### Metrics Not Available + +If HPA/KEDA shows `` for metrics: + +```bash +# Check if Dynamo metrics are being scraped +kubectl port-forward -n default svc/sglang-agg-frontend 8000:8000 +curl http://localhost:8000/metrics | grep dynamo_frontend + +# Example output: +# dynamo_frontend_queued_requests{model="Qwen/Qwen3-0.6B"} 2 +# dynamo_frontend_inflight_requests{model="Qwen/Qwen3-0.6B"} 5 + +# Verify Prometheus is scraping the metrics +kubectl port-forward -n monitoring svc/prometheus-kube-prometheus-prometheus 9090:9090 +# Then query: dynamo_frontend_time_to_first_token_seconds_bucket + +# Check KEDA operator logs +kubectl logs -n keda deployment/keda-operator +``` + +### Rapid Scaling Up and Down + +If you see unstable scaling: + +1. Check if multiple autoscalers are targeting the same adapter +2. Increase `cooldownPeriod` in KEDA ScaledObject +3. 
Increase `stabilizationWindowSeconds` in HPA behavior + +## References + +- [Kubernetes HPA Documentation](https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/) +- [KEDA Documentation](https://keda.sh/) +- [Prometheus Adapter](https://github.com/kubernetes-sigs/prometheus-adapter) +- [Planner Documentation](../planner/sla_planner.md) +- [Dynamo Metrics Reference](../observability/metrics.md) +- [Prometheus and Grafana Setup](../observability/prometheus-grafana.md) + diff --git a/docs/kvbm/trtllm-setup.md b/docs/kvbm/trtllm-setup.md index 3884fad4c2..17975e05cf 100644 --- a/docs/kvbm/trtllm-setup.md +++ b/docs/kvbm/trtllm-setup.md @@ -23,7 +23,6 @@ To learn what KVBM is, please check [here](kvbm_architecture.md) > [!Note] > - Ensure that `etcd` and `nats` are running before starting. -> - KVBM does not currently support CUDA graphs in TensorRT-LLM. > - KVBM only supports TensorRT-LLMโ€™s PyTorch backend. > - Disable partial reuse `enable_partial_reuse: false` in the LLM API configโ€™s `kv_connector_config` to increase offloading cache hits. > - KVBM requires TensorRT-LLM v1.1.0rc5 or newer. diff --git a/docs/observability/health-checks.md b/docs/observability/health-checks.md index 07dacaf0b2..980401fd03 100644 --- a/docs/observability/health-checks.md +++ b/docs/observability/health-checks.md @@ -20,6 +20,9 @@ orchestration frameworks such as Kubernetes. | `DYN_SYSTEM_HEALTH_PATH` | Custom health endpoint path | `/health` | `/custom/health` | | `DYN_SYSTEM_LIVE_PATH` | Custom liveness endpoint path | `/live` | `/custom/live` | | `DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS` | Endpoints required for ready state | none | `["generate"]` | +| `DYN_HEALTH_CHECK_ENABLED` | Enable canary health checks | `false` (K8s: `true`) | `true`, `false` | +| `DYN_CANARY_WAIT_TIME` | Seconds before sending canary health check | `10` | `5`, `30` | +| `DYN_HEALTH_CHECK_REQUEST_TIMEOUT` | Health check request timeout in seconds | `3` | `5`, `10` | ## Getting Started Quickly @@ -213,6 +216,127 @@ date: Wed, 03 Sep 2025 13:42:45 GMT } ``` +## Canary Health Checks (Active Monitoring) + +In addition to the HTTP endpoints described above, Dynamo includes a **canary health check** system that actively monitors worker endpoints. + +### Overview + +The canary health check system: +- **Monitors endpoint health** by sending periodic test requests to worker endpoints +- **Only activates during idle periods** - if there's ongoing traffic, health checks are skipped to avoid overhead +- **Automatically enabled in Kubernetes** deployments via the operator +- **Disabled by default** in local/development environments + +### How It Works + +1. **Idle Detection**: After no activity on an endpoint for a configurable wait time (default: 10 seconds), a canary health check is triggered +2. **Health Check Request**: A lightweight test request is sent to the endpoint with a minimal payload (generates 1 token) +3. **Activity Resets Timer**: If normal requests arrive, the canary timer resets and no health check is sent +4. **Timeout Handling**: If a health check doesn't respond within the timeout (default: 3 seconds), the endpoint is marked as unhealthy + +### Configuration + +#### In Kubernetes (Enabled by Default) + +Health checks are automatically enabled by the Dynamo operator. No additional configuration is required. 
+ +```yaml +apiVersion: nvidia.com/v1alpha1 +kind: DynamoGraphDeployment +metadata: + name: my-deployment +spec: + services: + VllmWorker: + componentType: worker + replicas: 2 + # Health checks automatically enabled by operator +``` + +#### In Local/Development Environments (Disabled by Default) + +To enable health checks locally: + +```bash +# Enable health checks +export DYN_HEALTH_CHECK_ENABLED=true + +# Optional: Customize timing +export DYN_CANARY_WAIT_TIME=5 # Wait 5 seconds before sending health check +export DYN_HEALTH_CHECK_REQUEST_TIMEOUT=5 # 5 second timeout + +# Start worker +python -m dynamo.vllm --model Qwen/Qwen3-0.6B +``` + +#### Configuration Options + +| Environment Variable | Description | Default | Notes | +|---------------------|-------------|---------|-------| +| `DYN_HEALTH_CHECK_ENABLED` | Enable/disable canary health checks | `false` (K8s: `true`) | Automatically set to `true` in K8s | +| `DYN_CANARY_WAIT_TIME` | Seconds to wait (during idle) before sending health check | `10` | Lower values = more frequent checks | +| `DYN_HEALTH_CHECK_REQUEST_TIMEOUT` | Max seconds to wait for health check response | `3` | Higher values = more tolerance for slow responses | + +### Health Check Payloads + +Each backend defines its own minimal health check payload: + +- **vLLM**: Single token generation with minimal sampling options +- **TensorRT-LLM**: Single token with BOS token ID +- **SGLang**: Single token generation request + +These payloads are designed to: +- Complete quickly (< 100ms typically) +- Minimize GPU overhead +- Verify the full inference stack is working + +### Observing Health Checks + +When health checks are enabled, you'll see logs like: + +``` +INFO Health check manager started (canary_wait_time: 10s, request_timeout: 3s) +INFO Spawned health check task for endpoint: generate +INFO Canary timer expired for generate, sending health check +INFO Health check successful for generate +``` + +If an endpoint fails: + +``` +WARN Health check timeout for generate +ERROR Health check request failed for generate: connection refused +``` + +### When to Use Canary Health Checks + +**Enable in production (Kubernetes):** +- โœ… Detect unhealthy workers before they affect user traffic +- โœ… Enable faster failure detection and recovery +- โœ… Monitor worker availability continuously + +**Disable in development:** +- โœ… Reduce log noise during debugging +- โœ… Avoid overhead when not needed +- โœ… Simplify local testing + +### Troubleshooting + +**Health checks timing out:** +- Increase `DYN_HEALTH_CHECK_REQUEST_TIMEOUT` +- Check worker logs for errors +- Verify network connectivity + +**Too many health check logs:** +- Increase `DYN_CANARY_WAIT_TIME` to reduce frequency +- Or disable with `DYN_HEALTH_CHECK_ENABLED=false` in dev + +**Health checks not running:** +- Verify `DYN_HEALTH_CHECK_ENABLED=true` is set +- Check that `DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS` includes the endpoint +- Ensure the worker is serving the endpoint + ## Related Documentation - [Distributed Runtime Architecture](../design_docs/distributed_runtime.md) diff --git a/docs/observability/metrics.md b/docs/observability/metrics.md index 4b4781f761..55d0d3e8d3 100644 --- a/docs/observability/metrics.md +++ b/docs/observability/metrics.md @@ -152,6 +152,7 @@ The Dynamo HTTP Frontend (`python -m dynamo.frontend`) exposes `dynamo_frontend_ - `dynamo_frontend_queued_requests`: Number of requests in HTTP processing queue (gauge) - `dynamo_frontend_disconnected_clients`: Number of disconnected clients 
(gauge) - `dynamo_frontend_input_sequence_tokens`: Input sequence length (histogram) +- `dynamo_frontend_cached_tokens`: Number of cached tokens (prefix cache hits) per request (histogram) - `dynamo_frontend_inter_token_latency_seconds`: Inter-token latency (histogram) - `dynamo_frontend_output_sequence_tokens`: Output sequence length (histogram) - `dynamo_frontend_output_tokens_total`: Total number of output tokens generated (counter) diff --git a/docs/planner/sla_planner_quickstart.md b/docs/planner/sla_planner_quickstart.md index 4d0c375f6e..c4029a2a2d 100644 --- a/docs/planner/sla_planner_quickstart.md +++ b/docs/planner/sla_planner_quickstart.md @@ -179,6 +179,25 @@ kubectl port-forward svc/trtllm-disagg-frontend 8000:8000 -n $NAMESPACE curl http://localhost:8000/v1/models ``` +### Step 5 (Optional): Access the Planner Grafana Dashboard + +If you want to monitor the SLA Planner's decision-making in real-time, you can deploy the Planner Grafana dashboard. + +```bash +kubectl apply -n monitoring -f deploy/observability/k8s/grafana-planner-dashboard-configmap.yaml +``` + +Follow the instructions in [Dynamo Metrics Collection on Kubernetes](../kubernetes/observability/metrics.md) to access the Grafana UI and select the **Dynamo Planner Dashboard**. + +The dashboard displays: +- **Worker Counts & GPU Usage**: Current prefill/decode worker counts and cumulative GPU hours +- **Observed Metrics**: Real-time TTFT, ITL, request rate, and sequence lengths from Prometheus +- **Predicted Metrics**: Planner's load predictions and recommended replica counts +- **Correction Factors**: How the planner adjusts predictions based on observed vs expected performance + +> [!TIP] +> Use the **Namespace** dropdown at the top of the dashboard to filter metrics for your specific deployment namespace. 
+ ## DGDR Configuration Details ### Required Fields diff --git a/docs/project.json b/docs/project.json index 3b94839f5f..a951ef7e58 100644 --- a/docs/project.json +++ b/docs/project.json @@ -1 +1 @@ -{"name": "NVIDIA Dynamo", "version": "latest"} \ No newline at end of file +{"name": "NVIDIA Dynamo", "version": "dev"} diff --git a/docs/reference/support-matrix.md b/docs/reference/support-matrix.md index 2efb446874..e6c862b8b4 100644 --- a/docs/reference/support-matrix.md +++ b/docs/reference/support-matrix.md @@ -58,12 +58,12 @@ If you are using a **GPU**, the following GPU models and architectures are suppo ### Build Dependency -| **Build Dependency** | **Version as of Dynamo v0.7.0** | -| :------------------- | :------------------------------------------------------------------------------- | -| **SGLang** | 0.5.3.post4 | -| **TensorRT-LLM** | 1.2.0rc2 | -| **vLLM** | 0.11.0 | -| **NIXL** | 0.7.1 | +| **Build Dependency** | **Version as of Dynamo v0.7.0** | +| :------------------- | :------------------------------ | +| **SGLang** | 0.5.3.post4 | +| **TensorRT-LLM** | 1.2.0rc5 | +| **vLLM** | 0.11.0 | +| **NIXL** | 0.7.1 | > [!Important] diff --git a/docs/versions1.json b/docs/versions1.json deleted file mode 100644 index 856c91f1f9..0000000000 --- a/docs/versions1.json +++ /dev/null @@ -1,62 +0,0 @@ -[ - { - "preferred": true, - "version": "latest", - "url": "https://docs.nvidia.com/dynamo/latest/" - }, - { - "version": "0.7.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.7.0/" - }, - { - "version": "0.6.1", - "url": "https://docs.nvidia.com/dynamo/archive/0.6.1/" - }, - { - "version": "0.6.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.6.0/" - }, - { - "version": "0.5.1", - "url": "https://docs.nvidia.com/dynamo/archive/0.5.1/" - }, - { - "version": "0.5.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.5.0/" - }, - { - "name": "0.4.1", - "version": "0.4.1", - "url": "https://docs.nvidia.com/dynamo/archive/0.4.1/" - }, - { - "name": "0.4.0", - "version": "0.4.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.4.0/" - }, - { - "name": "0.3.2", - "version": "0.3.2", - "url": "https://docs.nvidia.com/dynamo/archive/0.3.2/" - }, - { - "name": "0.3.1", - "version": "0.3.1", - "url": "https://docs.nvidia.com/dynamo/archive/0.3.1/" - }, - { - "name": "0.3.0", - "version": "0.3.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.3.0/" - }, - { - "name": "0.2.1", - "version": "0.2.1", - "url": "https://docs.nvidia.com/dynamo/archive/0.2.1/" - }, - { - "name": "0.2.0", - "version": "0.2.0", - "url": "https://docs.nvidia.com/dynamo/archive/0.2.0/" - } -] diff --git a/examples/backends/sglang/launch/agg.sh b/examples/backends/sglang/launch/agg.sh index 9ccb48f260..43e4f1f4af 100755 --- a/examples/backends/sglang/launch/agg.sh +++ b/examples/backends/sglang/launch/agg.sh @@ -46,10 +46,12 @@ while [[ $# -gt 0 ]]; do done # Enable tracing if requested +TRACE_ARGS=() if [ "$ENABLE_OTEL" = true ]; then export DYN_LOGGING_JSONL=true export OTEL_EXPORT_ENABLED=1 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_EXPORTER_OTLP_TRACES_ENDPOINT:-http://localhost:4317} + TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) fi # run ingress @@ -59,7 +61,7 @@ python3 -m dynamo.frontend & DYNAMO_PID=$! 
# run worker with metrics enabled -DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \ +OTEL_SERVICE_NAME=dynamo-worker DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \ python3 -m dynamo.sglang \ --model-path "$MODEL" \ --served-model-name "$MODEL" \ @@ -68,4 +70,5 @@ python3 -m dynamo.sglang \ --trust-remote-code \ --skip-tokenizer-init \ --enable-metrics \ + "${TRACE_ARGS[@]}" \ "${EXTRA_ARGS[@]}" diff --git a/examples/backends/sglang/launch/agg_embed.sh b/examples/backends/sglang/launch/agg_embed.sh index 9064273f30..e78ebb2458 100755 --- a/examples/backends/sglang/launch/agg_embed.sh +++ b/examples/backends/sglang/launch/agg_embed.sh @@ -37,10 +37,12 @@ while [[ $# -gt 0 ]]; do done # Enable tracing if requested +TRACE_ARGS=() if [ "$ENABLE_OTEL" = true ]; then export DYN_LOGGING_JSONL=true export OTEL_EXPORT_ENABLED=1 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_EXPORTER_OTLP_TRACES_ENDPOINT:-http://localhost:4317} + TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) fi # run ingress @@ -59,4 +61,5 @@ python3 -m dynamo.sglang \ --tp 1 \ --trust-remote-code \ --use-sglang-tokenizer \ - --enable-metrics + --enable-metrics \ + "${TRACE_ARGS[@]}" diff --git a/examples/backends/sglang/launch/agg_router.sh b/examples/backends/sglang/launch/agg_router.sh index 0b336f5f15..4cfca011f4 100755 --- a/examples/backends/sglang/launch/agg_router.sh +++ b/examples/backends/sglang/launch/agg_router.sh @@ -37,10 +37,12 @@ while [[ $# -gt 0 ]]; do done # Enable tracing if requested +TRACE_ARGS=() if [ "$ENABLE_OTEL" = true ]; then export DYN_LOGGING_JSONL=true export OTEL_EXPORT_ENABLED=1 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_EXPORTER_OTLP_TRACES_ENDPOINT:-http://localhost:4317} + TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) fi # run ingress @@ -58,7 +60,8 @@ python3 -m dynamo.sglang \ --tp 1 \ --trust-remote-code \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' \ - --enable-metrics & + --enable-metrics \ + "${TRACE_ARGS[@]}" & WORKER_PID=$! OTEL_SERVICE_NAME=dynamo-worker-2 DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT_WORKER2:-8082} \ @@ -69,4 +72,5 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ --tp 1 \ --trust-remote-code \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}' \ - --enable-metrics + --enable-metrics \ + "${TRACE_ARGS[@]}" diff --git a/examples/backends/sglang/launch/disagg.sh b/examples/backends/sglang/launch/disagg.sh index 53e22fc723..9291ffb0c8 100755 --- a/examples/backends/sglang/launch/disagg.sh +++ b/examples/backends/sglang/launch/disagg.sh @@ -37,10 +37,12 @@ while [[ $# -gt 0 ]]; do done # Enable tracing if requested +TRACE_ARGS=() if [ "$ENABLE_OTEL" = true ]; then export DYN_LOGGING_JSONL=true export OTEL_EXPORT_ENABLED=1 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_EXPORTER_OTLP_TRACES_ENDPOINT:-http://localhost:4317} + TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) fi # run ingress @@ -49,31 +51,38 @@ OTEL_SERVICE_NAME=dynamo-frontend \ python3 -m dynamo.frontend & DYNAMO_PID=$! +#AssertionError: Prefill round robin balance is required when dp size > 1. Please make sure that the prefill instance is launched with `--load-balance-method round_robin` and `--prefill-round-robin-balance` is set for decode server. 
+ # run prefill worker OTEL_SERVICE_NAME=dynamo-worker-prefill DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT_PREFILL:-8081} \ python3 -m dynamo.sglang \ - --model-path Qwen/Qwen3-0.6B \ - --served-model-name Qwen/Qwen3-0.6B \ + --model-path silence09/DeepSeek-R1-Small-2layers \ + --served-model-name silence09/DeepSeek-R1-Small-2layers \ --page-size 16 \ - --tp 1 \ + --tp 2 --dp-size 2 --enable-dp-attention \ + --load-balance-method round_robin \ --trust-remote-code \ --disaggregation-mode prefill \ --disaggregation-bootstrap-port 12345 \ --host 0.0.0.0 \ + --port 40000 \ --disaggregation-transfer-backend nixl \ - --enable-metrics & + --enable-metrics \ + "${TRACE_ARGS[@]}" & PREFILL_PID=$! # run decode worker OTEL_SERVICE_NAME=dynamo-worker-decode DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT_DECODE:-8082} \ -CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ - --model-path Qwen/Qwen3-0.6B \ - --served-model-name Qwen/Qwen3-0.6B \ +CUDA_VISIBLE_DEVICES=2,3 python3 -m dynamo.sglang \ + --model-path silence09/DeepSeek-R1-Small-2layers \ + --served-model-name silence09/DeepSeek-R1-Small-2layers \ --page-size 16 \ - --tp 1 \ + --prefill-round-robin-balance \ + --tp 2 --dp-size 2 --enable-dp-attention \ --trust-remote-code \ --disaggregation-mode decode \ --disaggregation-bootstrap-port 12345 \ --host 0.0.0.0 \ --disaggregation-transfer-backend nixl \ - --enable-metrics + --enable-metrics \ + "${TRACE_ARGS[@]}" diff --git a/examples/backends/sglang/launch/disagg_router.sh b/examples/backends/sglang/launch/disagg_router.sh index 916cbbf410..16a7db750e 100755 --- a/examples/backends/sglang/launch/disagg_router.sh +++ b/examples/backends/sglang/launch/disagg_router.sh @@ -38,10 +38,12 @@ while [[ $# -gt 0 ]]; do done # Enable tracing if requested +TRACE_ARGS=() if [ "$ENABLE_OTEL" = true ]; then export DYN_LOGGING_JSONL=true export OTEL_EXPORT_ENABLED=1 export OTEL_EXPORTER_OTLP_TRACES_ENDPOINT=${OTEL_EXPORTER_OTLP_TRACES_ENDPOINT:-http://localhost:4317} + TRACE_ARGS+=(--enable-trace --otlp-traces-endpoint localhost:4317) fi # run ingress @@ -74,7 +76,8 @@ python3 -m dynamo.sglang \ --host 0.0.0.0 \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' \ --disaggregation-transfer-backend nixl \ - --enable-metrics & + --enable-metrics \ + "${TRACE_ARGS[@]}" & PREFILL_PID=$! # run prefill worker @@ -89,7 +92,8 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \ --host 0.0.0.0 \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}' \ --disaggregation-transfer-backend nixl \ - --enable-metrics & + --enable-metrics \ + "${TRACE_ARGS[@]}" & PREFILL_PID=$! # run decode worker @@ -104,7 +108,8 @@ CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.sglang \ --host 0.0.0.0 \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5560"}' \ --disaggregation-transfer-backend nixl \ - --enable-metrics & + --enable-metrics \ + "${TRACE_ARGS[@]}" & PREFILL_PID=$! 
# run decode worker @@ -119,4 +124,5 @@ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \ --host 0.0.0.0 \ --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5559"}' \ --disaggregation-transfer-backend nixl \ - --enable-metrics + --enable-metrics \ + "${TRACE_ARGS[@]}" diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-low-latency.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-low-latency.sh new file mode 100755 index 0000000000..f128e5cb20 --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-low-latency.sh @@ -0,0 +1,179 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=true \ + python3 -m dynamo.sglang \ + --disaggregation-mode prefill \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --disable-radix-cache \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --stream-interval 10 \ + --watchdog-timeout 1000000 \ + 
--context-length 2200 \ + --mem-fraction-static 0.95 \ + --max-total-tokens 8192 \ + --chunked-prefill-size 8192 \ + --cuda-graph-max-bs 256 \ + --max-running-requests 512 \ + --scheduler-recv-interval 10 \ + --enable-symm-mem \ + --moe-dense-tp-size 1 \ + --load-balance-method round_robin \ + --disaggregation-bootstrap-port 30001 \ + --data-parallel-size 1 \ + --tensor-parallel-size "$TOTAL_GPUS" \ + --expert-parallel-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=true \ + python3 -m dynamo.sglang \ + --disaggregation-mode decode \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --prefill-round-robin-balance \ + --trust-remote-code \ + --disable-radix-cache \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --disaggregation-bootstrap-port 30001 \ + --stream-interval 10 \ + --watchdog-timeout 1000000 \ + --context-length 2200 \ + --mem-fraction-static 0.95 \ + --chunked-prefill-size 8192 \ + --cuda-graph-max-bs 256 \ + --scheduler-recv-interval 10 \ + --enable-symm-mem \ + --moe-dense-tp-size 1 \ + --tensor-parallel-size "$TOTAL_GPUS" \ + --expert-parallel-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi \ No newline at end of file diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-max-tpt.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-max-tpt.sh new file mode 100755 index 0000000000..f81aa51a6c --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-max-tpt.sh @@ -0,0 +1,200 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutlass \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 2176 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ + --mem-fraction-static 0.84 \ + --max-total-tokens 131072 \ + --max-prefill-tokens 32768 \ + --chunked-prefill-size 65536 \ + --enable-single-batch-overlap \ + --max-running-requests 30000 \ + --load-balance-method round_robin \ + --disable-cuda-graph \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + 
set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=1024 \ + SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=1 \ + SGLANG_FLASHINFER_FP4_GEMM_BACKEND=cutlass \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutedsl \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 2176 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ + --mem-fraction-static 0.83 \ + --max-total-tokens 3122380 \ + --chunked-prefill-size 786432 \ + --max-running-requests 67584 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --ep-dispatch-algorithm static \ + --ep-num-redundant-experts 32 \ + --cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 264 272 280 288 296 304 312 320 328 336 344 352 360 368 376 384 416 448 480 512 544 576 608 640 672 704 736 768 1024 \ + --num-reserved-decode-tokens 112 \ + --moe-dense-tp-size 1 \ + --enable-dp-lm-head \ + --prefill-round-robin-balance \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi \ No newline at end of file diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/default.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-middle-curve.sh similarity index 73% rename from examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/default.sh rename to examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-middle-curve.sh index de786d8e0e..43a435c95a 100755 --- a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/default.sh +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/1k1k-middle-curve.sh @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -# This comes from https://github.com/sgl-project/sglang/issues/10903 and uses the low-prec decode setup - # Function to print usage print_usage() { echo "Usage: $0 " @@ -64,152 +62,140 @@ if [ -z "$USE_INIT_LOCATIONS" ]; then exit 1 fi +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + # Construct command based on mode if [ "$mode" = "prefill" ]; then set -x - # no expert locations collected for fp4 yet + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + command_suffix="" - if [[ "${USE_INIT_LOCATIONS,,}" == "true" ]]; then command_suffix=" "; fi if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi - # we have to install pre-release cutedsl for a integer overflow fix - python3 -m pip install --no-cache-dir --upgrade --pre nvidia-cutlass-dsl - - # set your own cache variables here - export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 - export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" - export FLASHINFER_WORKSPACE_BASE="/configs/flashinfer-cache" - + PYTHONUNBUFFERED=1 \ DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ - SGL_JIT_DEEPGEMM_PRECOMPILE=0 \ SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ MC_TE_METRIC=true \ - SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ MC_FORCE_MNNVL=1 \ NCCL_MNNVL_ENABLE=1 \ NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ - PYTHONUNBUFFERED=1 \ python3 -m dynamo.sglang \ --served-model-name deepseek-ai/DeepSeek-R1 \ --model-path /model/ \ - --skip-tokenizer-init \ - --disaggregation-mode prefill \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutlass \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ --decode-log-interval 1000 \ - --max-running-requests 5632 \ + --watchdog-timeout 1000000 \ --context-length 2176 \ - --disable-radix-cache \ --disable-shared-experts-fusion \ - --watchdog-timeout 1000000 \ - --disable-chunked-prefix-cache \ - --attention-backend trtllm_mla \ - --kv-cache-dtype fp8_e4m3 \ - --enable-single-batch-overlap \ - --chunked-prefill-size 65536 \ --eplb-algorithm deepseek \ - --trust-remote-code \ - --disable-cuda-graph \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ --mem-fraction-static 0.84 \ --max-total-tokens 131072 \ - --max-prefill-tokens 16384 \ + --max-prefill-tokens 32768 \ + --chunked-prefill-size 65536 \ + --enable-single-batch-overlap \ + --max-running-requests 30000 \ --load-balance-method round_robin \ - --quantization modelopt_fp4 \ - --moe-runner-backend flashinfer_cutlass \ + --disable-cuda-graph \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ - --disaggregation-bootstrap-port 30001 \ --nnodes "$TOTAL_NODES" \ --node-rank "$RANK" \ - --ep-size "$TOTAL_GPUS" \ - --tp-size "$TOTAL_GPUS" \ - --dp-size 
"$TOTAL_GPUS" \ - --enable-dp-attention \ - --host 0.0.0.0 \ - --stream-interval 50 \ - --log-level debug ${command_suffix} - -# For now we must keep SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK and cuda-graph-bs at 1024 until -# DeepEP merges in https://github.com/deepseek-ai/DeepEP/pull/440 -# the nvidia-cutlass-dsl install fixes https://github.com/flashinfer-ai/flashinfer/issues/1830#issuecomment-3380074018 -# which was previously limiting us to DISPATCH_TOKENS and cuda-graph-bs == 384 -# For now use 12 nodes for fp4 since flashinfer_cutedsl requires experts per gpu < 8 -# We have 288 (256 + 32 redundant) => 288/48 = 6 + --host 0.0.0.0 ${command_suffix} elif [ "$mode" = "decode" ]; then set -x - # no expert locations collected for fp4 yet - command_suffix="" - if [[ "${USE_INIT_LOCATIONS,,}" == "true" ]]; then command_suffix=" "; fi - if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi - - # set your own cache variables here + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 - export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" - export FLASHINFER_WORKSPACE_BASE="/configs/flashinfer-cache" - # we have to install pre-release cutedsl for a integer overflow fix - python3 -m pip install --no-cache-dir --upgrade --pre nvidia-cutlass-dsl + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ - SGL_JIT_DEEPGEMM_PRECOMPILE=0 \ - MC_TE_METRIC=true \ SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ - SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=384 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=1024 \ SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=1 \ - SGLANG_FP4_GEMM_BACKEND=cutlass \ - DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ - PYTHONUNBUFFERED=1 \ + SGLANG_FLASHINFER_FP4_GEMM_BACKEND=cutlass \ python3 -m dynamo.sglang \ --served-model-name deepseek-ai/DeepSeek-R1 \ --model-path /model/ \ - --skip-tokenizer-init \ --trust-remote-code \ - --disaggregation-mode decode \ - --host 0.0.0.0 \ - --decode-log-interval 1 \ - --max-running-requests 67584 \ - --context-length 2176 \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutedsl \ --disable-radix-cache \ - --disable-shared-experts-fusion \ - --watchdog-timeout 1000000 \ --disable-chunked-prefix-cache \ - --attention-backend trtllm_mla \ - --kv-cache-dtype fp8_e4m3 \ - --enable-dp-attention \ - --chunked-prefill-size 786432 \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 2176 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ --mem-fraction-static 0.83 \ + --max-total-tokens 3122380 \ + 
--chunked-prefill-size 786432 \ + --max-running-requests 67584 \ + --enable-single-batch-overlap \ --moe-a2a-backend deepep \ --deepep-mode low_latency \ --ep-dispatch-algorithm static \ - --cuda-graph-bs 384 \ - --num-reserved-decode-tokens 112 \ --ep-num-redundant-experts 32 \ - --eplb-algorithm deepseek \ + --cuda-graph-bs 1 2 4 8 16 24 32 40 48 56 64 72 80 88 96 104 112 120 128 136 144 152 160 168 176 184 192 200 208 216 224 232 240 248 256 264 272 280 288 296 304 312 320 328 336 344 352 360 368 376 384 416 448 480 512 544 576 608 640 672 704 736 768 1024 \ + --num-reserved-decode-tokens 112 \ --moe-dense-tp-size 1 \ --enable-dp-lm-head \ --prefill-round-robin-balance \ - --max-total-tokens 3122380 \ - --quantization modelopt_fp4 \ - --moe-runner-backend flashinfer_cutedsl \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ - --disaggregation-bootstrap-port 30001 \ --nnodes "$TOTAL_NODES" \ --node-rank "$RANK" \ - --tp-size "$TOTAL_GPUS" \ - --ep-size "$TOTAL_GPUS" \ - --dp-size "$TOTAL_GPUS" \ - --enable-single-batch-overlap \ - --enable-dp-attention \ - --stream-interval 50 \ - --mem-fraction-static 0.82 ${command_suffix} -fi + --host 0.0.0.0 ${command_suffix} +fi \ No newline at end of file diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-low-latency.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-low-latency.sh new file mode 100755 index 0000000000..1d6007e13a --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-low-latency.sh @@ -0,0 +1,181 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=true \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --disable-radix-cache \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --stream-interval 50 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --mem-fraction-static 0.95 \ + --max-total-tokens 32768 \ + --chunked-prefill-size 24576 \ + --cuda-graph-max-bs 256 \ + --max-running-requests 512 \ + --scheduler-recv-interval 10 \ + --enable-symm-mem \ + --moe-dense-tp-size 1 \ + --load-balance-method round_robin \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ + --dp-size 1 \ + --tp-size "$TOTAL_GPUS" \ + --ep-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install 
/configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=true \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --prefill-round-robin-balance \ + --trust-remote-code \ + --disable-radix-cache \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ + --stream-interval 50 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --mem-fraction-static 0.95 \ + --chunked-prefill-size 8192 \ + --cuda-graph-max-bs 256 \ + --scheduler-recv-interval 10 \ + --enable-symm-mem \ + --moe-dense-tp-size 1 \ + --dp-size 1 \ + --tp-size "$TOTAL_GPUS" \ + --ep-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi + diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-max-tpt.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-max-tpt.sh new file mode 100755 index 0000000000..e7447fc5a4 --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-max-tpt.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disable-shared-experts-fusion \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ + --mem-fraction-static 0.95 \ + --max-total-tokens 131072 \ + --max-prefill-tokens 524288 \ + --chunked-prefill-size 131072 \ + --max-running-requests 30000 \ + --load-balance-method round_robin \ + --disable-cuda-graph \ + --tp-size "$TOTAL_GPUS" \ + --dp-size 1 \ + --ep-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install 
/configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=512 \ + SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=1 \ + SGLANG_FLASHINFER_FP4_GEMM_BACKEND=cutlass \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutedsl \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ + --mem-fraction-static 0.83 \ + --max-total-tokens 524288 \ + --chunked-prefill-size 24576 \ + --max-running-requests 16384 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --ep-dispatch-algorithm static \ + --ep-num-redundant-experts 32 \ + --cuda-graph-max-bs 512 \ + --num-reserved-decode-tokens 112 \ + --moe-dense-tp-size 1 \ + --enable-dp-lm-head \ + --prefill-round-robin-balance \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi + diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-middle-curve.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-middle-curve.sh new file mode 100755 index 0000000000..e7447fc5a4 --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp4/disagg/8k1k-middle-curve.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_trtllm \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disable-shared-experts-fusion \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ + --mem-fraction-static 0.95 \ + --max-total-tokens 131072 \ + --max-prefill-tokens 524288 \ + --chunked-prefill-size 131072 \ + --max-running-requests 30000 \ + --load-balance-method round_robin \ + --disable-cuda-graph \ + --tp-size "$TOTAL_GPUS" \ + --dp-size 1 \ + --ep-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install 
/configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN=1 \ + SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=512 \ + SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH=1 \ + SGLANG_FLASHINFER_FP4_GEMM_BACKEND=cutlass \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization modelopt_fp4 \ + --moe-runner-backend flashinfer_cutedsl \ + --disable-radix-cache \ + --disable-chunked-prefix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ + --mem-fraction-static 0.83 \ + --max-total-tokens 524288 \ + --chunked-prefill-size 24576 \ + --max-running-requests 16384 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --ep-dispatch-algorithm static \ + --ep-num-redundant-experts 32 \ + --cuda-graph-max-bs 512 \ + --num-reserved-decode-tokens 112 \ + --moe-dense-tp-size 1 \ + --enable-dp-lm-head \ + --prefill-round-robin-balance \ + --enable-dp-attention \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi + diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/agg/default.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/agg/default.sh deleted file mode 100755 index 84c06870b5..0000000000 --- a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/agg/default.sh +++ /dev/null @@ -1,94 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -# Simple agg script (not an optimized config) - -print_usage() { - echo "Usage: $0" - echo "" - echo "This script runs aggregated mode (single dynamo.sglang instance)" - exit 1 -} - -echo "Mode: aggregated" -echo "Command: dynamo" - -# Check if required environment variables are set -if [ -z "$HOST_IP_MACHINE" ]; then - echo "Error: HOST_IP_MACHINE environment variable is not set" - exit 1 -fi - -if [ -z "$PORT" ]; then - echo "Error: PORT environment variable is not set" - exit 1 -fi - -if [ -z "$TOTAL_GPUS" ]; then - echo "Error: TOTAL_GPUS environment variable is not set" - exit 1 -fi - -if [ -z "$RANK" ]; then - echo "Error: RANK environment variable is not set" - exit 1 -fi - -if [ -z "$TOTAL_NODES" ]; then - echo "Error: TOTAL_NODES environment variable is not set" - exit 1 -fi - -# Construct command suffix for config dump -command_suffix="" -if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="--dump-config-to ${DUMP_CONFIG_PATH}"; fi - -set -x -export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 -export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" -export FLASHINFER_WORKSPACE_BASE="/configs/flashinfer-cache" - -DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ -MC_TE_METRIC=true \ -SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ -MC_FORCE_MNNVL=1 \ -NCCL_MNNVL_ENABLE=1 \ -NCCL_CUMEM_ENABLE=1 \ -SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ -PYTHONUNBUFFERED=1 \ -python3 -m dynamo.sglang \ - --served-model-name deepseek-ai/DeepSeek-R1 \ - --model-path /model/ \ - --skip-tokenizer-init \ - --trust-remote-code \ - --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ - --nnodes "$TOTAL_NODES" \ - --node-rank "$RANK" \ - --tp-size "$TOTAL_GPUS" \ - --dp-size "$TOTAL_GPUS" \ - --enable-dp-attention \ - --host 0.0.0.0 \ - --max-running-requests 30000 \ - --context-length 2200 \ - --disable-radix-cache \ - --moe-a2a-backend deepep \ - --load-balance-method round_robin \ - --deepep-mode normal \ - --ep-dispatch-algorithm dynamic \ - --moe-dense-tp-size 1 \ - --enable-dp-lm-head \ - --disable-shared-experts-fusion \ - --ep-num-redundant-experts 32 \ - --eplb-algorithm deepseek \ - --attention-backend trtllm_mla \ - --kv-cache-dtype fp8_e4m3 \ - --watchdog-timeout 1000000 \ - --disable-cuda-graph \ - --chunked-prefill-size 131072 \ - --max-total-tokens 524288 \ - --deepep-config /configs/deepep_config.json \ - --stream-interval 50 \ - --mem-fraction-static 0.75 ${command_suffix} - - diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1p_4d.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-low-latency.sh similarity index 95% rename from examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1p_4d.sh rename to examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-low-latency.sh index 3f193c273e..090e238d70 100755 --- a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1p_4d.sh +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-low-latency.sh @@ -73,8 +73,8 @@ fi if [ "$mode" = "prefill" ]; then set -x if [[ "${RUN_IN_CI,,}" == "true" ]]; then - python3 -m pip install /configs/ai_dynamo_runtime-0.6.1-cp310-abi3-manylinux_2_28_aarch64.whl - python3 -m pip install /configs/ai_dynamo-0.6.1-py3-none-any.whl + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl fi export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" @@ -131,8 +131,8 @@ if [ "$mode" = 
"prefill" ]; then elif [ "$mode" = "decode" ]; then set -x if [[ "${RUN_IN_CI,,}" == "true" ]]; then - python3 -m pip install /configs/ai_dynamo_runtime-0.6.1-cp310-abi3-manylinux_2_28_aarch64.whl - python3 -m pip install /configs/ai_dynamo-0.6.1-py3-none-any.whl + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl fi export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/default.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-max-tpt.sh similarity index 96% rename from examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/default.sh rename to examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-max-tpt.sh index 7b0b0215fe..604e5c3a03 100755 --- a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/default.sh +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/1k1k-max-tpt.sh @@ -71,8 +71,8 @@ fi if [ "$mode" = "prefill" ]; then set -x if [[ "${RUN_IN_CI,,}" == "true" ]]; then - python3 -m pip install /configs/ai_dynamo_runtime-0.6.1-cp310-abi3-manylinux_2_28_aarch64.whl - python3 -m pip install /configs/ai_dynamo-0.6.1-py3-none-any.whl + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl fi export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" @@ -132,8 +132,8 @@ if [ "$mode" = "prefill" ]; then elif [ "$mode" = "decode" ]; then set -x if [[ "${RUN_IN_CI,,}" == "true" ]]; then - python3 -m pip install /configs/ai_dynamo_runtime-0.6.1-cp310-abi3-manylinux_2_28_aarch64.whl - python3 -m pip install /configs/ai_dynamo-0.6.1-py3-none-any.whl + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl fi export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-low-latency.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-low-latency.sh new file mode 100755 index 0000000000..93d3a68d92 --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-low-latency.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization fp8 \ + --moe-runner-backend flashinfer_trtllm \ + --disable-radix-cache \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disaggregation-mode prefill \ + --mem-fraction-static 0.95 \ + --max-total-tokens 32768 \ + --chunked-prefill-size 24576 \ + --cuda-graph-max-bs 512 \ + --max-running-requests 512 \ + --load-balance-method round_robin \ + --scheduler-recv-interval 10 \ + --enable-flashinfer-allreduce-fusion \ + --moe-dense-tp-size 1 \ + --tensor-parallel-size "$TOTAL_GPUS" \ + --data-parallel-size 1 \ + --expert-parallel-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --disaggregation-bootstrap-port 30001 \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install 
/configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_ENABLE_JIT_DEEPGEMM=false \ + SGLANG_ENABLE_FLASHINFER_GEMM=1 \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + MC_TE_METRIC=true \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --kv-cache-dtype fp8_e4m3 \ + --attention-backend trtllm_mla \ + --quantization fp8 \ + --moe-runner-backend flashinfer_trtllm \ + --disable-radix-cache \ + --watchdog-timeout 1000000 \ + --context-length 9600 \ + --disaggregation-mode decode \ + --mem-fraction-static 0.95 \ + --chunked-prefill-size 8192 \ + --cuda-graph-max-bs 512 \ + --max-running-requests 512 \ + --scheduler-recv-interval 10 \ + --enable-flashinfer-allreduce-fusion \ + --enable-symm-mem \ + --moe-dense-tp-size 1 \ + --prefill-round-robin-balance \ + --tensor-parallel-size "$TOTAL_GPUS" \ + --data-parallel-size 1 \ + --expert-parallel-size 1 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --disaggregation-bootstrap-port 30001 \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi diff --git a/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-max-tpt.sh b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-max-tpt.sh new file mode 100755 index 0000000000..4a4c01493e --- /dev/null +++ b/examples/backends/sglang/slurm_jobs/scripts/gb200-fp8/disagg/8k1k-max-tpt.sh @@ -0,0 +1,194 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# Function to print usage +print_usage() { + echo "Usage: $0 " + echo " mode: prefill or decode" + echo "" + echo "Examples:" + echo " $0 prefill" + echo " $0 decode" + exit 1 +} + +# Check if correct number of arguments provided +if [ $# -ne 1 ]; then + echo "Error: Expected 1 argument, got $#" + print_usage +fi + +# Parse arguments +mode=$1 + +# Validate mode argument +if [ "$mode" != "prefill" ] && [ "$mode" != "decode" ]; then + echo "Error: mode must be 'prefill' or 'decode', got '$mode'" + print_usage +fi + +echo "Mode: $mode" +echo "Command: dynamo" + +# Check if required environment variables are set +if [ -z "$HOST_IP_MACHINE" ]; then + echo "Error: HOST_IP_MACHINE environment variable is not set" + exit 1 +fi + +if [ -z "$PORT" ]; then + echo "Error: PORT environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_GPUS" ]; then + echo "Error: TOTAL_GPUS environment variable is not set" + exit 1 +fi + +if [ -z "$RANK" ]; then + echo "Error: RANK environment variable is not set" + exit 1 +fi + +if [ -z "$TOTAL_NODES" ]; then + echo "Error: TOTAL_NODES environment variable is not set" + exit 1 +fi + +if [ -z "$USE_INIT_LOCATIONS" ]; then + echo "Error: USE_INIT_LOCATIONS environment variable is not set" + exit 1 +fi + +if [ -z "$RUN_IN_CI" ]; then + echo "Error: RUN_IN_CI environment variable is not set" + exit 1 +fi + +# Construct command based on mode +if [ "$mode" = "prefill" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + MC_TE_METRIC=true \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + MC_FORCE_MNNVL=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --trust-remote-code \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --enable-dp-attention \ + --attention-backend trtllm_mla \ + --kv-cache-dtype fp8_e4m3 \ + --disable-radix-cache \ + --stream-interval 50 \ + --max-running-requests 30000 \ + --context-length 9300 \ + --watchdog-timeout 1000000 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode prefill \ + --mem-fraction-static 0.80 \ + --max-total-tokens 524288 \ + --chunked-prefill-size 131072 \ + --load-balance-method round_robin \ + --disable-cuda-graph \ + --moe-a2a-backend deepep \ + --deepep-mode normal \ + --ep-dispatch-algorithm dynamic \ + --moe-dense-tp-size 1 \ + --enable-dp-lm-head \ + --ep-num-redundant-experts 32 \ + --deepep-config /configs/deepep_config.json \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} + +elif [ "$mode" = "decode" ]; then + set -x + if [[ "${RUN_IN_CI,,}" == "true" ]]; then + 
python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl + python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl + fi + export TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=1800 + export SGLANG_DG_CACHE_DIR="/configs/dg-10212025" + + command_suffix="" + if [[ -n "${DUMP_CONFIG_PATH}" ]]; then command_suffix="${command_suffix} --dump-config-to ${DUMP_CONFIG_PATH}"; fi + + PYTHONUNBUFFERED=1 \ + DYN_SKIP_SGLANG_LOG_FORMATTING=1 \ + SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=256 \ + MC_TE_METRIC=true \ + SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE=100000 \ + SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=100000 \ + SGLANG_DISAGGREGATION_WAITING_TIMEOUT=100000 \ + SGLANG_DECODE_BOOTSTRAP_TIMEOUT=1000 \ + SGLANG_HACK_SEQ_BOOTSTRAP_ROOM=1 \ + SGLANG_MOONCAKE_CUSTOM_MEM_POOL=True \ + MC_FORCE_MNNVL=1 \ + NCCL_MNNVL_ENABLE=1 \ + NCCL_CUMEM_ENABLE=1 \ + SGLANG_USE_MESSAGE_QUEUE_BROADCASTER=0 \ + SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK=1 \ + python3 -m dynamo.sglang \ + --served-model-name deepseek-ai/DeepSeek-R1 \ + --model-path /model/ \ + --skip-tokenizer-init \ + --trust-remote-code \ + --tp-size "$TOTAL_GPUS" \ + --dp-size "$TOTAL_GPUS" \ + --ep-size "$TOTAL_GPUS" \ + --enable-dp-attention \ + --attention-backend trtllm_mla \ + --kv-cache-dtype fp8_e4m3 \ + --disable-radix-cache \ + --stream-interval 50 \ + --decode-log-interval 1000 \ + --max-running-requests 8192 \ + --context-length 9300 \ + --watchdog-timeout 1000000 \ + --disable-shared-experts-fusion \ + --eplb-algorithm deepseek \ + --disaggregation-bootstrap-port 30001 \ + --disaggregation-mode decode \ + --mem-fraction-static 0.82 \ + --chunked-prefill-size 36864 \ + --moe-a2a-backend deepep \ + --deepep-mode low_latency \ + --ep-dispatch-algorithm static \ + --moe-dense-tp-size 1 \ + --enable-dp-lm-head \ + --prefill-round-robin-balance \ + --ep-num-redundant-experts 32 \ + --deepep-config /configs/deepep_config.json \ + --cuda-graph-max-bs 256 \ + --dist-init-addr "$HOST_IP_MACHINE:$PORT" \ + --nnodes "$TOTAL_NODES" \ + --node-rank "$RANK" \ + --host 0.0.0.0 ${command_suffix} +fi diff --git a/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py b/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py index f9c67be7bc..a5962afe17 100644 --- a/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py +++ b/examples/backends/sglang/slurm_jobs/scripts/vllm/benchmark_serving.py @@ -8,7 +8,6 @@ vLLM OpenAI API server vllm serve \ --swap-space 16 \ - --disable-log-requests (TGI backend) ./launch_tgi_server.sh diff --git a/examples/backends/sglang/slurm_jobs/scripts/worker_setup.py b/examples/backends/sglang/slurm_jobs/scripts/worker_setup.py index 2260713e5d..59fd8f3f17 100644 --- a/examples/backends/sglang/slurm_jobs/scripts/worker_setup.py +++ b/examples/backends/sglang/slurm_jobs/scripts/worker_setup.py @@ -373,7 +373,7 @@ def setup_frontend_worker( # All frontends run the ingress server frontend_cmd = "python3 -m dynamo.frontend --http-port=8000" if run_in_ci: - frontend_cmd = "python3 -m pip install /configs/ai_dynamo_runtime-0.6.1-cp310-abi3-manylinux_2_28_aarch64.whl && python3 -m pip install /configs/ai_dynamo-0.6.1-py3-none-any.whl && python3 -m dynamo.frontend --http-port=8000" + frontend_cmd = "python3 -m pip install /configs/ai_dynamo_runtime-0.7.0-cp310-abi3-manylinux_2_28_aarch64.whl && python3 -m pip install /configs/ai_dynamo-0.7.0-py3-none-any.whl && python3 -m dynamo.frontend --http-port=8000" return run_command(frontend_cmd) diff --git 
a/examples/backends/sglang/slurm_jobs/submit_disagg.sh b/examples/backends/sglang/slurm_jobs/submit_disagg.sh index 62e4221e96..47501ba426 100755 --- a/examples/backends/sglang/slurm_jobs/submit_disagg.sh +++ b/examples/backends/sglang/slurm_jobs/submit_disagg.sh @@ -48,7 +48,6 @@ check_env MODEL_PATH check_env CONFIG_DIR check_env CONTAINER_IMAGE -GPU_TYPE="gb200-fp8" GPUS_PER_NODE=4 : "${NETWORK_INTERFACE:=enP6p9s0np0}" @@ -62,7 +61,8 @@ ISL=$6 OSL=$7 CONCURRENCIES=$8 REQUEST_RATE=$9 -SCRIPT_VARIANT=${10} +GPU_TYPE=${10} +SCRIPT_VARIANT=${11} RETRIES=1 # defaults to retry the job 1 time to avoid transient errors @@ -86,7 +86,7 @@ command=( --model-dir $MODEL_PATH --config-dir $CONFIG_DIR --container-image $CONTAINER_IMAGE - --gpu-type $GPU_TYPE --gpus-per-node $GPUS_PER_NODE --network-interface $NETWORK_INTERFACE + --gpus-per-node $GPUS_PER_NODE --network-interface $NETWORK_INTERFACE --prefill-nodes $PREFILL_NODES --prefill-workers $PREFILL_WORKERS --decode-nodes $DECODE_NODES --decode-workers $DECODE_WORKERS @@ -96,6 +96,8 @@ command=( --retries $RETRIES + --gpu-type $GPU_TYPE + --run-in-ci ${SCRIPT_VARIANT_ARGS[@]} ) diff --git a/examples/backends/sglang/test_sglang_profile.py b/examples/backends/sglang/test_sglang_profile.py new file mode 100644 index 0000000000..a2d75d491b --- /dev/null +++ b/examples/backends/sglang/test_sglang_profile.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Test script for /engine/start_profile and /engine/stop_profile routes. + +This script demonstrates the new custom engine route registration feature. +It starts a simple sglang server with dynamo and tests the profiling endpoints. + +Usage: + python test_sglang_profile.py +""" + +import os +import signal +import subprocess +import sys +import time +from pathlib import Path + +import requests + +# Configuration +MODEL = "Qwen/Qwen3-0.6B" # Small model for quick testing +HOST = "127.0.0.1" +PORT = 30000 +SYSTEM_PORT = 9090 +PROFILER_OUTPUT_DIR = "/tmp/dynamo_profiler_test" + + +def cleanup_output_dir(): + """Clean up the profiler output directory""" + import shutil + + if os.path.exists(PROFILER_OUTPUT_DIR): + shutil.rmtree(PROFILER_OUTPUT_DIR) + os.makedirs(PROFILER_OUTPUT_DIR, exist_ok=True) + + +def start_frontend(): + """Start the Dynamo frontend (HTTP server)""" + print("\nStarting Dynamo frontend...") + print(f" - Frontend HTTP: http://{HOST}:{PORT}") + + cmd = [ + "python", + "-m", + "dynamo.frontend", + "--http-port", + str(PORT), + ] + + print(f"Command: {' '.join(cmd)}") + print("(Output will appear below)\n") + + process = subprocess.Popen(cmd) + + # Wait for frontend to be ready + max_wait = 30 + start_time = time.time() + frontend_ready = False + + while time.time() - start_time < max_wait: + try: + # Check /health endpoint first + response = requests.get(f"http://{HOST}:{PORT}/health", timeout=1) + if response.status_code == 200: + print("โœ“ Frontend is ready!") + frontend_ready = True + break + except requests.exceptions.RequestException: + pass + + if process.poll() is not None: + print("โœ— Frontend process died!") + sys.exit(1) + + time.sleep(1) + + if not frontend_ready: + print("โœ— Frontend failed to start in time!") + process.kill() + sys.exit(1) + + return process + + +def start_sglang_backend(): + """Start the sglang backend (inference engine)""" + print("\nStarting SGLang backend...") + print(f" - Model: {MODEL}") + print(f" - System server: 
http://{HOST}:{SYSTEM_PORT}") + + # Set environment variables + env = os.environ.copy() + env["SGLANG_TORCH_PROFILER_DIR"] = PROFILER_OUTPUT_DIR + env["DYN_SYSTEM_PORT"] = str(SYSTEM_PORT) + + cmd = [ + "python", + "-m", + "dynamo.sglang", + "--model-path", + MODEL, + "--tp", + "1", + "--mem-fraction-static", + "0.8", + ] + + print(f"Command: {' '.join(cmd)}") + print("(Output will appear below)") + print("\nWaiting for backend to start...\n") + + process = subprocess.Popen(cmd, env=env) + + # Wait for backend to be ready (check system server health) + max_wait = 120 # 2 minutes + start_time = time.time() + backend_ready = False + + while time.time() - start_time < max_wait: + try: + # Check system server health endpoint + response = requests.get(f"http://{HOST}:{SYSTEM_PORT}/health", timeout=1) + if response.status_code == 200: + print("โœ“ Backend is ready!") + backend_ready = True + break + except requests.exceptions.RequestException: + pass + + # Check if process has died + if process.poll() is not None: + print("โœ— Backend process died!") + sys.exit(1) + + time.sleep(2) + + if not backend_ready: + print("โœ— Backend failed to start in time!") + process.kill() + sys.exit(1) + + return process + + +def test_profiling_endpoints(): + """Test the /engine/start_profile and /engine/stop_profile endpoints""" + base_url = f"http://{HOST}:{SYSTEM_PORT}" + + print("\n" + "=" * 60) + print("Testing /engine/start_profile and /engine/stop_profile") + print("=" * 60) + + # Test 1: Start profiling with parameters (no num_steps so we control stop manually) + print("\n1. Starting profiling with parameters...") + response = requests.post( + f"{base_url}/engine/start_profile", + json={ + "output_dir": PROFILER_OUTPUT_DIR, + "activities": ["CPU", "GPU"], + "with_stack": True, + "record_shapes": True, + }, + ) + print(f" Status: {response.status_code}") + print(f" Response: {response.json()}") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + assert response.json()["status"] == "ok", "Expected status 'ok'" + + # Check available models + print("\n2. Checking available models...") + response = requests.get(f"http://{HOST}:{PORT}/v1/models") + if response.status_code == 200: + models = response.json() + print(f" Available models: {models}") + + # Make a few inference requests to generate profiling data + print("\n3. Making inference requests...") + inference_url = f"http://{HOST}:{PORT}/v1/completions" + for i in range(3): + response = requests.post( + inference_url, + json={ + "model": MODEL, + "prompt": f"Hello, this is test request {i+1}. ", + "max_tokens": 10, + "temperature": 0.8, + }, + ) + print(f" Request {i+1}: {response.status_code}") + if response.status_code != 200: + print(f" Response: {response.text[:200]}") + time.sleep(0.5) + + # Test 2: Stop profiling + print("\n4. Stopping profiling...") + response = requests.post(f"{base_url}/engine/stop_profile") + print(f" Status: {response.status_code}") + print(f" Response: {response.json()}") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + assert response.json()["status"] == "ok", "Expected status 'ok'" + + # Test 3: Test with empty body (GET-like POST) + print("\n5. Starting profiling with empty body...") + response = requests.post(f"{base_url}/engine/start_profile") + print(f" Status: {response.status_code}") + print(f" Response: {response.json()}") + assert response.status_code == 200, f"Expected 200, got {response.status_code}" + + # Test 4: Test invalid route + print("\n6. 
Testing invalid route...") + response = requests.post(f"{base_url}/engine/nonexistent_route") + print(f" Status: {response.status_code}") + print(f" Response: {response.json()}") + assert response.status_code == 404, f"Expected 404, got {response.status_code}" + + # Stop profiling again + response = requests.post(f"{base_url}/engine/stop_profile") + + print("\n" + "=" * 60) + print("โœ“ All tests passed!") + print("=" * 60) + + # Check if profiling files were created + print(f"\nChecking profiler output directory: {PROFILER_OUTPUT_DIR}") + if os.path.exists(PROFILER_OUTPUT_DIR): + files = list(Path(PROFILER_OUTPUT_DIR).rglob("*")) + if files: + print(f"โœ“ Found {len(files)} files in output directory") + for f in files[:5]: # Show first 5 files + print(f" - {f}") + else: + print("โš  No files found (profiling may not have run long enough)") + else: + print("โš  Output directory not created") + + +def main(): + """Main test function""" + frontend_process = None + backend_process = None + try: + # Clean up output directory + cleanup_output_dir() + + # Start frontend first + frontend_process = start_frontend() + + # Start backend + backend_process = start_sglang_backend() + + # Run tests + print("\n" + "=" * 60) + print("Both frontend and backend are ready!") + print("=" * 60) + time.sleep(2) # Give everything a moment to fully settle + test_profiling_endpoints() + + print("\nโœ“ Test completed successfully!") + + except KeyboardInterrupt: + print("\nโš  Interrupted by user") + except Exception as e: + print(f"\nโœ— Test failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + finally: + # Cleanup + print("\nShutting down servers...") + if backend_process: + print(" Stopping backend...") + backend_process.send_signal(signal.SIGTERM) + try: + backend_process.wait(timeout=10) + except subprocess.TimeoutExpired: + print(" Force killing backend...") + backend_process.kill() + + if frontend_process: + print(" Stopping frontend...") + frontend_process.send_signal(signal.SIGTERM) + try: + frontend_process.wait(timeout=10) + except subprocess.TimeoutExpired: + print(" Force killing frontend...") + frontend_process.kill() + + print("โœ“ Servers stopped") + + +if __name__ == "__main__": + main() diff --git a/examples/backends/vllm/deploy/agg_kvbm.yaml b/examples/backends/vllm/deploy/agg_kvbm.yaml index 62e28386aa..5d84638769 100644 --- a/examples/backends/vllm/deploy/agg_kvbm.yaml +++ b/examples/backends/vllm/deploy/agg_kvbm.yaml @@ -40,9 +40,6 @@ spec: args: - --model - Qwen/Qwen3-8B - - --gpu-memory-utilization - - "0.45" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm.yaml b/examples/backends/vllm/deploy/disagg_kvbm.yaml index f4315a13cd..45b62f5617 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm.yaml @@ -33,9 +33,6 @@ spec: args: - --model - Qwen/Qwen3-8B - - --gpu-memory-utilization - - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -66,9 +63,6 @@ spec: - --model - Qwen/Qwen3-8B - --is-prefill-worker - - --gpu-memory-utilization - - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml b/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml index 1aa5281d09..d4203aafea 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm_2p2d.yaml @@ -33,9 +33,6 @@ spec: args: - 
--model - Qwen/Qwen3-8B - - --gpu-memory-utilization - - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -66,9 +63,6 @@ spec: - --model - Qwen/Qwen3-8B - --is-prefill-worker - - --gpu-memory-utilization - - "0.3" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml b/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml index 439b17a91f..141ca375fa 100644 --- a/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml +++ b/examples/backends/vllm/deploy/disagg_kvbm_tp2.yaml @@ -37,7 +37,6 @@ spec: - Qwen/Qwen3-8B - --gpu-memory-utilization - "0.23" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager @@ -72,7 +71,6 @@ spec: - --is-prefill-worker - --gpu-memory-utilization - "0.23" - - --disable-log-requests - --max-model-len - "32000" - --enforce-eager diff --git a/examples/backends/vllm/deploy/lora/README.md b/examples/backends/vllm/deploy/lora/README.md new file mode 100644 index 0000000000..425a2434a1 --- /dev/null +++ b/examples/backends/vllm/deploy/lora/README.md @@ -0,0 +1,297 @@ +# LoRA Deployment with MinIO on Kubernetes + +This guide explains how to deploy LoRA-enabled vLLM inference with S3-compatible storage backend on Kubernetes. + +## Overview + +This deployment pattern enables dynamic LoRA adapter loading from S3-compatible storage (MinIO) in a Kubernetes environment: + +## Prerequisites + +- Kubernetes cluster with GPU support +- Helm 3.x installed +- `kubectl` configured to access your cluster +- Dynamo Cloud Platform installed ([Installation Guide](../../../../../docs/kubernetes/installation_guide.md)) +- HuggingFace token for downloading Base and LoRA adapters + +## Files in This Directory + +| File | Description | +|------|-------------| +| `agg_lora.yaml` | DynamoGraphDeployment for vLLM with LoRA support | +| `minio-secret.yaml` | Kubernetes secret for MinIO credentials | +| `sync-lora-job.yaml` | Job to download LoRA from HuggingFace and upload to MinIO | +| `lora-model.yaml` | DynamoModel CRD for registering LoRA adapters | + +--- + +## Step 1: Set Up Environment Variables + +```bash +export NAMESPACE=dynamo # Your Dynamo namespace +export HF_TOKEN=your_hf_token # Your HuggingFace token +``` + +--- + +## Step 2: Create Secrets + +### Create HuggingFace Token Secret + +```bash +kubectl create secret generic hf-token-secret \ + --from-literal=HF_TOKEN=${HF_TOKEN} \ + -n ${NAMESPACE} +``` + +### Create MinIO Credentials Secret + +in this example, we are using the default credentials for MinIO. +You can change the credentials to point to your own S3 compatible storage. 
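If you are instead pointing at your own S3-compatible storage, one option is to create the same secret directly from your credentials and skip the manifest. This is only a sketch with hypothetical placeholder values; the secret name and key names must remain `minio-secret`, `AWS_ACCESS_KEY_ID`, and `AWS_SECRET_ACCESS_KEY`, since the worker deployment references them by those names:

```bash
# Create minio-secret from your own credentials instead of applying minio-secret.yaml
kubectl create secret generic minio-secret \
  --from-literal=AWS_ACCESS_KEY_ID=<your-access-key> \
  --from-literal=AWS_SECRET_ACCESS_KEY=<your-secret-key> \
  -n ${NAMESPACE}
```

For the default MinIO setup used in the rest of this guide, apply the provided manifest as shown below: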
+ +```bash +kubectl apply -f minio-secret.yaml -n ${NAMESPACE} +``` + +--- + +## Step 3: Install MinIO + +### Add MinIO Helm Repository + +```bash +helm repo add minio https://charts.min.io/ +helm repo update +``` + +### Deploy MinIO + +```bash +helm install minio minio/minio \ + --namespace ${NAMESPACE} \ + --set rootUser=minioadmin \ + --set rootPassword=minioadmin \ + --set mode=standalone \ + --set replicas=1 \ + --set persistence.enabled=true \ + --set persistence.size=10Gi \ + --set resources.requests.memory=512Mi \ + --set service.type=ClusterIP \ + --set consoleService.type=ClusterIP +``` + +### Verify MinIO Installation + +```bash +kubectl get pods -n ${NAMESPACE} | grep minio +kubectl get svc -n ${NAMESPACE} | grep minio +``` + +Expected output: +``` +minio-xxxx-xxxx 1/1 Running 0 1m +``` + +### (Optional) Access MinIO Console + +```bash +kubectl port-forward svc/minio-console -n ${NAMESPACE} 9001:9001 9000:9000 +``` + +Open http://localhost:9001 in your browser: +- Username: `minioadmin` +- Password: `minioadmin` + +--- + +## Step 4: Upload LoRA Adapters to MinIO + +Use the provided Kubernetes Job to download a LoRA adapter from HuggingFace and upload it to MinIO: + +```bash +kubectl apply -f sync-lora-job.yaml -n ${NAMESPACE} +``` + +### Monitor the Job + +```bash +# Watch job progress +kubectl get jobs -n ${NAMESPACE} -w + +# Check job logs +kubectl logs job/sync-hf-lora-to-minio -n ${NAMESPACE} -f +``` + +Wait for the job to complete successfully. + +### Verify Upload (Optional) + +```bash +# Port-forward MinIO API +kubectl port-forward svc/minio -n ${NAMESPACE} 9000:9000 & + +# Check uploaded files +export AWS_ACCESS_KEY_ID=minioadmin +export AWS_SECRET_ACCESS_KEY=minioadmin +export AWS_ENDPOINT_URL=http://localhost:9000 +aws s3 ls s3://my-loras/ --recursive +``` + +### Customizing the LoRA Adapter + +To upload a different LoRA adapter, edit `sync-lora-job.yaml` and change the `MODEL_NAME` environment variable: + +```yaml +env: +- name: MODEL_NAME + value: your-org/your-lora-adapter +``` + +--- + +## Step 5: Deploy vLLM with LoRA Support + +### Update the Image (if needed) + +Edit `agg_lora.yaml` to use your container image: + +```bash +# Using yq to update the image +export FRAMEWORK_RUNTIME_IMAGE=your-registry/your-image:tag +yq '.spec.services.[].extraPodSpec.mainContainer.image = env(FRAMEWORK_RUNTIME_IMAGE)' agg_lora.yaml > agg_lora_updated.yaml +``` + +### Deploy the LoRA-enabled vLLM Graph + +```bash +kubectl apply -f agg_lora.yaml -n ${NAMESPACE} +``` + +### Verify Deployment + +```bash +# Check pods +kubectl get pods -n ${NAMESPACE} + +# Watch worker logs +kubectl logs -f deployment/vllm-agg-lora-vllmdecode-worker -n ${NAMESPACE} +``` + +Wait for the worker to show "Application startup complete". + + +## Step 6: Using DynamoModel CRD + +The `lora-model.yaml` file demonstrates how to register a LoRA adapter using the DynamoModel Custom Resource: + +```bash +kubectl apply -f lora-model.yaml -n ${NAMESPACE} +``` + +This creates a declarative way to manage LoRA adapters in your cluster. 
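Once the worker is up and the adapter is registered, you can smoke-test it through the frontend's OpenAI-compatible API. The sketch below assumes the frontend Service created by the `vllm-agg-lora` deployment is reachable on port 8000; the Service name is an assumption, so check the actual name with `kubectl get svc -n ${NAMESPACE}` first:

```bash
# Forward the frontend port locally (service name is an assumption; adjust to your cluster)
kubectl port-forward svc/vllm-agg-lora-frontend 8000:8000 -n ${NAMESPACE} &

# The LoRA adapter should be listed alongside the base model
curl -s http://localhost:8000/v1/models | jq .

# Route a request to the adapter by using its name as the model
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "codelion/Qwen3-0.6B-accuracy-recovery-lora",
        "messages": [{"role": "user", "content": "What is deep learning?"}],
        "max_tokens": 100
      }' | jq .
```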
+ +--- + +## Configuration Reference + +### Environment Variables + +| Variable | Description | Default | +|----------|-------------|---------| +| `AWS_ENDPOINT` | MinIO/S3 endpoint URL | `http://minio:9000` | +| `AWS_ACCESS_KEY_ID` | MinIO access key | From secret | +| `AWS_SECRET_ACCESS_KEY` | MinIO secret key | From secret | +| `AWS_REGION` | AWS region (required for S3 SDK) | `us-east-1` | +| `AWS_ALLOW_HTTP` | Allow HTTP connections | `true` | +| `DYN_LORA_ENABLED` | Enable LoRA support | `true` | +| `DYN_LORA_PATH` | Local cache path for LoRA files | `/tmp/dynamo_loras_minio` | +| `BUCKET_NAME` | MinIO bucket name | `my-loras` | + +### vLLM LoRA Arguments + +| Argument | Description | +|----------|-------------| +| `--enable-lora` | Enable LoRA adapter support | +| `--max-lora-rank` | Maximum LoRA rank (must be >= your LoRA's rank) | +| `--max-loras` | Maximum number of LoRAs to load simultaneously | + +--- + +## Cleanup + +### Remove vLLM Deployment + +```bash +kubectl delete -f agg_lora.yaml -n ${NAMESPACE} +``` + +### Remove Sync Job + +```bash +kubectl delete -f sync-lora-job.yaml -n ${NAMESPACE} +``` + +### Remove MinIO + +```bash +helm uninstall minio -n ${NAMESPACE} +``` + +### Remove Secrets + +```bash +kubectl delete -f minio-secret.yaml -n ${NAMESPACE} +kubectl delete secret hf-token-secret -n ${NAMESPACE} +``` + +--- + +## Troubleshooting + +### LoRA Fails to Load + +1. **Check MinIO connectivity from worker**: + ```bash + kubectl exec -it deployment/vllm-agg-lora-vllmdecode-worker -n ${NAMESPACE} -- \ + curl http://minio:9000/minio/health/live + ``` + +2. **Verify LoRA exists in MinIO**: + ```bash + kubectl port-forward svc/minio -n ${NAMESPACE} 9000:9000 & + aws --endpoint-url=http://localhost:9000 s3 ls s3://my-loras/ --recursive + ``` + +3. **Check worker logs**: + ```bash + kubectl logs deployment/vllm-agg-lora-vllmdecode-worker -n ${NAMESPACE} + ``` + +### Sync Job Fails + +1. **Check job logs**: + ```bash + kubectl logs job/sync-hf-lora-to-minio -n ${NAMESPACE} + ``` + +2. **Verify HuggingFace token**: + ```bash + kubectl get secret hf-token-secret -n ${NAMESPACE} -o yaml + ``` + +3. **Check MinIO is accessible**: + ```bash + kubectl get svc minio -n ${NAMESPACE} + ``` + +### MinIO Connection Refused + +- Ensure MinIO pods are running: `kubectl get pods -n ${NAMESPACE} | grep minio` +- Check MinIO service: `kubectl get svc minio -n ${NAMESPACE}` +- Verify the `AWS_ENDPOINT` URL matches the service name + +## Further Reading + +- [vLLM Deployment Guide](../README.md) - Other deployment patterns +- [Dynamo Kubernetes Guide](../../../../../docs/kubernetes/README.md) - Platform setup +- [Installation Guide](../../../../../docs/kubernetes/installation_guide.md) - Platform installation diff --git a/examples/backends/vllm/deploy/lora/agg_lora.yaml b/examples/backends/vllm/deploy/lora/agg_lora.yaml new file mode 100644 index 0000000000..8c446beb69 --- /dev/null +++ b/examples/backends/vllm/deploy/lora/agg_lora.yaml @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +apiVersion: nvidia.com/v1alpha1 +kind: DynamoGraphDeployment +metadata: + name: vllm-agg-lora +spec: + services: + Frontend: + dynamoNamespace: vllm-agg-lora + componentType: frontend + replicas: 1 + extraPodSpec: + mainContainer: + image: nvcr.io/nvidian/dynamo-dev/biswa:7e499b5c460f1883a9945d221123e0760051210f-39500608-vllm-amd64 + VllmDecodeWorker: + envFromSecret: hf-token-secret + dynamoNamespace: vllm-agg-lora + componentType: worker + subComponentType: decode + replicas: 1 + resources: + limits: + gpu: "1" + modelRef: + name: Qwen/Qwen3-0.6B + extraPodSpec: + mainContainer: + image: nvcr.io/nvidian/dynamo-dev/biswa:7e499b5c460f1883a9945d221123e0760051210f-39500608-vllm-amd64 + workingDir: /workspace/examples/backends/vllm + env: + - name: DYN_LORA_ENABLED + value: "true" + - name: DYN_LORA_PATH + value: "/tmp/dynamo_loras_minio" + - name: DYN_SYSTEM_ENABLED + value: "true" + - name: DYN_SYSTEM_PORT + value: "9090" + - name: AWS_ENDPOINT + value: "http://minio:9000" + - name: AWS_ACCESS_KEY_ID + valueFrom: + secretKeyRef: + name: minio-secret + key: AWS_ACCESS_KEY_ID + - name: AWS_SECRET_ACCESS_KEY + valueFrom: + secretKeyRef: + name: minio-secret + key: AWS_SECRET_ACCESS_KEY + - name: AWS_REGION + value: "us-east-1" + - name: AWS_ALLOW_HTTP + value: "true" + - name: BUCKET_NAME + value: "my-loras" + command: + - python3 + - -m + - dynamo.vllm + args: + - --model + - Qwen/Qwen3-0.6B + - --connector + - none + - --enable-lora + - --max-lora-rank + - "64" + - --enforce-eager diff --git a/examples/backends/vllm/deploy/lora/lora-model.yaml b/examples/backends/vllm/deploy/lora/lora-model.yaml new file mode 100644 index 0000000000..f8c7a48f2f --- /dev/null +++ b/examples/backends/vllm/deploy/lora/lora-model.yaml @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +apiVersion: nvidia.com/v1alpha1 +kind: DynamoModel +metadata: + name: codelion-recovery-lora +spec: + modelName: codelion/Qwen3-0.6B-accuracy-recovery-lora + baseModelName: Qwen/Qwen3-0.6B + modelType: lora + source: + uri: s3://my-loras/codelion/Qwen3-0.6B-accuracy-recovery-lora \ No newline at end of file diff --git a/examples/backends/vllm/deploy/lora/minio-secret.yaml b/examples/backends/vllm/deploy/lora/minio-secret.yaml new file mode 100644 index 0000000000..7b14fc5574 --- /dev/null +++ b/examples/backends/vllm/deploy/lora/minio-secret.yaml @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +apiVersion: v1 +kind: Secret +type: Opaque +metadata: + name: minio-secret +stringData: + AWS_ACCESS_KEY_ID: minioadmin + AWS_SECRET_ACCESS_KEY: minioadmin diff --git a/examples/backends/vllm/deploy/lora/sync-lora-job.yaml b/examples/backends/vllm/deploy/lora/sync-lora-job.yaml new file mode 100644 index 0000000000..37779fff0c --- /dev/null +++ b/examples/backends/vllm/deploy/lora/sync-lora-job.yaml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +apiVersion: batch/v1 +kind: Job +metadata: + name: sync-hf-lora-to-minio +spec: + template: + spec: + containers: + - name: uploader + image: python:3.10-slim + command: + - /bin/sh + - -c + - | + set -eux + pip install --no-cache-dir huggingface-hub awscli + hf download $MODEL_NAME --local-dir /tmp/lora + rm -rf /tmp/lora/.cache + aws --endpoint-url=http://minio:9000 s3 mb s3://$LORA_ROOT_PATH || true + aws --endpoint-url=http://minio:9000 s3 sync /tmp/lora s3://$LORA_ROOT_PATH/$MODEL_NAME + envFrom: + - secretRef: + name: hf-token-secret + - secretRef: + name: minio-secret + env: + - name: AWS_REGION # set this to your aws region + value: us-east-1 + - name: AWS_ALLOW_HTTP # remove/disable this if you are using a S3 endpoint or secure MinIO + value: "true" + - name: LORA_ROOT_PATH + value: "my-loras" + - name: MODEL_NAME + value: codelion/Qwen3-0.6B-accuracy-recovery-lora + restartPolicy: Never + backoffLimit: 3 \ No newline at end of file diff --git a/examples/backends/vllm/launch/agg_multimodal.sh b/examples/backends/vllm/launch/agg_multimodal.sh index d016980331..0bcf5edfcf 100755 --- a/examples/backends/vllm/launch/agg_multimodal.sh +++ b/examples/backends/vllm/launch/agg_multimodal.sh @@ -18,6 +18,8 @@ trap 'echo Cleaning up...; kill 0' EXIT MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct" # Parse command line arguments +# Extra arguments are passed through to the vLLM worker +EXTRA_ARGS=() while [[ $# -gt 0 ]]; do case $1 in --model) @@ -25,16 +27,18 @@ while [[ $# -gt 0 ]]; do shift 2 ;; -h|--help) - echo "Usage: $0 [OPTIONS]" + echo "Usage: $0 [OPTIONS] [-- EXTRA_VLLM_ARGS]" echo "Options:" - echo " --model Specify the VLM model to use (default: $MODEL_NAME)" - echo " -h, --help Show this help message" + echo " --model Specify the VLM model to use (default: $MODEL_NAME)" + echo " -h, --help Show this help message" + echo "" + echo "Any additional arguments are passed through to the vLLM worker." 
+ echo "Example: $0 --model Qwen/Qwen3-VL-30B-A3B-Instruct-FP8 --dyn-tool-call-parser hermes" exit 0 ;; *) - echo "Unknown option: $1" - echo "Use --help for usage information" - exit 1 + EXTRA_ARGS+=("$1") + shift ;; esac done @@ -48,20 +52,21 @@ export DYN_REQUEST_PLANE=tcp # dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) python -m dynamo.frontend & -# Configure GPU memory optimization for specific models -EXTRA_ARGS="" +# Configure GPU memory optimization for specific models (if no extra args override) +MODEL_SPECIFIC_ARGS="" if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then - EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096" + MODEL_SPECIFIC_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096" elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then - EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048" + MODEL_SPECIFIC_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048" fi # Start vLLM worker with vision model # Multimodal data (images) are decoded in the backend worker using ImageLoader # --enforce-eager: Quick deployment (remove for production) # --connector none: No KV transfer needed for aggregated serving +# Extra args from command line come last to allow overrides DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \ - python -m dynamo.vllm --enable-multimodal --model $MODEL_NAME --enforce-eager --connector none $EXTRA_ARGS + python -m dynamo.vllm --enable-multimodal --model $MODEL_NAME --enforce-eager --connector none $MODEL_SPECIFIC_ARGS "${EXTRA_ARGS[@]}" # Wait for all background processes to complete wait diff --git a/examples/backends/vllm/launch/agg_multimodal_epd.sh b/examples/backends/vllm/launch/agg_multimodal_epd.sh index a94ab3c1f4..faf26ff1ea 100755 --- a/examples/backends/vllm/launch/agg_multimodal_epd.sh +++ b/examples/backends/vllm/launch/agg_multimodal_epd.sh @@ -80,7 +80,7 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_ # run E/P/D workers CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME & -CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS & +CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS & # Wait for all background processes to complete wait diff --git a/examples/backends/vllm/launch/agg_spec_decoding.sh b/examples/backends/vllm/launch/agg_spec_decoding.sh new file mode 100755 index 0000000000..7a30e69342 --- /dev/null +++ b/examples/backends/vllm/launch/agg_spec_decoding.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +set -e +trap 'echo Cleaning up...; kill 0' EXIT + + +# --------------------------- +# 1. Frontend (Ingress) +# --------------------------- +python -m dynamo.frontend --http-port=8000 & + + +# --------------------------- +# 2. 
Speculative Main Worker +# --------------------------- +# This runs the main model with EAGLE as the draft model for speculative decoding +DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \ +CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --enforce-eager \ + --speculative_config '{ + "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + "draft_tensor_parallel_size": 1, + "num_speculative_tokens": 2, + "method": "eagle" + }' \ + --connector none \ + --gpu-memory-utilization 0.8 \ No newline at end of file diff --git a/examples/backends/vllm/launch/disagg_multimodal_epd.sh b/examples/backends/vllm/launch/disagg_multimodal_epd.sh index 75b30abb8e..0e253c12be 100755 --- a/examples/backends/vllm/launch/disagg_multimodal_epd.sh +++ b/examples/backends/vllm/launch/disagg_multimodal_epd.sh @@ -81,23 +81,20 @@ python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_ # Configure GPU memory optimization for specific models EXTRA_ARGS="" -if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then - EXTRA_ARGS="--gpu-memory-utilization 0.85 --max-model-len 2048" -fi # Start encode worker -echo "Starting encode worker on GPU 1..." -VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' & +echo "Starting encode worker on GPU 0..." +VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' & # Start prefill worker -echo "Starting prefill worker on GPU 2..." +echo "Starting prefill worker on GPU 1..." VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \ -CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & +CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & # Start decode worker -echo "Starting decode worker on GPU 3..." +echo "Starting decode worker on GPU 2..." VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \ -CUDA_VISIBLE_DEVICES=3 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & +CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & echo "==================================================" echo "All components started. Waiting for initialization..." 
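For the new `agg_spec_decoding.sh` launch script introduced above, a quick way to confirm the EAGLE-assisted deployment is serving is an ordinary chat completion against the frontend; speculative decoding is transparent to clients, so the request shape is unchanged. A minimal smoke test, assuming the script has started both the frontend (port 8000) and the worker:

```bash
# Smoke test for agg_spec_decoding.sh: the draft model is invisible to the client,
# so this is just a standard OpenAI-style request against the frontend on port 8000.
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Explain speculative decoding in two sentences."}],
        "max_tokens": 128,
        "stream": false
      }' | jq .
```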
diff --git a/examples/backends/vllm/launch/lora/agg_lora_s3.sh b/examples/backends/vllm/launch/lora/agg_lora.sh similarity index 57% rename from examples/backends/vllm/launch/lora/agg_lora_s3.sh rename to examples/backends/vllm/launch/lora/agg_lora.sh index f2444abf51..4bd578613d 100755 --- a/examples/backends/vllm/launch/lora/agg_lora_s3.sh +++ b/examples/backends/vllm/launch/lora/agg_lora.sh @@ -4,14 +4,6 @@ set -e trap 'echo Cleaning up...; kill 0' EXIT -# Follow the README.md instructions to setup MinIO or upload the LoRA to s3/minio -# Adjust these values to match your local MinIO or S3 setup - - -# load math lora to minio -# LORA_NAME=Neural-Hacker/Qwen3-Math-Reasoning-LoRA HF_LORA_REPO=Neural-Hacker/Qwen3-Math-Reasoning-LoRA ./setup_minio.sh - - export AWS_ENDPOINT=http://localhost:9000 export AWS_ACCESS_KEY_ID=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin @@ -21,8 +13,6 @@ export AWS_ALLOW_HTTP=true # Dynamo LoRA Configuration export DYN_LORA_ENABLED=true export DYN_LORA_PATH=/tmp/dynamo_loras_minio -export DYN_LOG=debug -# export DYN_LOG_LEVEL=debug mkdir -p $DYN_LORA_PATH @@ -35,7 +25,7 @@ DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \ python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager \ --connector none \ --enable-lora \ - --max-lora-rank 32 + --max-lora-rank 64 ################################## Example Usage ################################## @@ -43,35 +33,30 @@ DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \ curl http://localhost:8000/v1/models | jq . # Load LoRA using s3 uri -curl -X POST http://localhost:8081/v1/loras \ - -H "Content-Type: application/json" \ - -d '{ - "lora_name": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA", - "source": { - "uri": "s3://my-loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA" - } - }' +curl -s -X POST http://localhost:8081/v1/loras \ + -H "Content-Type: application/json" \ + -d '{"lora_name": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "source": {"uri": "s3://my-loras/codelion/Qwen3-0.6B-accuracy-recovery-lora"}}' | jq . # Test LoRA inference curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA", - "messages": [{"role": "user", "content": "Solve (x*x - x + 1 = 0) for x"}], + "model": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "messages": [{"role": "user", "content": "What is deep learning?"}], "max_tokens": 300, "temperature": 0.0 }' -# Find the minimum possible value of \( x^2 + y^2 \) given that \( x \) and \( y \) are real numbers satisfying \( xy(x^2 - y^2) = x^2 + y^2 \) and \( x \neq 0 \) # Test base model inference (for comparison) curl -X POST http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Qwen/Qwen3-0.6B", - "messages": [{"role": "user", "content": "Solve (x*x - x + 1 = 0) for x"}], + "messages": [{"role": "user", "content": "What is deep learning?"}], "max_tokens": 300, "temperature": 0.0 }' # Unload LoRA -curl -X DELETE http://localhost:8081/v1/loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA +curl -X DELETE http://localhost:8081/v1/loras/codelion/Qwen3-0.6B-accuracy-recovery-lora diff --git a/examples/backends/vllm/launch/lora/agg_lora_router.sh b/examples/backends/vllm/launch/lora/agg_lora_router.sh new file mode 100755 index 0000000000..370301f7be --- /dev/null +++ b/examples/backends/vllm/launch/lora/agg_lora_router.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +set -e +trap 'echo Cleaning up...; kill 0' EXIT + +export AWS_ENDPOINT=http://localhost:9000 +export AWS_ACCESS_KEY_ID=minioadmin +export AWS_SECRET_ACCESS_KEY=minioadmin +export AWS_REGION=us-east-1 +export AWS_ALLOW_HTTP=true + +# Dynamo LoRA Configuration +export DYN_LORA_ENABLED=true +export DYN_LORA_PATH=/tmp/dynamo_loras_minio + +mkdir -p $DYN_LORA_PATH + +# Set deterministic hash for KV event IDs +export PYTHONHASHSEED=0 + +# Common configuration +MODEL="Qwen/Qwen3-0.6B" +BLOCK_SIZE=64 + +# run frontend + KV router +# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) +python -m dynamo.frontend \ + --router-mode kv \ + --router-reset-states & + +# run workers +# --enforce-eager is added for quick deployment. for production use, need to remove this flag +DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8082 \ +CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \ + --model $MODEL \ + --block-size $BLOCK_SIZE \ + --enforce-eager \ + --connector none \ + --enable-lora \ + --max-lora-rank 64 \ + --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080","enable_kv_cache_events":true}' & + +DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \ +VLLM_NIXL_SIDE_CHANNEL_PORT=20097 \ +CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \ + --model $MODEL \ + --block-size $BLOCK_SIZE \ + --enforce-eager \ + --connector none \ + --enable-lora \ + --max-lora-rank 64 \ + --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081","enable_kv_cache_events":true}' + +# below commands are not executed automatically in the script because previous backend launch command is blocking. + +################################## Example Usage ################################## + +# Check available models +curl http://localhost:8000/v1/models | jq . + +# Load LoRA to instances using s3 uri +curl -s -X POST http://localhost:8081/v1/loras \ + -H "Content-Type: application/json" \ + -d '{"lora_name": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "source": {"uri": "s3://my-loras/codelion/Qwen3-0.6B-accuracy-recovery-lora"}}' | jq . + +curl -s -X POST http://localhost:8082/v1/loras \ + -H "Content-Type: application/json" \ + -d '{"lora_name": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "source": {"uri": "s3://my-loras/codelion/Qwen3-0.6B-accuracy-recovery-lora"}}' | jq . + + # Test LoRA inference +curl localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "messages": [ + { + "role": "user", + "content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden." 
+ } + ], + "stream": false, + "max_tokens": 30 + }' | jq . + + + # Sample output after running above curl request twice. + # usage.prompt_tokens_details.cached_tokens is the number of tokens that were cached from the previous request. +{ + "id": "chatcmpl-0cf880c2-fe98-45c4-9c76-84c3ad1a56cc", + "choices": [ + { + "index": 0, + "message": { + "content": "\nOkay, so I need to develop a character background for a character named Elara. Let me start by understanding the requirements. The user wants", + "role": "assistant", + "reasoning_content": null + }, + "finish_reason": "length" + } + ], + "created": 1765230243, + "model": "codelion/Qwen3-0.6B-accuracy-recovery-lora", + "object": "chat.completion", + "usage": { + "prompt_tokens": 196, + "completion_tokens": 30, + "total_tokens": 226, + "prompt_tokens_details": { + "audio_tokens": null, + "cached_tokens": 192 # tokens that were cached from the previous request. + } + }, + "nvext": { + "worker_id": { + "prefill_worker_id": 7587891281668871552, + "decode_worker_id": 7587891281668871552 + } + } +} \ No newline at end of file diff --git a/examples/backends/vllm/launch/lora/setup_minio.sh b/examples/backends/vllm/launch/lora/setup_minio.sh index fded31795d..0b1668f231 100755 --- a/examples/backends/vllm/launch/lora/setup_minio.sh +++ b/examples/backends/vllm/launch/lora/setup_minio.sh @@ -20,8 +20,8 @@ MINIO_SECRET_KEY="minioadmin" BUCKET_NAME="my-loras" # Default LoRA to download (can be overridden) -HF_LORA_REPO="${HF_LORA_REPO:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA}" -LORA_NAME="${LORA_NAME:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA}" +HF_LORA_REPO="${HF_LORA_REPO:-codelion/Qwen3-0.6B-accuracy-recovery-lora}" +LORA_NAME="${LORA_NAME:-codelion/Qwen3-0.6B-accuracy-recovery-lora}" # TEMP_DIR will be created using mktemp when needed TEMP_DIR="" @@ -63,8 +63,8 @@ show_help() { echo " --help, -h Show this help message" echo "" echo "Environment Variables:" - echo " HF_LORA_REPO Hugging Face repository (default: ${HF_LORA_REPO:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA})" - echo " LORA_NAME Local name for the LoRA (default: ${LORA_NAME:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA})" + echo " HF_LORA_REPO Hugging Face repository (default: ${HF_LORA_REPO:-codelion/Qwen3-0.6B-accuracy-recovery-lora})" + echo " LORA_NAME Local name for the LoRA (default: ${LORA_NAME:-codelion/Qwen3-0.6B-accuracy-recovery-lora})" echo "" echo "Examples:" echo " $0 # Full setup" @@ -173,6 +173,7 @@ download_lora_from_hf() { print_success "LoRA downloaded to ${TEMP_DIR}" + rm -rf "${TEMP_DIR}/.cache" # List downloaded files echo "Downloaded files:" ls -lh "${TEMP_DIR}" diff --git a/examples/multimodal/components/audio_encode_worker.py b/examples/multimodal/components/audio_encode_worker.py index 29a80f6d89..4384ec2e9c 100644 --- a/examples/multimodal/components/audio_encode_worker.py +++ b/examples/multimodal/components/audio_encode_worker.py @@ -25,7 +25,7 @@ import uvloop from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser import dynamo.nixl_connect as connect from dynamo.runtime import Client, DistributedRuntime, dynamo_worker @@ -168,7 +168,7 @@ async def generate( with torch.no_grad(): audio_embeddings = self.get_audio_embeddings(audio_features) descriptor = connect.Descriptor(audio_embeddings) - with self._connector.create_readable(descriptor) as readable: + with await 
self._connector.create_readable(descriptor) as readable: request.serialized_request = readable.metadata() # Clear the audio URL as hint that the audio is passed as embeddings. request.multimodal_input.audio_url = None @@ -201,7 +201,6 @@ async def async_init(self, runtime: DistributedRuntime): # Create and initialize a dynamo connector for this worker. # We'll needs this to move data between this worker and remote workers efficiently. self._connector = connect.Connector() - await self._connector.initialize() logger.info("Startup completed.") diff --git a/examples/multimodal/components/encode_worker.py b/examples/multimodal/components/encode_worker.py index 42f8c7263e..282e785037 100644 --- a/examples/multimodal/components/encode_worker.py +++ b/examples/multimodal/components/encode_worker.py @@ -12,7 +12,7 @@ import uvloop from transformers import AutoImageProcessor from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser import dynamo.nixl_connect as connect from dynamo.runtime import Client, DistributedRuntime, dynamo_worker @@ -125,7 +125,7 @@ async def generate( request.embeddings_shape = tuple(embeddings.shape) descriptor = connect.Descriptor(embeddings) - with self._connector.create_readable(descriptor) as readable: + with await self._connector.create_readable(descriptor) as readable: request.serialized_request = readable.metadata() # Clear the image URL as hint that the image is passed as embeddings. request.multimodal_input.image_url = None @@ -158,7 +158,6 @@ async def async_init(self, runtime: DistributedRuntime): # Create and initialize a dynamo connector for this worker. # We'll needs this to move data between this worker and remote workers efficiently. 
self._connector = connect.Connector() - await self._connector.initialize() logger.info("Startup completed.") diff --git a/examples/multimodal/components/processor.py b/examples/multimodal/components/processor.py index 7bc1be7b25..ede65cc975 100644 --- a/examples/multimodal/components/processor.py +++ b/examples/multimodal/components/processor.py @@ -17,8 +17,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest from vllm.outputs import RequestOutput -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import FlexibleArgumentParser +from vllm.tokenizers import TokenizerLike as AnyTokenizer +from vllm.utils.argparse_utils import FlexibleArgumentParser from dynamo.llm import ModelInput, ModelType, register_llm from dynamo.runtime import Client, DistributedRuntime, dynamo_worker diff --git a/examples/multimodal/components/publisher.py b/examples/multimodal/components/publisher.py index c1937fd6c6..19fe18ccff 100644 --- a/examples/multimodal/components/publisher.py +++ b/examples/multimodal/components/publisher.py @@ -38,6 +38,8 @@ def record( scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], engine_idx: int = 0, + *args, + **kwargs, ): pass @@ -74,6 +76,8 @@ def record( scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats], engine_idx: int = 0, + *args, + **kwargs, ): # request_total_slots and kv_total_blocks are properties of model + gpu # we should only publish them once, not every metric update diff --git a/examples/multimodal/components/video_encode_worker.py b/examples/multimodal/components/video_encode_worker.py index 58f6700019..9602c6ed39 100644 --- a/examples/multimodal/components/video_encode_worker.py +++ b/examples/multimodal/components/video_encode_worker.py @@ -16,7 +16,7 @@ import torch import uvloop from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser import dynamo.nixl_connect as connect from dynamo.runtime import Client, DistributedRuntime, dynamo_worker @@ -153,7 +153,7 @@ async def generate( request.embeddings_shape = tuple(tensor_for_descriptor.shape) descriptor = connect.Descriptor(tensor_for_descriptor) - with self._connector.create_readable(descriptor) as readable: + with await self._connector.create_readable(descriptor) as readable: request.serialized_request = readable.metadata() # Clear the image URL as hint that the image is passed as embeddings. request.multimodal_input.video_url = None @@ -199,7 +199,6 @@ async def async_init(self, runtime: DistributedRuntime): # Create and initialize a dynamo connector for this worker. # We'll needs this to move data between this worker and remote workers efficiently. 
self._connector = connect.Connector() - await self._connector.initialize() logger.info("Startup completed.") diff --git a/examples/multimodal/components/worker.py b/examples/multimodal/components/worker.py index 4e3b7ba43e..d5efa22a85 100644 --- a/examples/multimodal/components/worker.py +++ b/examples/multimodal/components/worker.py @@ -15,7 +15,7 @@ from vllm.distributed.kv_events import ZmqEventPublisher from vllm.inputs.data import TokensPrompt from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.engine.async_llm import AsyncLLM import dynamo.nixl_connect as connect @@ -251,7 +251,6 @@ async def async_init(self, runtime: DistributedRuntime): # We'll needs this to move data between this worker and remote workers efficiently. parsed_namespace, _, _ = parse_endpoint(self.endpoint) self._connector = connect.Connector() - await self._connector.initialize() self.image_loader = ImageLoader() diff --git a/examples/multimodal/launch/audio_agg.sh b/examples/multimodal/launch/audio_agg.sh index 3f1af408b1..0ea01066f0 100755 --- a/examples/multimodal/launch/audio_agg.sh +++ b/examples/multimodal/launch/audio_agg.sh @@ -91,7 +91,7 @@ python3 components/processor.py --model $MODEL_NAME --prompt-template "$PROMPT_T # run E/P/D workers CUDA_VISIBLE_DEVICES=0 python3 components/audio_encode_worker.py --model $MODEL_NAME & -VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=1 python3 components/worker.py --model $MODEL_NAME --worker-type prefill & +VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python3 components/worker.py --model $MODEL_NAME --worker-type prefill & # Wait for all background processes to complete wait diff --git a/examples/multimodal/utils/args.py b/examples/multimodal/utils/args.py index 3fe10ee0b1..df6ce698da 100644 --- a/examples/multimodal/utils/args.py +++ b/examples/multimodal/utils/args.py @@ -159,6 +159,8 @@ def overwrite_args(config): "enable_prefix_caching": True, # KV routing relies on logging KV metrics "disable_log_stats": False, + # Enable multimodal embeddings input + "enable_mm_embeds": True, # Always setting up kv transfer for disagg "kv_transfer_config": KVTransferConfig( kv_connector="NixlConnector", kv_role="kv_both" diff --git a/examples/multimodal/utils/chat_processor.py b/examples/multimodal/utils/chat_processor.py index fe8d95dc81..3a693131d9 100644 --- a/examples/multimodal/utils/chat_processor.py +++ b/examples/multimodal/utils/chat_processor.py @@ -28,9 +28,22 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_engine import RequestPrompt +from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.inputs.data import TokensPrompt from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.tokenizers import TokenizerLike as AnyTokenizer + + +class StubEngineClient: + """ + Stub EngineClient for preprocessing-only use of OpenAIServingChat/Completion. + Provides the minimal attributes required by OpenAIServingModels. 
+ """ + + def __init__(self, model_config: ModelConfig): + self.model_config = model_config + self.input_processor = None + self.io_processor = None @runtime_checkable @@ -120,12 +133,19 @@ class ChatProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingChat( - engine_client=None, - model_config=model_config, - models=None, - request_logger=None, + engine_client=stub_engine, + models=serving_models, response_role="assistant", + request_logger=None, chat_template=None, chat_template_content_format="auto", ) @@ -186,7 +206,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): yield raw_response @@ -220,7 +239,6 @@ async def stream_response( conversation, self.tokenizer, request_metadata, - enable_force_include_usage=False, ): if raw_response.startswith("data: [DONE]"): break @@ -267,10 +285,17 @@ class CompletionsProcessor: def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig): self.tokenizer = tokenizer self.model_config = model_config + # Create stub engine client and models for preprocessing-only usage + stub_engine = StubEngineClient(model_config) + serving_models = OpenAIServingModels( + engine_client=stub_engine, + base_model_paths=[ + BaseModelPath(name=model_config.model, model_path=model_config.model) + ], + ) self.openai_serving = OpenAIServingCompletion( - engine_client=None, - model_config=model_config, - models=None, + engine_client=stub_engine, + models=serving_models, request_logger=None, ) diff --git a/examples/multimodal/utils/protocol.py b/examples/multimodal/utils/protocol.py index c31dd82799..a724b8720d 100644 --- a/examples/multimodal/utils/protocol.py +++ b/examples/multimodal/utils/protocol.py @@ -26,7 +26,7 @@ from vllm.multimodal.inputs import MultiModalUUIDDict # noqa: F401 from vllm.outputs import CompletionOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import RequestMetrics +from vllm.v1.metrics.stats import RequestStateStats import dynamo.nixl_connect as connect @@ -166,7 +166,7 @@ class MyRequestOutput(BaseModel): https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85 This class is used to serialize the RequestOutput and any recursively defined types - We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses + We can do this because PromptLogprobs, RequestStateStats, and CompletionOutput are all serializable dataclasses """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -177,7 +177,7 @@ class MyRequestOutput(BaseModel): prompt_logprobs: Optional[PromptLogprobs] = None outputs: List[CompletionOutput] finished: bool - metrics: Optional[RequestMetrics] = None + metrics: Optional[RequestStateStats] = None kv_transfer_params: Optional[dict[str, Any]] = None # lora_request: Optional[LoRARequest] = None # encoder_prompt: Optional[str] = None diff --git a/lib/bindings/c/src/lib.rs b/lib/bindings/c/src/lib.rs index 1eee40eb69..498b21dd39 100644 --- a/lib/bindings/c/src/lib.rs +++ 
b/lib/bindings/c/src/lib.rs @@ -1031,7 +1031,7 @@ pub async fn create_worker_selection_pipeline_chat( // Create worker monitor if busy_threshold is set // Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this - let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(Arc::new(client.clone()), t)); + let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t)); let engine = build_routed_pipeline::< NvCreateChatCompletionRequest, diff --git a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py index 8d06db7055..bfe371b41e 100644 --- a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py +++ b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/dynamo_connector.py @@ -23,6 +23,7 @@ from vllm.config import VllmConfig from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -40,8 +41,15 @@ def __init__(self, metadata: bytes): class DynamoConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None @@ -90,13 +98,19 @@ def request_finished( def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self._worker.register_kv_caches(kv_caches) + @override def bind_connector_metadata( self, connector_metadata: DynamoConnectorMetadata ) -> None: + # Must call super() to set _connector_metadata so has_connector_metadata() returns True + # This is required for save_kv_layer to be called during the forward pass + super().bind_connector_metadata(connector_metadata) assert isinstance(connector_metadata.metadata, bytes) self._worker.bind_connector_metadata(connector_metadata.metadata) + @override def clear_connector_metadata(self) -> None: + super().clear_connector_metadata() self._worker.clear_connector_metadata() @override diff --git a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/pd_connector.py b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/pd_connector.py index ceea2917ba..461815a7e7 100644 --- a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/pd_connector.py +++ b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector/pd_connector.py @@ -4,7 +4,10 @@ from typing import TYPE_CHECKING, Optional, Type from kvbm.vllm_integration.connector.dynamo_connector import DynamoConnector -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata, + KVConnectorRole, +) from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiConnector, MultiKVConnectorMetadata, @@ -29,6 +32,7 @@ LMCacheConnectorV1, ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -46,8 +50,15 @@ class PdConnector(MultiConnector): - The second connector must be NIXL and will be used by decode worker to get KV blocks 
from prefill worker. """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) if len(self._connectors) != 2: raise ValueError( f"PdConnector requires exactly two connectors (got {len(self._connectors)})" @@ -76,6 +87,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # Worker-side methods # ============================== + def set_xfer_handshake_metadata( + self, metadata: dict[int, KVConnectorHandshakeMetadata] + ) -> None: + """ + Propagate handshake metadata to child connectors. + + This is required for NIXL connector to start its handshake listener + which decode workers connect to for KV transfer coordination. + """ + for c in self._connectors: + c.set_xfer_handshake_metadata(metadata) + def bind_connector_metadata(self, connector_metadata: PdConnectorMetadata) -> None: assert isinstance(connector_metadata, PdConnectorMetadata) if connector_metadata.extra_async_saves: diff --git a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py index 3d2532602d..ef791d36ed 100644 --- a/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py +++ b/lib/bindings/kvbm/python/kvbm/vllm_integration/connector_worker.py @@ -14,7 +14,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.model_executor.models.utils import extract_layer_index -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader.rs index 3b6151f7b9..2523a67a68 100644 --- a/lib/bindings/kvbm/src/block_manager/vllm/connector/leader.rs +++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/leader.rs @@ -526,22 +526,33 @@ impl Leader for KvConnectorLeader { // remove the request from the inflight requests self.inflight_requests.remove(&request_id); - // if the slot has finished, we can return false to vllm, indicating all gpu blocks are free to be reused - // otherwise, we return true, which means there are still outstanding operations on gpu blocks which - // must be awaited before the gpu blocks can be reused. if we return true, then it is the worker side - // of the connector api which will be used to inform vllm that the request is finished. + // Return value semantics: + // - `false`: Tells vLLM all GPU blocks are free and the request can be fully cleaned up. + // vLLM will immediately remove the request from its internal hash table. + // - `true`: Tells vLLM there are outstanding async operations on GPU blocks. + // The worker side of the connector API will later call `finish_requests()` + // to notify vLLM when the request is truly complete. + // + // TODO(jthomson04): This is a temporary fix to ensure vLLM 0.11.2 compatibility. + // IMPORTANT: We must ALWAYS return `true` here, even when the slot is already Finished. + // + // Why? If we return `false`, vLLM removes the request from `self.requests` immediately. 
+ // However, our worker connector may still report completion later via `finish_requests()`. + // When that happens, vLLM's scheduler.py has an assertion `req_id in self.requests` + // that will fail because the request was already removed from the hash table. + // + // By always returning `true`, we ensure vLLM keeps the request in its hash table until + // our worker explicitly signals completion, avoiding the race condition. + // + // If the slot is already Finished (no pending operations), we clean it up from our side + // but still return `true` so vLLM waits for the worker's completion signal. if let SlotState::Finished = slot.state() { - // All operations complete - safe to remove slot and tell vLLM blocks are free self.slot_manager().remove_slot(&request_id)?; - Ok(false) } else { debug_assert!(matches!(slot.state(), SlotState::Finishing)); - // Still has pending operations - keep slot alive for worker to process - // Don't remove slot here. Worker needs it to process the finish event. - // Worker will remove it after verifying all operations are complete. - // The lock on the slot prevents new operations from being created in offload_blocks() - Ok(true) } + + Ok(true) } fn has_slot(&self, request_id: String) -> bool { diff --git a/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs b/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs index 1b12d28cad..a80760973c 100644 --- a/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs +++ b/lib/bindings/kvbm/src/block_manager/vllm/connector/worker.rs @@ -278,11 +278,6 @@ impl Worker for KvConnectorWorker { self.maybe_finished_onboarding.insert(request_id); } - // delay offloading operations until the end of the forward pass - debug_assert!( - self.offloading_operations.is_empty(), - "offloading operations should be empty" - ); self.offloading_operations = offloading_operations; Ok(()) @@ -304,15 +299,34 @@ impl Worker for KvConnectorWorker { /// Trigger block-wise completion signals afer last layer. fn save_kv_layer(&mut self, _layer_name: String) -> anyhow::Result<()> { self.layers_complete += 1; + tracing::debug!( + iteration = self.iteration, + layers_complete = self.layers_complete, + total_layers = self.kv_cache_layers.len(), + pending_offload_ops = self.offloading_operations.len(), + "save_kv_layer called" + ); if self.layers_complete == self.kv_cache_layers.len() { let offloading_operations = std::mem::take(&mut self.offloading_operations); + tracing::info!( + iteration = self.iteration, + num_operations = offloading_operations.len(), + "All layers complete, enqueuing {} offload operations", + offloading_operations.len() + ); + // block on the the completion of the last layer // todo(ryan): capture the context, pass this to the scheduler to do the await on another thread // or put the event on a stream and use stream waits to keep it all on device. 
event_sync_blocking(self.layer_events[self.layers_complete - 1]); - for operation in offloading_operations { - self.connector.enqueue_request(operation); + for operation in &offloading_operations { + tracing::debug!( + request_id = %operation.request_id, + operation_id = %operation.uuid, + "Enqueuing offload operation to scheduler" + ); + self.connector.enqueue_request(operation.clone()); } } Ok(()) diff --git a/lib/bindings/python/pyproject.toml b/lib/bindings/python/pyproject.toml index a27b20bdda..69ef03ffec 100644 --- a/lib/bindings/python/pyproject.toml +++ b/lib/bindings/python/pyproject.toml @@ -26,7 +26,7 @@ license = { text = "Apache-2.0" } license-files = ["LICENSE"] requires-python = ">=3.10" dependencies = [ - "pydantic>=2.10.6,<=2.11.7", + "pydantic>=2.10.6,<=2.13", "uvloop>=0.21.0", ] classifiers = [ diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index d8f2e785c3..78e04725f6 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -276,6 +276,8 @@ fn register_llm<'p>( ModelInput::Tensor => llm_rs::model_type::ModelInput::Tensor, }; + let is_tensor_based = model_type.inner.supports_tensor(); + let model_type_obj = model_type.inner; let inner_path = model_path.to_string(); @@ -323,7 +325,33 @@ fn register_llm<'p>( .or_else(|| Some(source_path.clone())); pyo3_async_runtimes::tokio::future_into_py(py, async move { - // Resolve the model path (local or fetch from HuggingFace) + // For TensorBased models, skip HuggingFace downloads and register directly + if is_tensor_based { + let model_name = model_name.unwrap_or_else(|| source_path.clone()); + let mut card = llm_rs::model_card::ModelDeploymentCard::with_name_only(&model_name); + card.model_type = model_type_obj; + card.model_input = model_input; + card.user_data = user_data_json; + + if let Some(cfg) = runtime_config { + card.runtime_config = cfg.inner; + } + + // Register the Model Deployment Card via discovery interface + let discovery = endpoint.inner.drt().discovery(); + let spec = rs::discovery::DiscoverySpec::from_model( + endpoint.inner.component().namespace().name().to_string(), + endpoint.inner.component().name().to_string(), + endpoint.inner.name().to_string(), + &card, + ) + .map_err(to_pyerr)?; + discovery.register(spec).await.map_err(to_pyerr)?; + + return Ok(()); + } + + // For non-TensorBased models, resolve the model path (local or fetch from HuggingFace) let model_path = if fs::exists(&source_path)? { PathBuf::from(&source_path) } else { @@ -596,6 +624,84 @@ impl DistributedRuntime { CancellationToken { inner } } + /// Register an async Python callback for /engine/{route_name} + /// + /// Args: + /// route_name: Route path (e.g., "start_profile" โ†’ /engine/start_profile) + /// callback: Async function with signature: async def(body: dict) -> dict + /// + /// Example: + /// ```python + /// async def start_profile(body: dict) -> dict: + /// await engine.start_profile(**body) + /// return {"status": "ok"} + /// + /// runtime.register_engine_route("start_profile", start_profile) + /// ``` + #[pyo3(signature = (route_name, callback))] + fn register_engine_route( + &self, + py: Python<'_>, + route_name: String, + callback: PyObject, + ) -> PyResult<()> { + // Capture TaskLocals at registration time when Python's event loop is running. + // This is needed because later, when the callback is invoked from an HTTP request, + // we'll be on a Rust thread without a running Python event loop. 
+ let locals = + Arc::new(pyo3_async_runtimes::tokio::get_current_locals(py).map_err(to_pyerr)?); + let callback = Arc::new(callback); + + // Wrap Python async callback in Rust async closure + let rust_callback: rs::engine_routes::EngineRouteCallback = + Arc::new(move |body: serde_json::Value| { + let callback = callback.clone(); + let locals = locals.clone(); + + // Return a boxed future + Box::pin(async move { + // Acquire GIL to call Python callback and convert coroutine to future + let py_future = Python::with_gil(|py| { + // Convert body to Python dict + let py_body = pythonize::pythonize(py, &body).map_err(|e| { + anyhow::anyhow!("Failed to convert request body to Python: {}", e) + })?; + + // Call Python async function to get a coroutine + let coroutine = callback.call1(py, (py_body,)).map_err(|e| { + anyhow::anyhow!("Failed to call Python callback: {}", e) + })?; + + // Use the TaskLocals captured at registration time + pyo3_async_runtimes::into_future_with_locals( + &locals, + coroutine.into_bound(py), + ) + .map_err(|e| { + anyhow::anyhow!("Failed to convert coroutine to future: {}", e) + }) + })?; + + // Await the Python coroutine (GIL is released during await) + let py_result = py_future + .await + .map_err(|e| anyhow::anyhow!("Python callback failed: {}", e))?; + + // Convert result back to serde_json::Value + Python::with_gil(|py| { + pythonize::depythonize::(py_result.bind(py)) + .map_err(|e| anyhow::anyhow!("Failed to serialize response: {}", e)) + }) + }) + }); + + self.inner + .engine_routes() + .register(&route_name, rust_callback); + tracing::debug!("Registered engine route: /engine/{}", route_name); + Ok(()) + } + // This is used to pass the DistributedRuntime from the dynamo-runtime bindings // to the KVBM bindings, since KVBM cannot directly use the struct from this cdylib. 
// TODO: Create a separate crate "dynamo-python" so that all binding crates can import diff --git a/lib/bindings/python/rust/llm/kv.rs b/lib/bindings/python/rust/llm/kv.rs index 986a95464d..e4802083ba 100644 --- a/lib/bindings/python/rust/llm/kv.rs +++ b/lib/bindings/python/rust/llm/kv.rs @@ -21,7 +21,7 @@ use rs::traits::events::EventSubscriber; use tracing; use llm_rs::kv_router::protocols::*; -use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks}; +use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener}; use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; #[pyfunction] @@ -106,6 +106,9 @@ pub struct ZmqKvEventPublisherConfig { pub zmq_endpoint: String, #[pyo3(get, set)] pub zmq_topic: String, + #[pyo3(get, set)] + pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to + // both global and worker-local KvIndexers } #[pymethods] @@ -115,19 +118,22 @@ impl ZmqKvEventPublisherConfig { worker_id, kv_block_size, zmq_endpoint = "tcp://127.0.0.1:5557".to_string(), - zmq_topic = "".to_string() + zmq_topic = "".to_string(), + enable_local_indexer = false ))] pub fn new( worker_id: WorkerId, kv_block_size: usize, zmq_endpoint: String, zmq_topic: String, + enable_local_indexer: bool, ) -> Self { Self { worker_id, kv_block_size, zmq_endpoint, zmq_topic, + enable_local_indexer, } } } @@ -141,13 +147,14 @@ pub(crate) struct ZmqKvEventPublisher { impl ZmqKvEventPublisher { #[new] fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult { - let inner = llm_rs::kv_router::publisher::KvEventPublisher::new( + let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer( component.inner, config.kv_block_size as u32, Some(KvEventSourceConfig::Zmq { endpoint: config.zmq_endpoint, topic: config.zmq_topic, }), + config.enable_local_indexer, ) .map_err(to_pyerr)?; Ok(Self { inner }) @@ -179,7 +186,7 @@ impl ZmqKvEventListener { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); let shutdown_token = tokio_util::sync::CancellationToken::new(); - tokio::spawn(llm_rs::kv_router::publisher::start_zmq_listener( + tokio::spawn(start_zmq_listener( zmq_endpoint, zmq_topic, tx, diff --git a/lib/bindings/python/rust/llm/local_model.rs b/lib/bindings/python/rust/llm/local_model.rs index 15fb24f373..3917c7a089 100644 --- a/lib/bindings/python/rust/llm/local_model.rs +++ b/lib/bindings/python/rust/llm/local_model.rs @@ -49,6 +49,11 @@ impl ModelRuntimeConfig { self.inner.data_parallel_size = data_parallel_size; } + #[setter] + fn set_enable_local_indexer(&mut self, enable_local_indexer: bool) { + self.inner.enable_local_indexer = enable_local_indexer; + } + fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> { let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?; self.inner @@ -103,6 +108,11 @@ impl ModelRuntimeConfig { self.inner.reasoning_parser.clone() } + #[getter] + fn enable_local_indexer(&self) -> bool { + self.inner.enable_local_indexer + } + #[getter] fn runtime_data(&self, py: Python<'_>) -> PyResult { let dict = PyDict::new(py); diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index 22841ef6e4..1a0d1913aa 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -5,6 +5,7 @@ from typing import ( Any, AsyncGenerator, AsyncIterator, + Awaitable, Callable, Dict, List, @@ -57,6 +58,32 @@ class 
DistributedRuntime: """ ... + def register_engine_route( + self, + route_name: str, + callback: Callable[[dict], Awaitable[dict]], + ) -> None: + """ + Register an async callback for /engine/{route_name} on the system status server. + + Args: + route_name: The route path (e.g., "start_profile" creates /engine/start_profile) + callback: Async function with signature: async def(body: dict) -> dict + + Example: + async def start_profile(body: dict) -> dict: + await engine.start_profile(**body) + return {"status": "ok", "message": "Profiling started"} + + runtime.register_engine_route("start_profile", start_profile) + + The callback receives the JSON request body as a dict and should return + a dict that will be serialized as the JSON response. + + For GET requests or empty bodies, an empty dict {} is passed. + """ + ... + class CancellationToken: def cancel(self) -> None: """ @@ -433,6 +460,7 @@ class ModelRuntimeConfig: max_num_batched_tokens: int | None tool_call_parser: str | None reasoning_parser: str | None + enable_local_indexer: bool runtime_data: dict[str, Any] tensor_model_config: Any | None @@ -816,7 +844,8 @@ class ZmqKvEventPublisherConfig: worker_id: int, kv_block_size: int, zmq_endpoint: str = "tcp://127.0.0.1:5557", - zmq_topic: str = "" + zmq_topic: str = "", + enable_local_indexer: bool = False ) -> None: """ Configuration for the ZmqKvEventPublisher. @@ -825,6 +854,7 @@ class ZmqKvEventPublisherConfig: :param kv_block_size: The block size for the key-value store. :param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557". :param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string. + :param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False. """ ... @@ -1077,6 +1107,10 @@ async def register_llm( Providing only one of these parameters will raise a ValueError. - `lora_name`: The served model name for the LoRA model - `base_model_path`: Path to the base model that the LoRA extends + + For TensorBased models (using ModelInput.Tensor), HuggingFace downloads are skipped + and a minimal model card is registered directly. Use model_path as the display name + for these models. """ ... diff --git a/lib/bindings/python/src/dynamo/nixl_connect/__init__.py b/lib/bindings/python/src/dynamo/nixl_connect/__init__.py index 59c7f31e48..6b7678ffbb 100644 --- a/lib/bindings/python/src/dynamo/nixl_connect/__init__.py +++ b/lib/bindings/python/src/dynamo/nixl_connect/__init__.py @@ -69,15 +69,15 @@ class AbstractOperation(ABC): def __init__( self, - connector: Connector, + connection: Connection, operation_kind: OperationKind, local_descriptors: Descriptor | list[Descriptor], remote_descriptors: Optional[Descriptor | list[Descriptor]], notification_key: Optional[str], ) -> None: - if not isinstance(connector, Connector): + if not isinstance(connection, Connection): raise TypeError( - "Argument `connector` must be `dynamo.nixl_connect.Connector`." + "Argument `connection` must be `dynamo.nixl_connect.Connection`." 
) if ( operation_kind is not OperationKind.READ @@ -126,7 +126,7 @@ def __init__( self._notification_key: str = ( "" if notification_key is None else notification_key ) - self._connector: Connector = connector + self._connection: Connection = connection self._operation_kind: OperationKind = operation_kind self._local_desc_list: Descriptor | list[Descriptor] = local_descriptors self._local_desc_tlist: Optional[list[tuple[int, int, int]]] = None @@ -141,9 +141,15 @@ def __init__( # Note: Only local descriptors should be registered with NIXL, if isinstance(local_descriptors, list): for d in local_descriptors: - d.register_memory(self._connector) + d.register_with_connector(self._connection) + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {d} with connector {self._connection}." + ) else: - local_descriptors.register_memory(self._connector) + local_descriptors.register_with_connector(self._connection) + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Registered descriptor {local_descriptors} with connector {self._connection}." + ) # Record local descriptors. device_kind, desc_tlist = self._create_desc_tlist(local_descriptors) @@ -166,14 +172,32 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self._release() def _release(self) -> None: - pass + """ + Private method to release resources. + """ + # Deregister local descriptors from NIXL, allowing them to reused by a future operation. + if isinstance(self._local_desc_list, list): + for d in self._local_desc_list: + if d.is_registered: + d.deregister_with_connector(self._connection) + else: + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {d} was not registered, skipping deregistration." + ) + else: + if self._local_desc_list.is_registered: + self._local_desc_list.deregister_with_connector(self._connection) + else: + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Descriptor {self._local_desc_list} was not registered, skipping deregistration." + ) @property - def connector(self) -> Connector: + def connection(self) -> Connection: """ - Gets the local associated with this operation. + Gets the local connection associated with this operation. """ - return self._connector + return self._connection @property def operation_kind(self) -> OperationKind: @@ -230,7 +254,7 @@ def __init__( remote_descriptors: Descriptor | list[Descriptor], notification_key: str, ) -> None: - if not isinstance(remote, Remote) or remote._connector is None: + if not isinstance(remote, Remote) or remote._connection is None: raise TypeError( "Argument `remote` must be valid `dynamo.nixl_connect.Remote`." 
) @@ -303,7 +327,7 @@ def __init__( self._status = OperationStatus.UNINITIALIZED super().__init__( - remote.connector, + remote.connection, operation_kind, local_descriptors, remote_descriptors, @@ -317,21 +341,21 @@ def __init__( self._remote_xfer_descs: Optional[nixl_bindings.nixlXferDList] = None self._xfer_hndl: Optional[nixl_api.nixl_xfer_handle] = None - self._local_xfer_descs = self._connector._nixl.get_xfer_descs( + self._local_xfer_descs = self._connection._nixl.get_xfer_descs( descs=self._local_desc_tlist, mem_type=str(self._local_device_kind), ) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Created local NIXL transfer descriptors: {self._local_xfer_descs}" ) - self._remote_xfer_descs = self._connector._nixl.get_xfer_descs( + self._remote_xfer_descs = self._connection._nixl.get_xfer_descs( descs=self._remote_desc_tlist, mem_type=str(self._remote_device_kind), ) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Created remote NIXL transfer descriptors: {self._remote_xfer_descs}" ) - self._xfer_hndl = self._connector._nixl.initialize_xfer( + self._xfer_hndl = self._connection._nixl.initialize_xfer( operation=str(operation_kind), local_descs=self._local_xfer_descs, remote_descs=self._remote_xfer_descs, @@ -380,7 +404,7 @@ def _release(self) -> None: logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL transfer handle {self._xfer_hndl} released." ) - self._connector._nixl.release_xfer_handle(self._xfer_hndl) + self._connection._nixl.release_xfer_handle(self._xfer_hndl) except Exception as e: logger.error( f"dynamo.nixl_connect.{self.__class__.__name__}: Failed to release resources: {e}" @@ -413,7 +437,7 @@ def _cancel_(self) -> None: ) # NIXL will cancel the transfer if it is in progress when the handle is released. - self._connector._nixl.release_xfer_handle(self._xfer_hndl) + self._connection._nixl.release_xfer_handle(self._xfer_hndl) self._status = OperationStatus.CANCELLED self._xfer_hndl = None @@ -467,7 +491,7 @@ def status(self) -> OperationStatus: old_status = self._status if self._status == OperationStatus.UNINITIALIZED: - state = self._connector._nixl.transfer( + state = self._connection._nixl.transfer( self._xfer_hndl, self._notification_key.encode("utf-8"), ) @@ -481,7 +505,7 @@ def status(self) -> OperationStatus: else: self._status = OperationStatus.INITIALIZED else: - state = self._connector._nixl.check_xfer_state(self._xfer_hndl) + state = self._connection._nixl.check_xfer_state(self._xfer_hndl) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: NIXL reported transfer state: {state}" ) @@ -500,6 +524,90 @@ def status(self) -> OperationStatus: return self._status +class Connection: + def __init__(self, connector: Connector, number: int): + """ + Creates a new Connection instance. + + Parameters + ---------- + connector : Connector + The connector associated with this connection. + number : int + The connection number. + Used to create a unique name for the connection. + + Raises + ------ + TypeError + When `connector` is provided and not of type `dynamo.nixl_connect.Connector`. + TypeError + When `number` is provided and not of type `int`. + ValueError + When `number` is provided and not greater than 0. + """ + if not isinstance(connector, Connector): + raise TypeError( + "Argument `connector` must be `dynamo.nixl_connect.Connector`." 
+ ) + if not isinstance(number, int): + raise TypeError("Argument `number` must be of type `int`.") + if number <= 0: + raise ValueError("Argument `number` must be greater than 0.") + + self._connector: Connector = connector + self._is_initialized = False + self._name = f"{connector.name}-{number}" + self._nixl = nixl_api.nixl_agent(self._name) + + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." + ) + + def __repr__(self) -> str: + return str( + f"{self.__class__.__name__}(" + f"is_initialized={self._is_initialized}, " + f"name='{self._name}'" + ")" + ) + + def __str__(self) -> str: + return self._name + + @property + def connector(self) -> Connector: + """ + Get the connector associated with this connection. + """ + return self._connector + + @property + def metadata(self) -> bytes: + """ + Get the metadata of the connection. + """ + return self._nixl.get_agent_metadata() + + @property + def name(self) -> str | None: + """ + Get the name of the connection. + """ + return self._name + + async def initialize(self) -> None: + # Only initialize the connection once. + if self._is_initialized: + return + + self._is_initialized = True + # This method is a no-op for now, in the future it may be used to initialize the connection. + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._name}' }} completed." + ) + + class Connector: """ Core class for managing the connection between workers in a distributed environment. @@ -529,28 +637,42 @@ def __init__( if not isinstance(worker_id, str) or len(worker_id) == 0: raise TypeError("Argument `worker_id` must be a non-empty `str` or `None`.") + self._connection_count: int = 0 self._worker_id = worker_id - self._is_initialized = False - self._nixl = nixl_api.nixl_agent(self._worker_id) self._hostname = socket.gethostname() - self._agent_metadata: Optional[bytes] = None logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}." ) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Connector): + return False + return self._worker_id == other._worker_id + + def __ne__(self, value: object) -> bool: + if not isinstance(value, Connector): + return True + return self._worker_id != value._worker_id + def __repr__(self) -> str: return str( f"{self.__class__.__name__}(" f"worker_id='{self._worker_id}', " - f"hostname={self._hostname}, " - f"metadata=<{0 if self._agent_metadata is None else len(self._agent_metadata)} bytes>" + f"hostname={self._hostname}" ")" ) def __str__(self) -> str: return self._worker_id + @property + def hostname(self) -> str: + """ + Get the name of the current worker's host. + """ + return self._hostname + @cached_property def is_cuda_available(self) -> bool: # Note: `cuda.is_available` initializes CUDA @@ -562,13 +684,6 @@ def is_cuda_available(self) -> bool: except CUDARuntimeError: return False - @property - def metadata(self) -> bytes: - """ - Get the metadata of the worker. - """ - return self._nixl.get_agent_metadata() - @property def name(self) -> str | None: """ @@ -620,12 +735,8 @@ async def begin_read( "Cannot create a `dynamo.nixl_connect.ReadOperation` to read from a remote `dynamo.nixl_connect.WritableOperation`." ) - if not self._is_initialized: - raise RuntimeError( - "Connector not initialized. Call `initialize()` before calling this method." 
- ) - - op = ReadOperation(self, remote_metadata, local_descriptors) + conn = await self._create_connection() + op = ReadOperation(conn, remote_metadata, local_descriptors) return op async def begin_write( @@ -655,22 +766,18 @@ async def begin_write( raise TypeError( "Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`." ) - if remote_metadata.operation_kind != OperationKind.WRITE: + if remote_metadata.operation_kind != OperationKind.WRITE.value: raise RuntimeError( "Cannot create a `WriteOperation` to write to a remote `ReadableOperation`." ) if not isinstance(remote_metadata.nixl_metadata, str): raise TypeError("Argument `remote_metadata.nixl_metadata` must be `str`.") - if not self._is_initialized: - raise RuntimeError( - "Connector not initialized. Call `initialize()` before calling this method." - ) - - op = WriteOperation(self, local_descriptors, remote_metadata) + conn = await self._create_connection() + op = WriteOperation(conn, local_descriptors, remote_metadata) return op - def create_readable( + async def create_readable( self, local_descriptors: Descriptor | list[Descriptor], ) -> ReadableOperation: @@ -682,15 +789,11 @@ def create_readable( ReadableOperation A readable operation that can be used to transfer data from a remote worker. """ - if not self._is_initialized: - raise RuntimeError( - "Connector not initialized. Call `initialize()` before calling this method." - ) - - op = ReadableOperation(self, local_descriptors) + conn = await self._create_connection() + op = ReadableOperation(conn, local_descriptors) return op - def create_writable( + async def create_writable( self, local_descriptors: Descriptor | list[Descriptor], ) -> WritableOperation: @@ -702,25 +805,27 @@ def create_writable( WritableOperation A writable operation that can be used to transfer data to a remote worker. """ - if not self._is_initialized: - raise RuntimeError( - "Connector not initialized. Call `initialize()` before calling this method." - ) - - op = WritableOperation(self, local_descriptors) + conn = await self._create_connection() + op = WritableOperation(conn, local_descriptors) return op async def initialize(self) -> None: - # Only initialize the connector once. - if self._is_initialized: - return - - self._is_initialized = True - # This method is a no-op for now, in the future it may be used to initialize the connector. + """ + Deprecated method. + """ logger.debug( - f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} completed." + f"dynamo.nixl_connect.{self.__class__.__name__}: Initialized {{ name: '{self._worker_id}' }} (This method is deprecated)." ) + async def _create_connection(self) -> Connection: + """ + Private method to create a new connection. + """ + self._connection_count += 1 + conn = Connection(self, self._connection_count) + await conn.initialize() + return conn + class Descriptor: """ @@ -784,7 +889,8 @@ def __init__( # Member fields for managing NIXL memory registration. # Note: ONLY local descriptors should be registered with NIXL, # remote descriptors do not have a valid memory address and registration will fault. - self._connector: Optional[Connector] = None + + self._connection: Optional[Connection] = None self._nixl_hndl: Optional[nixl_bindings.nixlRegDList] = None # Initially `None` cached serialized descriptor reference, populated when `get_metadata()` is called. 
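Net effect of the changes above on user code: `Connector.initialize()` is now a deprecated no-op, `create_readable`/`create_writable` are coroutines, and each operation gets its own `Connection` behind the scenes. A hedged usage sketch (buffer allocation, the polling interval, and the out-of-band exchange of `RdmaMetadata` with the remote worker are placeholders):

import asyncio

import torch

from dynamo.nixl_connect import Connector, Descriptor, OperationStatus

async def expose_buffer(connector: Connector) -> None:
    buffer = torch.zeros(4096, dtype=torch.uint8)
    # No connector.initialize() call is needed; a Connection is created per operation.
    with await connector.create_writable(Descriptor(buffer)) as writable:
        rdma_metadata = writable.metadata()  # hand this to the remote worker out of band
        while writable.status != OperationStatus.COMPLETE:
            await asyncio.sleep(0.01)        # poll until the remote write lands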
@@ -865,10 +971,11 @@ def __init__( raise TypeError(TYPE_ERROR_MESSAGE) def __del__(self) -> None: - if self._nixl_hndl is not None and self._connector is not None: - # Unregister the memory with NIXL. - self._connector._nixl.deregister_memory(self._nixl_hndl) + if not (self._nixl_hndl is None or self._connection is None): + # Deregister the memory with NIXL. + self._connection._nixl.deregister_memory(self._nixl_hndl) self._nixl_hndl = None + self._connection = None if self._data_ref is not None: # Release the reference to the data. @@ -891,6 +998,13 @@ def device(self) -> Device: """ return self._data_device + @property + def is_registered(self) -> bool: + """ + Gets whether the descriptor is registered with NIXL. + """ + return self._connection is not None and self._nixl_hndl is not None + @property def ptr(self) -> int: """ @@ -927,6 +1041,7 @@ def from_serialized( return serialized.to_descriptor() + @property def metadata(self) -> SerializedDescriptor: """ Serializes the descriptor into a `SerializedDescriptor` object. @@ -936,37 +1051,75 @@ def metadata(self) -> SerializedDescriptor: device=f"{self._data_device}", ptr=self._data_ptr, size=self._data_size, - ) + ) # type: ignore[operator] return self._serialized - def register_memory( + def deregister_with_connector(self, connection: Connection) -> None: + """ + Deregisters the memory of the descriptor with NIXL. + """ + if not isinstance(connection, Connection): + raise TypeError( + "Argument `connection` must be `dynamo.nixl_connect.Connection`." + ) + if connection != self._connection: + raise RuntimeError( + "Descriptor can only be deregistered from the connection it was registered with. " + f"Existing connection: {self._connection.name if self._connection is not None else None}, requested connection: {connection.name}." + ) + return + + if self._nixl_hndl is None: + logger.warning( + f"dynamo.nixl_connect.{self.__class__.__name__}: Request to deregister Descriptor {self.__repr__()} cannot be completed because the Descriptor is not registered." + ) + return + + connection._nixl.deregister_memory(self._nixl_hndl) + self._nixl_hndl = None + self._connection = None + logger.debug( + f"dynamo.nixl_connect.{self.__class__.__name__}: Deregistered {self.__repr__()} with NIXL." + ) + + def register_with_connector( self, - connector: Connector, + connection: Connection, ) -> None: """ Registers the memory of the descriptor with NIXL. """ - if not isinstance(connector, Connector): + if not isinstance(connection, Connection): raise TypeError( - "Argument `connector` must be `dynamo.nixl_connect.Connector`." + "Argument `connection` must be `dynamo.nixl_connect.Connection`." ) if self._data_ptr == 0: raise ValueError("Cannot register memory with a null pointer.") + if self._connection is not None: + if self._connection != connection: + raise RuntimeError( + "Descriptor cannot be registered with more than one connection. " + f"Existing connection: {self._connection.name}, new connection: {connection.name}." + ) + # Descriptor is already registered with this connection. + return - if not (self._nixl_hndl is None and self._connector is None): + # When the descriptor is already registered with NIXL, just return. + if self._nixl_hndl is not None: return # Register the memory with NIXL. 
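The registration checks above give each descriptor an exclusive, idempotent binding to a single connection. Roughly, in Python terms (illustrative; the two `Connection` instances and the descriptor are assumed to already exist):

from dynamo.nixl_connect import Connection, Descriptor

def bind_descriptor(desc: Descriptor, conn_a: Connection, conn_b: Connection) -> None:
    desc.register_with_connector(conn_a)  # registers the memory with NIXL via conn_a
    desc.register_with_connector(conn_a)  # no-op: already registered with this connection
    assert desc.is_registered
    desc.register_with_connector(conn_b)  # raises RuntimeError: one connection per descriptor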
- self._connector = connector + self._connection = connection + if isinstance(self._data_ref, torch.Tensor): - self._nixl_hndl = connector._nixl.register_memory(self._data_ref) + self._nixl_hndl = connection._nixl.register_memory(self._data_ref) else: mem_type = str(self._data_device.kind) reg_list = [ (self._data_ptr, self._data_size, self._data_device.id, mem_type) ] - self._nixl_hndl = connector._nixl.register_memory(reg_list, mem_type) + self._nixl_hndl = connection._nixl.register_memory(reg_list, mem_type) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Registered {self.__repr__()} with NIXL." @@ -1173,7 +1326,7 @@ class PassiveOperation(AbstractOperation): def __init__( self, - connector: Connector, + connection: Connection, operation_kind: OperationKind, local_descriptors: Descriptor | list[Descriptor], ) -> None: @@ -1188,7 +1341,7 @@ def __init__( self._status = OperationStatus.UNINITIALIZED super().__init__( - connector, + connection, operation_kind, local_descriptors, None, @@ -1240,12 +1393,12 @@ def metadata(self, hex_encode: bool = False) -> RdmaMetadata: # When we've not yet cached the serialized request, we need to generate one before returning it. # Handle both cases: multiple and single descriptors. if isinstance(self._local_desc_list, list): - descriptors = [desc.metadata() for desc in self._local_desc_list] + descriptors = [desc.metadata for desc in self._local_desc_list] else: - descriptors = [self._local_desc_list.metadata()] + descriptors = [self._local_desc_list.metadata] - original_len = len(self._connector.metadata) - nixl_metadata = self._connector.metadata + original_len = len(self._connection.metadata) + nixl_metadata = self._connection.metadata nixl_metadata = zlib.compress(nixl_metadata, level=6) compressed_len = len(nixl_metadata) logger.debug( @@ -1283,7 +1436,7 @@ def status(self) -> OperationStatus: old_status = self._status # Query NIXL for any notifications. - notifications = self._connector._nixl.update_notifs() + notifications = self._connection._nixl.update_notifs() if isinstance(notifications, dict): remote_state = OperationStatus.IN_PROGRESS @@ -1309,7 +1462,7 @@ def status(self) -> OperationStatus: if remote_state == OperationStatus.COMPLETE: self._status = remote_state logger.debug( - f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connector.name}' status: '{old_status}' => '{self._status}' }}." + f"dynamo.nixl_connect.{self.__class__.__name__}: {{ remote: '{self._connection.name}' status: '{old_status}' => '{self._status}' }}." ) return self._status @@ -1330,7 +1483,7 @@ class ReadOperation(ActiveOperation): def __init__( self, - connector: Connector, + connection: Connection, remote_metadata: RdmaMetadata, local_descriptors: Descriptor | list[Descriptor], ) -> None: @@ -1341,16 +1494,16 @@ def __init__( Parameters ---------- - connector : Connector - Connector instance to use for the operation. + connection : Connection + Connection instance to use for the operation. remote_metadata : RdmaMetadata Serialized request from the remote worker. local_descriptors : Descriptor | list[Descriptor] Local descriptor(s) to to receive the data from the remote worker. """ - if not isinstance(connector, Connector): + if not isinstance(connection, Connection): raise TypeError( - "Argument `connector` must be `dynamo.nixl_connect.Connector`." + "Argument `connection` must be `dynamo.nixl_connect.Connection`." 
) if not isinstance(remote_metadata, RdmaMetadata): raise TypeError( @@ -1359,7 +1512,7 @@ def __init__( if remote_metadata.operation_kind != OperationKind.READ.value: raise ValueError("Argument `remote_metadata` must be of kind `READ`.") - remote = Remote(connector, remote_metadata.nixl_metadata) + remote = Remote(connection, remote_metadata.nixl_metadata) remote_descriptors = remote_metadata.to_descriptors() if not ( @@ -1435,10 +1588,10 @@ class ReadableOperation(PassiveOperation): def __init__( self, - connector: Connector, + connection: Connection, local_descriptors: Descriptor | list[Descriptor], ) -> None: - super().__init__(connector, OperationKind.READ, local_descriptors) + super().__init__(connection, OperationKind.READ, local_descriptors) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}" ) @@ -1510,17 +1663,19 @@ class Remote: def __init__( self, - connector: Connector, + connection: Connection, nixl_metadata: bytes | str, ) -> None: - if not isinstance(connector, Connector): - raise TypeError("Argument `local` must be `dynamo.nixl_connect.Connector`.") + if not isinstance(connection, Connection): + raise TypeError( + "Argument `connection` must be `dynamo.nixl_connect.Connection`." + ) if not (isinstance(nixl_metadata, bytes) or isinstance(nixl_metadata, str)): raise TypeError("Argument `nixl_metadata` must be `bytes` or `str`.") if len(nixl_metadata) == 0: raise ValueError("Argument `nixl_metadata` cannot be empty.") - self._connector = connector + self._connection = connection # When `nixl_metadata` is a string, it is assumed to have come from a remote worker # via a `RdmaMetadata` object and therefore can assumed be a b64-encoded, compressed @@ -1535,7 +1690,7 @@ def __init__( # Decompress the NIXL metadata. nixl_metadata = zlib.decompress(nixl_metadata) - self._name = connector._nixl.add_remote_agent(nixl_metadata) + self._name = connection._nixl.add_remote_agent(nixl_metadata) if isinstance(self._name, bytes): self._name = self._name.decode("utf-8") @@ -1559,7 +1714,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self._release() def __repr__(self) -> str: - return f"Remote(name={self._name}, connector={self._connector.name})" + return f"Remote(name={self._name}, connection={self._connection.name})" def __str__(self) -> str: return self._name @@ -1568,19 +1723,19 @@ def _release(self) -> None: """ Private method for releasing NIXL resources. Not intended for public use. """ - # We have to unregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and + # We have to deregister the remote agent from NIXL because we cannot know if the remote worker has updated its descriptors or not, and # NIXL will return an error if we attempt to register a remote agent with the same name but different descriptors (aka conn_info). - self._connector._nixl.remove_remote_agent(self._name) + self._connection._nixl.remove_remote_agent(self._name) logger.debug( - f'dynamo.nixl_connect.{self.__class__.__name__}: Unregistered NIXL remote {{ name: "{self._name}" }}.' + f'dynamo.nixl_connect.{self.__class__.__name__}: Deregistered NIXL remote {{ name: "{self._name}" }}.' ) @property - def connector(self) -> Connector: + def connection(self) -> Connection: """ - Gets the local connector associated with this remote worker. + Gets the local connection associated with this remote worker. 
""" - return self._connector + return self._connection @property def name(self) -> str: @@ -1647,7 +1802,7 @@ class WritableOperation(PassiveOperation): def __init__( self, - connector: Connector, + connection: Connection, local_descriptors: Descriptor | list[Descriptor], ) -> None: """ @@ -1656,18 +1811,18 @@ def __init__( Parameters ---------- - connector : Connector - Connector instance to use for the operation. + connection : Connection + Connection instance to use for the operation. local_descriptors : Descriptor | list[Descriptor] Descriptors to receive data from a remote worker. Raises TypeError - When `local` is not a `dynamo.nixl_connect.Connector`. + When `connection` is not a `dynamo.nixl_connect.Connection`. TypeError When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`. """ - super().__init__(connector, OperationKind.WRITE, local_descriptors) + super().__init__(connection, OperationKind.WRITE, local_descriptors) logger.debug( f"dynamo.nixl_connect.{self.__class__.__name__}: Created {self.__repr__()}" ) @@ -1703,7 +1858,7 @@ class WriteOperation(ActiveOperation): def __init__( self, - connector: Connector, + connection: Connection, local_descriptors: Descriptor | list[Descriptor], remote_metadata: RdmaMetadata, ) -> None: @@ -1714,8 +1869,8 @@ def __init__( Parameters ---------- - connector : Connector - Connector instance to use for the operation. + connection : Connection + Connection instance to use for the operation. local_descriptors : Descriptor | list[Descriptor] Local descriptor(s) to send from, to the remote worker. remote_metadata : RdmaMetadata @@ -1733,9 +1888,9 @@ def __init__( TypeError When `local_descriptors` is not a `dynamo.nixl_connect.Descriptor` or `list[dynamo.nixl_connect.Descriptor]`. """ - if not isinstance(connector, Connector): + if not isinstance(connection, Connection): raise TypeError( - "Argument `connector` must be `dynamo.nixl_connect.Connector`." + "Argument `connection` must be `dynamo.nixl_connect.Connection`." 
) if not isinstance(remote_metadata, RdmaMetadata): raise TypeError( @@ -1744,7 +1899,7 @@ def __init__( if remote_metadata.operation_kind != OperationKind.WRITE.value: raise ValueError("Argument `remote_metadata` must be of kind `WRITE`.") - remote = Remote(connector, remote_metadata.nixl_metadata) + remote = Remote(connection, remote_metadata.nixl_metadata) remote_descriptors = remote_metadata.to_descriptors() super().__init__( diff --git a/lib/bindings/python/src/dynamo/prometheus_names.py b/lib/bindings/python/src/dynamo/prometheus_names.py index 615edad127..c17985c33a 100644 --- a/lib/bindings/python/src/dynamo/prometheus_names.py +++ b/lib/bindings/python/src/dynamo/prometheus_names.py @@ -55,6 +55,8 @@ class frontend_service: INPUT_SEQUENCE_TOKENS = "input_sequence_tokens" # Output sequence length in tokens OUTPUT_SEQUENCE_TOKENS = "output_sequence_tokens" + # Number of cached tokens (prefix cache hits) per request + CACHED_TOKENS = "cached_tokens" # Total number of output tokens generated (counter that updates in real-time) OUTPUT_TOKENS_TOTAL = "output_tokens_total" # Time to first token in seconds @@ -93,6 +95,10 @@ class kvbm: ONBOARD_BLOCKS_D2D = "onboard_blocks_d2d" # The number of matched tokens MATCHED_TOKENS = "matched_tokens" + # Host cache hit rate (0.0-1.0) from the sliding window + HOST_CACHE_HIT_RATE = "host_cache_hit_rate" + # Disk cache hit rate (0.0-1.0) from the sliding window + DISK_CACHE_HIT_RATE = "disk_cache_hit_rate" class kvrouter: diff --git a/lib/bindings/python/tests/cancellation/test_cancellation.py b/lib/bindings/python/tests/cancellation/test_cancellation.py index 42d29d8930..1aff5e10ae 100644 --- a/lib/bindings/python/tests/cancellation/test_cancellation.py +++ b/lib/bindings/python/tests/cancellation/test_cancellation.py @@ -165,6 +165,7 @@ async def client(runtime, namespace): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_client_context_cancel(temp_file_store, server, client): _, handler = server context = Context() @@ -198,6 +199,7 @@ async def test_client_context_cancel(temp_file_store, server, client): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_client_loop_break(temp_file_store, server, client): _, handler = server stream = await client.generate("_generate_until_context_cancelled") @@ -230,6 +232,7 @@ async def test_client_loop_break(temp_file_store, server, client): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_server_context_cancel(temp_file_store, server, client): _, handler = server stream = await client.generate("_generate_and_cancel_context") @@ -254,6 +257,7 @@ async def test_server_context_cancel(temp_file_store, server, client): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_server_raise_cancelled(temp_file_store, server, client): _, handler = server stream = await client.generate("_generate_and_raise_cancelled") @@ -282,6 +286,7 @@ async def test_server_raise_cancelled(temp_file_store, server, client): @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_client_context_already_cancelled(temp_file_store, server, client): _, handler = server context = Context() @@ -304,6 +309,7 @@ async def 
test_client_context_already_cancelled(temp_file_store, server, client) @pytest.mark.forked @pytest.mark.asyncio +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) async def test_client_context_cancel_before_await_request( temp_file_store, server, client ): diff --git a/lib/bindings/python/tests/conftest.py b/lib/bindings/python/tests/conftest.py index 9d5f33a932..b234e405bb 100644 --- a/lib/bindings/python/tests/conftest.py +++ b/lib/bindings/python/tests/conftest.py @@ -402,8 +402,34 @@ def temp_file_store(): yield tmpdir +@pytest.fixture +def store_kv(request): + """ + KV store for runtime. Defaults to "file". + + To iterate over multiple stores in a test: + @pytest.mark.parametrize("store_kv", ["file", "etcd"], indirect=True) + async def test_example(runtime): + ... + """ + return getattr(request, "param", "file") + + +@pytest.fixture +def request_plane(request): + """ + Request plane for runtime. Defaults to "nats". + + To iterate over multiple transports in a test: + @pytest.mark.parametrize("request_plane", ["tcp", "nats"], indirect=True) + async def test_example(runtime): + ... + """ + return getattr(request, "param", "nats") + + @pytest.fixture(scope="function", autouse=False) -async def runtime(request): +async def runtime(request, store_kv, request_plane): """ Create a DistributedRuntime for testing. @@ -413,6 +439,14 @@ async def runtime(request): Without @pytest.mark.forked in isolated mode, you will get "Worker already initialized" errors when multiple tests try to create runtimes in the same process. + + The store_kv and request_plane can be customized by overriding their fixtures + or using @pytest.mark.parametrize with indirect=True: + + @pytest.mark.forked + @pytest.mark.parametrize("store_kv", ["etcd"], indirect=True) + async def test_with_etcd(runtime): + ... 
""" # Check if the test is marked with @pytest.mark.forked (only in isolated mode) if ENABLE_ISOLATED_ETCD_AND_NATS: @@ -435,6 +469,6 @@ async def test_my_test(runtime): ) loop = asyncio.get_running_loop() - runtime = DistributedRuntime(loop, "file", "nats") + runtime = DistributedRuntime(loop, store_kv, request_plane) yield runtime runtime.shutdown() diff --git a/lib/bindings/python/tests/test_kv_bindings.py b/lib/bindings/python/tests/test_kv_bindings.py index 1ff9245b9e..c3f24ff4ed 100644 --- a/lib/bindings/python/tests/test_kv_bindings.py +++ b/lib/bindings/python/tests/test_kv_bindings.py @@ -36,8 +36,8 @@ async def distributed_runtime(): runtime.shutdown() -@pytest.mark.asyncio -async def test_radix_tree_binding(distributed_runtime): +@pytest.mark.timeout(5) # Expected: ~1s, timeout set to 5x for safety +def test_radix_tree_binding(): """Test RadixTree binding directly with store event and find matches""" import json @@ -102,13 +102,12 @@ async def test_radix_tree_binding(distributed_runtime): ) -@pytest.mark.asyncio +@pytest.mark.timeout(5) # Expected: ~1s, timeout set to 5x for safety @pytest.mark.parametrize("num_threads", [2, 3, 5, 128]) @pytest.mark.parametrize("prepopulate_worker_ids", [True, False]) @pytest.mark.parametrize("expiration_duration_secs", [None]) @pytest.mark.parametrize("is_threaded", [True, False]) -async def test_radix_tree_thread_safety( - distributed_runtime, +def test_radix_tree_thread_safety( num_threads, prepopulate_worker_ids, expiration_duration_secs, @@ -205,6 +204,7 @@ def worker(worker_id, prepopulate_worker_ids: bool = False): @pytest.mark.asyncio +@pytest.mark.timeout(5) # Expected: ~1s, timeout set to 5x for safety async def test_event_handler(distributed_runtime): kv_block_size = 32 namespace = "kv_test" @@ -247,6 +247,7 @@ async def test_event_handler(distributed_runtime): @pytest.mark.asyncio +@pytest.mark.timeout(5) # Expected: ~1s, timeout set to 5x for safety async def test_approx_kv_indexer(distributed_runtime): """Test ApproxKvIndexer with TTL-based block tracking""" kv_block_size = 32 diff --git a/lib/bindings/python/tests/test_lora_utils.py b/lib/bindings/python/tests/test_lora_utils.py index fdce6ff4c5..b33b7a43bb 100644 --- a/lib/bindings/python/tests/test_lora_utils.py +++ b/lib/bindings/python/tests/test_lora_utils.py @@ -1,26 +1,32 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import pytest + from dynamo.llm import lora_name_to_id max_int32 = 0x7FFFFFFF class TestLoraNameToId: + @pytest.mark.timeout(5) def test_import_function(self): assert callable(lora_name_to_id) + @pytest.mark.timeout(5) def test_returns_positive_integer_for_different_names(self): for i in range(100): result = lora_name_to_id(f"test_lora_{i}") assert isinstance(result, int) assert 1 <= result <= max_int32 + @pytest.mark.timeout(5) def test_different_names_produce_different_ids(self): id1 = lora_name_to_id("lora_adapter_1") id2 = lora_name_to_id("lora_adapter_2") assert id1 != id2 + @pytest.mark.timeout(5) def test_consistency_across_multiple_calls(self): test_names = [f"lora_{i}" for i in range(100)] results_first = [lora_name_to_id(name) for name in test_names] diff --git a/lib/bindings/python/tests/test_tensor.py b/lib/bindings/python/tests/test_tensor.py index 30b1fde01c..e48de90f31 100644 --- a/lib/bindings/python/tests/test_tensor.py +++ b/lib/bindings/python/tests/test_tensor.py @@ -34,15 +34,12 @@ async def test_register(runtime: DistributedRuntime): assert model_config == runtime_config.get_tensor_model_config() - # [gluo FIXME] register_llm will attempt to load a LLM model, - # which is not well-defined for Tensor yet. Currently provide - # a valid model name to pass the registration. + # Use register_llm for tensor-based backends (skips HuggingFace downloads) await register_llm( ModelInput.Tensor, ModelType.TensorBased, endpoint, - "Qwen/Qwen3-0.6B", - "tensor", + "tensor", # model_path (used as display name for tensor-based models) runtime_config=runtime_config, ) diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 8bb5533e02..97da4b9c59 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -22,6 +22,7 @@ testing-cuda = ["dep:cudarc"] testing-nixl = ["dep:nixl-sys"] testing-etcd = [] block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec"] +block-manager-bench = ["block-manager", "testing-full", "dep:clap", "dep:indicatif"] cuda = ["dep:cudarc"] integration = ["dynamo-runtime/integration"] media-nixl = ["dep:nixl-sys", "dep:dynamo-memory"] @@ -105,6 +106,10 @@ nixl-sys = { version = "=0.7.1", optional = true } cudarc = { workspace = true, optional = true } nix = { version = "0.26", optional = true } +# block_manager_bench +clap = { version = "4.5.49", features = ["derive"], optional = true } +indicatif = { version = "0.18.0", optional = true } + # protocols unicode-segmentation = "1.12" @@ -188,3 +193,8 @@ mockito = "1.7.0" [build-dependencies] tonic-build = { version = "0.13.1" } + +[[bin]] +name = "bench_local_transfer_v2" +path = "bin/bench_local_transfer_v2.rs" +required-features = ["block-manager-bench"] diff --git a/lib/llm/bin/bench_local_transfer_v2.rs b/lib/llm/bin/bench_local_transfer_v2.rs new file mode 100644 index 0000000000..b82a5b89b8 --- /dev/null +++ b/lib/llm/bin/bench_local_transfer_v2.rs @@ -0,0 +1,196 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Result; +use clap::Parser; + +use core::time::Duration; +use indicatif::ProgressIterator; +use std::time::Instant; + +use dynamo_llm::block_manager::v2::physical::{ + layout::LayoutConfig, + transfer::{ + BounceBufferSpec, NixlAgent, PhysicalLayout, StorageKind, TransferOptions, + TransportManager, executor::execute_transfer, + }, +}; + +use std::sync::Arc; + +#[derive(Parser)] +struct Args { + /// Amount of layers + #[clap(long, default_value_t = 24)] + num_layers: usize, + + /// Inner dimension + #[clap(long, default_value_t = 4096)] + inner_dim: usize, + + /// Block size + #[clap(long, default_value_t = 32)] + block_size: usize, + + /// Amount of blocks per pool + #[clap(long, default_value_t = 16)] + num_blocks: usize, + + /// Amount of blocks per transferred batch + #[clap(long, default_value_t = 4)] + blocks_per_batch: usize, + + /// Amount of pinned bounce buffer blocks + #[clap(long, default_value_t = 2)] + num_bounce_blocks: usize, + + /// Amount of iterations + #[clap(long, default_value_t = 100)] + iterations: usize, +} + +struct DummyBounceBufferSpec { + pub layout: PhysicalLayout, + pub block_ids: Vec, +} + +impl BounceBufferSpec for DummyBounceBufferSpec { + fn layout(&self) -> &PhysicalLayout { + &self.layout + } + fn block_ids(&self) -> &[usize] { + &self.block_ids + } +} + +#[tokio::main] +pub async fn main() -> Result<()> { + let args = Args::parse(); + + // let manager = build_manager(&args).await?; + + benchmark(&args).await?; + + Ok(()) +} + +fn build_layout( + agent: NixlAgent, + config: LayoutConfig, + storage_kind: StorageKind, +) -> PhysicalLayout { + let builder = PhysicalLayout::builder(agent) + .with_config(config) + .fully_contiguous(); + + match storage_kind { + StorageKind::System => builder.allocate_system().build().unwrap(), + StorageKind::Pinned => builder.allocate_pinned(false).build().unwrap(), + StorageKind::Device(device_id) => builder.allocate_device(device_id).build().unwrap(), + StorageKind::Disk(_) => builder.allocate_disk(None).build().unwrap(), + } +} + +fn get_bandwidth_gbs(latencies: Vec, args: &Args) -> f64 { + let total_bytes = + args.num_layers * args.inner_dim * args.block_size * args.blocks_per_batch * 2; + let mean = latencies.iter().sum::() / latencies.len() as u32; + + total_bytes as f64 / mean.as_nanos() as f64 +} + +async fn benchmark(args: &Args) -> Result<()> { + let agent = NixlAgent::require_backends("test_agent", &["POSIX", "GDS_MT"])?; + let src_dst_config = LayoutConfig::builder() + .num_blocks(args.num_blocks) + .num_layers(args.num_layers) + .outer_dim(2) + .page_size(args.block_size) + .inner_dim(args.inner_dim) + .dtype_width_bytes(2) + .build()?; + + let disk_layout = build_layout(agent.clone(), src_dst_config.clone(), StorageKind::Disk(0)); + let device_layout = build_layout( + agent.clone(), + src_dst_config.clone(), + StorageKind::Device(0), + ); + + let bounce_config = LayoutConfig::builder() + .num_blocks(args.num_bounce_blocks) + .num_layers(args.num_layers) + .outer_dim(2) + .page_size(args.block_size) + .inner_dim(args.inner_dim) + .dtype_width_bytes(2) + .build()?; + + let 
bounce_layout = build_layout(agent.clone(), bounce_config.clone(), StorageKind::Pinned); + + let ctx = TransportManager::builder() + .worker_id(0) + .nixl_agent(agent) + .cuda_device_id(0) + .build()?; + + let bounce_buffer_spec: Arc = Arc::new(DummyBounceBufferSpec { + layout: bounce_layout, + block_ids: (0..args.num_bounce_blocks).collect(), + }); + + let options = TransferOptions::builder() + .bounce_buffer(bounce_buffer_spec) + .build()?; + + anyhow::ensure!( + args.blocks_per_batch <= args.num_blocks, + "blocks_per_batch must be less than or equal to num_blocks" + ); + let blocks = (0..args.blocks_per_batch).collect::>(); + + for (src, dst, name) in vec![ + (disk_layout.clone(), device_layout.clone(), "disk_to_device"), + (device_layout, disk_layout, "device_to_disk"), + ] { + println!("Starting {} benchmark...", name); + + let mut latencies = Vec::new(); + for _ in (0..args.iterations).progress() { + let options_clone = options.clone(); + let start = Instant::now(); + execute_transfer( + &src, + &dst, + blocks.as_slice(), + blocks.as_slice(), + options_clone, + ctx.context(), + )? + .await?; + let end = Instant::now(); + let duration = end.duration_since(start); + latencies.push(duration); + } + + println!( + "{} bandwidth: {:?} GB/s", + name, + get_bandwidth_gbs(latencies, args) + ); + } + + Ok(()) +} diff --git a/lib/llm/src/block_manager/block/transfer/kernels/vectorized_copy.fatbin b/lib/llm/src/block_manager/block/transfer/kernels/vectorized_copy.fatbin index 558ba11ff2..d1a3c05fb1 100644 Binary files a/lib/llm/src/block_manager/block/transfer/kernels/vectorized_copy.fatbin and b/lib/llm/src/block_manager/block/transfer/kernels/vectorized_copy.fatbin differ diff --git a/lib/llm/src/block_manager/v2/physical/transfer/executor/mod.rs b/lib/llm/src/block_manager/v2/physical/transfer/executor/mod.rs index a3eeb36379..896956f20b 100644 --- a/lib/llm/src/block_manager/v2/physical/transfer/executor/mod.rs +++ b/lib/llm/src/block_manager/v2/physical/transfer/executor/mod.rs @@ -17,6 +17,7 @@ use anyhow::Result; use std::ops::Range; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::sync::Mutex; // Re-export the NIXL transfer builder for public use pub use nixl::NixlTransferBuilder; @@ -181,6 +182,64 @@ struct TwoHopTransferParams<'a> { ctx: &'a TransferContext, } +#[allow(clippy::too_many_arguments)] +async fn handle_buffered_transfer( + src: &PhysicalLayout, + bounce_layout: &PhysicalLayout, + dst: &PhysicalLayout, + src_block_ids: &[usize], + bounce_block_ids: &[usize], + dst_block_ids: &[usize], + first_strategy: TransferStrategy, + second_strategy: TransferStrategy, + layer_range: &Option>, + ctx: &TransferContext, +) -> Result<()> { + let bounce_groups = + &bounce_block_ids[0..std::cmp::min(src_block_ids.len(), bounce_block_ids.len())]; + let (bounce_group_0, bounce_group_1) = bounce_groups.split_at(bounce_groups.len() / 2); + let bounce_group_0 = bounce_group_0.to_vec(); + let bounce_group_1 = bounce_group_1.to_vec(); + + let src_dst_iter = Arc::new(Mutex::new(src_block_ids.iter().zip(dst_block_ids.iter()))); + + let transfer_task = async move |bounce_group: &[usize]| -> Result<()> { + loop { + let (src_ids, dst_ids): (Vec, Vec); + { + let mut x = src_dst_iter.lock().await; + (src_ids, dst_ids) = x.by_ref().take(bounce_group.len()).unzip(); + if src_ids.is_empty() { + break; + } + } + + execute_two_hop_transfer_chunk( + src, + bounce_layout, + dst, + &src_ids, + &bounce_group[0..src_ids.len()], + &dst_ids, + first_strategy, + second_strategy, + 
layer_range, + ctx, + ) + .await?; + } + + Ok(()) + }; + + let transfer_0 = transfer_task(&bounce_group_0); + let transfer_1 = transfer_task(&bounce_group_1); + + futures::future::try_join(transfer_0, transfer_1).await?; + + Ok(()) +} + fn execute_two_hop_transfer(params: TwoHopTransferParams) -> Result { let TwoHopTransferParams { src, @@ -223,22 +282,26 @@ fn execute_two_hop_transfer(params: TwoHopTransferParams) -> Result Result, + client: Client, threshold: f64, ) -> KvWorkerMonitor { let mut monitors = self.worker_monitors.write(); diff --git a/lib/llm/src/discovery/watcher.rs b/lib/llm/src/discovery/watcher.rs index 690560c1d9..059cd8eb48 100644 --- a/lib/llm/src/discovery/watcher.rs +++ b/lib/llm/src/discovery/watcher.rs @@ -405,11 +405,8 @@ impl ModelWatcher { // Get or create the worker monitor for this model // This allows dynamic threshold updates via the ModelManager let worker_monitor = self.router_config.busy_threshold.map(|threshold| { - self.manager.get_or_create_worker_monitor( - card.name(), - Arc::new(client.clone()), - threshold, - ) + self.manager + .get_or_create_worker_monitor(card.name(), client.clone(), threshold) }); // Add chat engine only if the model supports chat diff --git a/lib/llm/src/discovery/worker_monitor.rs b/lib/llm/src/discovery/worker_monitor.rs index fda3985715..d6ef5a97d8 100644 --- a/lib/llm/src/discovery/worker_monitor.rs +++ b/lib/llm/src/discovery/worker_monitor.rs @@ -55,11 +55,11 @@ impl WorkerLoadState { /// Worker monitor for tracking KV cache usage and busy states. /// -/// All fields are `Arc`, so cloning shares state. This allows multiple pipelines +/// Cloning shares state via internal Arc-wrapped fields. This allows multiple pipelines /// (e.g., chat and completions) to share the same monitor instance. #[derive(Clone)] pub struct KvWorkerMonitor { - client: Arc, + client: Client, worker_load_states: Arc>>, /// Threshold stored as parts-per-10000 (e.g., 8500 = 0.85) busy_threshold: Arc, @@ -72,7 +72,7 @@ impl KvWorkerMonitor { /// /// The threshold (0.0-1.0) controls when workers are considered busy based on /// KV cache utilization. It can be dynamically updated via `set_threshold()`. - pub fn new(client: Arc, threshold: f64) -> Self { + pub fn new(client: Client, threshold: f64) -> Self { Self { client, worker_load_states: Arc::new(RwLock::new(HashMap::new())), diff --git a/lib/llm/src/entrypoint/input/common.rs b/lib/llm/src/entrypoint/input/common.rs index 2fef30b37a..beb3939927 100644 --- a/lib/llm/src/entrypoint/input/common.rs +++ b/lib/llm/src/entrypoint/input/common.rs @@ -271,13 +271,13 @@ where // Link with prefill chooser including backward edge for response flow let engine = frontend .link(preprocessor_op.forward_edge())? - .link(backend.forward_edge())? .link(migration.forward_edge())? + .link(backend.forward_edge())? .link(prefill_op.forward_edge())? .link(service_backend)? .link(prefill_op.backward_edge())? - .link(migration.backward_edge())? .link(backend.backward_edge())? + .link(migration.backward_edge())? .link(preprocessor_op.backward_edge())? .link(frontend)?; diff --git a/lib/llm/src/grpc/protos/kserve.proto b/lib/llm/src/grpc/protos/kserve.proto index b9efb9cefd..4d7fefd6a2 100644 --- a/lib/llm/src/grpc/protos/kserve.proto +++ b/lib/llm/src/grpc/protos/kserve.proto @@ -16,6 +16,27 @@ import "model_config.proto"; //@@ service GRPCInferenceService { + //@@ .. cpp:var:: rpc ServerLive(ServerLiveRequest) returns + //@@ (ServerLiveResponse) + //@@ + //@@ Check liveness of the inference server. 
+ //@@ + rpc ServerLive(ServerLiveRequest) returns (ServerLiveResponse) {} + + //@@ .. cpp:var:: rpc ServerReady(ServerReadyRequest) returns + //@@ (ServerReadyResponse) + //@@ + //@@ Check readiness of the inference server. + //@@ + rpc ServerReady(ServerReadyRequest) returns (ServerReadyResponse) {} + + //@@ .. cpp:var:: rpc ModelReady(ModelReadyRequest) returns + //@@ (ModelReadyResponse) + //@@ + //@@ Check readiness of a model in the inference server. + //@@ + rpc ModelReady(ModelReadyRequest) returns (ModelReadyResponse) {} + //@@ .. cpp:var:: rpc ModelMetadata(ModelMetadataRequest) returns //@@ (ModelMetadataResponse) //@@ @@ -45,6 +66,89 @@ service GRPCInferenceService rpc ModelConfig(ModelConfigRequest) returns (ModelConfigResponse) {} } +//@@ +//@@.. cpp:var:: message ServerLiveRequest +//@@ +//@@ Request message for ServerLive. +//@@ +message ServerLiveRequest {} + +//@@ +//@@.. cpp:var:: message ServerLiveResponse +//@@ +//@@ Response message for ServerLive. +//@@ +message ServerLiveResponse +{ + //@@ + //@@ .. cpp:var:: bool live + //@@ + //@@ True if the inference server is live, false if not live. + //@@ + bool live = 1; +} + +//@@ +//@@.. cpp:var:: message ServerReadyRequest +//@@ +//@@ Request message for ServerReady. +//@@ +message ServerReadyRequest {} + +//@@ +//@@.. cpp:var:: message ServerReadyResponse +//@@ +//@@ Response message for ServerReady. +//@@ +message ServerReadyResponse +{ + //@@ + //@@ .. cpp:var:: bool ready + //@@ + //@@ True if the inference server is ready, false if not ready. The server + //@@ is considered ready if it has any registered models, since models + //@@ can freely be registered and unregistered at runtime. + //@@ + bool ready = 1; +} + +//@@ +//@@.. cpp:var:: message ModelReadyRequest +//@@ +//@@ Request message for ModelReady. +//@@ +message ModelReadyRequest +{ + //@@ + //@@ .. cpp:var:: string name + //@@ + //@@ The name of the model to check for readiness. + //@@ + string name = 1; + + //@@ .. cpp:var:: string version + //@@ + //@@ The version of the model to check for readiness. If not given the + //@@ server will choose a version based on the model and internal policy. + //@@ + string version = 2; +} + +//@@ +//@@.. cpp:var:: message ModelReadyResponse +//@@ +//@@ Response message for ModelReady. +//@@ +message ModelReadyResponse +{ + //@@ + //@@ .. cpp:var:: bool ready + //@@ + //@@ True if the model is ready, false if not ready. + //@@ + bool ready = 1; +} + //@@ //@@.. 
cpp:var:: message ModelMetadataRequest //@@ diff --git a/lib/llm/src/grpc/service/kserve.rs b/lib/llm/src/grpc/service/kserve.rs index 6fa2518942..1d950a1bbb 100644 --- a/lib/llm/src/grpc/service/kserve.rs +++ b/lib/llm/src/grpc/service/kserve.rs @@ -675,4 +675,38 @@ impl GrpcInferenceService for KserveService { request_model_name ))) } + + async fn server_live( + &self, + _request: Request, + ) -> Result, Status> { + // server is live if we can respond + Ok(Response::new(inference::ServerLiveResponse { live: true })) + } + + async fn server_ready( + &self, + _request: Request, + ) -> Result, Status> { + let has_models = !self.state.manager().get_model_cards().is_empty(); + Ok(Response::new(inference::ServerReadyResponse { + ready: has_models, + })) + } + + async fn model_ready( + &self, + request: Request, + ) -> Result, Status> { + let request_model_name = &request.into_inner().name; + let is_ready = self + .state + .manager() + .get_model_cards() + .into_iter() + .any(|card| request_model_name == &card.display_name); + Ok(Response::new(inference::ModelReadyResponse { + ready: is_ready, + })) + } } diff --git a/lib/llm/src/http/service/metrics.rs b/lib/llm/src/http/service/metrics.rs index 65f3867f39..58eeb20944 100644 --- a/lib/llm/src/http/service/metrics.rs +++ b/lib/llm/src/http/service/metrics.rs @@ -165,6 +165,7 @@ pub struct Metrics { request_duration: HistogramVec, input_sequence_length: HistogramVec, output_sequence_length: HistogramVec, + cached_tokens: HistogramVec, output_tokens_counter: IntCounterVec, time_to_first_token: HistogramVec, inter_token_latency: HistogramVec, @@ -252,6 +253,8 @@ pub struct ResponseMetricCollector { // be computed. last_response_time: Option, osl: usize, + // we track if cached_tokens has been observed to ensure we only increment once per request + cached_tokens_observed: bool, } impl Default for Metrics { @@ -378,7 +381,7 @@ impl Metrics { frontend_metric_name(frontend_service::INPUT_SEQUENCE_TOKENS), "Input sequence length in tokens", ) - .buckets(input_sequence_buckets), + .buckets(input_sequence_buckets.clone()), &["model"], ) .unwrap(); @@ -436,6 +439,16 @@ impl Metrics { ) .unwrap(); + let cached_tokens = HistogramVec::new( + HistogramOpts::new( + frontend_metric_name(frontend_service::CACHED_TOKENS), + "Number of cached tokens (prefix cache hits) per request", + ) + .buckets(input_sequence_buckets.clone()), + &["model"], + ) + .unwrap(); + // Runtime configuration metrics // Note: Some of these metrics represent counter-like values from source systems, // but are implemented as gauges because they are copied/synchronized from upstream @@ -502,6 +515,7 @@ impl Metrics { request_duration, input_sequence_length, output_sequence_length, + cached_tokens, output_tokens_counter, time_to_first_token, inter_token_latency, @@ -597,6 +611,7 @@ impl Metrics { registry.register(Box::new(self.request_duration.clone()))?; registry.register(Box::new(self.input_sequence_length.clone()))?; registry.register(Box::new(self.output_sequence_length.clone()))?; + registry.register(Box::new(self.cached_tokens.clone()))?; registry.register(Box::new(self.output_tokens_counter.clone()))?; registry.register(Box::new(self.time_to_first_token.clone()))?; registry.register(Box::new(self.inter_token_latency.clone()))?; @@ -830,6 +845,7 @@ impl ResponseMetricCollector { last_response_time: None, start_time: Instant::now(), osl: 0, + cached_tokens_observed: false, } } @@ -843,6 +859,19 @@ impl ResponseMetricCollector { self.is_first_token } + /// Observe cached tokens (prefix 
cache hits), observing only once per request when value is available + pub fn observe_cached_tokens(&mut self, cached_tokens: Option) { + if let Some(tokens) = cached_tokens + && !self.cached_tokens_observed + { + self.cached_tokens_observed = true; + self.metrics + .cached_tokens + .with_label_values(&[&self.model]) + .observe(tokens as f64); + } + } + /// Observe a response with input sequence length and number of new tokens pub fn observe_response(&mut self, isl: usize, num_tokens: usize) { if num_tokens == 0 { @@ -943,11 +972,13 @@ impl From> for EventConverter { /// /// This function handles metrics collection, http_queue_guard management, and converts /// annotated responses to SSE events for streaming responses. +/// +/// Returns None for metrics annotation events (events without SSE data payload). pub fn process_response_using_event_converter_and_observe_metrics( annotated: EventConverter, response_collector: &mut ResponseMetricCollector, http_queue_guard: &mut Option, -) -> Result { +) -> Result, axum::Error> { use crate::preprocessor::LLMMetricAnnotation; let mut annotated = annotated.0; @@ -955,6 +986,7 @@ pub fn process_response_using_event_converter_and_observe_metrics( // update metrics if let Ok(Some(metrics)) = LLMMetricAnnotation::from_annotation(&annotated) { response_collector.observe_current_osl(metrics.output_tokens); + response_collector.observe_cached_tokens(metrics.cached_tokens); // Drop http_queue_guard on first token for streaming if response_collector.is_first_token() @@ -976,11 +1008,11 @@ pub fn process_response_using_event_converter_and_observe_metrics( let mut event = Event::default(); - if let Some(data) = annotated.data { + if let Some(ref data) = annotated.data { event = event.json_data(data)?; } - if let Some(msg) = annotated.event { + if let Some(ref msg) = annotated.event { if msg == "error" { let msgs = annotated .comment @@ -996,7 +1028,12 @@ pub fn process_response_using_event_converter_and_observe_metrics( } } - Ok(event) + // Filter out metrics annotation events (events without SSE data payload) + if annotated.data.is_none() && annotated.event.is_none() { + Ok(None) + } else { + Ok(Some(event)) + } } /// Create a new router with optional custom backend metrics support @@ -1357,4 +1394,120 @@ mod tests { 20 ); } + + #[test] + fn test_cached_tokens_once_per_request() { + let metrics = Arc::new(Metrics::new()); + let registry = prometheus::Registry::new(); + metrics.register(®istry).unwrap(); + + let model = "test-model"; + let expected_metric_name = "dynamo_frontend_cached_tokens"; + let mut collector = metrics.clone().create_response_collector(model); + + // Create histogram handle first + let _histogram = metrics.cached_tokens.with_label_values(&[model]); + + // First call should observe and record 1 sample + collector.observe_cached_tokens(Some(100)); + let metric_families = registry.gather(); + let histogram_family = metric_families + .iter() + .find(|mf| mf.name() == expected_metric_name) + .expect("histogram should be registered"); + assert_eq!( + histogram_family.get_metric()[0] + .get_histogram() + .get_sample_count(), + 1 + ); + + // Second call with same collector should not observe again (idempotent) + collector.observe_cached_tokens(Some(50)); + let metric_families = registry.gather(); + let histogram_family = metric_families + .iter() + .find(|mf| mf.name() == expected_metric_name) + .expect("histogram should be registered"); + assert_eq!( + histogram_family.get_metric()[0] + .get_histogram() + .get_sample_count(), + 1 + ); + + // Third 
call with different value should still be idempotent + collector.observe_cached_tokens(Some(75)); + let metric_families = registry.gather(); + let histogram_family = metric_families + .iter() + .find(|mf| mf.name() == expected_metric_name) + .expect("histogram should be registered"); + assert_eq!( + histogram_family.get_metric()[0] + .get_histogram() + .get_sample_count(), + 1 + ); + } + + #[test] + fn test_metrics_annotation_event_handling() { + use crate::preprocessor::LLMMetricAnnotation; + use crate::types::Annotated; + + let metrics = Arc::new(Metrics::new()); + let registry = prometheus::Registry::new(); + metrics.register(®istry).unwrap(); + + let model = "test-model"; + let expected_metric_name = "dynamo_frontend_cached_tokens"; + let mut collector = metrics.clone().create_response_collector(model); + + // Create a metrics annotation event (event without SSE data payload) + let mut annotated = Annotated::< + crate::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + > { + id: None, + data: None, + event: Some(crate::preprocessor::ANNOTATION_LLM_METRICS.to_string()), + comment: None, + }; + + // Add metrics annotation with cached_tokens + let llm_metrics = LLMMetricAnnotation { + input_tokens: 10, + output_tokens: 20, + chunk_tokens: 5, + cached_tokens: Some(15), + }; + + let annotation = llm_metrics.to_annotation::<()>().unwrap(); + annotated.event = annotation.event; + annotated.comment = annotation.comment; + + // Process the event + let mut http_queue_guard = None; + let result = process_response_using_event_converter_and_observe_metrics( + EventConverter::from(annotated), + &mut collector, + &mut http_queue_guard, + ); + + // Should return Ok(None) for metrics annotation events + assert!(matches!(result, Ok(None))); + + // Should have observed the cached tokens from the metrics annotation event + let metric_families = registry.gather(); + let histogram_family = metric_families + .iter() + .find(|mf| mf.name() == expected_metric_name) + .expect("histogram should be registered"); + assert_eq!( + histogram_family.get_metric()[0] + .get_histogram() + .get_sample_count(), + 1 + ); + } } diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index 4f65f16c10..145c253d2f 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -41,7 +41,10 @@ use super::{ use crate::engines::ValidateRequest; use crate::protocols::openai::chat_completions::aggregator::ChatCompletionAggregator; use crate::protocols::openai::{ - chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionResponse}, + chat_completions::{ + NvCreateChatCompletionRequest, NvCreateChatCompletionResponse, + NvCreateChatCompletionStreamResponse, + }, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, responses::{NvCreateResponse, NvResponse}, @@ -408,14 +411,20 @@ async fn completions_single( if streaming { // For streaming, we'll drop the http_queue_guard on the first token let mut http_queue_guard = Some(http_queue_guard); - let stream = stream.map(move |response| { - // Calls observe_response() on each token - process_response_using_event_converter_and_observe_metrics( - EventConverter::from(response), - &mut response_collector, - &mut http_queue_guard, - ) - }); + let stream = stream + .map(move |response| { + // Calls observe_response() on each token + process_response_using_event_converter_and_observe_metrics( + 
EventConverter::from(response), + &mut response_collector, + &mut http_queue_guard, + ) + }) + .filter_map(|result| { + use futures::future; + // Transpose Result> -> Option> + future::ready(result.transpose()) + }); let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle); let mut sse_stream = Sse::new(stream); @@ -564,14 +573,20 @@ async fn completions_batch( if streaming { // For streaming, we'll drop the http_queue_guard on the first token let mut http_queue_guard = Some(http_queue_guard); - let stream = merged_stream.map(move |response| { - // Calls observe_response() on each token - process_response_using_event_converter_and_observe_metrics( - EventConverter::from(response), - &mut response_collector, - &mut http_queue_guard, - ) - }); + let stream = merged_stream + .map(move |response| { + // Calls observe_response() on each token + process_response_using_event_converter_and_observe_metrics( + EventConverter::from(response), + &mut response_collector, + &mut http_queue_guard, + ) + }) + .filter_map(|result| { + use futures::future; + // Transpose Result> -> Option> + future::ready(result.transpose()) + }); let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle); let mut sse_stream = Sse::new(stream); @@ -717,6 +732,116 @@ async fn handler_chat_completions( response } +/// Checks if an Annotated event represents a backend error and extracts error information. +/// Returns Some((message, status_code)) if it's an error, None otherwise. +fn extract_backend_error_if_present( + event: &Annotated, +) -> Option<(String, StatusCode)> { + #[derive(serde::Deserialize)] + struct ErrorPayload { + message: Option, + code: Option, + } + + // Check if event type is "error" (from postprocessor when FinishReason::Error is encountered) + if let Some(event_type) = &event.event + && event_type == "error" + { + let comment_str = event + .comment + .as_ref() + .map(|c| c.join(", ")) + .unwrap_or_else(|| "Unknown error".to_string()); + + // Try to parse comment as error JSON to extract status code + if let Ok(error_payload) = serde_json::from_str::(&comment_str) { + let code = error_payload + .code + .and_then(|c| StatusCode::from_u16(c).ok()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let message = error_payload.message.unwrap_or(comment_str); + return Some((message, code)); + } + + return Some((comment_str, StatusCode::INTERNAL_SERVER_ERROR)); + } + + // Check if the data payload itself contains an error structure with code >= 400 + if let Some(data) = &event.data + && let Ok(json_value) = serde_json::to_value(data) + && let Ok(error_payload) = serde_json::from_value::(json_value.clone()) + && let Some(code_num) = error_payload.code + && code_num >= 400 + { + let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let message = error_payload + .message + .unwrap_or_else(|| json_value.to_string()); + return Some((message, code)); + } + + // Check if comment contains error information (without event: error) + if let Some(comments) = &event.comment + && !comments.is_empty() + { + let comment_str = comments.join(", "); + + // Try to parse comment as error JSON with code >= 400 + if let Ok(error_payload) = serde_json::from_str::(&comment_str) + && let Some(code_num) = error_payload.code + && code_num >= 400 + { + let code = StatusCode::from_u16(code_num).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let message = error_payload.message.unwrap_or(comment_str); + return Some((message, code)); + } + + // Comments present 
with no data AND no event type indicates error + // (events with event types like "request_id" or "event.dynamo.test.sentinel" are annotations) + if event.data.is_none() && event.event.is_none() { + return Some((comment_str, StatusCode::INTERNAL_SERVER_ERROR)); + } + } + + None +} + +/// Checks if the first event in the stream is a backend error. +/// Returns Err(ErrorResponse) if error detected, Ok(stream) otherwise. +async fn check_for_backend_error( + mut stream: impl futures::Stream> + + Send + + Unpin + + 'static, +) -> Result< + impl futures::Stream> + Send, + ErrorResponse, +> { + use futures::stream::StreamExt; + + // Peek at the first event + if let Some(first_event) = stream.next().await { + // Check if it's an error event + if let Some((error_msg, status_code)) = extract_backend_error_if_present(&first_event) { + return Err(( + status_code, + Json(ErrorMessage { + message: error_msg, + error_type: map_error_code_to_error_type(status_code), + code: status_code.as_u16(), + }), + )); + } + + // Not an error - reconstruct stream with first event + let reconstructed_stream = futures::stream::iter(vec![first_event]).chain(stream); + Ok(reconstructed_stream) + } else { + // Empty stream - this shouldn't happen but handle gracefully + Ok(futures::stream::iter(vec![]).chain(stream)) + } +} + /// OpenAI Chat Completions Request Handler /// /// This method will handle the incoming request for the /v1/chat/completions endpoint. The endpoint is a "source" @@ -822,17 +947,28 @@ async fn chat_completions( // note - we might do this as part of the post processing set to make it more generic if streaming { + // For streaming responses, we return HTTP 200 immediately without checking for errors. + // Once HTTP 200 OK is sent, we cannot change the status code, so any backend errors + // must be delivered as SSE events with `event: error` in the stream (handled by + // EventConverter and monitor_for_disconnects). This is standard SSE behavior. 
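            // Illustrative sketch, assuming the error payload format exercised in the tests
            // below (e.g. {"message":"prompt > max_seq_len","type":"Internal Server Error","code":500}):
            // a mid-stream failure reaches the client as an `event: error` SSE frame carrying
            // those error details, whereas the non-streaming branch further down can still map
            // the same failure to a real HTTP status, because check_for_backend_error() inspects
            // the first event before any response bytes are written.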
stream_handle.arm(); // allows the system to detect client disconnects and cancel the LLM generation let mut http_queue_guard = Some(http_queue_guard); - let stream = stream.map(move |response| { - // Calls observe_response() on each token - process_response_using_event_converter_and_observe_metrics( - EventConverter::from(response), - &mut response_collector, - &mut http_queue_guard, - ) - }); + let stream = stream + .map(move |response| { + // Calls observe_response() on each token + // EventConverter will detect `event: "error"` and convert to SSE error events + process_response_using_event_converter_and_observe_metrics( + EventConverter::from(response), + &mut response_collector, + &mut http_queue_guard, + ) + }) + .filter_map(|result| { + use futures::future; + // Transpose Result> -> Option> + future::ready(result.transpose()) + }); let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle); let mut sse_stream = Sse::new(stream); @@ -843,8 +979,17 @@ async fn chat_completions( Ok(sse_stream.into_response()) } else { + // Check first event for backend errors before aggregating (non-streaming only) + let stream_with_check = + check_for_backend_error(stream) + .await + .map_err(|error_response| { + tracing::error!(request_id, "Backend error detected: {:?}", error_response); + error_response + })?; + let mut http_queue_guard = Some(http_queue_guard); - let stream = stream.inspect(move |response| { + let stream = stream_with_check.inspect(move |response| { // Calls observe_response() on each token - drops http_queue_guard on first token process_response_and_observe_metrics( response, @@ -859,11 +1004,11 @@ async fn chat_completions( .map_err(|e| { tracing::error!( request_id, - "Failed to fold chat completions stream for: {:?}", + "Failed to parse chat completion response: {:?}", e ); ErrorMessage::internal_server_error(&format!( - "Failed to fold chat completions stream: {}", + "Failed to parse chat completion response: {}", e )) })?; @@ -2055,4 +2200,136 @@ mod tests { assert!(msg.contains("response_format")); } } + + #[tokio::test] + async fn test_check_for_backend_error_with_error_event() { + use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse; + use futures::stream; + + // Create an error event + let error_event = Annotated:: { + data: None, + id: None, + event: Some("error".to_string()), + comment: Some(vec!["Backend service unavailable".to_string()]), + }; + + let test_stream = stream::iter(vec![error_event]); + let result = check_for_backend_error(test_stream).await; + + // Should return an error + assert!(result.is_err()); + if let Err(error_response) = result { + assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(error_response.1.message, "Backend service unavailable"); + } + } + + #[tokio::test] + async fn test_check_for_backend_error_with_json_error_and_code() { + use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse; + use futures::stream; + + // Create an error event with JSON payload containing error code in comment + let error_json = + r#"{"message":"prompt > max_seq_len","type":"Internal Server Error","code":500}"#; + let error_event = Annotated:: { + data: None, + id: None, + event: Some("error".to_string()), + comment: Some(vec![error_json.to_string()]), + }; + + let test_stream = stream::iter(vec![error_event]); + let result = check_for_backend_error(test_stream).await; + + // Should return an error with correct status code extracted from JSON + 
assert!(result.is_err()); + if let Err(error_response) = result { + assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(error_response.1.message, "prompt > max_seq_len"); + assert_eq!(error_response.1.code, 500); + } + } + + #[tokio::test] + async fn test_check_for_backend_error_with_normal_event() { + use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse; + use dynamo_async_openai::types::CreateChatCompletionStreamResponse; + use futures::stream::{self, StreamExt}; + + // Create a normal data event + let normal_event = Annotated:: { + data: Some(CreateChatCompletionStreamResponse { + id: "test-id".to_string(), + choices: vec![], + created: 0, + model: "test-model".to_string(), + system_fingerprint: None, + object: "chat.completion.chunk".to_string(), + service_tier: None, + usage: None, + nvext: None, + }), + id: Some("msg-1".to_string()), + event: None, + comment: None, + }; + + let test_stream = stream::iter(vec![normal_event.clone()]); + let result = check_for_backend_error(test_stream).await; + + // Should return Ok with the stream + assert!(result.is_ok()); + let mut returned_stream = result.unwrap(); + + // Verify we can read the event back from the stream + let first = returned_stream.next().await; + assert!(first.is_some()); + let first_event = first.unwrap(); + assert_eq!(first_event.id, Some("msg-1".to_string())); + } + + #[tokio::test] + async fn test_check_for_backend_error_with_empty_stream() { + use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse; + use futures::stream::{self, StreamExt}; + + // Create an empty stream + let test_stream = + stream::iter::>>(vec![]); + let result = check_for_backend_error(test_stream).await; + + // Should return Ok with an empty stream + assert!(result.is_ok()); + let mut returned_stream = result.unwrap(); + + // Verify stream is empty + let first = returned_stream.next().await; + assert!(first.is_none()); + } + + #[tokio::test] + async fn test_check_for_backend_error_with_comment_but_no_event_type() { + use crate::types::openai::chat_completions::NvCreateChatCompletionStreamResponse; + use futures::stream; + + // Create an event with comment but no event type and no data (error indicator) + let error_event = Annotated:: { + data: None, + id: None, + event: None, + comment: Some(vec!["Connection timeout".to_string()]), + }; + + let test_stream = stream::iter(vec![error_event]); + let result = check_for_backend_error(test_stream).await; + + // Should return an error based on is_backend_error_event logic + assert!(result.is_err()); + if let Err(error_response) = result { + assert_eq!(error_response.0, StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(error_response.1.message, "Connection timeout"); + } + } } diff --git a/lib/llm/src/kv_router.rs b/lib/llm/src/kv_router.rs index 7b90429ce3..0cf3cfa099 100644 --- a/lib/llm/src/kv_router.rs +++ b/lib/llm/src/kv_router.rs @@ -22,6 +22,8 @@ use futures::stream::{self, StreamExt}; use serde::{Deserialize, Serialize}; use serde_json::json; +use crate::protocols::openai::nvext::WorkerIdInfo; + pub mod approx; pub mod indexer; pub mod prefill_router; @@ -51,7 +53,7 @@ use crate::{ }, scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest}, sequence::SequenceError, - subscriber::start_kv_router_background, + subscriber::{recover_from_all_workers, start_kv_router_background}, }, local_model::runtime_config::ModelRuntimeConfig, model_card::ModelDeploymentCard, @@ -81,6 +83,7 @@ pub const RADIX_STATE_FILE: 
&str = "radix-state"; // for worker-local kvindexer query pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query"; +pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer // for router discovery registration pub const KV_ROUTER_COMPONENT: &str = "kv-router"; @@ -303,13 +306,23 @@ impl KvRouter { endpoint: endpoint_id.name.clone(), }; let discovery_stream = discovery - .list_and_watch(discovery_key, Some(cancellation_token.clone())) + .list_and_watch(discovery_key.clone(), Some(cancellation_token.clone())) .await?; let runtime_configs_rx = watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| { card.runtime_config }); + // Watch for local indexer states via discovery interface (separate stream needed + // because streams are consumed by watch_and_extract_field) + let discovery_stream_local_indexer = discovery + .list_and_watch(discovery_key, Some(cancellation_token.clone())) + .await?; + let local_indexer_rx = watch_and_extract_field( + discovery_stream_local_indexer, + |card: ModelDeploymentCard| card.runtime_config.enable_local_indexer, + ); + let indexer = if kv_router_config.overlap_score_weight == 0.0 { // When overlap_score_weight is zero, we don't need to track prefixes Indexer::None @@ -347,6 +360,12 @@ impl KvRouter { ) .await?; + // Initialize worker query client using namespace abstraction + // (created before background task so we can use it for startup recovery) + let worker_query_client = + worker_query::WorkerQueryClient::new(component.clone(), local_indexer_rx); + tracing::info!("Worker query client initialized"); + // Start KV event subscriber background process (only when use_kv_events is enabled) if kv_router_config.use_kv_events && let Indexer::KvIndexer(ref kv_indexer) = indexer @@ -367,11 +386,48 @@ impl KvRouter { kv_router_config.router_reset_states, ) .await?; - } - // Initialize worker query client using the namespace abstraction - // NATS client is managed by DRT and accessed through namespace.drt() - let worker_query_client = Some(WorkerQueryClient::new(component.namespace().clone())); + // Perform startup recovery from workers with local indexers + // This catches up on any events missed while the router was offline + let last_event_ids = kv_indexer + .get_last_received_event_ids() + .await + .unwrap_or_default(); + let instances = client.instance_source.as_ref().borrow().clone(); + let worker_ids: Vec = instances.iter().map(|i| i.instance_id).collect(); + + if !worker_ids.is_empty() { + tracing::info!( + worker_count = worker_ids.len(), + "Starting recovery from workers with local indexers" + ); + + // NOTE: recover_from_all_workers() is a no-op if + // Worker with worker_id is not associated with a + // local indexer instance. 
+ let recovered = recover_from_all_workers( + &worker_query_client, + &last_event_ids, + &worker_ids, + &kv_indexer.event_sender(), + ) + .await; + + if recovered > 0 { + tracing::info!( + recovered_events = recovered, + "KV Router startup: Recovered {} KV events from workers {:?}", + recovered, + worker_ids + ); + } else { + tracing::info!( + "KV Router startup: No KV events recovered from workers {:?}", + worker_ids + ); + } + } + } tracing::info!("KV Routing initialized"); Ok(Self { @@ -381,7 +437,7 @@ impl KvRouter { kv_router_config, cancellation_token, client, - worker_query_client, + worker_query_client: Some(worker_query_client), }) } @@ -515,17 +571,60 @@ impl KvRouter { self.indexer.dump_events().await } - /// Query a specific worker's local KV indexer for its buffered events + /// Query a specific worker's local KV indexer for its events + /// (See docstring for `WorkerQueryClient.query_worker()`) pub async fn query_worker_local_kv( &self, worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, ) -> Result { let query_client = self .worker_query_client .as_ref() .ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?; - query_client.query_worker(worker_id).await + query_client + .query_worker(worker_id, start_event_id, end_event_id) + .await + } + + /// Recover missed KV events from a specific worker. + /// + /// Queries the worker's local KV indexer for events starting from + /// `start_event_id` and applies them to the router's indexer. + /// + /// # Arguments + /// + /// * `worker_id` - The worker to recover from + /// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning + /// * `end_event_id` - Last event ID to fetch (inclusive), or None for all + pub async fn recover_from_worker( + &self, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + ) -> Result { + let query_client = self + .worker_query_client + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?; + + let event_tx = match &self.indexer { + Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(), + Indexer::None => { + anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)") + } + }; + + subscriber::recover_from_worker( + query_client, + worker_id, + start_event_id, + end_event_id, + &event_tx, + ) + .await } } @@ -673,13 +772,19 @@ impl AsyncEngine, ManyOut("prefill_worker_id") - .ok() - .map(|arc| *arc) + let prefill_worker_id = backend_input + .prefill_result + .as_ref() + .and_then(|prefill_result| { + prefill_result + .disaggregated_params + .get("worker_id") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .and_then(|info| info.prefill_worker_id) + }) .or(Some(decode_worker_id)); // Use decode_worker_id if no separate prefill worker let updated_request = context.map(|_| backend_input); @@ -726,12 +831,14 @@ impl AsyncEngine, ManyOut, + pub end_event_id: Option, } /// Response from a worker's local KV indexer. @@ -803,6 +806,13 @@ pub struct GetWorkersRequest { pub resp: oneshot::Sender>, } +/// A request to get the last received event ID per worker. +/// Used for fault tolerance recovery to determine which events to request from workers. +pub struct GetLastReceivedEventIdsRequest { + /// Channel to send the last received event IDs per worker + pub resp: oneshot::Sender>, +} + #[async_trait] pub trait KvIndexerInterface { /// Find matches for a given sequence of `LocalBlockHash`es. 
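// A minimal sketch of the round trip these request types support, assuming a running
// `KvIndexer` with its background task alive. The `get_last_received_event_ids()` helper
// added further below wraps exactly this pattern; `recovery_example` and its error
// handling are illustrative only, not part of the patch.
async fn recovery_example(indexer: &KvIndexer) -> anyhow::Result<()> {
    use tokio::sync::oneshot;

    // Ask the indexer's background task for the newest event id it has seen per worker.
    let (resp_tx, resp_rx) = oneshot::channel();
    indexer
        .last_event_ids_sender()
        .send(GetLastReceivedEventIdsRequest { resp: resp_tx })
        .await
        .map_err(|_| anyhow::anyhow!("indexer task is offline"))?;
    let last_ids = resp_rx.await?;

    // Anything newer than `last_id` would then be requested from the worker's local
    // indexer, e.g. via a WorkerKvQueryRequest with `start_event_id: Some(last_id + 1)`
    // and `end_event_id: None`.
    for (worker_id, last_id) in &last_ids {
        tracing::debug!(
            worker_id = *worker_id,
            last_id = *last_id,
            "would recover events after this id"
        );
    }
    Ok(())
}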
@@ -907,6 +917,8 @@ pub struct KvIndexer { dump_tx: mpsc::Sender, /// A sender for routing decision requests. routing_tx: mpsc::Sender, + /// A sender for getting last received event IDs (for fault tolerance recovery). + last_event_ids_tx: mpsc::Sender, /// A handle to the background task managing the KV store. task: OnceLock>, /// The size of the KV block this indexer can handle. @@ -940,6 +952,9 @@ impl KvIndexer { let (dump_tx, dump_rx) = mpsc::channel::(16); let (routing_tx, mut routing_rx) = mpsc::channel::(2048); let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1); + let (last_event_ids_tx, mut last_event_ids_rx) = + mpsc::channel::(16); + let cancel_clone = token.clone(); let task = std::thread::spawn(move || { @@ -964,6 +979,10 @@ impl KvIndexer { }); let mut event_id_counter = 0u64; + // Track last received event ID per worker (for fault tolerance recovery) + // Only used when enable_event_tracking is true + let mut last_received_event_id: HashMap = HashMap::new(); + loop { // Create a future that sleeps until the next expiration time let expiry_fut = if let Some(ref pm) = prune_manager @@ -990,6 +1009,10 @@ impl KvIndexer { let _ = get_workers_req.resp.send(workers); } + Some(req) = last_event_ids_rx.recv() => { + let _ = req.resp.send(last_received_event_id.clone()); + } + Some(_) = prune_rx.recv() => { // Tree size-based pruning triggered let Some(ref mut pm) = prune_manager else { continue }; @@ -1012,6 +1035,33 @@ impl KvIndexer { } Some(event) = event_rx.recv() => { + // Track last received event ID per worker + // Check for gaps before updating the last received ID + // TODO should this trigger a recovery event? + let last_id = *last_received_event_id.get(&event.worker_id).unwrap_or(&0); + let incoming_id = event.event.event_id; + + // Detect gap: if incoming ID is more than 1 greater than last received + if incoming_id > last_id + 1 && last_id > 0 { + let gap_start = last_id + 1; + let gap_end = incoming_id - 1; + tracing::warn!( + worker_id = event.worker_id, + gap_start, + gap_end, + gap_size = gap_end - gap_start + 1, + "Event ID gap detected! Missed events [{}, {}]. \ + If this is a global KvIndexer, within a KvRouter context, + consider calling KvRouter::query_worker_local_kv() to potentially recover worker-stored events.", + gap_start, + gap_end, + ); + } + + // Update last received event ID (use max to handle out-of-order events) + let entry = last_received_event_id.entry(event.worker_id).or_insert(0); + *entry = (*entry).max(event.event.event_id); + let event_type = KvIndexerMetrics::get_event_type(&event.event.data); let result = trie.apply_event(event.clone()); let result_is_ok = result.is_ok(); @@ -1143,6 +1193,7 @@ impl KvIndexer { get_workers_tx, dump_tx, routing_tx, + last_event_ids_tx, task: once, kv_block_size, } @@ -1195,6 +1246,48 @@ impl KvIndexer { pub fn get_workers_sender(&self) -> mpsc::Sender { self.get_workers_tx.clone() } + + /// Get a sender for last received event IDs requests. + /// + /// ### Returns + /// + /// A `mpsc::Sender` for `GetLastReceivedEventIdsRequest`s. + pub fn last_event_ids_sender(&self) -> mpsc::Sender { + self.last_event_ids_tx.clone() + } + + /// Get the last received event ID for each worker. + /// + /// This method is used for **fault tolerance recovery** when the router needs to + /// catch up on missed events after a disconnect. By tracking the last event ID + /// received from each worker, the router can query workers for events starting + /// from `last_id + 1` to recover missed state. 
+ /// + /// **Note**: This method is intdned for the global `KvIndexer` used by routers, + /// not on `LocalKvIndexer` (worker-side) or `KvIndexerSharded`. + /// + /// ### Returns + /// + /// A `HashMap` mapping worker IDs to their last received event ID. + /// + pub async fn get_last_received_event_ids( + &self, + ) -> Result, KvRouterError> { + let (resp_tx, resp_rx) = oneshot::channel(); + let req = GetLastReceivedEventIdsRequest { resp: resp_tx }; + + if let Err(e) = self.last_event_ids_tx.send(req).await { + tracing::error!( + "Failed to send last event IDs request: {:?}; the indexer maybe offline", + e + ); + return Err(KvRouterError::IndexerOffline); + } + + resp_rx + .await + .map_err(|_| KvRouterError::IndexerDroppedRequest) + } } #[async_trait] @@ -1320,7 +1413,7 @@ pub struct LocalKvIndexer { /// Circular buffer of recent events event_buffer: Mutex>, /// Maximum number of events to keep in buffer - max_buffer_size: usize, + max_buffer_size: usize, // Router sets this to WORKER_KV_INDEXER_BUFFER_SIZE } impl LocalKvIndexer { @@ -1338,45 +1431,138 @@ impl LocalKvIndexer { } } - /// get the N most recent events (returned in oldest->newest order) - pub fn get_recent_events(&self, n: usize) -> Vec { - // TODO what if n > buffer size + /// Get all buffered events (oldest first). + pub fn get_all_events_in_buffer(&self) -> Vec { let buffer = self.event_buffer.lock().unwrap(); - buffer.iter().rev().take(n).cloned().rev().collect() + buffer.iter().cloned().collect() } - /// get all buffered events (oldest first) - pub fn get_all_buffered_events(&self) -> Vec { - let buffer = self.event_buffer.lock().unwrap(); - buffer.iter().cloned().collect() + /// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive). + /// + /// This method attempts to serve the request from the in-memory event buffer when possible. + /// If the requested range extends beyond what's available in the buffer, a full tree dump + /// is performed instead. + /// + /// ### Arguments + /// + /// * `start_id` - Starting event ID (inclusive). If `None`, returns from oldest available. + /// * `end_id` - Ending event ID (inclusive). If `None`, returns up to newest available. + /// + /// ### Behavior + /// + /// - **Buffer path**: If `start_id >= first_buffered_id`, events are retrieved directly + /// from the buffer with their original event IDs. + /// + /// - **Tree dump path**: If the range extends before the buffer or no range is specified, + /// a full tree dump is performed. **Note**: Tree dumps generate synthetic 0-indexed + /// event IDs that do NOT correspond to the original event IDs. The entire tree state + /// is returned regardless of the requested range. + /// + /// ### Returns + /// + /// A vector of `RouterEvent`s. When served from buffer, events have their original IDs. + /// When served from tree dump, events have synthetic sequential IDs starting from 0. + pub async fn get_events_in_id_range( + &self, + start_id: Option, + end_id: Option, + ) -> Vec { + // Validate range if both specified + if let (Some(s), Some(e)) = (start_id, end_id) + && s > e + { + tracing::warn!( + start_id = s, + end_id = e, + "Requested start_id > end_id; returning empty result." 
+ ); + return Vec::new(); + } + + // Check if we can serve from buffer + let buffer_range = { + let buffer = self.event_buffer.lock().unwrap(); + if buffer.is_empty() { + None + } else { + Some(( + buffer.front().unwrap().event.event_id, + buffer.back().unwrap().event.event_id, + )) + } + }; + + // Determine if request can be served from buffer + let can_use_buffer = match (start_id, buffer_range) { + // No start specified means we need everything from the beginning -> tree dump + (None, _) => false, + // Buffer is empty -> tree dump + (_, None) => false, + // start_id is within or after buffer range -> can use buffer + (Some(s), Some((first_buffered, _))) => s >= first_buffered, + }; + + if can_use_buffer { + // Serve from buffer - these have real event IDs + self.get_buffer_events_in_id_range(start_id, end_id) + } else { + // Must dump entire tree + if let (Some(s), Some(e)) = (start_id, end_id) { + tracing::warn!( + requested_start_id = s, + requested_end_id = e, + buffer_range = ?buffer_range, + "Requested event ID range extends before buffer; dumping entire tree. \ + Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs." + ); + } else if start_id.is_some() || end_id.is_some() { + tracing::warn!( + requested_start_id = ?start_id, + requested_end_id = ?end_id, + buffer_range = ?buffer_range, + "Partial range specified but cannot serve from buffer; dumping entire tree. \ + Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs." + ); + } + // Return full tree dump - no filtering since IDs are synthetic + self.dump_events().await.unwrap_or_default() + } } - /// Returns events in [start_id, end_id) - pub fn get_events_in_id_range(&self, start_id: u64, end_id: u64) -> Vec { + /// Get events from the buffer in the range `[start_id, end_id]` (both inclusive). + pub fn get_buffer_events_in_id_range( + &self, + start_id: Option, + end_id: Option, + ) -> Vec { let buffer = self.event_buffer.lock().unwrap(); if buffer.is_empty() { tracing::warn!("No events in buffer yet; returning empty result."); return Vec::new(); } - if start_id >= end_id { + + let first_id = buffer.front().map(|e| e.event.event_id).unwrap(); + let last_id = buffer.back().map(|e| e.event.event_id).unwrap(); + + let start_id = start_id.unwrap_or(first_id); + let end_id = end_id.unwrap_or(last_id); + + if start_id > end_id { tracing::warn!( start_id, end_id, - "Requested start id is greater than or equal to end id; returning empty result." + "Requested start_id > end_id; returning empty result." ); return Vec::new(); } - let first_id = buffer.front().map(|e| e.event.event_id).unwrap(); - let last_id = buffer.back().map(|e| e.event.event_id).unwrap(); - let start_idx = match buffer.binary_search_by_key(&start_id, |e| e.event.event_id) { Ok(idx) => idx, Err(_) if start_id < first_id => { tracing::warn!( start_id, first_id, - "Requested start id precedes buffered range; clamping to oldest. TODO: implement logic to pull older events into buffer." + "Requested start_id precedes buffer; clamping to oldest." ); 0 } @@ -1384,15 +1570,16 @@ impl LocalKvIndexer { tracing::error!( start_id, last_id, - "Requested start id is newer than any buffered event; returning empty result." + "Requested start_id is newer than buffer; returning empty." 
); return Vec::new(); } Err(insertion_point) => insertion_point, }; + // For inclusive end, we need idx + 1 when we find an exact match let end_idx = match buffer.binary_search_by_key(&end_id, |e| e.event.event_id) { - Ok(idx) => idx, + Ok(idx) => idx + 1, // Include the matched element Err(_) if end_id < first_id => { return Vec::new(); } @@ -1400,7 +1587,7 @@ impl LocalKvIndexer { tracing::warn!( end_id, last_id, - "Requested end id exceeds buffered range; clamping to newest. TODO: maybe just error if requesting events which do not exist yet." + "Requested end_id exceeds buffer; clamping to newest." ); buffer.len() } @@ -1431,6 +1618,11 @@ impl LocalKvIndexer { "Non-consecutive KV event id; buffer may have gaps" ); } + tracing::info!( + "Recorded event {:?} in buffer, now size is {}", + event, + buffer.len() + ); // Add to back buffer.push_back(event); @@ -1525,33 +1717,192 @@ mod local_kv_indexer_tests { #[test] fn returns_slice_within_range() { let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]); - let mut result = indexer.get_events_in_id_range(2, 4); + + // Test get_buffer_events_in_id_range (buffer-only queries) + // Range is [start, end] inclusive + let mut result = indexer.get_buffer_events_in_id_range(Some(2), Some(4)); let mut ids: Vec = result .iter() .map(|router_event| router_event.event.event_id) .collect(); - assert_eq!(ids, vec![2, 3]); // return slice within range + assert_eq!(ids, vec![2, 3, 4]); // inclusive range [2, 4] + + result = indexer.get_buffer_events_in_id_range(Some(2), Some(6)); + ids = result + .iter() + .map(|router_event| router_event.event.event_id) + .collect(); + assert_eq!(ids, vec![2, 3, 4, 5]); // clamp end to buffer max - result = indexer.get_events_in_id_range(2, 6); + result = indexer.get_buffer_events_in_id_range(Some(0), Some(4)); ids = result .iter() .map(|router_event| router_event.event.event_id) .collect(); - assert_eq!(ids, vec![2, 3, 4, 5]); // clamp max (TODO error instead?) + assert_eq!(ids, vec![1, 2, 3, 4]); // clamp start to buffer min, inclusive end - result = indexer.get_events_in_id_range(0, 4); + result = indexer.get_buffer_events_in_id_range(Some(3), Some(3)); ids = result .iter() .map(|router_event| router_event.event.event_id) .collect(); - assert_eq!(ids, vec![1, 2, 3]); // clamp min (TODO error instead?) 
+ assert_eq!(ids, vec![3]); // single element when start == end - result = indexer.get_events_in_id_range(0, 0); + result = indexer.get_buffer_events_in_id_range(Some(5), Some(2)); ids = result .iter() .map(|router_event| router_event.event.event_id) .collect(); - assert!(ids.is_empty()); // return empty when start is before buffer + assert!(ids.is_empty()); // return empty when start > end + } + + #[tokio::test] + async fn test_get_events_in_id_range_all_cases() { + use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; + + // Create indexer with small buffer (5 events max) + // This way older events will only be in the tree, not the buffer + let indexer = LocalKvIndexer::new( + CancellationToken::new(), + 4, // block_size + Arc::new(KvIndexerMetrics::new_unregistered()), + 5, // max_buffer_size - only keeps 5 most recent events + ); + + // Helper to create a test event + let make_event = |id: u64| { + RouterEvent::new( + 0, // worker_id + KvCacheEvent { + event_id: id, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: None, + blocks: vec![KvCacheStoredBlockData { + block_hash: ExternalSequenceBlockHash(id * 100), + tokens_hash: LocalBlockHash(id * 200), + }], + }), + dp_rank: 0, + }, + ) + }; + + // Add 10 events (IDs 5-14) + // Buffer will only keep the last 5: events 10-14 + // Tree will have all blocks + for id in 5..15 { + indexer + .apply_event_with_buffer(make_event(id)) + .await + .unwrap(); + } + + // Wait for events to be processed by the tree + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Helper to extract event IDs from result + let get_ids = |events: Vec| -> Vec { + events.iter().map(|e| e.event.event_id).collect() + }; + + // Verify buffer state: should have events 10-14 (last 5) + let buffer_events = indexer.get_all_events_in_buffer(); + assert_eq!( + get_ids(buffer_events), + vec![10, 11, 12, 13, 14], + "Buffer should have events 10-14" + ); + + // ========== BUFFER PATH TESTS (start_id >= first_buffered) ========== + // Range is [start, end] inclusive + + // Test: start_id within buffer, no end + let result = indexer.get_events_in_id_range(Some(11), None).await; + assert_eq!( + get_ids(result), + vec![11, 12, 13, 14], + "start_id=11 (in buffer) should return [11, 14]" + ); + + // Test: start_id at buffer boundary + let result = indexer.get_events_in_id_range(Some(10), None).await; + assert_eq!( + get_ids(result), + vec![10, 11, 12, 13, 14], + "start_id=10 (buffer start) should return [10, 14]" + ); + + // Test: both start and end within buffer (inclusive) + let result = indexer.get_events_in_id_range(Some(11), Some(13)).await; + assert_eq!( + get_ids(result), + vec![11, 12, 13], + "range [11, 13] inclusive should return 3 events" + ); + + let result = indexer.get_events_in_id_range(Some(10), Some(14)).await; + assert_eq!( + get_ids(result), + vec![10, 11, 12, 13, 14], + "range [10, 14] should return all buffer events" + ); + + // ========== TREE DUMP PATH TESTS (range extends before buffer) ========== + // Note: Tree dumps return synthetic 0-indexed event IDs, so we just check + // that we get events back (the IDs won't match original IDs) + + // Test: (None, None) dumps entire tree + let result = indexer.get_events_in_id_range(None, None).await; + assert_eq!( + result.len(), + 10, + "(None, None) should dump entire tree (10 events)" + ); + + // Test: (None, Some(_)) dumps entire tree + let result = indexer.get_events_in_id_range(None, Some(8)).await; + assert_eq!( + result.len(), + 10, + "(None, 
Some(_)) dumps entire tree - end_id is ignored for tree dumps" + ); + + // Test: start_id before buffer triggers tree dump + let result = indexer.get_events_in_id_range(Some(7), None).await; + assert_eq!( + result.len(), + 10, + "start_id=7 (before buffer) should dump entire tree" + ); + + let result = indexer.get_events_in_id_range(Some(5), Some(12)).await; + assert_eq!( + result.len(), + 10, + "range [5, 12] extending before buffer should dump entire tree" + ); + + // ========== EDGE CASES ========== + + // Single element when start == end (inclusive range) + let result = indexer.get_events_in_id_range(Some(12), Some(12)).await; + assert_eq!( + get_ids(result), + vec![12], + "start == end should return single event" + ); + + // Empty when start > end + let result = indexer.get_events_in_id_range(Some(15), Some(10)).await; + assert!(result.is_empty(), "start > end should return empty"); + + // Request beyond buffer but valid range -> buffer returns what it has + let result = indexer.get_events_in_id_range(Some(12), Some(100)).await; + assert_eq!( + get_ids(result), + vec![12, 13, 14], + "range with end beyond buffer should return available buffer events" + ); } } @@ -3312,11 +3663,45 @@ mod tests { } #[cfg(test)] -mod tests_local_indexer_query { +mod tests_local_indexer { use super::*; use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; + use tokio::time; use tokio_util::sync::CancellationToken; + fn setup() { + dynamo_runtime::logging::init(); + } + + fn make_blocks(hashes: Vec) -> Vec { + hashes + .iter() + .map(|i| KvCacheStoredBlockData { + tokens_hash: LocalBlockHash(*i), + block_hash: ExternalSequenceBlockHash(*i * 100), + }) + .collect() + } + + fn create_store_event( + worker_id: WorkerId, + event_id: u64, + hashes: Vec, + parent: Option, + ) -> RouterEvent { + RouterEvent { + worker_id, + event: KvCacheEvent { + event_id, + data: KvCacheEventData::Stored(KvCacheStoreData { + parent_hash: parent, + blocks: make_blocks(hashes), + }), + dp_rank: 0, + }, + } + } + #[tokio::test] async fn test_local_indexer_buffer_and_serialization() { // Tests components of the LocalKvIndexer query without using nats @@ -3354,7 +3739,7 @@ mod tests_local_indexer_query { tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; // Get buffered events (what the query service would return) - let buffered_events = local_indexer.get_all_buffered_events(); + let buffered_events = local_indexer.get_all_events_in_buffer(); // Verify buffer contents assert_eq!(buffered_events.len(), 1, "Buffer should have 1 event"); @@ -3385,4 +3770,49 @@ mod tests_local_indexer_query { _ => panic!("Expected Stored event"), } } + + #[tokio::test] + async fn test_gap_detection_per_worker() { + setup(); + + let token = CancellationToken::new(); + let metrics = Arc::new(KvIndexerMetrics::new_unregistered()); + let indexer = KvIndexer::new(token.clone(), 4, metrics); + + let worker_a: WorkerId = 100; + let worker_b: WorkerId = 200; + let event_tx = indexer.event_sender(); + + // Worker A: events 1, 2, 3 (no gap) + for id in 1..=3 { + let event = create_store_event(worker_a, id, vec![id], None); + event_tx.send(event).await.unwrap(); + } + + // Worker B: events 1, then 5 (gap of 2, 3, 4) + let event_b1 = create_store_event(worker_b, 1, vec![10], None); + event_tx.send(event_b1).await.unwrap(); + + let event_b5 = create_store_event(worker_b, 5, vec![50], None); + event_tx.send(event_b5).await.unwrap(); + + // Give time for events to be processed + time::sleep(Duration::from_millis(20)).await; + + 
// Verify each worker has correct last_received_event_id + let last_ids = indexer.get_last_received_event_ids().await.unwrap(); + assert_eq!( + last_ids.get(&worker_a), + Some(&3), + "Worker A should have last_id = 3 (no gap)" + ); + assert_eq!( + last_ids.get(&worker_b), + Some(&5), + "Worker B should have last_id = 5 (despite gap)" + ); + + // Cleanup + token.cancel(); + } } diff --git a/lib/llm/src/kv_router/prefill_router.rs b/lib/llm/src/kv_router/prefill_router.rs index 6619bbd991..b2f16adb0f 100644 --- a/lib/llm/src/kv_router/prefill_router.rs +++ b/lib/llm/src/kv_router/prefill_router.rs @@ -176,11 +176,11 @@ impl PrefillRouter { Ok(()) } - /// Call the prefill router and extract structured prefill result and worker ID + /// Call the prefill router and extract structured prefill result async fn call_prefill( &self, request: SingleIn, - ) -> Result<(PrefillResult, Option), PrefillError> { + ) -> Result { // Get the prefill router, error if not activated let Some(prefill_router) = self.prefill_router.get() else { return Err(PrefillError::NotActivated); @@ -239,21 +239,10 @@ impl PrefillRouter { )); }; - // Extract prefill worker ID from disaggregated_params - let prefill_worker_id = disaggregated_params - .get("worker_id") - .and_then(|worker_id_json| { - worker_id_json - .get("prefill_worker_id") - .and_then(|v| v.as_u64()) - }); - Ok(( - PrefillResult { - disaggregated_params, - prompt_tokens_details, - }, - prefill_worker_id, - )) + Ok(PrefillResult { + disaggregated_params, + prompt_tokens_details, + }) } } @@ -310,7 +299,7 @@ impl // Handle prefill result match prefill_result { - Ok((prefill_result, prefill_worker_id)) => { + Ok(prefill_result) => { tracing::debug!("Prefill succeeded, using disaggregated params for decode"); let mut decode_req = req; @@ -326,14 +315,8 @@ impl ..existing_override.unwrap_or_default() }); - // Store prefill worker ID in context if available - let mut decode_context = context; - if let Some(worker_id) = prefill_worker_id { - decode_context.insert("prefill_worker_id", worker_id); - } - // Map the modified request through with preserved context - let decode_request = decode_context.map(|_| decode_req); + let decode_request = context.map(|_| decode_req); next.generate(decode_request).await } Err(PrefillError::NotActivated) => { diff --git a/lib/llm/src/kv_router/publisher.rs b/lib/llm/src/kv_router/publisher.rs index 572955e026..d36fcf2c18 100644 --- a/lib/llm/src/kv_router/publisher.rs +++ b/lib/llm/src/kv_router/publisher.rs @@ -26,10 +26,11 @@ use dynamo_runtime::{ use futures::StreamExt; use crate::kv_router::{ - KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_QUERY_SUBJECT, + KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, + WORKER_KV_INDEXER_QUERY_SUBJECT, indexer::{ - KvIndexerMetrics, LocalKvIndexer, RouterEvent, WorkerKvQueryRequest, WorkerKvQueryResponse, - compute_block_hash_for_seq, + KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, RouterEvent, WorkerKvQueryRequest, + WorkerKvQueryResponse, compute_block_hash_for_seq, }, protocols::*, scoring::LoadEvent, @@ -101,11 +102,6 @@ pub struct KvEventPublisher { cancellation_token: CancellationToken, /// The channel to send events to. tx: mpsc::UnboundedSender, - /// Optional worker-local indexer for tracking this worker's own KV cache. - /// When present, events are applied to this indexer before being published to NATS. 
- local_indexer: Option>, - /// Optional runtime for router->local indexer comm (TODO might be refactored) - local_indexer_query_handle: Option>, } impl KvEventPublisher { @@ -114,7 +110,7 @@ impl KvEventPublisher { kv_block_size: u32, source_config: Option, ) -> Result { - Self::new_with_local_indexer(component, kv_block_size, source_config, true) + Self::new_with_local_indexer(component, kv_block_size, source_config, false) } pub fn new_with_local_indexer( @@ -161,14 +157,14 @@ impl KvEventPublisher { cancellation_token.clone(), kv_block_size, metrics, - 100, // TODO make this a parameter available for user change? + WORKER_KV_INDEXER_BUFFER_SIZE, ))) } else { None }; // Spawn runtime for router->local indexer comm if requested - let local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| { + let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| { let component = component.clone(); let local_indexer = local_indexer_ref.clone(); @@ -180,6 +176,7 @@ impl KvEventPublisher { component, worker_id, local_indexer, + cancellation_token.clone(), )) }); @@ -218,8 +215,6 @@ impl KvEventPublisher { source, cancellation_token, tx, - local_indexer, - local_indexer_query_handle, }) } @@ -231,11 +226,6 @@ impl KvEventPublisher { self.kv_block_size } - /// Get reference to local indexer if enabled. - pub fn local_indexer(&self) -> Option> { - self.local_indexer.clone() - } - pub fn shutdown(&mut self) { if !self.cancellation_token.is_cancelled() { self.cancellation_token.cancel(); @@ -244,10 +234,6 @@ impl KvEventPublisher { if let Some(source) = self.source.take() { source.shutdown(); } - - if let Some(handle) = self.local_indexer_query_handle.take() { - handle.abort(); - } } } @@ -307,29 +293,24 @@ async fn start_worker_kv_query_service( component: Component, worker_id: u64, local_indexer: Arc, + cancellation_token: CancellationToken, ) { - // NOTE: referenced discover/worker_monitor.rs for pub/sub pattern - let cancellation_token = component.drt().child_token(); - // Create NATS subscriber on a subject specific to worker's id let subject = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id); - let full_subject = format!("namespace.{}.{}", component.namespace().name(), subject); // TODO make a helper like subject() - let mut subscriber = match component.namespace().subscribe(&subject).await { + let mut subscriber = match component.subscribe(&subject).await { Ok(sub) => sub, Err(e) => { tracing::error!("Failed to subscribe to {}: {}", subject, e); return; // No ? because function doesn't return Result } }; - tracing::info!( + tracing::debug!( "Query service on worker {} listening on NATS subject: {}", worker_id, - full_subject + subject ); // Receive query request from router, retrieve event(s) from LocalKvIndexer, return response - // TODO: currently just dumps all events from LocalKvIndexer; need to implement - // event selection logic from buffer loop { tokio::select! { _ = cancellation_token.cancelled() => { @@ -355,8 +336,25 @@ async fn start_worker_kv_query_service( // TODO extract request event id range. 
For now, just debug print tracing::debug!("Received WorkerKvQueryRequest: {:?}", request); - // Get events from local indexer (TODO for now, dump all events) - let events = local_indexer.get_all_buffered_events(); + // Resolve which events to return based on optional start/end ids + let events = match (request.start_event_id, request.end_event_id) { + (None, None) => { + match local_indexer.dump_events().await { + Ok(events) => events, + Err(err) => { + tracing::error!( + error = %err, + worker_id, + "Failed to dump events for WorkerKvQueryRequest; returning buffered events instead" + ); + local_indexer.get_all_events_in_buffer() + } + } + } + _ => { + local_indexer.get_events_in_id_range(request.start_event_id, request.end_event_id).await + } + }; // Build WorkerKvQueryResponse let response = WorkerKvQueryResponse { events }; @@ -1760,7 +1758,7 @@ mod tests_startup_helpers { ); // assert: Worker's local indexer buffered event - let buffered = local_indexer_1.get_all_buffered_events(); + let buffered = local_indexer_1.get_all_events_in_buffer(); assert_eq!(buffered.len(), 1, "Local indexer should buffer 1 event"); // === STEP 2 & 3: Simulate Outage - Stop forwarding to router === @@ -1796,7 +1794,7 @@ mod tests_startup_helpers { } // assert: Worker's local indexer has both events - let buffered = local_indexer_1.get_all_buffered_events(); + let buffered = local_indexer_1.get_all_events_in_buffer(); assert_eq!( buffered.len(), 2, @@ -1819,17 +1817,32 @@ mod tests_startup_helpers { "Router should only see 1 shared block (not the new block from event_2)" ); - // === STEP 4 & 5: Recovery - Apply buffered events to router === - // This simulates: router.query_worker_local_kv(worker_1_id) - // followed by applying the returned events - // TODO be able to identify which event id is the last that Router has, - // and query worker(s) for buffer starting after it - for router_event in buffered { + // === STEP 4 & 5: Recovery - Query last received event IDs and fetch missed events === + // Step 4a: Router queries its last received event ID per worker + let last_ids = router_indexer.get_last_received_event_ids().await.unwrap(); + let last_known_id = last_ids.get(&worker_1_id).copied().unwrap_or(0); + assert_eq!( + last_known_id, 1, + "Router should have last_received_event_id = 1 for worker (only event_1 was forwarded)" + ); + + // Step 4b: Query worker's local indexer for events after last_known_id + let missed_events = local_indexer_1 + .get_events_in_id_range(Some(last_known_id + 1), None) + .await; + assert_eq!( + missed_events.len(), + 1, + "Should get 1 missed event (event_2 with id=2)" + ); + + // Step 5: Apply missed events to router + for router_event in missed_events { router_indexer .event_sender() .send(router_event) .await - .unwrap(); // TODO use apply_event() instead? 
+ .unwrap(); } tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -1846,6 +1859,14 @@ mod tests_startup_helpers { "Router should now see both blocks after recovery" ); + // assert: Router's last_received_event_id is updated after recovery + let last_ids_after = router_indexer.get_last_received_event_ids().await.unwrap(); + assert_eq!( + last_ids_after.get(&worker_1_id), + Some(&2), + "Router should have last_received_event_id = 2 after recovery" + ); + token.cancel(); } } @@ -1903,10 +1924,7 @@ mod test_exponential_backoff { mod test_integration_publisher { use super::*; use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; - use dynamo_runtime::distributed_test_utils::{ - create_test_drt_async, create_test_shared_drt_async, - }; - use dynamo_runtime::pipeline::AsyncEngine; + use dynamo_runtime::distributed_test_utils::create_test_drt_async; use dynamo_runtime::traits::events::EventSubscriber; use futures::StreamExt; @@ -2095,25 +2113,29 @@ mod test_integration_publisher { "โœ… KvStatsPrometheusGauges constructor and publish() work correctly with real Component" ); } +} + +#[cfg(all(test, feature = "integration"))] +mod test_integration_publisher_with_kvindexer { + use super::*; + + use crate::kv_router::scheduler::DefaultWorkerSelector; + use crate::kv_router::{KvPushRouter, KvRouter, KvRouterConfig}; + use crate::local_model::LocalModelBuilder; + use crate::local_model::runtime_config::ModelRuntimeConfig; + use crate::mocker::engine::{MOCKER_COMPONENT, MockVllmEngine}; + use crate::mocker::protocols::MockEngineArgs; + use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}; + use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; + use dynamo_runtime::distributed_test_utils::create_test_shared_drt_async; + use dynamo_runtime::engine::AsyncEngine; + use dynamo_runtime::pipeline::{Context, PushRouter, RouterMode, network::Ingress}; + use dynamo_runtime::protocols::annotated::Annotated; /// Integration test: KvPushRouter end-to-end routing with mock engines. #[tokio::test] #[ignore] // Requires NATS/etcd. 
Run with: cargo test --package dynamo-llm --lib --features integration test_distributed_kvindexer_e2e -- --ignored --nocapture async fn test_distributed_kvindexer_e2e() -> anyhow::Result<()> { - use crate::kv_router::scheduler::DefaultWorkerSelector; - use crate::kv_router::{ - KvPushRouter, KvRouter, KvRouterConfig, worker_query::WorkerQueryClient, - }; - use crate::mocker::engine::MOCKER_COMPONENT; - use crate::mocker::engine::MockVllmEngine; - use crate::mocker::protocols::{MockEngineArgs, WorkerType}; - use crate::protocols::common::{ - OutputOptions, SamplingOptions, StopConditions, - llm_backend::{LLMEngineOutput, PreprocessedRequest}, - }; - use dynamo_runtime::pipeline::{Context, PushRouter, RouterMode, network::Ingress}; - use dynamo_runtime::protocols::annotated::Annotated; - const BLOCK_SIZE: u32 = 4; const NUM_REQUESTS: usize = 4; @@ -2138,9 +2160,8 @@ mod test_integration_publisher { let mocker_args = MockEngineArgs::builder() .block_size(BLOCK_SIZE as usize) .dp_size(1) // single worker per runtime - .worker_type(WorkerType::Aggregated) - .speedup_ratio(50.0) .enable_prefix_caching(true) + .enable_local_indexer(true) // affects scheduler/publisher args .build()?; let worker_components = vec![component1.clone(), component2.clone()]; @@ -2152,6 +2173,31 @@ mod test_integration_publisher { engine.start(comp.clone()).await?; tracing::info!("MockVllmEngine started for {:?}", comp); + // Register MDC with runtime_config so router can discover enable_local_indexer. + // (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.) + // This inlines code which in the Python path would be performed by: + // - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs + // - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery + let endpoint = comp.endpoint("generate"); + let runtime_config = ModelRuntimeConfig { + enable_local_indexer: true, + ..Default::default() + }; + let mut builder = LocalModelBuilder::default(); + builder + .model_name(Some("mock".to_string())) + .kv_cache_block_size(Some(BLOCK_SIZE)) + .runtime_config(runtime_config); + let mut local_model = builder.build().await?; + local_model + .attach( + &endpoint, + crate::model_type::ModelType::Chat, + crate::model_type::ModelInput::Tokens, + None, + ) + .await?; + let ingress = Ingress::for_engine(engine.clone())?; let endpoint_component = comp.clone(); let handle = tokio::spawn(async move { @@ -2195,13 +2241,13 @@ mod test_integration_publisher { ); let push_router = - PushRouter::>::from_client_with_threshold( - client, - RouterMode::KV, - None, - None, - ) - .await?; + PushRouter::>::from_client_with_threshold( + client, + RouterMode::KV, + None, + None, + ) + .await?; let kv_push_router = KvPushRouter::new(push_router, kv_router.clone()); @@ -2216,7 +2262,6 @@ mod test_integration_publisher { }) .sampling_options(SamplingOptions::default()) .output_options(OutputOptions::default()) - .eos_token_ids(vec![]) .build() .unwrap() }; // from mocker/engine.rs @@ -2240,23 +2285,25 @@ mod test_integration_publisher { // ===== TEST PART 2: QUERY WORKER-LOCAL KVINDEXERS DIRECTLY ===== // TODO: This could be refactored as router function (e.g. router.refresh_from_worker(worker_id)) // (which should also update the global kvIndexer with the buffer from the local kvIndexer) - let query_client = WorkerQueryClient::new(router_namespace.clone()); let mut best_worker_info: Option<(u64, usize)> = None; // Exactly one worker should have been routed requests. 
Find that worker for &worker_id in &worker_ids { - let response = query_client.query_worker(worker_id).await?; - - if !response.events.is_empty() { - let event_count = response.events.len(); - tracing::info!( - worker_id, - events = event_count, - "Worker query on worker {worker_id} returned buffered KV events" - ); - best_worker_info = Some((worker_id, event_count)); - break; + let response = kv_router + .query_worker_local_kv(worker_id, None, None) + .await?; + if response.events.is_empty() { + continue; } + + let event_count = response.events.len(); + tracing::info!( + worker_id, + events = event_count, + "Worker query on worker {worker_id} returned buffered KV events" + ); + best_worker_info = Some((worker_id, event_count)); + break; } // Verify that only one worker has KV events in buffer @@ -2272,7 +2319,9 @@ mod test_integration_publisher { continue; } - let response = query_client.query_worker(worker_id).await?; + let response = kv_router + .query_worker_local_kv(worker_id, None, None) + .await?; assert!( response.events.is_empty(), "Worker {worker_id} should not report buffered KV events; best worker {best_worker_id} reported {best_worker_event_count}" @@ -2289,4 +2338,178 @@ mod test_integration_publisher { Ok(()) } + + #[tokio::test] + #[ignore] + async fn test_distributed_kvindexer_e2e_startup() -> anyhow::Result<()> { + const BLOCK_SIZE: u32 = 4; + + dynamo_runtime::logging::init(); + + // === SETUP: Distributed runtimes and namespaces === + let shared_store_dir = tempfile::tempdir()?; + let shared_store_path = shared_store_dir.path().to_path_buf(); + + // Use a unique namespace per test run for full isolation + let test_namespace = format!("test_e2e_{}", uuid::Uuid::new_v4().simple()); + + // Make both runtimes point at the same file-backed storage backend so worker + // registrations and heartbeats remain visible to every DRT instance. + let distributed1 = create_test_shared_drt_async(&shared_store_path).await; + let distributed2 = create_test_shared_drt_async(&shared_store_path).await; + let component1 = distributed1 + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)?; + let component2 = distributed2 + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)?; + + // === SETUP: Start mocker workers === + let mocker_args = MockEngineArgs::builder() + .block_size(BLOCK_SIZE as usize) + .dp_size(1) // single worker per runtime + .enable_prefix_caching(true) + .enable_local_indexer(true) // affects scheduler/publisher args + .build()?; + + let worker_components = vec![component1.clone(), component2.clone()]; + let mut server_handles = Vec::new(); + let mut worker_ids = Vec::new(); + + for comp in worker_components { + let engine: Arc = Arc::new(MockVllmEngine::new(mocker_args.clone())); + engine.start(comp.clone()).await?; + tracing::info!("MockVllmEngine started for {:?}", comp); + + // Register MDC with runtime_config so router can discover enable_local_indexer. + // (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.) 
+ // This inlines code which in the Python path would be performed by: + // - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs + // - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery + let endpoint = comp.endpoint("generate"); + let runtime_config = ModelRuntimeConfig { + enable_local_indexer: true, + ..Default::default() + }; + let mut builder = LocalModelBuilder::default(); + builder + .model_name(Some("mock".to_string())) + .kv_cache_block_size(Some(BLOCK_SIZE)) + .runtime_config(runtime_config); + let mut local_model = builder.build().await?; + local_model + .attach( + &endpoint, + crate::model_type::ModelType::Chat, + crate::model_type::ModelInput::Tokens, + None, + ) + .await?; + + let ingress = Ingress::for_engine(engine.clone())?; + let endpoint_component = comp.clone(); + let handle = tokio::spawn(async move { + if let Err(e) = endpoint_component + .endpoint("generate") + .endpoint_builder() + .handler(ingress) + .start() + .await + { + tracing::error!("Generate endpoint failed: {e}"); + } + }); + server_handles.push(handle); + worker_ids.push(comp.drt().connection_id()); + } + tracing::info!("Generate endpoint servers launched"); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // === STEP 1: Send request to worker_ids[0] to populate its local indexer === + // This simulates a situation where KvPushRouter is initialized + // to route to workers which already have KV events + let pre_router_distributed = create_test_shared_drt_async(&shared_store_path).await; + let pre_backend_endpoint = pre_router_distributed + .namespace(&test_namespace)? + .component(MOCKER_COMPONENT)? + .endpoint("generate"); + let pre_client = pre_backend_endpoint.client().await?; + + // Create a PushRouter to send requests directly to a specific worker + let pre_push_router = + PushRouter::>::from_client_with_threshold( + pre_client, + RouterMode::Random, // We'll use direct() so mode doesn't matter + None, + None, + ) + .await?; + + // Force sending one requests each to the two workers + for &worker_id in &worker_ids { + let tokens: Vec = vec![0, 1, 2, 3]; + let request = PreprocessedRequest::builder() + .model("mock".to_string()) + .token_ids(tokens.clone()) + .sampling_options(SamplingOptions::default()) + .output_options(OutputOptions::default()) + .stop_conditions(StopConditions { + max_tokens: Some(5), + ..Default::default() + }) + .build()?; + let response_stream = pre_push_router + .direct(Context::new(request), worker_id) + .await?; + // Consume the stream to complete the request + let _responses: Vec<_> = response_stream.collect().await; + tracing::debug!( + "Sent request {:?} directly to worker {} to populate its local indexer", + tokens, + worker_id + ); + } + tokio::time::sleep(Duration::from_millis(1000)).await; + + // === SETUP: Build KvPushRouter === + let router_distributed = create_test_shared_drt_async(&shared_store_path).await; + let router_namespace = router_distributed.namespace(&test_namespace)?; + let backend_component = router_namespace.component(MOCKER_COMPONENT)?; + let backend_endpoint = backend_component.endpoint("generate"); + let client = backend_endpoint.client().await?; + let kv_router_config = KvRouterConfig::default(); + let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config))); + let consumer_id = format!("test-router-{}", router_distributed.connection_id()); + + let kv_router: Arc = Arc::new( + KvRouter::new( + backend_endpoint.clone(), + client.clone(), + BLOCK_SIZE, + 
Some(selector), + Some(kv_router_config), + consumer_id, + ) + .await?, + ); + + // At this point kvrouter's indexer should already have the + // events stored in the workers, due to the catch-up built into KvRouter::new. + // Each request generates 2 events: input block (parent_hash: None) + output block (parent_hash: Some) + // With 2 workers, that's 4 events total. + let global_kv_events = kv_router.indexer.dump_events().await?; + tracing::debug!("Global KV events: {:?}", global_kv_events); + assert_eq!(global_kv_events.len(), 4); // 2 workers ร— 2 events per request (input + output) + + // === Cleanup === + for handle in server_handles { + handle.abort(); + } + distributed1.shutdown(); + distributed2.shutdown(); + router_distributed.shutdown(); + + Ok(()) + } } diff --git a/lib/llm/src/kv_router/subscriber.rs b/lib/llm/src/kv_router/subscriber.rs index b233220cb2..ba2ad43676 100644 --- a/lib/llm/src/kv_router/subscriber.rs +++ b/lib/llm/src/kv_router/subscriber.rs @@ -1,9 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Background processes for the KV Router including event consumption and snapshot uploads. - -use std::{collections::HashSet, time::Duration}; +use std::{collections::HashMap, collections::HashSet, time::Duration}; use anyhow::Result; use dynamo_runtime::{ @@ -24,6 +22,7 @@ use crate::kv_router::{ indexer::{DumpRequest, GetWorkersRequest, RouterEvent}, protocols::WorkerId, router_discovery_query, + worker_query::WorkerQueryClient, }; /// Delay between snapshot reads to verify stability @@ -33,6 +32,163 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10; const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1); const CHECK_INTERVAL_JITTER_MS: i64 = 100; +// ============================================================================ +// Local KvIndexer-based Recovery +// ============================================================================ + +/// Recover missed events from all workers with local indexers. +/// +/// This function should be called on router startup to catch up on any events +/// that were missed while the router was offline. 
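As a hedged sketch of how this recovery path might be wired up at router startup — the names `indexer`, `worker_query_client`, and `worker_ids` are assumed to exist in the caller, and the generic parameters on `HashMap`/`Sender` are inferred from the call sites above, not spelled out in this diff:

```
// Sketch only: catch up on KV events missed while the router was offline.
// Assumes the router-side indexer exposes get_last_received_event_ids() and
// event_sender(), as exercised in the recovery test above.
let last_ids: HashMap<WorkerId, u64> = indexer.get_last_received_event_ids().await?;
let event_tx = indexer.event_sender();
let recovered = recover_from_all_workers(
    &worker_query_client,
    &last_ids,
    &worker_ids,
    &event_tx,
)
.await;
tracing::info!(recovered, "startup recovery from worker-local KV indexers complete");
```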
+/// +/// # Arguments +/// +/// * `worker_query_client` - Client for querying worker local indexers +/// * `last_received_event_ids` - Map of worker ID to last received event ID +/// * `worker_ids` - List of worker IDs to recover from +/// * `event_tx` - Channel to send recovered events to the indexer +/// +/// # Returns +/// +/// Total number of events recovered across all workers +pub async fn recover_from_all_workers( + worker_query_client: &WorkerQueryClient, + last_received_event_ids: &HashMap, + worker_ids: &Vec, + event_tx: &mpsc::Sender, +) -> usize { + let mut total_recovered = 0; + let mut successful_workers = 0; + let mut failed_workers = 0; + + for &worker_id in worker_ids { + // Skip workers without local indexer + if !worker_query_client.has_local_indexer(worker_id) { + tracing::debug!( + worker_id, + "Skipping recovery - worker does not have local indexer enabled" + ); + continue; + } + + // If we haven't seen any events from this worker, start from beginning (None) + // If we've seen events, start from last_known_id + 1 + let start_event_id = last_received_event_ids + .get(&worker_id) + .map(|&last_id| last_id + 1); + + match recover_from_worker( + worker_query_client, + worker_id, + start_event_id, + None, // Get all events after start_event_id + event_tx, + ) + .await + { + Ok(count) => { + total_recovered += count; + if count > 0 { + successful_workers += 1; + } + } + Err(_) => { + failed_workers += 1; + } + } + } + + // Log summary + if total_recovered > 0 || failed_workers > 0 { + tracing::info!( + total_recovered, + successful_workers, + failed_workers, + "Startup recovery completed" + ); + } + + total_recovered +} + +/// Recover missed KV events from a specific worker. +/// +/// # Arguments +/// +/// * `worker_query_client` - Client for querying worker local indexers +/// * `worker_id` - The worker to recover from +/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning +/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all +/// * `event_tx` - Channel to send recovered events to the indexer +/// +/// # Returns +/// +/// Number of events recovered, or error if recovery failed +pub async fn recover_from_worker( + worker_query_client: &WorkerQueryClient, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + event_tx: &mpsc::Sender, +) -> Result { + if worker_query_client.has_local_indexer(worker_id) { + tracing::debug!( + worker_id, + start_event_id = ?start_event_id, + end_event_id = ?end_event_id, + "Attempting recovery from worker" + ); + } else { + tracing::warn!( + "Worker {} does not have local indexer enabled, skipping recovery", + worker_id + ); + return Ok(0); + } + + // Query worker for events in range + let response = worker_query_client + .query_worker(worker_id, start_event_id, end_event_id) + .await?; + + let events_count = response.events.len(); + + if events_count == 0 { + tracing::debug!( + worker_id, + start_event_id = ?start_event_id, + "No missed events to recover from worker" + ); + return Ok(0); + } + + tracing::info!( + worker_id, + start_event_id = ?start_event_id, + events_count, + "Recovered {} missed events from worker", + events_count + ); + + // Apply recovered events to the indexer + for event in response.events { + if let Err(e) = event_tx.send(event).await { + tracing::error!( + worker_id, + error = %e, + "Failed to send recovered event to indexer" + ); + anyhow::bail!("Failed to send recovered event: {}", e); + } + } + + Ok(events_count) +} + +// 
============================================================================ +// Snapshot Management +// ============================================================================ + /// Download a stable snapshot from object store and send events to the indexer. /// Retries until two consecutive reads match or max attempts is reached. async fn download_stable_snapshot( @@ -187,24 +343,29 @@ impl SnapshotResources { .await .map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?; - // Upload the snapshot to NATS object store - let url = url::Url::parse(&format!( - "nats://{}/{}/{RADIX_STATE_FILE}", - self.nats_client.addr(), - self.bucket_name - ))?; - - self.nats_client - .object_store_upload_data(&events, &url) - .await - .map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?; + // Upload the snapshot to NATS object store in background (non-blocking) + let nats_client = self.nats_client.clone(); + let bucket_name = self.bucket_name.clone(); + let event_count = events.len(); + tokio::spawn(async move { + let Ok(url) = url::Url::parse(&format!( + "nats://{}/{bucket_name}/{RADIX_STATE_FILE}", + nats_client.addr(), + )) else { + tracing::warn!("Failed to parse snapshot URL"); + return; + }; + + if let Err(e) = nats_client.object_store_upload_data(&events, &url).await { + tracing::warn!("Failed to upload snapshot: {e:?}"); + return; + } - tracing::info!( - "Successfully performed snapshot of radix tree with {} events to bucket {} in {}ms", - events.len(), - self.bucket_name, - start_time.elapsed().as_millis() - ); + tracing::info!( + "Successfully uploaded snapshot with {event_count} events to bucket {bucket_name} in {}ms", + start_time.elapsed().as_millis() + ); + }); Ok(()) } diff --git a/lib/llm/src/kv_router/worker_query.rs b/lib/llm/src/kv_router/worker_query.rs index 0e5d18639b..abea534285 100644 --- a/lib/llm/src/kv_router/worker_query.rs +++ b/lib/llm/src/kv_router/worker_query.rs @@ -1,10 +1,13 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::collections::HashMap; + use anyhow::{Context, Result}; -use dynamo_runtime::component::Namespace; +use dynamo_runtime::component::Component; use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::events::EventPublisher; +use tokio::sync::watch; use crate::kv_router::WORKER_KV_INDEXER_QUERY_SUBJECT; use crate::kv_router::indexer::{WorkerKvQueryRequest, WorkerKvQueryResponse}; @@ -12,42 +15,75 @@ use crate::kv_router::protocols::WorkerId; /// Router-side client for querying worker local KV indexers /// -/// Uses the namespace abstraction for clean request/reply communication -/// with workers via NATS. +/// Performs request/reply communication with workers via NATS. +/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data) +/// The client is spawned by KvRouter; it watches same discovery stream as the router. 
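For reference, a hedged sketch of the range semantics these query parameters carry end to end: the worker-side handler in publisher.rs dumps its full state for `(None, None)` (falling back to its in-memory buffer if the dump fails) and otherwise returns the inclusive ID range. Here `client` is the `WorkerQueryClient` defined below and `last_known_id` is assumed to come from the router-side indexer:

```
// Sketch only: both bounds are inclusive; None means "unbounded" on that side.
if client.has_local_indexer(worker_id) {
    // Everything the worker has buffered or can dump.
    let all = client.query_worker(worker_id, None, None).await?;
    // Only the events the router has not yet seen.
    let tail = client
        .query_worker(worker_id, Some(last_known_id + 1), None)
        .await?;
    tracing::debug!(all = all.events.len(), tail = tail.events.len(), "worker query results");
}
```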
pub struct WorkerQueryClient { - namespace: Namespace, + component: Component, + /// Watch receiver for enable_local_indexer state per worker + local_indexer_rx: watch::Receiver>, } impl WorkerQueryClient { - pub fn new(namespace: Namespace) -> Self { - Self { namespace } + /// Create a new WorkerQueryClient with a watch receiver for local indexer states + pub fn new( + component: Component, + local_indexer_rx: watch::Receiver>, + ) -> Self { + Self { + component, + local_indexer_rx, + } } - /// Query a specific worker's local KV indexer and return its buffered events - pub async fn query_worker(&self, worker_id: WorkerId) -> Result { - // Match worker's subscribe format: namespace.{namespace_name}.{SUBJECT}.{worker_id} - let subject = format!( - "{}.{}.{}", - self.namespace.subject(), - WORKER_KV_INDEXER_QUERY_SUBJECT, - worker_id - ); + /// Check if a worker has local indexer enabled + pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool { + self.local_indexer_rx + .borrow() + .get(&worker_id) + .copied() + .unwrap_or(false) + } - tracing::info!( - "Router sending request to worker {} on NATS subject: {}", + /// Query a specific worker's local KV indexer and return its buffered events. + /// Returns an error if the worker does not have enable_local_indexer=true. + pub async fn query_worker( + &self, + worker_id: WorkerId, + start_event_id: Option, + end_event_id: Option, + ) -> Result { + // Check if worker has local indexer enabled + if !self.has_local_indexer(worker_id) { + anyhow::bail!( + "Worker {} does not have local indexer enabled (enable_local_indexer=false or not set in MDC user_data)", + worker_id + ); + } + + // Match worker's subscribe format + let subject_str = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id); // see publisher.rs/start_worker_kv_query_service() + let subject = format!("{}.{}", self.component.subject(), subject_str); + + tracing::debug!( + "Router sending query request to worker {} on NATS subject: {}", worker_id, subject ); // Create and serialize request - let request = WorkerKvQueryRequest { worker_id }; + let request = WorkerKvQueryRequest { + worker_id, + start_event_id, + end_event_id, + }; let request_bytes = serde_json::to_vec(&request).context("Failed to serialize WorkerKvQueryRequest")?; // Send NATS request with timeout using DRT helper let timeout = tokio::time::Duration::from_secs(1); let response_msg = self - .namespace + .component .drt() .kv_router_nats_request(subject.clone(), request_bytes.into(), timeout) .await diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 3c0adf0b26..9ab5e7f6be 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -234,6 +234,7 @@ impl LocalModelBuilder { self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64); self.runtime_config.max_num_batched_tokens = mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); + self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer; self.runtime_config.data_parallel_size = mocker_engine_args.dp_size; self.media_decoder = Some(MediaDecoder::default()); self.media_fetcher = Some(MediaFetcher::default()); diff --git a/lib/llm/src/local_model/runtime_config.rs b/lib/llm/src/local_model/runtime_config.rs index 833465a672..482b77578f 100644 --- a/lib/llm/src/local_model/runtime_config.rs +++ b/lib/llm/src/local_model/runtime_config.rs @@ -23,6 +23,10 @@ pub struct ModelRuntimeConfig { #[serde(default = "default_data_parallel_size")] pub data_parallel_size: u32, 
+ /// Enable worker-local KV indexer for tracking this worker's own KV cache state + #[serde(default)] + pub enable_local_indexer: bool, + /// Mapping of engine-specific runtime configs #[serde(default, skip_serializing_if = "HashMap::is_empty")] pub runtime_data: HashMap, @@ -51,6 +55,7 @@ impl Default for ModelRuntimeConfig { tool_call_parser: None, reasoning_parser: None, data_parallel_size: default_data_parallel_size(), + enable_local_indexer: false, runtime_data: HashMap::new(), tensor_model_config: None, } diff --git a/lib/llm/src/migration.rs b/lib/llm/src/migration.rs index 39ffa371b0..333f89a289 100644 --- a/lib/llm/src/migration.rs +++ b/lib/llm/src/migration.rs @@ -11,8 +11,8 @@ use async_nats::client::{ }; use crate::{ - model_card::ModelDeploymentCard, - protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, + model_card::ModelDeploymentCard, preprocessor::BackendOutput, + protocols::common::llm_backend::PreprocessedRequest, }; use dynamo_runtime::{ @@ -44,16 +44,16 @@ impl Migration { impl Operator< SingleIn, - ManyOut>, + ManyOut>, SingleIn, - ManyOut>, + ManyOut>, > for Migration { async fn generate( &self, request: SingleIn, - next: ServerStreamingEngine>, - ) -> Result>> { + next: ServerStreamingEngine>, + ) -> Result>> { let (preprocessed_request, context) = request.transfer(()); let engine_ctx = context.context(); let engine_ctx_ = engine_ctx.clone(); @@ -73,8 +73,8 @@ impl struct RetryManager { context: Arc, request: PreprocessedRequest, - next_generate: ServerStreamingEngine>, - next_stream: Option>>, + next_generate: ServerStreamingEngine>, + next_stream: Option>>, retries_left: u32, } @@ -82,7 +82,7 @@ impl RetryManager { pub async fn build( context: Arc, preprocessed_request: PreprocessedRequest, - next: ServerStreamingEngine>, + next: ServerStreamingEngine>, retries_left: u32, ) -> Result { let mut slf = Self { @@ -96,7 +96,7 @@ impl RetryManager { Ok(slf) } - pub async fn next(&mut self) -> Option> { + pub async fn next(&mut self) -> Option> { loop { let response_stream = match self.next_stream.as_mut() { Some(stream) => stream, @@ -128,7 +128,7 @@ impl RetryManager { } async fn new_stream(&mut self) -> Result<()> { - let mut response_stream: Option>>> = None; + let mut response_stream: Option>>> = None; while self.retries_left > 0 { self.retries_left -= 1; let request = Context::with_id(self.request.clone(), self.context.id().to_string()); @@ -162,7 +162,7 @@ impl RetryManager { } } - fn track_response(&mut self, response: &Annotated) { + fn track_response(&mut self, response: &Annotated) { if self.retries_left == 0 { return; } @@ -207,18 +207,17 @@ mod tests { } // Helper to create mock LLM engine output - fn create_mock_output(token_id: u32) -> Annotated { - Annotated::from_data(LLMEngineOutput { + fn create_mock_output(token_id: u32) -> Annotated { + Annotated::from_data(BackendOutput { token_ids: vec![token_id], - tokens: None, - text: Some(format!("token_{}", token_id)), + tokens: vec![], + text: Some(format!("token_{token_id}")), cum_log_probs: None, log_probs: None, top_logprobs: None, finish_reason: None, index: None, disaggregated_params: None, - extra_args: None, completion_usage: None, }) } @@ -267,16 +266,13 @@ mod tests { #[async_trait] impl - AsyncEngine< - SingleIn, - ManyOut>, - anyhow::Error, - > for MockEngine + AsyncEngine, ManyOut>, anyhow::Error> + for MockEngine { async fn generate( &self, request: SingleIn, - ) -> Result>> { + ) -> Result>> { let call_num = self.call_count.fetch_add(1, Ordering::SeqCst); let 
(preprocessed_request, context) = request.transfer(()); @@ -457,7 +453,7 @@ mod tests { &self, start: usize, end: usize, - ) -> Result>> { + ) -> Result>> { let (tx, rx) = mpsc::channel(1); let token_offset = self.token_offset; @@ -494,7 +490,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); @@ -533,7 +529,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); @@ -573,7 +569,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); @@ -613,7 +609,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; // Should fail to build due to initial stream creation failure after exhausting all 3 retries @@ -641,7 +637,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); @@ -690,7 +686,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); @@ -739,7 +735,7 @@ mod tests { 100, context_id.clone(), )); - let next_generate: ServerStreamingEngine> = + let next_generate: ServerStreamingEngine> = mock_engine; let ctx = Arc::new(Controller::new(context_id.clone())); diff --git a/lib/llm/src/mocker/engine.rs b/lib/llm/src/mocker/engine.rs index 94b1d47b74..afe09260c8 100644 --- a/lib/llm/src/mocker/engine.rs +++ b/lib/llm/src/mocker/engine.rs @@ -60,7 +60,11 @@ impl MockVllmEngine { } pub async fn start(&self, component: Component) -> Result<()> { - let cancel_token = component.drt().runtime().child_token(); + // Use primary_token() instead of child_token() so the mocker continues running + // during graceful shutdown (Phase 1/2) and only stops in Phase 3. + // child_token() is a child of endpoint_shutdown_token which is cancelled in Phase 1. + // primary_token() is only cancelled in Phase 3, after waiting for inflight requests. 
+ let cancel_token = component.drt().primary_token(); // Simulate engine startup time if configured if let Some(startup_time_secs) = self.engine_args.startup_time { @@ -143,6 +147,11 @@ impl MockVllmEngine { } } _ = cancel_token_cloned.cancelled() => { + tracing::info!("Scheduler output task cancelled, clearing active requests"); + // Clear all active requests to unblock waiting request handlers + // This will cause their request_rx.recv() to return None + let mut active = active_requests_clone.lock().await; + active.clear(); break; } } diff --git a/lib/llm/src/mocker/kv_manager.rs b/lib/llm/src/mocker/kv_manager.rs index 17d7491162..a949139475 100644 --- a/lib/llm/src/mocker/kv_manager.rs +++ b/lib/llm/src/mocker/kv_manager.rs @@ -72,7 +72,7 @@ pub struct KvManager { impl KvManager { pub fn new(max_capacity: usize, block_size: usize) -> Self { - Self::new_with_publisher(max_capacity, block_size, None, 0) + Self::new_with_publisher(max_capacity, block_size, None, 0, false) } pub fn new_with_publisher( @@ -80,6 +80,7 @@ impl KvManager { block_size: usize, component: Option, dp_rank: u32, + enable_local_indexer: bool, ) -> Self { let active_blocks = HashMap::new(); let inactive_blocks = LRUEvictor::default(); @@ -87,10 +88,10 @@ impl KvManager { let kv_event_publisher = component.map(|comp| { tracing::info!( - "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}" + "Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}" ); Arc::new( - KvEventPublisher::new(comp, block_size as u32, None) + KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer) .expect("Failed to create KV event publisher"), ) }); diff --git a/lib/llm/src/mocker/protocols.rs b/lib/llm/src/mocker/protocols.rs index d67e707ee6..4e62d836e4 100644 --- a/lib/llm/src/mocker/protocols.rs +++ b/lib/llm/src/mocker/protocols.rs @@ -120,6 +120,10 @@ pub struct MockEngineArgs { #[serde(skip)] #[builder(default = "Arc::new(PerfModel::default())")] pub perf_model: Arc, + + /// Enable worker-local KV indexer for tracking this worker's own KV cache state + #[builder(default = "false")] + pub enable_local_indexer: bool, } impl Default for MockEngineArgs { @@ -158,6 +162,7 @@ impl MockEngineArgs { "is_prefill", "is_decode", "planner_profile_data", + "enable_local_indexer", ] .iter() .cloned() @@ -239,6 +244,12 @@ impl MockEngineArgs { builder = builder.startup_time(Some(num)); } + if let Some(value) = extra_args.get("enable_local_indexer") + && let Some(enabled) = value.as_bool() + { + builder = builder.enable_local_indexer(enabled); + } + // Parse worker type from is_prefill and is_decode flags let is_prefill = extra_args .get("is_prefill") diff --git a/lib/llm/src/mocker/scheduler.rs b/lib/llm/src/mocker/scheduler.rs index aa33e6d13b..5ea205a4cb 100644 --- a/lib/llm/src/mocker/scheduler.rs +++ b/lib/llm/src/mocker/scheduler.rs @@ -275,32 +275,17 @@ impl Scheduler { args.block_size, component, dp_rank, + args.enable_local_indexer, ); let mut hit_rates = RunningMean::new(1000); loop { // 1. Receive requests - if state.is_empty() { - // Fully idle - block until new request arrives - tokio::select! 
{ - biased; - Some(request) = request_rx.recv() => { - state.receive(request); - } - _ = cancel_token_clone.cancelled() => { - break; - } - } - } else { - // Has active/waiting work - collect any pending requests without blocking - while let Ok(request) = request_rx.try_recv() { - state.receive(request); - } - - // Check for cancellation - if cancel_token_clone.is_cancelled() { - break; - } + if receive_requests(&mut state, &mut request_rx, &cancel_token_clone) + .await + .is_none() + { + break; } // Start timing for this forward pass (schedule + simulate) @@ -310,106 +295,30 @@ impl Scheduler { try_schedule(&mut state, &kv_manager, &mut hit_rates, &args); // 3. Simulate prefill + decode - let mut total_time = Duration::ZERO; - - // Process prefilling - while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = - state.try_prefill(&args.perf_model) - { - // NOTE: Prefill cost/time is always incremented for new blocks, even if they - // could be cached by other requests in the same batch. This matches vLLM behavior. - // For decode workers, skip adding prefill compute time - if args.worker_type != WorkerType::Decode { - total_time += Duration::from_secs_f64(prefill_compute / 1000.0); - } - - if let Some(creation_signal) = maybe_creation_signal - && !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal)) - { - panic!("Block allocation for prefilling cannot fail."); - } - - // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill - if !is_full_prefill { - break; - } - } - - // Compute decode timing - let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size; - // Compute average context length across all active decode requests - let (total_length, count) = state - .decode - .keys() - .filter_map(|uuid| state.requests.get(uuid)) - .fold((0usize, 0usize), |(sum, cnt), req| { - if let Request::Active(seq) = req { - (sum + seq.len(), cnt + 1) - } else { - (sum, cnt) - } - }); - let context_length = if count > 0 { total_length / count } else { 0 }; - let decoding_time = args - .perf_model - .predict_decode_time(active_kv_tokens, context_length); - total_time += Duration::from_secs_f64(decoding_time / 1000.0); - - state.reset_active_tokens(); - - // Process decoding - let uuids: Vec = state.decode.keys().cloned().collect(); - for uuid in uuids { - let Some(sequence) = state.run(uuid) else { - continue; - }; - let signals = sequence.generate(); - - // Process all signals with the KvManager - // Handling of preemption on failure - if !process_signals(&mut kv_manager, &signals) { - sequence.pop(); // revert the failed generation op - for signal in state.preempt() { - kv_manager.process(&signal); - } - continue; - } - - // Check completion and send notification - let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); - let should_output = - sequence.generated_tokens() > sequence.already_generated_tokens(); - - let mut send_failed = false; - if should_output { - send_failed = output_tx.as_ref().is_some_and(|tx| { - tx.send(OutputSignal { - uuid, - completed: is_complete, - }) - .is_err() - }); - } - - if send_failed { - for signal in &sequence.free_signal() { - kv_manager.process(signal); - } - } - - if send_failed || is_complete { - state.complete(&uuid); - continue; - } - } - - // Send metrics once per forward pass (after all prefill and decode processing) - { - let metrics = get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank); - let _ = metrics_tx.send(metrics); - } - - // 4. 
Sleep to maintain target iteration timing + let prefill_time = simulate_prefill( + &mut state, + &mut kv_manager, + &args.perf_model, + args.worker_type, + ); + let decode_time = simulate_decode( + &mut state, + &mut kv_manager, + &output_tx, + &args.perf_model, + args.block_size, + ); + let total_time = prefill_time + decode_time; + + // 4. Send metrics once per forward pass (after all prefill and decode processing) + let _ = metrics_tx.send(get_fwd_pass_metrics( + &state, + &kv_manager, + &hit_rates, + dp_rank, + )); + + // 5. Sleep to maintain target iteration timing let target_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio); let elapsed = iteration_start.elapsed(); @@ -441,6 +350,148 @@ impl Scheduler { } } +/// Receive requests from the channel. +/// Returns `Some(())` to continue the loop, `None` to break (on cancellation). +async fn receive_requests( + state: &mut SchedulerState, + request_rx: &mut mpsc::UnboundedReceiver, + cancel_token: &CancellationToken, +) -> Option<()> { + if cancel_token.is_cancelled() { + return None; + } + + if state.is_empty() { + // Fully idle - block until new request arrives + tokio::select! { + biased; + _ = cancel_token.cancelled() => { + return None; + } + Some(request) = request_rx.recv() => { + state.receive(request); + return Some(()); + } + } + } + + // Has active/waiting work - collect any pending requests without blocking + while let Ok(request) = request_rx.try_recv() { + state.receive(request); + } + + Some(()) +} + +/// Simulate prefill phase for all pending prefill requests. +/// Returns the total prefill compute time. +fn simulate_prefill( + state: &mut SchedulerState, + kv_manager: &mut KvManager, + perf_model: &PerfModel, + worker_type: WorkerType, +) -> Duration { + let mut total_time = Duration::ZERO; + + while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = + state.try_prefill(perf_model) + { + // NOTE: Prefill cost/time is always incremented for new blocks, even if they + // could be cached by other requests in the same batch. This matches vLLM behavior. + // For decode workers, skip adding prefill compute time + if worker_type != WorkerType::Decode { + total_time += Duration::from_secs_f64(prefill_compute / 1000.0); + } + + if let Some(creation_signal) = maybe_creation_signal + && !process_signals(kv_manager, std::slice::from_ref(&creation_signal)) + { + panic!("Block allocation for prefilling cannot fail."); + } + + // Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill + if !is_full_prefill { + break; + } + } + + total_time +} + +/// Simulate decode phase for all active decode requests. +/// Returns the total decode compute time. 
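Taken together, the refactored forward-pass loop composes these helpers roughly as follows. This is a condensed paraphrase of the loop above rather than new behavior; `iteration_start`, `args`, and the channels are as in scheduler.rs, and the sleep at the end is sketched since its tail is not shown in this hunk:

```
// Condensed paraphrase of the scheduler loop after this refactor.
loop {
    // 1. Receive requests (returns None on cancellation)
    if receive_requests(&mut state, &mut request_rx, &cancel_token).await.is_none() {
        break;
    }
    let iteration_start = Instant::now();

    // 2. Schedule, then 3. simulate prefill + decode
    try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
    let total_time =
        simulate_prefill(&mut state, &mut kv_manager, &args.perf_model, args.worker_type)
            + simulate_decode(&mut state, &mut kv_manager, &output_tx, &args.perf_model, args.block_size);

    // 4. Emit metrics once per forward pass
    let _ = metrics_tx.send(get_fwd_pass_metrics(&state, &kv_manager, &hit_rates, dp_rank));

    // 5. Sleep to hold the target iteration time, scaled by speedup_ratio (sketched)
    let target = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
    if let Some(remaining) = target.checked_sub(iteration_start.elapsed()) {
        tokio::time::sleep(remaining).await;
    }
}
```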
+fn simulate_decode( + state: &mut SchedulerState, + kv_manager: &mut KvManager, + output_tx: &Option>, + perf_model: &PerfModel, + block_size: usize, +) -> Duration { + // Compute decode timing + let active_kv_tokens = kv_manager.num_active_blocks() * block_size; + // Compute average context length across all active decode requests + let (total_length, count) = state + .decode + .keys() + .filter_map(|uuid| state.requests.get(uuid)) + .fold((0usize, 0usize), |(sum, cnt), req| { + if let Request::Active(seq) = req { + (sum + seq.len(), cnt + 1) + } else { + (sum, cnt) + } + }); + let context_length = if count > 0 { total_length / count } else { 0 }; + let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length); + let total_time = Duration::from_secs_f64(decoding_time / 1000.0); + + state.reset_active_tokens(); + + // Process decoding + let uuids: Vec = state.decode.keys().cloned().collect(); + for uuid in uuids { + let Some(sequence) = state.run(uuid) else { + continue; + }; + let signals = sequence.generate(); + + // Process all signals with the KvManager + // Handling of preemption on failure + if !process_signals(kv_manager, &signals) { + sequence.pop(); // revert the failed generation op + for signal in state.preempt() { + kv_manager.process(&signal); + } + continue; + } + + // Check completion and send notification + let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens(); + let should_output = sequence.generated_tokens() > sequence.already_generated_tokens(); + + let send_failed = should_output + && output_tx.as_ref().is_some_and(|tx| { + tx.send(OutputSignal { + uuid, + completed: is_complete, + }) + .is_err() + }); + + if send_failed { + for signal in &sequence.free_signal() { + kv_manager.process(signal); + } + } + + if send_failed || is_complete { + state.complete(&uuid); + } + } + + total_time +} + /// Calculate forward pass metrics from current state fn get_fwd_pass_metrics( state: &SchedulerState, diff --git a/lib/llm/src/model_card.rs b/lib/llm/src/model_card.rs index a29267e43a..77fc780f6c 100644 --- a/lib/llm/src/model_card.rs +++ b/lib/llm/src/model_card.rs @@ -385,6 +385,15 @@ impl ModelDeploymentCard { return Ok(()); } + // For TensorBased models, config files are not used - they handle everything in the backend + if self.model_type.supports_tensor() { + tracing::debug!( + display_name = %self.display_name, + "Skipping config download for TensorBased model" + ); + return Ok(()); + } + let ignore_weights = true; let local_path = crate::hub::from_hf(&self.display_name, ignore_weights).await?; diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 8dd3dcb947..b7bfd8ac4c 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -28,7 +28,7 @@ use tracing; use crate::model_card::{ModelDeploymentCard, ModelInfo}; #[cfg(feature = "media-nixl")] -use crate::preprocessor::media::{MediaDecoder, MediaFetcher, MediaLoader}; +use crate::preprocessor::media::MediaLoader; use crate::preprocessor::prompt::OAIChatLikeRequest; use crate::protocols::common::preprocessor::{ MultimodalData, MultimodalDataMap, PreprocessedRequestBuilder, @@ -71,6 +71,7 @@ pub struct LLMMetricAnnotation { pub input_tokens: usize, pub output_tokens: usize, pub chunk_tokens: usize, + pub cached_tokens: Option, } impl LLMMetricAnnotation { @@ -646,6 +647,7 @@ impl OpenAIPreprocessor { input_tokens: isl, output_tokens: current_osl, chunk_tokens, + cached_tokens: None, }; if let Ok(metrics_annotated) = 
llm_metrics.to_annotation::<()>() { @@ -673,20 +675,39 @@ impl OpenAIPreprocessor { // again. The stream is exhausted and will panic if polled after None. inner.finished = true; - // Check if we need to send a usage chunk - if inner.response_generator.is_usage_enabled() - && inner.finish_reason_sent - && !inner.usage_chunk_sent - { + if inner.finish_reason_sent && !inner.usage_chunk_sent { inner.usage_chunk_sent = true; - // Create the final usage chunk let usage_chunk = inner.response_generator.create_usage_chunk(); + let usage = inner.response_generator.get_usage(); + let llm_metrics = LLMMetricAnnotation { + input_tokens: usage.prompt_tokens as usize, + output_tokens: usage.completion_tokens as usize, + chunk_tokens: 0, + cached_tokens: usage + .prompt_tokens_details + .as_ref() + .and_then(|d| d.cached_tokens.map(|c| c as usize)), + }; + + // Create annotation string + let annotation = llm_metrics.to_annotation::<()>().unwrap_or_else(|e| { + tracing::warn!("Failed to serialize metrics: {}", e); + Annotated::<()>::from_data(()) + }); + + // Send the usage chunk if needed + let data = if inner.response_generator.is_usage_enabled() { + Some(usage_chunk) + } else { + None + }; + let annotated_usage = Annotated:: { id: None, - data: Some(usage_chunk), - event: None, - comment: None, + data, + event: Some(ANNOTATION_LLM_METRICS.to_string()), + comment: annotation.comment, }; tracing::trace!( @@ -752,25 +773,17 @@ impl OpenAIPreprocessor { has_tools: bool, ) -> std::result::Result { match (tool_call_parser, tool_choice, has_tools) { - // No parser but tools requested - error cases - (None, Some(ChatCompletionToolChoiceOption::Required), true) => { - tracing::warn!( - "Tool choice 'required' specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } + // tool_choice=required/named work without parser (use Immediate jail mode) + (None, Some(ChatCompletionToolChoiceOption::Required), true) => Ok(true), + (None, Some(ChatCompletionToolChoiceOption::Named(_)), true) => Ok(true), + + // tool_choice=auto requires a parser (None, Some(ChatCompletionToolChoiceOption::Auto), true) => { tracing::warn!( "Tool choice 'auto' specified but no tool parser configured; proceeding without jailing" ); Ok(false) } - (None, Some(ChatCompletionToolChoiceOption::Named(_)), _) => { - tracing::warn!( - "Named tool choice specified but no tool parser configured; proceeding without jailing" - ); - Ok(false) - } // Parser exists and tools might be called (Some(_), Some(ChatCompletionToolChoiceOption::None), _) => { @@ -786,15 +799,38 @@ impl OpenAIPreprocessor { /// Apply tool calling jail to the stream if needed pub fn apply_tool_calling_jail( - tool_call_parser: String, + tool_call_parser: Option, + tool_choice: Option, stream: S, ) -> impl Stream> + Send where S: Stream> + Send + 'static, { - let jail = JailedStream::builder() - .tool_call_parser(tool_call_parser) - .build(); + use dynamo_async_openai::types::ChatCompletionToolChoiceOption; + + let mut builder = JailedStream::builder(); + + // Configure jail based on tool_choice + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(named)) => { + // Immediate jail mode for named tool choice + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + // Immediate jail mode for required tool choice + builder = builder.tool_choice_required(); + } + Some(ChatCompletionToolChoiceOption::Auto) + | Some(ChatCompletionToolChoiceOption::None) + | None => { + // 
Traditional marker-based jail for auto/none/unspecified + if let Some(parser) = tool_call_parser { + builder = builder.tool_call_parser(parser); + } + } + } + + let jail = builder.build(); jail.apply_with_finish_reason(stream) } @@ -957,11 +993,11 @@ impl // Apply jail conditionally let transformed_stream: Pin + Send>> = if should_jail { - if let Some(parser) = self.tool_call_parser.clone() { - Box::pin(Self::apply_tool_calling_jail(parser, stream)) - } else { - Box::pin(stream) // Should not happen due to should_jail check - } + Box::pin(Self::apply_tool_calling_jail( + self.tool_call_parser.clone(), + request.inner.tool_choice.clone(), + stream, + )) } else { Box::pin(stream) }; diff --git a/lib/llm/src/preprocessor/media/README.md b/lib/llm/src/preprocessor/media/README.md index fede33bc9f..3db2bc5faf 100644 --- a/lib/llm/src/preprocessor/media/README.md +++ b/lib/llm/src/preprocessor/media/README.md @@ -38,6 +38,14 @@ register_llm( ``` +## Known Limitations + +> [!WARNING] +> **Incompatible with `Dockerfile.frontend`**: Frontend media decoding (enabled with `--features media-nixl`) is not supported when using `Dockerfile.frontend`. The frontend image built from `Dockerfile.frontend` does not enable the feature + include the required NIXL/UCX dependencies. + +> [!WARNING] +> **Requires GPU node**: The frontend must run on a node with GPU access. During media processing, decoded tensors are written to GPU memory via NIXL, which requires `libcuda.so.1` to be available. Running the frontend on a CPU-only node will fail with something like: `Failed to initialize required backends: [UCX: No UCX plugin found]`. + ## TODOs ### Modalities diff --git a/lib/llm/src/preprocessor/media/loader.rs b/lib/llm/src/preprocessor/media/loader.rs index 0d229d7437..df96f43d50 100644 --- a/lib/llm/src/preprocessor/media/loader.rs +++ b/lib/llm/src/preprocessor/media/loader.rs @@ -166,7 +166,7 @@ mod tests { ..Default::default() }; - let loader: MediaLoader = MediaLoader::new(media_decoder, fetcher).unwrap(); + let loader: MediaLoader = MediaLoader::new(media_decoder, Some(fetcher)).unwrap(); let image_url = ImageUrl::from(format!("{}/llm-optimize-deploy-graphic.png", server.url())); let content_part = ChatCompletionRequestUserMessageContentPart::ImageUrl( diff --git a/lib/llm/src/protocols/common.rs b/lib/llm/src/protocols/common.rs index 76d35f9fa8..cef7b8d857 100644 --- a/lib/llm/src/protocols/common.rs +++ b/lib/llm/src/protocols/common.rs @@ -22,6 +22,7 @@ use super::TokenIdType; pub mod llm_backend; pub mod postprocessor; pub mod preprocessor; +pub mod timing; /// SamplingOptionsProvider is a trait that allows the caller to extract the sampling options from /// the object that implements it. This will mutate the object. @@ -254,7 +255,6 @@ pub struct StopConditions { impl StopConditions { pub fn apply_ignore_eos(&mut self) { if self.ignore_eos.unwrap_or(false) { - self.min_tokens = self.max_tokens; self.stop = None; self.stop_token_ids_hidden = None; } diff --git a/lib/llm/src/protocols/common/timing.rs b/lib/llm/src/protocols/common/timing.rs new file mode 100644 index 0000000000..7e36c60428 --- /dev/null +++ b/lib/llm/src/protocols/common/timing.rs @@ -0,0 +1,109 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Per-request timing tracker for capturing request lifecycle metrics. +//! +//! This module provides [`RequestTimingTracker`] for tracking timing information +//! 
that can be returned to clients via the `nvext` response field. + +use serde::{Deserialize, Serialize}; +use std::sync::OnceLock; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; + +/// Per-request timing tracker. +/// +/// Captures timing information throughout the request lifecycle: +/// - `request_received`: When the request was received +/// - `first_token_time`: When the first token was generated (set once via OnceLock) +/// - `request_finish_time`: When the request finished (set once via OnceLock) +/// +/// The `OnceLock` fields ensure that timing values are set exactly once, +/// which is important for disaggregated serving where the "first token" +/// might appear multiple times. +pub struct RequestTimingTracker { + /// When the request was received (monotonic clock for duration calculations) + request_received: Instant, + + /// When the request was received (wall clock time as epoch milliseconds) + request_received_epoch_ms: u64, + + /// When the first token was generated - set once via OnceLock + first_token_time: OnceLock, + + /// When the request finished - set once via OnceLock + request_finish_time: OnceLock, +} + +impl RequestTimingTracker { + /// Create a new timing tracker, capturing the current time as request received. + pub fn new() -> Self { + let now = Instant::now(); + let epoch_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + RequestTimingTracker { + request_received: now, + request_received_epoch_ms: epoch_ms, + first_token_time: OnceLock::new(), + request_finish_time: OnceLock::new(), + } + } + + pub fn record_first_token(&self) -> bool { + self.first_token_time.set(Instant::now()).is_ok() + } + + pub fn record_finish(&self) -> bool { + self.request_finish_time.set(Instant::now()).is_ok() + } + + pub fn ttft_ms(&self) -> Option { + self.first_token_time + .get() + .map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0) + } + + pub fn total_time_ms(&self) -> Option { + self.request_finish_time + .get() + .map(|t| t.duration_since(self.request_received).as_secs_f64() * 1000.0) + } + + pub fn request_received_epoch_ms(&self) -> u64 { + self.request_received_epoch_ms + } + + pub fn get_timing_info(&self) -> TimingInfo { + TimingInfo { + request_received_ms: self.request_received_epoch_ms, + ttft_ms: self.ttft_ms(), + total_time_ms: self.total_time_ms(), + } + } +} + +impl Default for RequestTimingTracker { + fn default() -> Self { + Self::new() + } +} + +/// Timing information for response injection. +/// +/// This struct is serialized and included in the response's `nvext` field +/// when the client requests timing information via `extra_fields: ["timing"]`. 
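A hedged usage sketch of the tracker in a streaming path. Clients opt in by sending `"extra_fields": ["timing"]` in the request's `nvext`, and the values below end up in the `TimingInfo` that delta.rs injects into the final chunk's `nvext`; the surrounding stream handling here is illustrative, not the actual generator code:

```
// Sketch only: OnceLock semantics mean repeated calls are harmless; only the
// first record_first_token()/record_finish() per request takes effect.
let tracker = RequestTimingTracker::new();

// ... on the first generated token of the stream ...
tracker.record_first_token();

// ... once the stream has finished ...
tracker.record_finish();

let timing = tracker.get_timing_info();
assert!(timing.ttft_ms.unwrap_or(0.0) <= timing.total_time_ms.unwrap_or(0.0));
// `timing` serializes as { "request_received_ms": ..., "ttft_ms": ..., "total_time_ms": ... }
```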
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct TimingInfo { + /// When the request was received (epoch milliseconds) + pub request_received_ms: u64, + + /// Time to first token in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub ttft_ms: Option, + + /// Total request time in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub total_time_ms: Option, +} diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 574d7db31c..a67ff0a719 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -17,6 +17,7 @@ pub mod embeddings; pub mod models; pub mod nvext; pub mod responses; +pub mod tools; pub mod validate; use validate::{ @@ -131,7 +132,7 @@ impl SamplingOptionsProvid let guided_whitespace_pattern = self.get_guided_whitespace_pattern(); let guided_decoding = match common::GuidedDecodingOptions::from_optional( - guided_json.cloned(), + guided_json, guided_regex, guided_choice, guided_grammar, @@ -224,6 +225,9 @@ pub trait DeltaGeneratorExt: /// Check if usage tracking is enabled. fn is_usage_enabled(&self) -> bool; + + /// Get the current usage statistics with properly calculated total_tokens. + fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage; } #[derive(Clone, Debug, Serialize, Deserialize, Default)] diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index bb712086ec..cab55032cb 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -12,7 +12,7 @@ use super::{ common_ext::{CommonExt, CommonExtProvider}, nvext::NvExt, nvext::NvExtProvider, - validate, + tools, validate, }; pub mod aggregator; @@ -159,8 +159,24 @@ impl CommonExtProvider for NvCreateChatCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + if let Some(value) = self.common.guided_json.clone() { + return Some(value); + } + + let tool_choice = self.inner.tool_choice.as_ref()?; + let tools = self.inner.tools.as_deref()?; + + match tools::get_json_schema_from_tools(Some(tool_choice), Some(tools)) { + Ok(schema) => schema, + Err(err) => { + tracing::warn!( + error = %err, + "failed to derive guided_json from tool_choice" + ); + None + } + } } fn get_guided_regex(&self) -> Option { diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 186bb7f095..51fbdf966b 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -4,7 +4,10 @@ use super::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}; use crate::{ local_model::runtime_config::ModelRuntimeConfig, - protocols::common::{self}, + protocols::{ + common::{self, timing::RequestTimingTracker}, + openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo}, + }, types::TokenIdType, }; @@ -41,6 +44,12 @@ impl NvCreateChatCompletionRequest { /// # Returns /// * [`DeltaGenerator`] configured with model name and response options. 
pub fn response_generator(&self, request_id: String) -> DeltaGenerator { + // Check if client requested timing in extra_fields + let enable_timing = self + .nvext() + .and_then(|nv| nv.extra_fields.as_ref()) + .is_some_and(|fields| fields.iter().any(|f| f == "timing")); + let options = DeltaGeneratorOptions { enable_usage: self .inner @@ -50,6 +59,7 @@ impl NvCreateChatCompletionRequest { .unwrap_or(false), enable_logprobs: self.inner.logprobs.unwrap_or(false) || self.inner.top_logprobs.unwrap_or(0) > 0, + enable_timing, runtime_config: ModelRuntimeConfig::default(), }; @@ -64,12 +74,13 @@ pub struct DeltaGeneratorOptions { pub enable_usage: bool, /// Determines whether log probabilities should be included in the response. pub enable_logprobs: bool, + /// Determines whether timing information should be included in the response's nvext. + pub enable_timing: bool, pub runtime_config: ModelRuntimeConfig, } /// Generates incremental chat completion responses in a streaming fashion. -#[derive(Debug)] pub struct DeltaGenerator { /// Unique identifier for the chat completion session. id: String, @@ -88,6 +99,8 @@ pub struct DeltaGenerator { msg_counter: u64, /// Configuration options for response generation. options: DeltaGeneratorOptions, + /// Optional timing tracker for per-request timing metrics. + timing_tracker: Option, } impl DeltaGenerator { @@ -120,6 +133,13 @@ impl DeltaGenerator { let chatcmpl_id = format!("chatcmpl-{request_id}"); + // Create timing tracker if timing is enabled + let timing_tracker = if options.enable_timing { + Some(RequestTimingTracker::new()) + } else { + None + }; + Self { id: chatcmpl_id, object: "chat.completion.chunk".to_string(), @@ -130,6 +150,7 @@ impl DeltaGenerator { usage, msg_counter: 0, options, + timing_tracker, } } @@ -268,8 +289,7 @@ impl DeltaGenerator { /// # Returns /// * A [`CreateChatCompletionStreamResponse`] with empty choices and usage stats. pub fn create_usage_chunk(&self) -> NvCreateChatCompletionStreamResponse { - let mut usage = self.usage.clone(); - usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); + let usage = self.get_usage(); dynamo_async_openai::types::CreateChatCompletionStreamResponse { id: self.id.clone(), @@ -288,6 +308,12 @@ impl DeltaGenerator { pub fn is_usage_enabled(&self) -> bool { self.options.enable_usage } + + pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { + let mut usage = self.usage.clone(); + usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); + usage + } } /// Implements the [`crate::protocols::openai::DeltaGeneratorExt`] trait for [`DeltaGenerator`], allowing @@ -307,27 +333,25 @@ impl crate::protocols::openai::DeltaGeneratorExt anyhow::Result { - // Aggregate token usage if enabled. - if self.options.enable_usage { - // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`, - // but this will not be an issue until context lengths exceed 4_294_967_295. 
- let token_length: u32 = delta - .token_ids - .len() - .try_into() - .expect("token_ids length exceeds u32::MAX"); - - self.usage.completion_tokens += token_length; - - // If backend provides completion_usage with prompt token details, - // propagate the entire details struct to usage tracking - if let Some(prompt_details) = delta - .completion_usage - .as_ref() - .and_then(|usage| usage.prompt_tokens_details.as_ref()) - { - self.usage.prompt_tokens_details = Some(prompt_details.clone()); - } + // Aggregate token usage even if usage tracking is disabled for metrics tracking + // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`, + // but this will not be an issue until context lengths exceed 4_294_967_295. + let token_length: u32 = delta + .token_ids + .len() + .try_into() + .expect("token_ids length exceeds u32::MAX"); + + self.usage.completion_tokens += token_length; + + // If backend provides completion_usage with prompt token details, + // propagate the entire details struct to usage tracking + if let Some(prompt_details) = delta + .completion_usage + .as_ref() + .and_then(|usage| usage.prompt_tokens_details.as_ref()) + { + self.usage.prompt_tokens_details = Some(prompt_details.clone()); } let logprobs = self.create_logprobs( @@ -362,37 +386,44 @@ impl crate::protocols::openai::DeltaGeneratorExt(v.clone()).ok()); + + // Get timing info if this is the final response (has finish_reason) + let timing_info: Option = if finish_reason.is_some() { + self.timing_tracker.as_ref().map(|tracker| { + tracker.record_finish(); + tracker.get_timing_info() + }) + } else { + None + }; + // Inject nvext if we have worker_id or timing + if worker_id_info.is_some() || timing_info.is_some() { let nvext_response = NvExtResponse { - worker_id: Some(worker_id_info), + worker_id: worker_id_info.clone(), + timing: timing_info, }; if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { stream_response.nvext = Some(nvext_json); - tracing::debug!( - "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", - prefill_worker_id, - decode_worker_id - ); + if let Some(ref info) = worker_id_info { + tracing::debug!( + "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } } } @@ -410,6 +441,10 @@ impl crate::protocols::openai::DeltaGeneratorExt bool { DeltaGenerator::is_usage_enabled(self) } + + fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { + DeltaGenerator::get_usage(self) + } } #[cfg(test)] diff --git a/lib/llm/src/protocols/openai/chat_completions/jail.rs b/lib/llm/src/protocols/openai/chat_completions/jail.rs index 0057f4993b..2a716cc743 100644 --- a/lib/llm/src/protocols/openai/chat_completions/jail.rs +++ b/lib/llm/src/protocols/openai/chat_completions/jail.rs @@ -14,6 +14,7 @@ use dynamo_parsers::tool_calling::{ use dynamo_runtime::protocols::annotated::Annotated; use futures::{Stream, StreamExt}; use std::collections::HashMap; +use uuid::Uuid; use crate::utils::{MarkerMatcher, MatchResult}; @@ -62,6 +63,24 @@ pub struct JailConfig<'a> { pub tool_call_parser: Option<&'a str>, } +/// Jail activation mode +#[derive(Debug, Clone, PartialEq)] +pub enum JailMode { + /// Traditional: wait for start marker, then jail + MarkerBased, + /// Immediate: start jailed from first token (for tool_choice) + Immediate { format: ToolChoiceFormat }, +} + +/// Format for tool_choice immediate jail mode +#[derive(Debug, Clone, PartialEq)] +pub enum ToolChoiceFormat { + /// 
tool_choice=named: expect single object {"location": "Paris", ...} + SingleObject { tool_name: String }, + /// tool_choice=required: expect array [{name:"search", parameters:{...}}, ...] + ArrayOfTools, +} + /// State tracking for an individual choice during jail processing #[derive(Debug, Clone)] struct ChoiceJailState { @@ -75,6 +94,8 @@ struct ChoiceJailState { partial_match_buffer: String, /// Stream finish reason stream_finish_reason: Option, + /// Number of tool calls already emitted for this choice + emitted_tool_calls_count: usize, } fn create_choice_stream( @@ -103,13 +124,14 @@ fn create_choice_stream( impl ChoiceJailState { /// Create a new jail state for a choice - fn new(index: u32) -> Self { + fn new(index: u32, starts_jailed: bool) -> Self { Self { index, - is_jailed: false, + is_jailed: starts_jailed, accumulated_content: String::new(), partial_match_buffer: String::new(), stream_finish_reason: None, + emitted_tool_calls_count: 0, } } @@ -178,10 +200,18 @@ impl ChoiceJailState { // Create the tool call choice let tool_choice = jail_stream - .create_tool_call_choice(choice.index, jailed_part, choice) + .create_tool_call_choice( + choice.index, + jailed_part, + choice, + self.emitted_tool_calls_count, + ) .await; if tool_choice.delta.tool_calls.is_some() { + if let Some(ref tool_calls) = tool_choice.delta.tool_calls { + self.emitted_tool_calls_count += tool_calls.len(); + } emissions.push(ChoiceEmission::ToolCall(tool_choice)); } else { emissions.push(ChoiceEmission::Content(tool_choice)); @@ -297,11 +327,19 @@ impl ChoiceJailState { // Create the unjailed choice let unjailed_choice = jail_stream - .create_tool_call_choice(choice.index, jailed_part, choice) + .create_tool_call_choice( + choice.index, + jailed_part, + choice, + self.emitted_tool_calls_count, + ) .await; // Determine emission type based on whether tool calls were parsed if unjailed_choice.delta.tool_calls.is_some() { + if let Some(ref tool_calls) = unjailed_choice.delta.tool_calls { + self.emitted_tool_calls_count += tool_calls.len(); + } emissions.push(ChoiceEmission::ToolCall(unjailed_choice)); } else { emissions.push(ChoiceEmission::Content(unjailed_choice)); @@ -349,9 +387,18 @@ impl ChoiceJailState { ); let final_choice = jail_stream - .create_tool_call_choice(self.index, &self.accumulated_content, &dummy_choice) + .create_tool_call_choice( + self.index, + &self.accumulated_content, + &dummy_choice, + self.emitted_tool_calls_count, + ) .await; + if let Some(ref tool_calls) = final_choice.delta.tool_calls { + self.emitted_tool_calls_count += tool_calls.len(); + } + // End jailing self.end_jail(); @@ -381,7 +428,7 @@ impl ChoiceJailStateCollection { } /// Get or create state for a choice index - fn get_or_create_state(&mut self, index: u32) -> &mut ChoiceJailState { + fn get_or_create_state(&mut self, index: u32, starts_jailed: bool) -> &mut ChoiceJailState { // Find the position where this index should be match self.states.binary_search_by_key(&index, |s| s.index) { Ok(pos) => { @@ -390,7 +437,7 @@ impl ChoiceJailStateCollection { } Err(insert_pos) => { // Need to create new state - let new_state = ChoiceJailState::new(index); + let new_state = ChoiceJailState::new(index, starts_jailed); self.states.insert(insert_pos, new_state); &mut self.states[insert_pos] } @@ -399,20 +446,15 @@ impl ChoiceJailStateCollection { } /// Emission mode for handling multiple choices -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum EmissionMode { /// Pack multiple 
choices in the same chunk (default, matches original behavior) + #[default] Packed, /// Emit one choice per chunk for OpenAI compatibility SingleChoicePerChunk, } -impl Default for EmissionMode { - fn default() -> Self { - Self::Packed - } -} - /// A stream transformer that can "jail" tokens based on configurable start/end sequences /// When jailed, tokens are accumulated rather than yielded immediately /// When the jail ends (via end sequence or stream completion), accumulated content is processed and released @@ -422,6 +464,7 @@ pub struct JailedStream { tool_call_parser: Option, emission_mode: EmissionMode, marker_matcher: MarkerMatcher, + jail_mode: JailMode, } impl JailedStream { @@ -439,8 +482,9 @@ impl JailedStream { where S: Stream> + Send + 'static, { + let jail_mode = self.jail_mode.clone(); let jailed_stream = self.apply(stream); - JailedStream::fix_finish_reason(jailed_stream) + JailedStream::fix_finish_reason(jailed_stream, jail_mode) } /// Apply the jail transformation to a stream of chat completion responses @@ -480,7 +524,8 @@ impl JailedStream { // Process each choice independently using the new architecture for choice in &chat_response.choices { if let Some(ref content) = choice.delta.content { - let choice_state = choice_states.get_or_create_state(choice.index); + let starts_jailed = matches!(self.jail_mode, JailMode::Immediate { .. }); + let choice_state = choice_states.get_or_create_state(choice.index, starts_jailed); // Store metadata when any choice becomes jailed (first time only) if !choice_state.is_jailed && self.should_start_jail(content) @@ -498,14 +543,24 @@ impl JailedStream { all_emissions.extend(emissions); } else { // Handle choices without content (e.g., final chunks with finish_reason) - // These should always pass through - let pass_through_choice = ChatChoiceStream { - index: choice.index, - delta: choice.delta.clone(), - finish_reason: choice.finish_reason, - logprobs: choice.logprobs.clone(), - }; - all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + // Only filter out if this choice was ever jailed and lacks role + // (to avoid aggregator issues with deltas missing role after unjail) + let choice_state = choice_states.get_or_create_state(choice.index, false); + let was_ever_jailed = !choice_state.accumulated_content.is_empty() || choice_state.is_jailed; + + let should_emit = choice.delta.role.is_some() + || choice.delta.tool_calls.is_some() + || !was_ever_jailed; // Always pass through if never jailed + + if should_emit { + let pass_through_choice = ChatChoiceStream { + index: choice.index, + delta: choice.delta.clone(), + finish_reason: choice.finish_reason, + logprobs: choice.logprobs.clone(), + }; + all_emissions.push(ChoiceEmission::PassThrough(pass_through_choice)); + } } } @@ -673,38 +728,69 @@ impl JailedStream { /// Check if accumulated content should end jail async fn should_end_jail(&self, accumulated_content: &str) -> (bool, usize) { - // Path 1: End sequence detected - let end_marker_info = if !self.jail_end_sequences.is_empty() { - self.jail_end_sequences.iter().find_map(|seq| { - accumulated_content - .find(seq) - .map(|pos| (pos + seq.len(), seq.clone())) - }) - } else { - None - }; + match &self.jail_mode { + JailMode::MarkerBased => { + // Path 1: End sequence detected + let end_marker_info = if !self.jail_end_sequences.is_empty() { + self.jail_end_sequences.iter().find_map(|seq| { + accumulated_content + .find(seq) + .map(|pos| (pos + seq.len(), seq.clone())) + }) + } else { + None + }; - // Path 2: Complete tool 
call(s) can be parsed (early exit) - let early_exit = self.should_exit_jail_early(accumulated_content).await; + // Path 2: Complete tool call(s) can be parsed (early exit) + let early_exit = self.should_exit_jail_early(accumulated_content).await; - if let Some((end_pos, _)) = end_marker_info { - (true, end_pos) - } else if early_exit { - // For early exit, find where the complete tool call ends - if let Some(parser) = &self.tool_call_parser { - if let Ok((_, _)) = - try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await - { - let split_pos = find_tool_call_end_position(accumulated_content, Some(parser)); - (true, split_pos) + if let Some((end_pos, _)) = end_marker_info { + (true, end_pos) + } else if early_exit { + // For early exit, find where the complete tool call ends + if let Some(parser) = &self.tool_call_parser { + if let Ok((_, _)) = + try_tool_call_parse_aggregate(accumulated_content, Some(parser)).await + { + let split_pos = + find_tool_call_end_position(accumulated_content, Some(parser)); + (true, split_pos) + } else { + (false, accumulated_content.len()) + } + } else { + (false, accumulated_content.len()) + } } else { (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) } - } else { - (false, accumulated_content.len()) + JailMode::Immediate { format } => { + // For tool_choice, check if we have valid complete JSON + match format { + ToolChoiceFormat::SingleObject { .. } => { + // Expect single object: {"location": "Paris", "unit": "celsius"} + if let Ok(value) = + serde_json::from_str::(accumulated_content) + && value.is_object() + { + return (true, accumulated_content.len()); + } + (false, accumulated_content.len()) + } + ToolChoiceFormat::ArrayOfTools => { + // Expect array: [{"name":"search","parameters":{...}}, ...] 
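Reviewer note: the `Immediate` arm above only releases the jail once the buffered text is complete JSON of the expected shape (an object for a named tool, a non-empty array for `required`). Below is a standalone sketch of that completeness check, assuming `serde_json`; it is not the crate's actual helper, just a restatement of the condition.

```rust
// Minimal sketch (assumed behavior): in Immediate mode the jail stays closed
// until the accumulated text parses as complete JSON of the expected shape.
use serde_json::Value;

fn ready_single_object(buf: &str) -> bool {
    serde_json::from_str::<Value>(buf)
        .map(|v| v.is_object())
        .unwrap_or(false)
}

fn ready_array_of_tools(buf: &str) -> bool {
    serde_json::from_str::<Value>(buf)
        .ok()
        .and_then(|v| v.as_array().map(|a| !a.is_empty()))
        .unwrap_or(false)
}

fn main() {
    // Partial chunks keep the choice jailed...
    assert!(!ready_single_object(r#"{"location":"Par"#));
    // ...and the full object ends the jail.
    assert!(ready_single_object(r#"{"location":"Paris","unit":"celsius"}"#));

    assert!(!ready_array_of_tools("[]")); // an empty array is not enough
    assert!(ready_array_of_tools(
        r#"[{"name":"search","parameters":{"query":"rust"}}]"#
    ));
}
```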
+ if let Ok(value) = + serde_json::from_str::(accumulated_content) + && let Some(arr) = value.as_array() + && !arr.is_empty() + { + return (true, accumulated_content.len()); + } + (false, accumulated_content.len()) + } + } + } } } @@ -714,47 +800,138 @@ impl JailedStream { choice_index: u32, accumulated_content: &str, base_choice: &ChatChoiceStream, + tool_call_offset: usize, ) -> ChatChoiceStream { - if let Ok((tool_calls, normal_text)) = - try_tool_call_parse_aggregate(accumulated_content, self.tool_call_parser.as_deref()) + match &self.jail_mode { + JailMode::MarkerBased => { + // Traditional marker-based tool call parsing + if let Ok((tool_calls, normal_text)) = try_tool_call_parse_aggregate( + accumulated_content, + self.tool_call_parser.as_deref(), + ) .await - && !tool_calls.is_empty() - { - // Convert to streaming format - let tool_call_chunks: Vec = tool_calls - .into_iter() - .enumerate() - .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { - index: idx as u32, - id: Some(tool_call.id), - r#type: Some(tool_call.r#type), - function: Some(FunctionCallStream { - name: Some(tool_call.function.name), - arguments: Some(tool_call.function.arguments), - }), - }) - .collect(); - // Create choice with tool calls - let choice = create_choice_stream( - choice_index, - Some(Role::Assistant), - normal_text.as_deref().unwrap_or(""), - Some(tool_call_chunks), - None, - None, - ); - return choice; + && !tool_calls.is_empty() + { + // Convert to streaming format + let tool_call_chunks: Vec = tool_calls + .into_iter() + .enumerate() + .map(|(idx, tool_call)| ChatCompletionMessageToolCallChunk { + index: (tool_call_offset + idx) as u32, + id: Some(tool_call.id), + r#type: Some(tool_call.r#type), + function: Some(FunctionCallStream { + name: Some(tool_call.function.name), + arguments: Some(tool_call.function.arguments), + }), + }) + .collect(); + // Create choice with tool calls + let choice = create_choice_stream( + choice_index, + Some(Role::Assistant), + normal_text.as_deref().unwrap_or(""), + Some(tool_call_chunks), + None, + None, + ); + return choice; + } + + // No tool calls found or parsing failed, return content choice + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + JailMode::Immediate { format } => { + // tool_choice mode: parse JSON and convert to tool calls + match self.parse_tool_choice_json(accumulated_content, format) { + Ok(tool_call_chunks) if !tool_call_chunks.is_empty() => create_choice_stream( + choice_index, + Some(Role::Assistant), + "", + Some(tool_call_chunks), + base_choice.finish_reason, + base_choice.logprobs.clone(), + ), + Ok(_) | Err(_) => { + // Parsing failed, return as content + create_choice_stream( + choice_index, + Some(Role::Assistant), + accumulated_content, + None, + base_choice.finish_reason, + base_choice.logprobs.clone(), + ) + } + } + } } + } - // No tool calls found or parsing failed, return content choice - create_choice_stream( - choice_index, - Some(Role::Assistant), - accumulated_content, - None, - base_choice.finish_reason, - base_choice.logprobs.clone(), - ) + /// Helper to create a ChatCompletionMessageToolCallChunk + fn create_tool_call_chunk( + index: u32, + name: String, + arguments: String, + ) -> ChatCompletionMessageToolCallChunk { + ChatCompletionMessageToolCallChunk { + index, + id: Some(format!("call-{}", Uuid::new_v4())), + r#type: Some(dynamo_async_openai::types::ChatCompletionToolType::Function), + function: 
Some(FunctionCallStream { + name: Some(name), + arguments: Some(arguments), + }), + } + } + + /// Parse tool_choice JSON output into tool call chunks + fn parse_tool_choice_json( + &self, + json_content: &str, + format: &ToolChoiceFormat, + ) -> anyhow::Result> { + let parsed = serde_json::from_str::(json_content)?; + + match format { + ToolChoiceFormat::SingleObject { tool_name } => { + // For named tool choice: JSON is the parameters object + if parsed.is_object() { + Ok(vec![Self::create_tool_call_chunk( + 0, + tool_name.clone(), + json_content.to_string(), + )]) + } else { + Ok(vec![]) + } + } + ToolChoiceFormat::ArrayOfTools => { + // For required tool choice: JSON is array of {name, parameters} + if let Some(array) = parsed.as_array() { + let chunks: Vec = array + .iter() + .enumerate() + .filter_map(|(idx, entry)| { + let name = entry.get("name")?.as_str()?.to_string(); + let parameters = entry.get("parameters")?; + let args = serde_json::to_string(parameters).ok()?; + Some(Self::create_tool_call_chunk(idx as u32, name, args)) + }) + .collect(); + Ok(chunks) + } else { + Ok(vec![]) + } + } + } } /// Check if accumulated content contains complete tool calls that can be parsed @@ -775,8 +952,9 @@ impl JailedStream { /// Post-processor that sets finish_reason to ToolCalls when tool calls were emitted /// This should be called after apply() to fix the finish_reason for tool call chunks - pub fn fix_finish_reason( + fn fix_finish_reason( input_stream: S, + jail_mode: JailMode, ) -> impl Stream> + Send where S: Stream> + Send + 'static, @@ -795,13 +973,39 @@ impl JailedStream { } } - // If this chunk has finish_reason and the choice had tool calls, override to ToolCalls + // Fix finish_reason based on jail mode and whether tool calls were emitted if let Some(ref mut data) = response.data { for choice in &mut data.choices { - if choice.finish_reason.is_some() && choice.finish_reason == Some(FinishReason::Stop) - && has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false) - { - choice.finish_reason = Some(FinishReason::ToolCalls); + if let Some(finish) = choice.finish_reason { + // Only modify Stop finish reason, preserve Length/ContentFilter + if finish == FinishReason::Stop { + let has_tool_calls = has_tool_calls_per_choice.get(&choice.index).copied().unwrap_or(false); + + match &jail_mode { + JailMode::MarkerBased => { + // Traditional: if tool calls emitted, change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + JailMode::Immediate { format } => { + // tool_choice mode: apply specific finish_reason logic + match format { + ToolChoiceFormat::SingleObject { .. 
} => { + // Named tool choice: keep Stop + // (already Stop, no change needed) + } + ToolChoiceFormat::ArrayOfTools => { + // Required tool choice: change to ToolCalls + if has_tool_calls { + choice.finish_reason = Some(FinishReason::ToolCalls); + } + } + } + } + } + } + // Length and ContentFilter are preserved as-is } } } @@ -818,6 +1022,7 @@ pub struct JailedStreamBuilder { jail_end_sequences: Vec, tool_call_parser: Option, emission_mode: EmissionMode, + jail_mode: JailMode, } impl JailedStreamBuilder { @@ -828,6 +1033,7 @@ impl JailedStreamBuilder { jail_end_sequences: Vec::new(), tool_call_parser: None, emission_mode: EmissionMode::default(), + jail_mode: JailMode::MarkerBased, } } @@ -887,6 +1093,22 @@ impl JailedStreamBuilder { self } + /// Enable immediate jail mode for tool_choice=named + pub fn tool_choice_named(mut self, tool_name: String) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::SingleObject { tool_name }, + }; + self + } + + /// Enable immediate jail mode for tool_choice=required + pub fn tool_choice_required(mut self) -> Self { + self.jail_mode = JailMode::Immediate { + format: ToolChoiceFormat::ArrayOfTools, + }; + self + } + /// Build the configured JailedStream pub fn build(mut self) -> JailedStream { // Auto-populate jail sequences from parser config if not manually configured @@ -965,6 +1187,7 @@ impl JailedStreamBuilder { tool_call_parser: self.tool_call_parser, emission_mode: self.emission_mode, marker_matcher, + jail_mode: self.jail_mode, } } } diff --git a/lib/llm/src/protocols/openai/common_ext.rs b/lib/llm/src/protocols/openai/common_ext.rs index a77f765ae6..51d9ea0a99 100644 --- a/lib/llm/src/protocols/openai/common_ext.rs +++ b/lib/llm/src/protocols/openai/common_ext.rs @@ -94,7 +94,7 @@ pub trait CommonExtProvider { fn common_ext(&self) -> Option<&CommonExt>; /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value>; + fn get_guided_json(&self) -> Option; fn get_guided_regex(&self) -> Option; fn get_guided_grammar(&self) -> Option; fn get_guided_choice(&self) -> Option>; diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs index b62f801c40..056c2a3a2d 100644 --- a/lib/llm/src/protocols/openai/completions.rs +++ b/lib/llm/src/protocols/openai/completions.rs @@ -183,8 +183,8 @@ impl CommonExtProvider for NvCreateCompletionRequest { } /// Guided Decoding Options - fn get_guided_json(&self) -> Option<&serde_json::Value> { - self.common.guided_json.as_ref() + fn get_guided_json(&self) -> Option { + self.common.guided_json.clone() } fn get_guided_regex(&self) -> Option { @@ -238,7 +238,11 @@ impl OpenAIStopConditionsProvider for NvCreateCompletionRequest { } fn get_stop(&self) -> Option> { - None + use dynamo_async_openai::types::Stop; + self.inner.stop.as_ref().map(|s| match s { + Stop::String(s) => vec![s.clone()], + Stop::StringArray(arr) => arr.clone(), + }) } fn nvext(&self) -> Option<&NvExt> { @@ -494,4 +498,36 @@ mod tests { assert_eq!(output_options.skip_special_tokens, Some(skip_value)); } } + + #[test] + fn test_stop() { + let null_stop = json!({ + "model": "test-model", + "prompt": "Hello, world!" 
+ }); + let request: NvCreateCompletionRequest = + serde_json::from_value(null_stop).expect("Failed to deserialize request"); + assert_eq!(request.get_stop(), None); + + let one_stop = json!({ + "model": "test-model", + "prompt": "Hello, world!", + "stop": "foo" + }); + let request: NvCreateCompletionRequest = + serde_json::from_value(one_stop).expect("Failed to deserialize request"); + assert_eq!(request.get_stop(), Some(vec!["foo".to_string()])); + + let many_stops = json!({ + "model": "test-model", + "prompt": "Hello, world!", + "stop": ["foo", "bar"] + }); + let request: NvCreateCompletionRequest = + serde_json::from_value(many_stops).expect("Failed to deserialize request"); + assert_eq!( + request.get_stop(), + Some(vec!["foo".to_string(), "bar".to_string()]) + ); + } } diff --git a/lib/llm/src/protocols/openai/completions/delta.rs b/lib/llm/src/protocols/openai/completions/delta.rs index 3b27ffebdb..608a30565d 100644 --- a/lib/llm/src/protocols/openai/completions/delta.rs +++ b/lib/llm/src/protocols/openai/completions/delta.rs @@ -2,7 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 use super::{NvCreateCompletionRequest, NvCreateCompletionResponse}; -use crate::{protocols::common, types::TokenIdType}; +use crate::{ + protocols::{ + common::{self, timing::RequestTimingTracker}, + openai::nvext::{NvExtProvider, NvExtResponse, TimingInfo, WorkerIdInfo}, + }, + types::TokenIdType, +}; impl NvCreateCompletionRequest { /// Enables usage tracking for non-streaming requests to comply with OpenAI API specification. @@ -33,6 +39,12 @@ impl NvCreateCompletionRequest { // put this method on the request // inspect the request to extract options pub fn response_generator(&self, request_id: String) -> DeltaGenerator { + // Check if client requested timing in extra_fields + let enable_timing = self + .nvext() + .and_then(|nv| nv.extra_fields.as_ref()) + .is_some_and(|fields| fields.iter().any(|f| f == "timing")); + let options = DeltaGeneratorOptions { enable_usage: self .inner @@ -41,6 +53,7 @@ impl NvCreateCompletionRequest { .map(|opts| opts.include_usage) .unwrap_or(false), enable_logprobs: self.inner.logprobs.unwrap_or(0) > 0, + enable_timing, }; DeltaGenerator::new(self.inner.model.clone(), options, request_id) @@ -51,9 +64,9 @@ impl NvCreateCompletionRequest { pub struct DeltaGeneratorOptions { pub enable_usage: bool, pub enable_logprobs: bool, + pub enable_timing: bool, } -#[derive(Debug, Clone)] pub struct DeltaGenerator { id: String, object: String, @@ -62,6 +75,7 @@ pub struct DeltaGenerator { system_fingerprint: Option, usage: dynamo_async_openai::types::CompletionUsage, options: DeltaGeneratorOptions, + timing_tracker: Option, } impl DeltaGenerator { @@ -87,6 +101,13 @@ impl DeltaGenerator { let completion_id = format!("cmpl-{request_id}"); + // Create timing tracker if timing is enabled + let timing_tracker = if options.enable_timing { + Some(RequestTimingTracker::new()) + } else { + None + }; + Self { id: completion_id, object: "text_completion".to_string(), @@ -95,6 +116,7 @@ impl DeltaGenerator { system_fingerprint: None, usage, options, + timing_tracker, } } @@ -201,8 +223,7 @@ impl DeltaGenerator { /// # Returns /// * A [`NvCreateCompletionResponse`] with empty choices and usage stats. 
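Reviewer note: both the chat and completions generators now route the usage chunk through `get_usage()`, which recomputes `total_tokens` on every read instead of maintaining it incrementally. A tiny sketch of that derivation (function name is illustrative) — the `saturating_add` means an overflow clamps rather than panics in debug builds, which is the design choice this patch keeps:

```rust
// Sketch of the total_tokens derivation used by get_usage()/create_usage_chunk():
// total is recomputed from the aggregated prompt and completion counts on read.
fn total_tokens(prompt_tokens: u32, completion_tokens: u32) -> u32 {
    prompt_tokens.saturating_add(completion_tokens)
}

fn main() {
    assert_eq!(total_tokens(12, 30), 42);
    assert_eq!(total_tokens(u32::MAX, 1), u32::MAX); // clamped, not wrapped
}
```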
pub fn create_usage_chunk(&self) -> NvCreateCompletionResponse { - let mut usage = self.usage.clone(); - usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); + let usage = self.get_usage(); let inner = dynamo_async_openai::types::CreateCompletionResponse { id: self.id.clone(), @@ -222,6 +243,12 @@ impl DeltaGenerator { pub fn is_usage_enabled(&self) -> bool { self.options.enable_usage } + + pub fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { + let mut usage = self.usage.clone(); + usage.total_tokens = usage.prompt_tokens.saturating_add(usage.completion_tokens); + usage + } } impl crate::protocols::openai::DeltaGeneratorExt for DeltaGenerator { @@ -229,27 +256,25 @@ impl crate::protocols::openai::DeltaGeneratorExt for &mut self, delta: common::llm_backend::BackendOutput, ) -> anyhow::Result { - // aggregate usage - if self.options.enable_usage { - // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`, - // but this will not be an issue until context lengths exceed 4_294_967_295. - let token_length: u32 = delta - .token_ids - .len() - .try_into() - .expect("token_ids length exceeds u32::MAX"); - - self.usage.completion_tokens += token_length; - - // If backend provides completion_usage with prompt token details, - // propagate the entire details struct to usage tracking - if let Some(prompt_details) = delta - .completion_usage - .as_ref() - .and_then(|usage| usage.prompt_tokens_details.as_ref()) - { - self.usage.prompt_tokens_details = Some(prompt_details.clone()); - } + // Aggregate token usage even if usage tracking is disabled for metrics tracking + // SAFETY: Casting from `usize` to `u32` could lead to precision loss after `u32::MAX`, + // but this will not be an issue until context lengths exceed 4_294_967_295. 
+ let token_length: u32 = delta + .token_ids + .len() + .try_into() + .expect("token_ids length exceeds u32::MAX"); + + self.usage.completion_tokens += token_length; + + // If backend provides completion_usage with prompt token details, + // propagate the entire details struct to usage tracking + if let Some(prompt_details) = delta + .completion_usage + .as_ref() + .and_then(|usage| usage.prompt_tokens_details.as_ref()) + { + self.usage.prompt_tokens_details = Some(prompt_details.clone()); } let logprobs = self.create_logprobs( @@ -265,37 +290,44 @@ impl crate::protocols::openai::DeltaGeneratorExt for let index = delta.index.unwrap_or(0); let mut response = self.create_choice(index, delta.text.clone(), finish_reason, logprobs); - // Extract worker_id from disaggregated_params and inject into nvext if present - if let Some(worker_id_json) = delta + // Record first token time (only succeeds on first call due to OnceLock) + if let Some(ref tracker) = self.timing_tracker { + tracker.record_first_token(); + } + + // Extract worker_id from disaggregated_params + let worker_id_info = delta .disaggregated_params .as_ref() .and_then(|params| params.get("worker_id")) - { - use crate::protocols::openai::nvext::{NvExtResponse, WorkerIdInfo}; - - let prefill_worker_id = worker_id_json - .get("prefill_worker_id") - .and_then(|v| v.as_u64()); - let decode_worker_id = worker_id_json - .get("decode_worker_id") - .and_then(|v| v.as_u64()); - - let worker_id_info = WorkerIdInfo { - prefill_worker_id, - decode_worker_id, - }; + .and_then(|v| serde_json::from_value::(v.clone()).ok()); + + // Get timing info if this is the final response (has finish_reason) + let timing_info: Option = if finish_reason.is_some() { + self.timing_tracker.as_ref().map(|tracker| { + tracker.record_finish(); + tracker.get_timing_info() + }) + } else { + None + }; + // Inject nvext if we have worker_id or timing + if worker_id_info.is_some() || timing_info.is_some() { let nvext_response = NvExtResponse { - worker_id: Some(worker_id_info), + worker_id: worker_id_info.clone(), + timing: timing_info, }; if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { response.inner.nvext = Some(nvext_json); - tracing::debug!( - "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", - prefill_worker_id, - decode_worker_id - ); + if let Some(ref info) = worker_id_info { + tracing::debug!( + "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } } } @@ -313,4 +345,8 @@ impl crate::protocols::openai::DeltaGeneratorExt for fn is_usage_enabled(&self) -> bool { DeltaGenerator::is_usage_enabled(self) } + + fn get_usage(&self) -> dynamo_async_openai::types::CompletionUsage { + DeltaGenerator::get_usage(self) + } } diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index d4edc19069..26980b4215 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -5,6 +5,8 @@ use derive_builder::Builder; use serde::{Deserialize, Serialize}; use validator::{Validate, ValidationError}; +pub use crate::protocols::common::timing::TimingInfo; + pub trait NvExtProvider { fn nvext(&self) -> Option<&NvExt>; fn raw_prompt(&self) -> Option; @@ -28,6 +30,11 @@ pub struct NvExtResponse { /// Worker ID information (prefill and decode worker IDs) #[serde(skip_serializing_if = "Option::is_none")] pub worker_id: Option, + + /// Per-request timing information + /// Populated when client requests 
`extra_fields: ["timing"]` + #[serde(skip_serializing_if = "Option::is_none")] + pub timing: Option, } /// NVIDIA LLM extensions to the OpenAI API @@ -76,7 +83,7 @@ pub struct NvExt { /// Extra fields to be included in the response's nvext /// This is a list of field names that should be populated in the response - /// Supported fields: "worker_id" + /// Supported fields: "worker_id", "timing", which has a 1:1 mapping with the NvExtResponse names #[serde(default, skip_serializing_if = "Option::is_none")] #[builder(default, setter(strip_option))] pub extra_fields: Option>, diff --git a/lib/llm/src/protocols/openai/tools.rs b/lib/llm/src/protocols/openai/tools.rs new file mode 100644 index 0000000000..457f5b37a9 --- /dev/null +++ b/lib/llm/src/protocols/openai/tools.rs @@ -0,0 +1,404 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::BTreeMap; + +use dynamo_async_openai::types::{ + ChatCompletionTool, ChatCompletionToolChoiceOption, FunctionObject, +}; +use serde_json::{Value, json}; +use thiserror::Error; + +/// Errors that can occur when deriving JSON schemas for tool_choice requests. +#[derive(Debug, Error, PartialEq, Eq)] +pub enum ToolChoiceError { + #[error("tool_choice requires a matching `tools` array")] + MissingTools, + #[error("tool `{0}` was not provided in `tools`")] + ToolNotFound(String), + #[error("$defs for tool `{0}` must be an object")] + InvalidDefinitionMap(String), + #[error("duplicate $defs entry `{0}` has conflicting schemas")] + ConflictingDefinition(String), + #[error("tool_choice `required` needs at least one tool definition")] + EmptyTools, +} + +/// Builds the JSON schema enforced by Guided Decoding for the given tool_choice/tools pair. +pub fn get_json_schema_from_tools( + tool_choice: Option<&ChatCompletionToolChoiceOption>, + tools: Option<&[ChatCompletionTool]>, +) -> Result, ToolChoiceError> { + let Some(choice) = tool_choice else { + return Ok(None); + }; + + match choice { + ChatCompletionToolChoiceOption::None | ChatCompletionToolChoiceOption::Auto => Ok(None), + ChatCompletionToolChoiceOption::Named(named) => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + let tool = find_tool(tools, &named.function.name) + .ok_or_else(|| ToolChoiceError::ToolNotFound(named.function.name.clone()))?; + Ok(Some(clone_parameters(&tool.function))) + } + ChatCompletionToolChoiceOption::Required => { + let tools = tools.ok_or(ToolChoiceError::MissingTools)?; + if tools.is_empty() { + return Err(ToolChoiceError::EmptyTools); + } + build_required_schema(tools).map(Some) + } + } +} + +fn find_tool<'a>(tools: &'a [ChatCompletionTool], name: &str) -> Option<&'a ChatCompletionTool> { + tools.iter().find(|tool| tool.function.name == name) +} + +fn clone_parameters(function: &FunctionObject) -> Value { + function + .parameters + .clone() + .unwrap_or_else(|| json!({"type": "object", "properties": {}})) +} + +/// Builds a JSON Schema for `tool_choice=required` that enforces an array of tool calls. 
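Reviewer note: as a concrete illustration of what the `required` schema constrains (tool names are taken from this module's test fixtures; the argument values are assumed), a conforming model output is a non-empty array of name/parameters pairs, which the jail's `ArrayOfTools` mode then splits into individual tool-call chunks:

```rust
// Hypothetical model output that satisfies the tool_choice=required schema:
// a non-empty array whose items pair a known tool name with arguments that
// match that tool's parameter schema.
use serde_json::json;

fn main() {
    let output = json!([
        { "name": "get_weather", "parameters": { "location": "Paris", "unit": "celsius" } },
        { "name": "add_numbers", "parameters": { "a": 1, "b": 2 } }
    ]);

    // Downstream, JailMode::Immediate { format: ArrayOfTools } turns each entry
    // into one tool-call chunk (see parse_tool_choice_json in this patch).
    assert!(output.as_array().is_some_and(|a| !a.is_empty()));
}
```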
+/// +/// # Schema Structure +/// +/// The generated schema looks like: +/// ```json +/// { +/// "type": "array", +/// "minItems": 1, +/// "items": { +/// "type": "object", +/// "anyOf": [ +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool1"]}, +/// "parameters": { /* tool1's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// }, +/// { +/// "properties": { +/// "name": {"type": "string", "enum": ["tool2"]}, +/// "parameters": { /* tool2's parameter schema */ } +/// }, +/// "required": ["name", "parameters"] +/// } +/// ] +/// }, +/// "$defs": { /* shared type definitions from all tools */ } +/// } +/// ``` +/// +/// # $defs Handling +/// +/// `$defs` contains shared JSON Schema definitions that can be referenced via `$ref`. +/// For example, if two tools reference a common type: +/// ```json +/// { +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": { +/// "city": {"type": "string"}, +/// "country": {"type": "string"} +/// } +/// } +/// } +/// } +/// ``` +/// +/// We extract `$defs` from each tool's schema and merge them into a global `$defs` map +/// at the root level. If multiple tools define the same type, we verify they match to +/// avoid conflicts. +fn build_required_schema(tools: &[ChatCompletionTool]) -> Result { + // Accumulator for all shared type definitions ($defs) across tools + let mut defs: BTreeMap = BTreeMap::new(); + let mut any_of = Vec::with_capacity(tools.len()); + + for tool in tools { + // Extract parameter schema and its $defs (if any) + let ParamsAndDefs { + schema, + defs: new_defs, + } = split_defs(&tool.function)?; + merge_defs(&mut defs, new_defs)?; + any_of.push(json!({ + "properties": { + "name": { + "type": "string", + "enum": [tool.function.name], + }, + "parameters": schema, + }, + "required": ["name", "parameters"], + })); + } + + // Build the top-level array schema with anyOf constraints + let mut result = json!({ + "type": "array", + "minItems": 1, + "items": { + "type": "object", + "anyOf": any_of, + }, + }); + + // Attach the merged $defs at the root level if any were collected + if !defs.is_empty() + && let Value::Object(map) = &mut result + { + map.insert( + "$defs".to_string(), + Value::Object(defs.into_iter().collect()), + ); + } + + Ok(result) +} + +/// Holds a tool's parameter schema and its extracted $defs (if any). +/// +/// When a tool's parameters reference shared types via `$ref`, those types +/// are defined in a `$defs` section within the schema. We extract them separately +/// to merge into a global definitions map. +struct ParamsAndDefs { + /// The parameter schema with `$defs` removed (if it had one) + schema: Value, + /// Extracted `$defs` map, or None if the schema had no definitions + defs: Option>, +} + +/// Extracts `$defs` from a function's parameter schema, returning both the +/// cleaned schema and the definitions separately. 
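Reviewer note: the point of hoisting `$defs` to the root is that `$ref` pointers of the form `#/$defs/...` stay valid after each tool's parameter schema is embedded under `anyOf`. A small sketch of the shared-definition case (names and shapes assumed; this is not the crate's code, just the merge semantics described above):

```rust
// Sketch: two tools that both reference the same definition end up contributing
// a single, identical $defs entry, which merge_defs keeps once at the root.
use serde_json::json;

fn main() {
    let location_def = json!({
        "type": "object",
        "properties": { "city": { "type": "string" }, "country": { "type": "string" } }
    });

    // Both tools' parameter schemas carry the same $defs entry...
    let tool_a_params = json!({
        "type": "object",
        "properties": { "origin": { "$ref": "#/$defs/Location" } },
        "$defs": { "Location": location_def.clone() }
    });
    let tool_b_params = json!({
        "type": "object",
        "properties": { "destination": { "$ref": "#/$defs/Location" } },
        "$defs": { "Location": location_def.clone() }
    });

    // ...so after split_defs + merge_defs the root schema carries it exactly once,
    // and both $ref pointers remain resolvable. Differing schemas would error instead.
    assert_eq!(tool_a_params["$defs"], tool_b_params["$defs"]);
}
```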
+/// +/// # Example +/// +/// Input schema: +/// ```json +/// { +/// "type": "object", +/// "properties": { +/// "location": {"$ref": "#/$defs/Location"} +/// }, +/// "$defs": { +/// "Location": { +/// "type": "object", +/// "properties": {"city": {"type": "string"}} +/// } +/// } +/// } +/// ``` +/// +/// Returns: +/// - schema: same as input but with `$defs` removed +/// - defs: `Some({"Location": {...}})` +fn split_defs(function: &FunctionObject) -> Result { + let mut schema = clone_parameters(function); + let defs = match &mut schema { + Value::Object(obj) => { + if let Some(value) = obj.remove("$defs") { + Some(convert_defs(function, value)?) + } else { + None + } + } + _ => None, + }; + + Ok(ParamsAndDefs { schema, defs }) +} + +fn convert_defs( + function: &FunctionObject, + defs_value: Value, +) -> Result, ToolChoiceError> { + match defs_value { + Value::Object(map) => Ok(map.into_iter().collect()), + _ => Err(ToolChoiceError::InvalidDefinitionMap(function.name.clone())), + } +} + +/// Merges definitions from one tool into the global `$defs` accumulator. +/// +/// # Conflict Detection +/// +/// If two tools define the same type name but with different schemas, we return +/// an error. This ensures consistency across tool definitions. +/// +/// # Example +/// +/// If `target` contains: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "string"}}}} +/// ``` +/// +/// And we try to merge: +/// ```json +/// {"Location": {"type": "object", "properties": {"city": {"type": "number"}}}} +/// ``` +/// +/// This will return `ToolChoiceError::ConflictingDefinition("Location")`. +fn merge_defs( + target: &mut BTreeMap, + defs: Option>, +) -> Result<(), ToolChoiceError> { + let Some(defs) = defs else { + return Ok(()); + }; + + for (name, schema) in defs { + if let Some(existing) = target.get(&name) { + if existing != &schema { + return Err(ToolChoiceError::ConflictingDefinition(name)); + } + } else { + target.insert(name, schema); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use dynamo_async_openai::types::{ChatCompletionToolChoiceOption, ChatCompletionToolType}; + + fn sample_tools() -> Vec { + vec![ + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "add_numbers".to_string(), + description: Some("Add two integers".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + })), + strict: None, + }, + }, + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "get_weather".to_string(), + description: Some("Get weather".to_string()), + parameters: Some(json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + })), + strict: None, + }, + }, + ] + } + + #[test] + fn named_choice_returns_parameters() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "get_weather".to_string(), + }, + }, + ); + let schema = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).expect("schema"); + + assert_eq!( + schema.unwrap(), + json!({ + "type": "object", + "properties": { + "location": {"type": "string"}, + 
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location", "unit"], + }) + ); + } + + #[test] + fn required_choice_builds_any_of_schema() { + let tools = sample_tools(); + let schema = get_json_schema_from_tools( + Some(&ChatCompletionToolChoiceOption::Required), + Some(&tools), + ) + .expect("schema"); + + let schema = schema.expect("required schema"); + assert_eq!(schema["type"], "array"); + assert_eq!(schema["minItems"], 1); + assert!(schema["items"]["anyOf"].is_array()); + + let any_of = schema["items"]["anyOf"].as_array().unwrap(); + assert_eq!(any_of.len(), 2); + assert_eq!( + any_of[0]["properties"]["name"], + json!({"type": "string", "enum": ["add_numbers"]}) + ); + } + + #[test] + fn missing_tool_errors() { + let tools = sample_tools(); + let tool_choice = ChatCompletionToolChoiceOption::Named( + dynamo_async_openai::types::ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: dynamo_async_openai::types::FunctionName { + name: "unknown".to_string(), + }, + }, + ); + let err = get_json_schema_from_tools(Some(&tool_choice), Some(&tools)).unwrap_err(); + assert_eq!(err, ToolChoiceError::ToolNotFound("unknown".to_string())); + } + + #[test] + fn conflicting_defs_errors() { + let tool = ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: FunctionObject { + name: "foo".to_string(), + description: None, + parameters: Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "string"} + } + })), + strict: None, + }, + }; + + let mut tool_with_conflict = tool.clone(); + tool_with_conflict.function.parameters = Some(json!({ + "type": "object", + "$defs": { + "shared": {"type": "number"} + } + })); + + let tools = vec![tool, tool_with_conflict]; + let err = build_required_schema(&tools).unwrap_err(); + assert_eq!( + err, + ToolChoiceError::ConflictingDefinition("shared".to_string()) + ); + } +} diff --git a/lib/llm/tests/kserve_service.rs b/lib/llm/tests/kserve_service.rs index 371bbd5eda..49d1070b8f 100644 --- a/lib/llm/tests/kserve_service.rs +++ b/lib/llm/tests/kserve_service.rs @@ -5,6 +5,22 @@ mod ports; pub mod kserve_test { + // [gluo NOTE] Tests may run in parallel, use this enum to keep track of + // port used for different test cases + enum TestPort { + InferFailure = 8988, + InferSuccess = 8989, + StreamInferFailure = 8990, + StreamInferSuccess = 8991, + InferCancellation = 8992, + StreamInferCancellation = 8993, + ModelInfo = 8994, + TensorModel = 8995, + TensorModelTypes = 8996, + TritonModelConfig = 8997, + LiveReady = 8998, + } + // For using gRPC client for test pub mod inference { tonic::include_proto!("inference"); @@ -16,6 +32,7 @@ pub mod kserve_test { use inference::grpc_inference_service_client::GrpcInferenceServiceClient; use inference::{ DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest, + ModelReadyRequest, ServerLiveRequest, ServerReadyRequest, }; use anyhow::Error; @@ -354,21 +371,6 @@ pub mod kserve_test { } } - // Tests may run in parallel, use this enum to keep track of port used for different - // test cases - enum TestPort { - InferFailure = 8988, - InferSuccess = 8989, - StreamInferFailure = 8990, - StreamInferSuccess = 8991, - InferCancellation = 8992, - StreamInferCancellation = 8993, - ModelInfo = 8994, - TensorModel = 8995, - TensorModelTypes = 8996, - TritonModelConfig = 8997, - } - #[rstest] #[tokio::test] async fn test_infer_failure( @@ -1971,4 +1973,86 @@ pub mod kserve_test { cancel_token.cancel(); let _ = 
tokio::join!(grpc_task, http_task); } + + #[rstest] + #[tokio::test] + async fn test_live_ready() { + let grpc_port = TestPort::LiveReady as u16; + let service = KserveService::builder().port(grpc_port).build().unwrap(); + + // start server + let _running = RunningService::spawn(service.clone()); + + let mut client = get_ready_client(grpc_port, 5).await; + + // Check server liveness + let server_live_request = tonic::Request::new(ServerLiveRequest {}); + let server_live_response = client.server_live(server_live_request).await.unwrap(); + let server_live = server_live_response.get_ref().live; + assert!(server_live, "Server should be live"); + + // Check server readiness + let server_ready_request = tonic::Request::new(ServerReadyRequest {}); + let server_ready_response = client.server_ready(server_ready_request).await.unwrap(); + let server_ready = server_ready_response.get_ref().ready; + assert!( + !server_ready, + "Server should not be ready without model registered" + ); + + // Check model readiness for unregistered model + let model_ready_request = tonic::Request::new(ModelReadyRequest { + name: "tensor".into(), + version: "".into(), + }); + let model_ready_response = client.model_ready(model_ready_request).await.unwrap(); + let model_ready = model_ready_response.get_ref().ready; + assert!(!model_ready, "Unregistered model should not be ready"); + + // Register a tensor model + let mut card = ModelDeploymentCard::with_name_only("tensor"); + card.model_type = ModelType::TensorBased; + card.model_input = ModelInput::Tensor; + card.runtime_config = ModelRuntimeConfig { + tensor_model_config: Some(tensor::TensorModelConfig { + name: "tensor".to_string(), + inputs: vec![tensor::TensorMetadata { + name: "input".to_string(), + data_type: tensor::DataType::Int32, + shape: vec![1], + parameters: Default::default(), + }], + outputs: vec![tensor::TensorMetadata { + name: "output".to_string(), + data_type: tensor::DataType::Bool, + shape: vec![-1], + parameters: Default::default(), + }], + triton_model_config: None, + }), + ..Default::default() + }; + let tensor = Arc::new(TensorEngine {}); + service + .model_manager() + .add_tensor_model("tensor", card.mdcsum(), tensor.clone()) + .unwrap(); + let _ = service.model_manager().save_model_card("key", card); + + // Re-check readiness + // Check server readiness + let server_ready_request = tonic::Request::new(ServerReadyRequest {}); + let server_ready_response = client.server_ready(server_ready_request).await.unwrap(); + let server_ready = server_ready_response.get_ref().ready; + assert!(server_ready, "Server should be ready with model registered"); + + // Check model readiness for unregistered model + let model_ready_request = tonic::Request::new(ModelReadyRequest { + name: "tensor".into(), + version: "".into(), + }); + let model_ready_response = client.model_ready(model_ready_request).await.unwrap(); + let model_ready = model_ready_response.get_ref().ready; + assert!(model_ready, "Registered model should be ready"); + } } diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs index 933e486a86..149a205491 100644 --- a/lib/llm/tests/test_common_ext.rs +++ b/lib/llm/tests/test_common_ext.rs @@ -92,7 +92,7 @@ fn test_chat_completions_guided_decoding_from_common() { ); assert_eq!( request.get_guided_json(), - Some(&serde_json::json!({"key": "value"})) + Some(serde_json::json!({"key": "value"})) ); // Test guided_regex can be specified at root level diff --git a/lib/llm/tests/test_jail.rs b/lib/llm/tests/test_jail.rs index 
c575ad1a79..82e298d945 100644 --- a/lib/llm/tests/test_jail.rs +++ b/lib/llm/tests/test_jail.rs @@ -2392,6 +2392,13 @@ mod parallel_jail_tests { for (i, (expected_name, expected_args)) in expected_tool_calls.iter().enumerate() { let tool_call = &all_tool_calls[i]; assert!(tool_call.id.is_some(), "Tool call {} should have an ID", i); + + assert_eq!( + tool_call.index, i as u32, + "Tool call {} should have index {}, got {}", + i, i, tool_call.index + ); + assert_eq!( tool_call.r#type, Some(dynamo_async_openai::types::ChatCompletionToolType::Function), diff --git a/lib/llm/tests/test_reasoning_parser.rs b/lib/llm/tests/test_reasoning_parser.rs index 190fd9badb..19a0ec328a 100644 --- a/lib/llm/tests/test_reasoning_parser.rs +++ b/lib/llm/tests/test_reasoning_parser.rs @@ -484,7 +484,8 @@ mod tests { // Step 2: Apply tool calling jail transformation let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "nemotron_deci".to_string(), + Some("nemotron_deci".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); @@ -596,7 +597,8 @@ mod tests { let reasoning_parsed_stream = stream::iter(debug_chunks); let tool_parsed_stream = OpenAIPreprocessor::apply_tool_calling_jail( - "harmony".to_string(), + Some("harmony".to_string()), + None, // No tool_choice in this test reasoning_parsed_stream, ); diff --git a/lib/llm/tests/test_streaming_tool_parsers.rs b/lib/llm/tests/test_streaming_tool_parsers.rs index 03392d2fcb..c2214e707b 100644 --- a/lib/llm/tests/test_streaming_tool_parsers.rs +++ b/lib/llm/tests/test_streaming_tool_parsers.rs @@ -158,7 +158,8 @@ async fn parse_response_stream( > = if tool_parse_enable { if let Some(tool_parser) = tool_parser_str { Box::pin(OpenAIPreprocessor::apply_tool_calling_jail( - tool_parser, + Some(tool_parser), + None, // No tool_choice in this test stream, )) } else { diff --git a/lib/llm/tests/test_streaming_usage.rs b/lib/llm/tests/test_streaming_usage.rs index 068b85bab7..b1fe1b873f 100644 --- a/lib/llm/tests/test_streaming_usage.rs +++ b/lib/llm/tests/test_streaming_usage.rs @@ -208,11 +208,29 @@ async fn test_streaming_without_usage() { // Collect all chunks let chunks: Vec<_> = transformed_stream.collect().await; - // Verify we got exactly 3 chunks (no extra usage chunk) - assert_eq!(chunks.len(), 3, "Should have exactly 3 content chunks"); + // Filter out metrics annotation events (events without SSE data payload) + let content_chunks: Vec<_> = chunks + .into_iter() + .filter(|chunk| { + // Metrics annotation events have event=Some(ANNOTATION_LLM_METRICS) and data=None + !(chunk + .event + .as_ref() + .map(|e| e == "llm_metrics") + .unwrap_or(false) + && chunk.data.is_none()) + }) + .collect(); + + // Verify we got exactly 3 content chunks (no extra usage chunk) + assert_eq!( + content_chunks.len(), + 3, + "Should have exactly 3 content chunks" + ); // Verify all chunks have usage: None - for (i, chunk) in chunks.iter().enumerate() { + for (i, chunk) in content_chunks.iter().enumerate() { if let Some(response) = &chunk.data { assert!( response.usage.is_none(), @@ -322,15 +340,29 @@ async fn test_streaming_with_usage_false() { // Collect all chunks let chunks: Vec<_> = transformed_stream.collect().await; + // Filter out metrics annotation events (events without SSE data payload) + let content_chunks: Vec<_> = chunks + .into_iter() + .filter(|chunk| { + // Metrics annotation events have event=Some(ANNOTATION_LLM_METRICS) and data=None + !(chunk + .event + .as_ref() + .map(|e| e == "llm_metrics") + .unwrap_or(false) + && 
chunk.data.is_none()) + }) + .collect(); + // Verify we got exactly 3 chunks (no extra usage chunk when explicitly false) assert_eq!( - chunks.len(), + content_chunks.len(), 3, "Should have exactly 3 content chunks when include_usage is false" ); // Verify all chunks have usage: None - for (i, chunk) in chunks.iter().enumerate() { + for (i, chunk) in content_chunks.iter().enumerate() { if let Some(response) = &chunk.data { assert!( response.usage.is_none(), diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs new file mode 100644 index 0000000000..c970108d9b --- /dev/null +++ b/lib/llm/tests/tool_choice.rs @@ -0,0 +1,436 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use dynamo_async_openai::types::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, + ChatCompletionToolType, CreateChatCompletionRequest, FunctionName, +}; +use dynamo_llm::protocols::common; +use dynamo_llm::protocols::common::llm_backend::BackendOutput; +use dynamo_llm::protocols::openai::DeltaGeneratorExt; +use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; + +fn create_test_request() -> NvCreateChatCompletionRequest { + let messages = vec![ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text("test".to_string()), + name: None, + }, + )]; + + NvCreateChatCompletionRequest { + inner: CreateChatCompletionRequest { + model: "test-model".to_string(), + messages, + stream: Some(false), + stream_options: None, + ..Default::default() + }, + common: Default::default(), + nvext: None, + chat_template_args: None, + unsupported_fields: Default::default(), + } +} + +async fn apply_jail_transformation( + raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + tool_choice: Option, +) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; + + let input_stream = stream::iter(vec![Annotated { + data: Some(raw_response), + id: None, + event: None, + comment: None, + }]); + + let mut builder = JailedStream::builder(); + + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream.next().await.unwrap().data.unwrap() +} + +async fn apply_jail_transformation_streaming( + raw_responses: Vec< + dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + >, + tool_choice: Option, +) -> Vec { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; + + let input_stream = stream::iter(raw_responses.into_iter().map(|r| Annotated { + data: Some(r), + id: None, + event: None, + comment: None, + })); + + let mut builder = JailedStream::builder(); + + 
match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream + .filter_map(|ann| async move { ann.data }) + .collect() + .await +} + +fn build_backend_output(text: &str) -> BackendOutput { + BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(text.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: Some(common::FinishReason::Stop), + index: Some(0), + completion_usage: None, + disaggregated_params: None, + } +} + +#[tokio::test] +async fn test_named_tool_choice_parses_json() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-1".to_string()); + let backend_output = build_backend_output(r#"{"location":"Paris"}"#); + let raw_response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let response = apply_jail_transformation(raw_response, tool_choice).await; + let choice = &response.choices[0]; + + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop) + ); + let delta = &choice.delta; + assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); + let tool_calls = delta.tool_calls.as_ref().unwrap(); + + assert_eq!(tool_calls.len(), 1); + + let tool_call = &tool_calls[0]; + assert_eq!(tool_call.index, 0); + assert!(tool_call.id.as_ref().unwrap().starts_with("call-")); + assert_eq!(tool_call.r#type, Some(ChatCompletionToolType::Function)); + assert_eq!( + tool_call.function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + assert_eq!( + tool_call.function.as_ref().unwrap().arguments.as_deref(), + Some(r#"{"location":"Paris"}"#) + ); +} + +#[tokio::test] +async fn test_required_tool_choice_parses_json_array() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-2".to_string()); + let backend_output = build_backend_output( + r#"[{"name":"search","parameters":{"query":"rust"}}, + {"name":"summarize","parameters":{"topic":"memory"}}]"#, + ); + let raw_response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let response = apply_jail_transformation(raw_response, tool_choice).await; + let choice = &response.choices[0]; + + assert_eq!( + choice.finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + ); + let delta = &choice.delta; + assert!(delta.content.is_none() || delta.content.as_deref() == Some("")); + let tool_calls = delta.tool_calls.as_ref().unwrap(); + + assert_eq!(tool_calls.len(), 2); + + assert_eq!(tool_calls[0].index, 0); + assert!(tool_calls[0].id.as_ref().unwrap().starts_with("call-")); + assert_eq!(tool_calls[0].r#type, Some(ChatCompletionToolType::Function)); + assert_eq!( + 
tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("search") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"query":"rust"}"#) + ); + + assert_eq!(tool_calls[1].index, 1); + assert!(tool_calls[1].id.as_ref().unwrap().starts_with("call-")); + assert_eq!(tool_calls[1].r#type, Some(ChatCompletionToolType::Function)); + assert_eq!( + tool_calls[1].function.as_ref().unwrap().name.as_deref(), + Some("summarize") + ); + assert_eq!( + tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"topic":"memory"}"#) + ); +} + +#[tokio::test] +async fn test_tool_choice_parse_failure_returns_as_content() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-3".to_string()); + let backend_output = build_backend_output("not-json"); + let raw_response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let response = apply_jail_transformation(raw_response, tool_choice).await; + let delta = &response.choices[0].delta; + + // Jail stream behavior: if parsing fails, return accumulated content as-is + // This matches marker-based FC behavior + assert_eq!(delta.content.as_deref(), Some("not-json")); + assert!(delta.tool_calls.is_none()); +} + +#[tokio::test] +async fn test_streaming_named_tool_buffers_until_finish() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-stream-1".to_string()); + + let chunks = [r#"{"location":""#, r#"Paris","unit":""#, r#"celsius"}"#]; + + let mut raw_responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("streaming chunk"); + raw_responses.push(response); + } + + let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; + + // Jail stream buffers content until valid JSON, then emits once + assert_eq!(all_responses.len(), 1); + + let response = &all_responses[0]; + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop) + ); + + let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"location":"Paris","unit":"celsius"}"#) + ); +} + +#[tokio::test] +async fn test_streaming_required_tool_parallel() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); + + let mut 
generator = request.response_generator("req-stream-2".to_string()); + + let chunks = [ + r#"[{"name":"search","parameters":{"query":"rust"}},"#, + r#"{"name":"summarize","parameters":{"topic":"memory"}}]"#, + ]; + + let mut raw_responses = Vec::new(); + for (i, chunk) in chunks.iter().enumerate() { + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(chunk.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: if i == chunks.len() - 1 { + Some(common::FinishReason::Stop) + } else { + None + }, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("streaming chunk"); + raw_responses.push(response); + } + + let all_responses = apply_jail_transformation_streaming(raw_responses, tool_choice).await; + + // Jail stream buffers until complete JSON array + assert_eq!(all_responses.len(), 1); + + let response = &all_responses[0]; + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls) + ); + + let tool_calls = response.choices[0].delta.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2); + + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("search") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"query":"rust"}"#) + ); + + assert_eq!( + tool_calls[1].function.as_ref().unwrap().name.as_deref(), + Some("summarize") + ); + assert_eq!( + tool_calls[1] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some(r#"{"topic":"memory"}"#) + ); +} + +#[test] +fn test_no_tool_choice_outputs_normal_text() { + let request = create_test_request(); + + let mut generator = request.response_generator("req-stream-4".to_string()); + + let backend_output = BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some("Hello world".to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: None, + index: Some(0), + completion_usage: None, + disaggregated_params: None, + }; + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("normal text"); + + assert_eq!( + response.choices[0].delta.content.as_deref(), + Some("Hello world") + ); + assert!(response.choices[0].delta.tool_calls.is_none()); +} diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs new file mode 100644 index 0000000000..07f28d5962 --- /dev/null +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -0,0 +1,250 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Tests for tool_choice finish_reason handling. 
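+//! +//! These tests check that backend finish reasons such as `Length` and `ContentFilter` +//! are preserved for `tool_choice=named` and `tool_choice=required`, and that a normal +//! `Stop` is reported as `ToolCalls` only for `tool_choice=required`.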
+ +use dynamo_async_openai::types::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, ChatCompletionToolChoiceOption, + ChatCompletionToolType, CreateChatCompletionRequest, FunctionName, +}; +use dynamo_llm::protocols::common; +use dynamo_llm::protocols::common::llm_backend::BackendOutput; +use dynamo_llm::protocols::openai::DeltaGeneratorExt; +use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest; + +fn create_test_request() -> NvCreateChatCompletionRequest { + let messages = vec![ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text("test".to_string()), + name: None, + }, + )]; + + NvCreateChatCompletionRequest { + inner: CreateChatCompletionRequest { + model: "test-model".to_string(), + messages, + stream: Some(false), + stream_options: None, + ..Default::default() + }, + common: Default::default(), + nvext: None, + chat_template_args: None, + unsupported_fields: Default::default(), + } +} + +fn build_backend_output_with_finish(text: &str, finish: common::FinishReason) -> BackendOutput { + BackendOutput { + token_ids: vec![], + tokens: vec![], + text: Some(text.to_string()), + cum_log_probs: None, + log_probs: None, + top_logprobs: None, + finish_reason: Some(finish), + index: Some(0), + completion_usage: None, + disaggregated_params: None, + } +} + +async fn apply_jail_transformation( + raw_response: dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse, + tool_choice: Option, +) -> dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionStreamResponse { + use dynamo_llm::protocols::openai::chat_completions::jail::JailedStream; + use dynamo_runtime::protocols::annotated::Annotated; + use futures::StreamExt; + use futures::stream; + + let input_stream = stream::iter(vec![Annotated { + data: Some(raw_response), + id: None, + event: None, + comment: None, + }]); + + let mut builder = JailedStream::builder(); + + match tool_choice { + Some(ChatCompletionToolChoiceOption::Named(ref named)) => { + builder = builder.tool_choice_named(named.function.name.clone()); + } + Some(ChatCompletionToolChoiceOption::Required) => { + builder = builder.tool_choice_required(); + } + _ => {} + } + + let jail = builder.build(); + let output_stream = jail.apply_with_finish_reason(input_stream); + + tokio::pin!(output_stream); + output_stream.next().await.unwrap().data.unwrap() +} + +#[tokio::test] +async fn test_named_tool_choice_preserves_length_finish_reason() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-length-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"location":"Par"#, // Incomplete due to length limit + common::FinishReason::Length, + ); + + let raw_response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let response = apply_jail_transformation(raw_response, tool_choice).await; + + // Critical: Length finish reason should be preserved, NOT replaced with Stop + assert_eq!( + response.choices[0].finish_reason, + 
Some(dynamo_async_openai::types::FinishReason::Length), + "Length finish reason must be preserved for tool_choice=named" + ); +} + +#[test] +fn test_required_tool_choice_preserves_length_finish_reason() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-length-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"search","parameters":{"query":"incomplete"#, // Incomplete due to length + common::FinishReason::Length, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: Length finish reason should be preserved, NOT replaced with ToolCalls + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Length), + "Length finish reason must be preserved for tool_choice=required" + ); +} + +#[test] +fn test_named_tool_choice_preserves_content_filter() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "search".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-filter-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"query":"filtered content"#, + common::FinishReason::ContentFilter, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: ContentFilter finish reason should be preserved + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ContentFilter), + "ContentFilter finish reason must be preserved for tool_choice=named" + ); +} + +#[test] +fn test_required_tool_choice_preserves_content_filter() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Required); + + let mut generator = request.response_generator("req-filter-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"harmful_action"#, + common::FinishReason::ContentFilter, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Critical: ContentFilter finish reason should be preserved + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ContentFilter), + "ContentFilter finish reason must be preserved for tool_choice=required" + ); +} + +#[test] +fn test_named_tool_choice_normal_stop_becomes_stop() { + let mut request = create_test_request(); + request.inner.tool_choice = Some(ChatCompletionToolChoiceOption::Named( + ChatCompletionNamedToolChoice { + r#type: ChatCompletionToolType::Function, + function: FunctionName { + name: "get_weather".to_string(), + }, + }, + )); + + let mut generator = request.response_generator("req-stop-1".to_string()); + let backend_output = build_backend_output_with_finish( + r#"{"location":"Paris","unit":"celsius"}"#, + common::FinishReason::Stop, + ); + + let response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + // Normal completion: Stop should remain Stop for named tool choice + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::Stop), + ); +} + +#[tokio::test] +async fn 
test_required_tool_choice_normal_stop_becomes_tool_calls() { + let mut request = create_test_request(); + let tool_choice = Some(ChatCompletionToolChoiceOption::Required); + request.inner.tool_choice = tool_choice.clone(); + + let mut generator = request.response_generator("req-stop-2".to_string()); + let backend_output = build_backend_output_with_finish( + r#"[{"name":"search","parameters":{"query":"rust"}}]"#, + common::FinishReason::Stop, + ); + + let raw_response = generator + .choice_from_postprocessor(backend_output) + .expect("choice generation"); + + let response = apply_jail_transformation(raw_response, tool_choice).await; + + // Normal completion: Stop should become ToolCalls for required tool choice + assert_eq!( + response.choices[0].finish_reason, + Some(dynamo_async_openai::types::FinishReason::ToolCalls), + ); +} diff --git a/lib/parsers/src/tool_calling/config.rs b/lib/parsers/src/tool_calling/config.rs index a7534e0920..e03295c754 100644 --- a/lib/parsers/src/tool_calling/config.rs +++ b/lib/parsers/src/tool_calling/config.rs @@ -69,6 +69,36 @@ impl Default for XmlParserConfig { } } +/// Configuration for DSML-style tool call parser (DeepSeek V3.2+) +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct DsmlParserConfig { + /// Start token for function_calls block (e.g., "<๏ฝœDSML๏ฝœfunction_calls>") + pub function_calls_start: String, + /// End token for function_calls block (e.g., "</๏ฝœDSML๏ฝœfunction_calls>") + pub function_calls_end: String, + /// Start prefix for invoke (e.g., "<๏ฝœDSML๏ฝœinvoke name=") + pub invoke_start_prefix: String, + /// End token for invoke (e.g., "</๏ฝœDSML๏ฝœinvoke>") + pub invoke_end: String, + /// Start prefix for parameter (e.g., "<๏ฝœDSML๏ฝœparameter name=") + pub parameter_prefix: String, + /// End token for parameter (e.g., "</๏ฝœDSML๏ฝœparameter>") + pub parameter_end: String, +} + +impl Default for DsmlParserConfig { + fn default() -> Self { + Self { + function_calls_start: "<๏ฝœDSML๏ฝœfunction_calls>".to_string(), + function_calls_end: "</๏ฝœDSML๏ฝœfunction_calls>".to_string(), + invoke_start_prefix: "<๏ฝœDSML๏ฝœinvoke name=".to_string(), + invoke_end: "</๏ฝœDSML๏ฝœinvoke>".to_string(), + parameter_prefix: "<๏ฝœDSML๏ฝœparameter name=".to_string(), + parameter_end: "</๏ฝœDSML๏ฝœparameter>".to_string(), + } + } +} + /// Parser-specific configuration #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -78,6 +108,7 @@ pub enum ParserConfig { Pythonic, Harmony(JsonParserConfig), Typescript, + Dsml(DsmlParserConfig), } impl ParserConfig { @@ -90,6 +121,7 @@ impl ParserConfig { ParserConfig::Xml(config) => vec![config.tool_call_start_token.clone()], ParserConfig::Pythonic => vec![], ParserConfig::Typescript => vec![], + ParserConfig::Dsml(config) => vec![config.function_calls_start.clone()], } } @@ -102,6 +134,7 @@ impl ParserConfig { ParserConfig::Xml(config) => vec![config.tool_call_end_token.clone()], ParserConfig::Pythonic => vec![], ParserConfig::Typescript => vec![], + ParserConfig::Dsml(config) => vec![config.function_calls_end.clone()], } } } @@ -239,4 +272,26 @@ impl ToolCallConfig { parser_config: ParserConfig::Xml(XmlParserConfig::default()), } } + + pub fn jamba() -> Self { + Self { + parser_config: ParserConfig::Json(JsonParserConfig { + tool_call_start_tokens: vec!["<tool_calls>".to_string()], + tool_call_end_tokens: vec!["</tool_calls>".to_string()], + ..Default::default() + }), + } + } + + pub fn deepseek_v3_2() -> Self { + // DeepSeek V3.2 format (DSML): + // <๏ฝœDSML๏ฝœfunction_calls> + // <๏ฝœDSML๏ฝœinvoke name="function_name"> + // <๏ฝœDSML๏ฝœparameter name="param_name" string="true|false">value</๏ฝœDSML๏ฝœparameter>
+ // + // + Self { + parser_config: ParserConfig::Dsml(DsmlParserConfig::default()), + } + } } diff --git a/lib/parsers/src/tool_calling/dsml/mod.rs b/lib/parsers/src/tool_calling/dsml/mod.rs new file mode 100644 index 0000000000..43d412421b --- /dev/null +++ b/lib/parsers/src/tool_calling/dsml/mod.rs @@ -0,0 +1,9 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod parser; + +pub use super::response; +pub use parser::{ + detect_tool_call_start_dsml, find_tool_call_end_position_dsml, try_tool_call_parse_dsml, +}; diff --git a/lib/parsers/src/tool_calling/dsml/parser.rs b/lib/parsers/src/tool_calling/dsml/parser.rs new file mode 100644 index 0000000000..11bb8503e5 --- /dev/null +++ b/lib/parsers/src/tool_calling/dsml/parser.rs @@ -0,0 +1,493 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Reference implementation: +// https://huggingface.co/deepseek-ai/DeepSeek-V3.2/tree/main/encoding/encoding_dsv32.py + +use regex::Regex; +use uuid::Uuid; + +use super::super::config::DsmlParserConfig; +use super::super::response::{CalledFunction, ToolCallResponse, ToolCallType}; + +/// DeepSeek V3.2 uses DSML (DeepSeek Markup Language) format for tool calls: +/// +/// <๏ฝœDSML๏ฝœfunction_calls> +/// <๏ฝœDSML๏ฝœinvoke name="function_name"> +/// <๏ฝœDSML๏ฝœparameter name="param_name" string="true|false">value +/// ... +/// +/// +/// Check if a chunk contains the start of a DSML tool call +pub fn detect_tool_call_start_dsml(chunk: &str, config: &DsmlParserConfig) -> bool { + let start_token = &config.function_calls_start; + + // Check for complete start token + if chunk.contains(start_token.as_str()) { + return true; + } + + // Check for partial match at the end (streaming scenario) + let start_chars: Vec = start_token.chars().collect(); + for i in 1..start_chars.len() { + let partial: String = start_chars[..i].iter().collect(); + if chunk.ends_with(&partial) { + return true; + } + } + + false +} + +/// Find the end position of a DSML tool call block +pub fn find_tool_call_end_position_dsml(chunk: &str, config: &DsmlParserConfig) -> usize { + let end_token = &config.function_calls_end; + + if let Some(pos) = chunk.find(end_token.as_str()) { + pos + end_token.len() + } else { + chunk.len() + } +} + +/// Parse DSML formatted tool calls from a message +/// Returns (parsed_tool_calls, normal_text_content) +pub fn try_tool_call_parse_dsml( + message: &str, + config: &DsmlParserConfig, +) -> anyhow::Result<(Vec, Option)> { + let trimmed = message.trim(); + + // Early exit if no content + if trimmed.is_empty() { + return Ok((vec![], Some(String::new()))); + } + + // Check if tool call block exists + if !trimmed.contains(&config.function_calls_start) { + return Ok((vec![], Some(trimmed.to_string()))); + } + + // Extract normal text before tool calls + let normal_text = if let Some(start_idx) = trimmed.find(&config.function_calls_start) { + let text = trimmed[..start_idx].trim(); + if text.is_empty() { + String::new() + } else { + text.to_string() + } + } else { + String::new() + }; + + // Extract tool calls blocks + let tool_calls = extract_tool_calls(trimmed, config)?; + + if tool_calls.is_empty() { + // No valid tool calls found + return Ok((vec![], Some(trimmed.to_string()))); + } + + Ok((tool_calls, Some(normal_text))) +} + +/// Extract all tool calls from the DSML formatted text +fn extract_tool_calls( + 
text: &str, + config: &DsmlParserConfig, +) -> anyhow::Result<Vec<ToolCallResponse>> { + let mut tool_calls = Vec::new(); + + // Find all function_calls blocks + // Matches: <๏ฝœDSML๏ฝœfunction_calls> ... </๏ฝœDSML๏ฝœfunction_calls> + // Pattern: (?s) = dot matches newlines + // \s*(.*?)\s* = capture content between start/end tags (non-greedy) + let block_pattern = format!( + r"(?s){}\s*(.*?)\s*{}", + regex::escape(&config.function_calls_start), + regex::escape(&config.function_calls_end) + ); + let block_regex = Regex::new(&block_pattern)?; + + for block_match in block_regex.captures_iter(text) { + if let Some(block_content) = block_match.get(1) { + let block = block_content.as_str(); + + // Extract individual invokes from this block + let invokes = extract_invokes(block, config)?; + tool_calls.extend(invokes); + } + } + + Ok(tool_calls) +} + +/// Extract individual invoke blocks from function_calls content +fn extract_invokes( + block: &str, + config: &DsmlParserConfig, +) -> anyhow::Result<Vec<ToolCallResponse>> { + let mut invokes = Vec::new(); + + // Regex to match: <๏ฝœDSML๏ฝœinvoke name="function_name">..content..</๏ฝœDSML๏ฝœinvoke> + // Note: invoke_start_prefix is "<๏ฝœDSML๏ฝœinvoke name=" (no quotes, we add them in pattern) + let invoke_pattern = format!( + r#"(?s){}\"([^"]+)\"\s*>(.*?){}"#, + regex::escape(&config.invoke_start_prefix), + regex::escape(&config.invoke_end) + ); + let invoke_regex = Regex::new(&invoke_pattern)?; + + for invoke_match in invoke_regex.captures_iter(block) { + if let (Some(name_match), Some(content_match)) = (invoke_match.get(1), invoke_match.get(2)) + { + let function_name = name_match.as_str().trim().to_string(); + let invoke_content = content_match.as_str(); + + // Parse parameters from invoke content + let parameters = parse_parameters(invoke_content, config)?; + + // Create tool call response + let arguments_json = serde_json::to_string(&parameters)?; + + invokes.push(ToolCallResponse { + id: format!("call-{}", Uuid::new_v4()), + tp: ToolCallType::Function, + function: CalledFunction { + name: function_name, + arguments: arguments_json, + }, + }); + } + } + + Ok(invokes) +} + +/// Parse parameters from invoke content +fn parse_parameters( + content: &str, + config: &DsmlParserConfig, +) -> anyhow::Result<serde_json::Map<String, serde_json::Value>> { + let mut parameters = serde_json::Map::new(); + + // Build pattern with proper escaping + // Match: <๏ฝœDSML๏ฝœparameter name="param_name" string="true|false">value</๏ฝœDSML๏ฝœparameter> + // Note: parameter_prefix is "<๏ฝœDSML๏ฝœparameter name=" (no quotes, we add them in pattern) + let prefix_escaped = regex::escape(&config.parameter_prefix); + let end_escaped = regex::escape(&config.parameter_end); + + let param_pattern = format!( + r#"(?s){}\"([^"]+)\"\s+string=\"(true|false)\"\s*>(.*?){}"#, + prefix_escaped, end_escaped + ); + + let param_regex = Regex::new(&param_pattern)?; + + for param_match in param_regex.captures_iter(content) { + if let (Some(name_match), Some(string_match), Some(value_match)) = + (param_match.get(1), param_match.get(2), param_match.get(3)) + { + let param_name = name_match.as_str().trim(); + let is_string = string_match.as_str() == "true"; + let param_value = value_match.as_str().trim(); + + // Parse value based on string attribute + let value = if is_string { + // String type - use as-is + serde_json::Value::String(param_value.to_string()) + } else { + // Non-string type - parse as JSON + serde_json::from_str(param_value).unwrap_or_else(|_| { + // Fallback to string if JSON parsing fails + serde_json::Value::String(param_value.to_string()) + }) + }; + + parameters.insert(param_name.to_string(), value); + } + } + + Ok(parameters) +}
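+ +// Illustrative example (hypothetical input, mirroring the tests below): with the default +// `DsmlParserConfig`, a block such as +// +// <๏ฝœDSML๏ฝœfunction_calls> +// <๏ฝœDSML๏ฝœinvoke name="get_weather"> +// <๏ฝœDSML๏ฝœparameter name="location" string="true">Paris</๏ฝœDSML๏ฝœparameter> +// <๏ฝœDSML๏ฝœparameter name="days" string="false">3</๏ฝœDSML๏ฝœparameter> +// </๏ฝœDSML๏ฝœinvoke> +// </๏ฝœDSML๏ฝœfunction_calls> +// +// should yield one ToolCallResponse named "get_weather" whose arguments contain +// "location": "Paris" (string="true" values are kept as strings) and "days": 3 +// (string="false" values are parsed as JSON, falling back to a string on failure).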
+ +#[cfg(test)] +mod tests { + use super::*; + + fn extract_name_and_args(call: ToolCallResponse) -> (String, serde_json::Value) { + let args: serde_json::Value = serde_json::from_str(&call.function.arguments).unwrap(); + (call.function.name, args) + } + + fn get_test_config() -> DsmlParserConfig { + DsmlParserConfig::default() + } + + #[test] + fn test_detect_tool_call_start() { + let config = get_test_config(); + assert!(detect_tool_call_start_dsml( + "<๏ฝœDSML๏ฝœfunction_calls>", + &config + )); + assert!(detect_tool_call_start_dsml( + "text <๏ฝœDSML๏ฝœfunction_calls>", + &config + )); + assert!(detect_tool_call_start_dsml("<๏ฝœDSML๏ฝœfunction_c", &config)); // Partial + assert!(!detect_tool_call_start_dsml("no tool call here", &config)); + } + + #[test] + fn test_find_tool_call_end_position() { + let config = get_test_config(); + let text = "<๏ฝœDSML๏ฝœfunction_calls><๏ฝœDSML๏ฝœinvoke name=\"test\"></๏ฝœDSML๏ฝœinvoke></๏ฝœDSML๏ฝœfunction_calls>more"; + let pos = find_tool_call_end_position_dsml(text, &config); + assert_eq!(&text[pos..], "more"); + } + + #[test] + fn test_parse_single_tool_call_string_param() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="get_weather"> +<๏ฝœDSML๏ฝœparameter name="location" string="true">San Francisco</๏ฝœDSML๏ฝœparameter> +</๏ฝœDSML๏ฝœinvoke> +</๏ฝœDSML๏ฝœfunction_calls>"#; + + let config = get_test_config(); + let result = try_tool_call_parse_dsml(input, &config); + if let Err(e) = &result { + eprintln!("Parse error: {:?}", e); + } + let (calls, normal) = result.unwrap(); + + if calls.is_empty() { + eprintln!("Input: {}", input); + eprintln!("No calls parsed!"); + } + + assert_eq!(calls.len(), 1, "Expected 1 tool call, got {}", calls.len()); + assert_eq!(normal, Some("".to_string())); + + let (name, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(name, "get_weather"); + assert_eq!(args["location"], "San Francisco"); + } + + #[test] + fn test_parse_single_tool_call_mixed_params() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="search"> +<๏ฝœDSML๏ฝœparameter name="query" string="true">test query</๏ฝœDSML๏ฝœparameter> +<๏ฝœDSML๏ฝœparameter name="topn" string="false">10</๏ฝœDSML๏ฝœparameter> +</๏ฝœDSML๏ฝœinvoke> +</๏ฝœDSML๏ฝœfunction_calls>"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (name, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(name, "search"); + assert_eq!(args["query"], "test query"); + assert_eq!(args["topn"], 10); + } + + #[test] + fn test_parse_multiple_tool_calls() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="get_weather"> +<๏ฝœDSML๏ฝœparameter name="location" string="true">Beijing</๏ฝœDSML๏ฝœparameter> +<๏ฝœDSML๏ฝœparameter name="date" string="true">2024-01-16</๏ฝœDSML๏ฝœparameter> +</๏ฝœDSML๏ฝœinvoke> +<๏ฝœDSML๏ฝœinvoke name="get_weather"> +<๏ฝœDSML๏ฝœparameter name="location" string="true">Hangzhou</๏ฝœDSML๏ฝœparameter> +<๏ฝœDSML๏ฝœparameter name="date" string="true">2024-01-16</๏ฝœDSML๏ฝœparameter> +</๏ฝœDSML๏ฝœinvoke> +</๏ฝœDSML๏ฝœfunction_calls>"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 2); + + let (name1, args1) = extract_name_and_args(calls[0].clone()); + assert_eq!(name1, "get_weather"); + assert_eq!(args1["location"], "Beijing"); + + let (name2, args2) = extract_name_and_args(calls[1].clone()); + assert_eq!(name2, "get_weather"); + assert_eq!(args2["location"], "Hangzhou"); + } + + #[test] + fn test_parse_with_normal_text() { + let input = r#"Here's the result: <๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="test"> +<๏ฝœDSML๏ฝœparameter name="value" string="true">test</๏ฝœDSML๏ฝœparameter> +</๏ฝœDSML๏ฝœinvoke> +</๏ฝœDSML๏ฝœfunction_calls>"#; + + let config = get_test_config(); + let (calls, normal) = try_tool_call_parse_dsml(input, &config).unwrap(); +
assert_eq!(calls.len(), 1); + assert_eq!(normal, Some("Here's the result:".to_string())); + } + + #[test] + fn test_parse_no_tool_calls() { + let input = "This is just normal text without any tool calls."; + let config = get_test_config(); + let (calls, normal) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 0); + assert_eq!(normal, Some(input.to_string())); + } + + #[test] + fn test_parse_json_parameter_value() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="process"> +<๏ฝœDSML๏ฝœparameter name="config" string="false">{"key": "value", "count": 42} + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert!(args["config"].is_object()); + assert_eq!(args["config"]["key"], "value"); + assert_eq!(args["config"]["count"], 42); + } + + #[test] + fn test_parse_array_parameter_value() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="process"> +<๏ฝœDSML๏ฝœparameter name="items" string="false">[1, 2, 3] + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert!(args["items"].is_array()); + assert_eq!(args["items"][0], 1); + assert_eq!(args["items"][2], 3); + } + + #[test] + fn test_parse_boolean_parameters() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="config"> +<๏ฝœDSML๏ฝœparameter name="enabled" string="false">true +<๏ฝœDSML๏ฝœparameter name="disabled" string="false">false + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(args["enabled"], true); + assert_eq!(args["disabled"], false); + } + + #[test] + fn test_parse_number_parameters() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="calculate"> +<๏ฝœDSML๏ฝœparameter name="integer" string="false">42 +<๏ฝœDSML๏ฝœparameter name="float" string="false">2.7 +<๏ฝœDSML๏ฝœparameter name="negative" string="false">-100 + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(args["integer"], 42); + assert_eq!(args["float"], 2.7); + assert_eq!(args["negative"], -100); + } + + #[test] + fn test_parse_mixed_types_realistic() { + // Realistic example based on test data + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="search"> +<๏ฝœDSML๏ฝœparameter name="query" string="true">search agent benchmark 2024 +<๏ฝœDSML๏ฝœparameter name="topn" string="false">10 +<๏ฝœDSML๏ฝœparameter name="source" string="true">web + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (name, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(name, "search"); + assert_eq!(args["query"], "search agent benchmark 2024"); + assert_eq!(args["topn"], 10); // Should be number, not string + assert_eq!(args["source"], "web"); + } + + #[test] + fn test_parse_nested_object_parameter() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="configure"> +<๏ฝœDSML๏ฝœparameter name="settings" 
string="false">{"timeout": 30, "retry": true, "endpoints": ["a", "b"]} + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert!(args["settings"].is_object()); + assert_eq!(args["settings"]["timeout"], 30); + assert_eq!(args["settings"]["retry"], true); + assert!(args["settings"]["endpoints"].is_array()); + assert_eq!(args["settings"]["endpoints"][0], "a"); + } + + #[test] + fn test_parse_empty_string_parameter() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="test"> +<๏ฝœDSML๏ฝœparameter name="empty" string="true"> + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert_eq!(args["empty"], ""); + } + + #[test] + fn test_parse_null_parameter() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="test"> +<๏ฝœDSML๏ฝœparameter name="value" string="false">null + +"#; + + let config = get_test_config(); + let (calls, _) = try_tool_call_parse_dsml(input, &config).unwrap(); + assert_eq!(calls.len(), 1); + + let (_, args) = extract_name_and_args(calls[0].clone()); + assert!(args["value"].is_null()); + } +} diff --git a/lib/parsers/src/tool_calling/mod.rs b/lib/parsers/src/tool_calling/mod.rs index e2412722fb..ff86235f15 100644 --- a/lib/parsers/src/tool_calling/mod.rs +++ b/lib/parsers/src/tool_calling/mod.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 pub mod config; +pub mod dsml; pub mod harmony; pub mod json; pub mod parsers; @@ -14,6 +15,7 @@ pub mod xml; // Re-export main types and functions for convenience pub use config::{JsonParserConfig, ParserConfig, ToolCallConfig, XmlParserConfig}; +pub use dsml::try_tool_call_parse_dsml; pub use harmony::parse_tool_calls_harmony_complete; pub use json::try_tool_call_parse_json; pub use parsers::{ diff --git a/lib/parsers/src/tool_calling/parsers.rs b/lib/parsers/src/tool_calling/parsers.rs index 93c9ae7e24..d454599d16 100644 --- a/lib/parsers/src/tool_calling/parsers.rs +++ b/lib/parsers/src/tool_calling/parsers.rs @@ -2,6 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use super::config::{ParserConfig, ToolCallConfig}; +use super::dsml::{ + detect_tool_call_start_dsml, find_tool_call_end_position_dsml, try_tool_call_parse_dsml, +}; use super::harmony::{ detect_tool_call_start_harmony, find_tool_call_end_position_harmony, parse_tool_calls_harmony_complete, @@ -35,7 +38,9 @@ pub fn get_tool_parser_map() -> &'static HashMap<&'static str, ToolCallConfig> { map.insert("harmony", ToolCallConfig::harmony()); map.insert("deepseek_v3", ToolCallConfig::deepseek_v3()); map.insert("deepseek_v3_1", ToolCallConfig::deepseek_v3_1()); + map.insert("deepseek_v3_2", ToolCallConfig::deepseek_v3_2()); map.insert("qwen3_coder", ToolCallConfig::qwen3_coder()); + map.insert("jamba", ToolCallConfig::jamba()); map.insert("default", ToolCallConfig::default()); map }) @@ -71,6 +76,10 @@ pub async fn try_tool_call_parse( let (results, normal_content) = try_tool_call_parse_xml(message, xml_config)?; Ok((results, normal_content)) } + ParserConfig::Dsml(dsml_config) => { + let (results, normal_content) = try_tool_call_parse_dsml(message, dsml_config)?; + Ok((results, normal_content)) + } } } @@ -119,6 +128,7 @@ pub fn detect_tool_call_start(chunk: &str, parser_str: Option<&str>) -> anyhow:: anyhow::bail!("Typescript parser 
not implemented"); } ParserConfig::Xml(xml_config) => Ok(detect_tool_call_start_xml(chunk, xml_config)), + ParserConfig::Dsml(dsml_config) => Ok(detect_tool_call_start_dsml(chunk, dsml_config)), }, None => anyhow::bail!( "Parser '{}' is not implemented. Available parsers: {:?}", @@ -155,6 +165,7 @@ pub fn find_tool_call_end_position(chunk: &str, parser_str: Option<&str>) -> usi chunk.len() } ParserConfig::Xml(xml_config) => find_tool_call_end_position_xml(chunk, xml_config), + ParserConfig::Dsml(dsml_config) => find_tool_call_end_position_dsml(chunk, dsml_config), }, None => { // Unknown parser, return full content length @@ -190,7 +201,9 @@ mod tests { "pythonic", "deepseek_v3", "deepseek_v3_1", + "deepseek_v3_2", "qwen3_coder", + "jamba", ]; for parser in available_parsers { assert!(parsers.contains(&parser)); @@ -940,19 +953,11 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it } #[tokio::test] - #[ignore] async fn test_ai21labs_ai21_jamba_15_mini_simple() { - let input = r#" [ - {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} -]"#; - let config = ToolCallConfig { - parser_config: ParserConfig::Json(JsonParserConfig { - tool_call_start_tokens: vec![], - tool_call_end_tokens: vec![], - arguments_keys: vec!["arguments".to_string()], - ..Default::default() - }), - }; + let input = r#"[ +{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}} +]"#; + let config = ToolCallConfig::jamba(); let (result, content) = try_tool_call_parse(input, &config).await.unwrap(); assert_eq!(content, Some("".to_string())); assert!(!result.is_empty()); @@ -963,6 +968,29 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it assert_eq!(args["unit"], "fahrenheit"); } + #[tokio::test] + async fn test_ai21labs_ai21_jamba_15_mini_multiple() { + let input = r#"[ +{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}}, +{"name": "get_weather", "arguments": {"location": "New York, NY", "unit": "celsius"}} +]"#; + let config = ToolCallConfig::jamba(); + let (result, content) = try_tool_call_parse(input, &config).await.unwrap(); + assert_eq!(content, Some("".to_string())); + assert!(!result.is_empty()); + assert_eq!(result.len(), 2); + + let (name, args) = extract_name_and_args(result[0].clone()); + assert_eq!(name, "get_weather"); + assert_eq!(args["location"], "San Francisco, CA"); + assert_eq!(args["unit"], "fahrenheit"); + + let (name, args) = extract_name_and_args(result[1].clone()); + assert_eq!(name, "get_weather"); + assert_eq!(args["location"], "New York, NY"); + assert_eq!(args["unit"], "celsius"); + } + #[tokio::test] #[ignore] async fn test_salesforce_llama_xlam_2_8b_fc_r_simple() { @@ -1552,6 +1580,82 @@ Remember, San Francisco weather can be quite unpredictable, particularly with it assert_eq!(args["location"], "Paris"); } + #[tokio::test] + async fn test_deepseek_v3_2_single_tool_call() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="get_datetime"> +<๏ฝœDSML๏ฝœparameter name="timezone" string="true">Asia/Shanghai + +"#; + + let (tool_calls, normal_text) = detect_and_parse_tool_call(input, Some("deepseek_v3_2")) + .await + .expect("Failed to parse"); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "get_datetime"); + assert_eq!(normal_text, Some("".to_string())); + + let args: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + 
assert_eq!(args["timezone"], "Asia/Shanghai"); + } + + #[tokio::test] + async fn test_deepseek_v3_2_multiple_tool_calls() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="get_weather"> +<๏ฝœDSML๏ฝœparameter name="location" string="true">Hangzhou +<๏ฝœDSML๏ฝœparameter name="date" string="true">2024-01-16 + +<๏ฝœDSML๏ฝœinvoke name="get_weather"> +<๏ฝœDSML๏ฝœparameter name="location" string="true">Beijing +<๏ฝœDSML๏ฝœparameter name="date" string="true">2024-01-16 + +"#; + + let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2")) + .await + .expect("Failed to parse"); + + assert_eq!(tool_calls.len(), 2); + assert_eq!(tool_calls[0].function.name, "get_weather"); + assert_eq!(tool_calls[1].function.name, "get_weather"); + + let args0: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args0["location"], "Hangzhou"); + assert_eq!(args0["date"], "2024-01-16"); + + let args1: serde_json::Value = + serde_json::from_str(&tool_calls[1].function.arguments).unwrap(); + assert_eq!(args1["location"], "Beijing"); + } + + #[tokio::test] + async fn test_deepseek_v3_2_mixed_parameter_types() { + let input = r#"<๏ฝœDSML๏ฝœfunction_calls> +<๏ฝœDSML๏ฝœinvoke name="search"> +<๏ฝœDSML๏ฝœparameter name="query" string="true">search agent benchmark 2024 +<๏ฝœDSML๏ฝœparameter name="topn" string="false">10 +<๏ฝœDSML๏ฝœparameter name="source" string="true">web + +"#; + + let (tool_calls, _) = detect_and_parse_tool_call(input, Some("deepseek_v3_2")) + .await + .expect("Failed to parse"); + + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "search"); + + let args: serde_json::Value = + serde_json::from_str(&tool_calls[0].function.arguments).unwrap(); + assert_eq!(args["query"], "search agent benchmark 2024"); + assert_eq!(args["topn"], 10); // Should be number, not string + assert_eq!(args["source"], "web"); + } + #[tokio::test] async fn test_hermes_parser_without_new_line() { let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "unit": "celsius"}}" diff --git a/lib/runtime/examples/service_metrics/README.md b/lib/runtime/examples/service_metrics/README.md index 27b66398a9..0aacc8b28e 100644 --- a/lib/runtime/examples/service_metrics/README.md +++ b/lib/runtime/examples/service_metrics/README.md @@ -27,13 +27,7 @@ Annotated { data: Some("o"), id: None, event: None, comment: None } Annotated { data: Some("r"), id: None, event: None, comment: None } Annotated { data: Some("l"), id: None, event: None, comment: None } Annotated { data: Some("d"), id: None, event: None, comment: None } -ServiceSet { services: [ServiceInfo { name: "dynamo_init_backend_720278f8", id: "eOHMc4ndRw8s5flv4WOZx7", version: "0.0.1", started: "2025-02-26T18:54:04.917294605Z", endpoints: [EndpointInfo { name: "dynamo_init_backend_720278f8-generate-694d951a80e06abf", subject: "dynamo_init_backend_720278f8.generate-694d951a80e06abf", data: Some(Metrics(Object {"average_processing_time": Number(53662), "data": Object {"val": Number(10)}, "last_error": String(""), "num_errors": Number(0), "num_requests": Number(2), "processing_time": Number(107325), "queue_group": String("q")})) }] }] } -``` - -Note the following stats in the output demonstrate the custom -`stats_handler` attached to the service in `server.rs` is being invoked: -``` -data: Some(Metrics(Object {..., "data": Object {"val": Number(10)}, ...) 
+ServiceSet { services: [ServiceInfo { name: "dynamo_init_backend_720278f8", id: "eOHMc4ndRw8s5flv4WOZx7", version: "0.0.1", started: "2025-02-26T18:54:04.917294605Z", endpoints: [EndpointInfo { name: "dynamo_init_backend_720278f8-generate-694d951a80e06abf", subject: "dynamo_init_backend_720278f8.generate-694d951a80e06abf", data: Some(Metrics(Object {"average_processing_time": Number(53662), "last_error": String(""), "num_errors": Number(0), "num_requests": Number(2), "processing_time": Number(107325), "queue_group": String("q")})) }] }] } ``` If you start two copies of the server, you will see two entries being emitted. diff --git a/lib/runtime/examples/service_metrics/src/bin/service_server.rs b/lib/runtime/examples/service_metrics/src/bin/service_server.rs index 3f6dc209e1..3c1e4b6d67 100644 --- a/lib/runtime/examples/service_metrics/src/bin/service_server.rs +++ b/lib/runtime/examples/service_metrics/src/bin/service_server.rs @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use service_metrics::{DEFAULT_NAMESPACE, MyStats}; +use service_metrics::DEFAULT_NAMESPACE; use dynamo_runtime::{ DistributedRuntime, Runtime, Worker, logging, @@ -63,11 +63,6 @@ async fn backend(runtime: DistributedRuntime) -> anyhow::Result<()> { component .endpoint("generate") .endpoint_builder() - .stats_handler(|stats| { - println!("stats: {:?}", stats); - let stats = MyStats { val: 10 }; - serde_json::to_value(stats).unwrap() - }) .handler(ingress) .start() .await diff --git a/lib/runtime/examples/service_metrics/src/lib.rs b/lib/runtime/examples/service_metrics/src/lib.rs index fcffdcc88f..4c58ba787d 100644 --- a/lib/runtime/examples/service_metrics/src/lib.rs +++ b/lib/runtime/examples/service_metrics/src/lib.rs @@ -1,12 +1,4 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use serde::{Deserialize, Serialize}; - pub const DEFAULT_NAMESPACE: &str = "dynamo"; - -#[derive(Serialize, Deserialize)] -// Dummy Stats object to demonstrate how to attach a custom stats handler -pub struct MyStats { - pub val: u32, -} diff --git a/lib/runtime/examples/system_metrics/src/lib.rs b/lib/runtime/examples/system_metrics/src/lib.rs index 885543dee6..3abc0bdc1a 100644 --- a/lib/runtime/examples/system_metrics/src/lib.rs +++ b/lib/runtime/examples/system_metrics/src/lib.rs @@ -18,13 +18,6 @@ pub const DEFAULT_NAMESPACE: &str = "dyn_example_namespace"; pub const DEFAULT_COMPONENT: &str = "dyn_example_component"; pub const DEFAULT_ENDPOINT: &str = "dyn_example_endpoint"; -/// Stats structure returned by the endpoint's stats handler -#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] -pub struct MyStats { - // Example value for demonstration purposes - pub val: i32, -} - /// Custom metrics for system stats with data bytes tracking #[derive(Clone, Debug)] pub struct MySystemStatsMetrics { @@ -103,17 +96,7 @@ pub async fn backend(drt: DistributedRuntime, endpoint_name: Option<&str>) -> an // Use the factory pattern - single line factory call with metrics let ingress = Ingress::for_engine(RequestHandler::with_metrics(system_metrics))?; - endpoint - .endpoint_builder() - .stats_handler(|_stats| { - println!("Stats handler called with stats: {:?}", _stats); - // TODO(keivenc): return a real stats object - let stats = MyStats { val: 10 }; - serde_json::to_value(stats).unwrap() - }) - .handler(ingress) - .start() - .await?; + endpoint.endpoint_builder().handler(ingress).start().await?; Ok(()) } diff --git a/lib/runtime/src/component.rs b/lib/runtime/src/component.rs index 87b02cae4f..4d01faac22 100644 --- a/lib/runtime/src/component.rs +++ b/lib/runtime/src/component.rs @@ -43,7 +43,6 @@ use super::{DistributedRuntime, Runtime, traits::*, transports::nats::Slug, util use crate::pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint}; use crate::protocols::EndpointId; -use crate::service::ComponentNatsServerPrometheusMetrics; use async_nats::{ rustls::quic, service::{Service, ServiceExt}, @@ -52,7 +51,6 @@ use derive_builder::Builder; use derive_getters::Getters; use educe::Educe; use serde::{Deserialize, Serialize}; -use service::EndpointStatsHandler; use std::{collections::HashMap, hash::Hash, sync::Arc}; use validator::{Validate, ValidationError}; @@ -79,8 +77,6 @@ pub enum TransportType { #[derive(Default)] pub struct RegistryInner { pub(crate) services: HashMap, - pub(crate) stats_handlers: - HashMap>>>, } #[derive(Clone)] @@ -279,10 +275,38 @@ impl ComponentBuilder { pub fn build(self) -> Result { let component = self.build_internal()?; - // If this component is using NATS, gather it's metrics + // If this component is using NATS, register the NATS service and wait for completion. + // This prevents a race condition where serve_endpoint() tries to look up the service + // before it's registered in the component registry. let drt = component.drt(); if drt.request_plane().is_nats() { - drt.start_stats_service(component.clone()); + let mut rx = drt.register_nats_service(component.clone()); + // Wait synchronously for the NATS service registration to complete. + // Uses block_in_place() to safely call blocking_recv() from async contexts. + // This temporarily moves the current task off the runtime thread to allow + // blocking without deadlocking the runtime. 
+ let result = tokio::task::block_in_place(|| rx.blocking_recv()); + match result { + Some(Ok(())) => { + tracing::debug!( + component = component.service_name(), + "NATS service registration completed" + ); + } + Some(Err(e)) => { + return Err(anyhow::anyhow!( + "NATS service registration failed for component '{}': {}", + component.service_name(), + e + )); + } + None => { + return Err(anyhow::anyhow!( + "NATS service registration channel closed unexpectedly for component '{}'", + component.service_name() + )); + } + } } Ok(component) } diff --git a/lib/runtime/src/component/endpoint.rs b/lib/runtime/src/component/endpoint.rs index 16618dece6..9e0826d3a1 100644 --- a/lib/runtime/src/component/endpoint.rs +++ b/lib/runtime/src/component/endpoint.rs @@ -4,14 +4,13 @@ use std::sync::Arc; use anyhow::Result; -pub use async_nats::service::endpoint::Stats as EndpointStats; use derive_builder::Builder; use derive_getters::Dissolve; use educe::Educe; use tokio_util::sync::CancellationToken; use crate::{ - component::{Endpoint, Instance, TransportType, service::EndpointStatsHandler}, + component::{Endpoint, Instance, TransportType}, distributed::RequestPlaneMode, pipeline::network::{PushWorkHandler, ingress::push_endpoint::PushEndpoint}, protocols::EndpointId, @@ -30,11 +29,6 @@ pub struct EndpointConfig { #[educe(Debug(ignore))] handler: Arc, - /// Stats handler - #[educe(Debug(ignore))] - #[builder(default, private)] - _stats_handler: Option, - /// Additional labels for metrics #[builder(default, setter(into))] metrics_labels: Option>, @@ -56,13 +50,6 @@ impl EndpointConfigBuilder { Self::default().endpoint(endpoint) } - pub fn stats_handler(self, handler: F) -> Self - where - F: FnMut(EndpointStats) -> serde_json::Value + Send + Sync + 'static, - { - self._stats_handler(Some(Box::new(handler))) - } - /// Register an async engine in the local endpoint registry for direct in-process calls pub fn register_local_engine( self, @@ -80,46 +67,19 @@ impl EndpointConfigBuilder { } pub async fn start(self) -> Result<()> { - let ( - endpoint, - handler, - stats_handler, - metrics_labels, - graceful_shutdown, - health_check_payload, - ) = self.build_internal()?.dissolve(); + let (endpoint, handler, metrics_labels, graceful_shutdown, health_check_payload) = + self.build_internal()?.dissolve(); let connection_id = endpoint.drt().connection_id(); let endpoint_id = endpoint.id(); tracing::debug!("Starting endpoint: {endpoint_id}"); - let service_name = endpoint.component.service_name(); - let metrics_labels: Option> = metrics_labels .as_ref() .map(|v| v.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect()); // Add metrics to the handler. The endpoint provides additional information to the handler. handler.add_metrics(&endpoint, metrics_labels.as_deref())?; - // Insert the stats handler. depends on NATS. - if let Some(stats_handler) = stats_handler { - let registry = endpoint.drt().component_registry().inner.lock().await; - let handler_map = registry - .stats_handlers - .get(&service_name) - .cloned() - .expect("no stats handler registry; this is unexpected"); - // There is something wrong with the stats handler map I think. - // Here the connection_id is included, but in component/service.rs add_stats_service it uses service_name, - // no connection id so it's per-endpoint not per-instance. Doesn't match. - // To not block current refactor I am keeping previous behavior, but I think needs - // investigation. 
- handler_map.lock().insert( - nats::instance_subject(&endpoint_id, connection_id), - stats_handler, - ); - } - // This creates a child token of the runtime's endpoint_shutdown_token. That token is // cancelled first as part of graceful shutdown. See Runtime::shutdown. let endpoint_shutdown_token = endpoint.drt().child_token(); diff --git a/lib/runtime/src/component/service.rs b/lib/runtime/src/component/service.rs index e1128acc4d..c42df30592 100644 --- a/lib/runtime/src/component/service.rs +++ b/lib/runtime/src/component/service.rs @@ -1,60 +1,34 @@ // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::*; use crate::component::Component; -use async_nats::service::Service as NatsService; -use async_nats::service::ServiceExt as _; -use derive_builder::Builder; -use derive_getters::Dissolve; -use parking_lot::Mutex; -use std::collections::HashMap; -use std::sync::Arc; - -pub use super::endpoint::EndpointStats; - -type StatsHandlerRegistry = Arc>>; -pub type StatsHandler = - Box serde_json::Value + Send + Sync + 'static>; -pub type EndpointStatsHandler = - Box serde_json::Value + Send + Sync + 'static>; +use async_nats::service::{Service as NatsService, ServiceExt}; pub const PROJECT_NAME: &str = "Dynamo"; const SERVICE_VERSION: &str = env!("CARGO_PKG_VERSION"); +/// Minimal NATS service builder to support legacy NATS request plane. +/// This will be removed once all components migrate to TCP request plane. pub async fn build_nats_service( nats_client: &crate::transports::nats::Client, component: &Component, description: Option, -) -> anyhow::Result<(NatsService, StatsHandlerRegistry)> { +) -> anyhow::Result { let service_name = component.service_name(); - tracing::trace!("component: {component}; creating, service_name: {service_name}"); + tracing::trace!("component: {component}; creating NATS service, service_name: {service_name}"); let description = description.unwrap_or(format!( "{PROJECT_NAME} component {} in namespace {}", component.name, component.namespace )); - let stats_handler_registry: StatsHandlerRegistry = Arc::new(Mutex::new(HashMap::new())); - let stats_handler_registry_clone = stats_handler_registry.clone(); - - let nats_service_builder = nats_client.client().service_builder(); - - let nats_service_builder = - nats_service_builder - .description(description) - .stats_handler(move |name, stats| { - tracing::trace!("stats_handler: {name}, {stats:?}"); - let mut guard = stats_handler_registry.lock(); - match guard.get_mut(&name) { - Some(handler) => handler(stats), - None => serde_json::Value::Null, - } - }); - let nats_service = nats_service_builder + let nats_service = nats_client + .client() + .service_builder() + .description(description) .start(service_name, SERVICE_VERSION.to_string()) .await .map_err(|e| anyhow::anyhow!("Failed to start NATS service: {e}"))?; - Ok((nats_service, stats_handler_registry_clone)) + Ok(nats_service) } diff --git a/lib/runtime/src/distributed.rs b/lib/runtime/src/distributed.rs index c46a861d19..d517664467 100644 --- a/lib/runtime/src/distributed.rs +++ b/lib/runtime/src/distributed.rs @@ -4,9 +4,8 @@ use crate::component::{Component, Instance}; use crate::pipeline::PipelineError; use crate::pipeline::network::manager::NetworkManager; -use crate::service::{ComponentNatsServerPrometheusMetrics, ServiceClient, ServiceSet}; +use crate::service::{ServiceClient, ServiceSet}; use crate::storage::kv::{self, Store as _}; -use 
crate::transports::nats::DRTNatsClientPrometheusMetrics; use crate::{ component::{self, ComponentBuilder, Endpoint, Namespace}, discovery::Discovery, @@ -75,6 +74,9 @@ pub struct DistributedRuntime { // This hierarchy's own metrics registry metrics_registry: MetricsRegistry, + + // Registry for /engine/* route callbacks + engine_routes: crate::engine_routes::EngineRouteRegistry, } impl MetricsHierarchy for DistributedRuntime { @@ -174,7 +176,6 @@ impl DistributedRuntime { }; let component_registry = component::Registry::new(); - let nats_client_for_metrics = nats_client.clone(); // NetworkManager for request plane let network_manager = NetworkManager::new( @@ -199,26 +200,9 @@ impl DistributedRuntime { system_health, request_plane, local_endpoint_registry: crate::local_endpoint_registry::LocalEndpointRegistry::new(), + engine_routes: crate::engine_routes::EngineRouteRegistry::new(), }; - if let Some(nats_client_for_metrics) = nats_client_for_metrics { - let nats_client_metrics = DRTNatsClientPrometheusMetrics::new( - &distributed_runtime, - nats_client_for_metrics.client().clone(), - )?; - // Register a callback to update NATS client metrics on the DRT's metrics registry - let nats_client_callback = Arc::new({ - let nats_client_clone = nats_client_metrics.clone(); - move || { - nats_client_clone.set_from_client_stats(); - Ok(()) - } - }); - distributed_runtime - .metrics_registry - .add_update_callback(nats_client_callback); - } - // Initialize the uptime gauge in SystemHealth distributed_runtime .system_health @@ -327,6 +311,11 @@ impl DistributedRuntime { &self.local_endpoint_registry } + /// Get the engine route registry for registering custom /engine/* routes + pub fn engine_routes(&self) -> &crate::engine_routes::EngineRouteRegistry { + &self.engine_routes + } + pub fn connection_id(&self) -> u64 { self.discovery_client.instance_id() } @@ -450,124 +439,84 @@ impl DistributedRuntime { Ok(response) } - /// Start NATS metrics service in the background to isolate the async, - /// and because we don't need it yet. - /// TODO: This and the things it calls should be in a nats module somewhere. - pub fn start_stats_service(&self, component: Component) { + /// DEPRECATED: This method exists only for NATS request plane support. + /// Once everything uses the TCP request plane, this can be removed along with + /// the NATS service registration infrastructure. + /// + /// Returns a receiver that signals when the NATS service registration is complete. + /// The caller should use `blocking_recv()` to wait for completion. + pub fn register_nats_service( + &self, + component: Component, + ) -> tokio::sync::mpsc::Receiver> { + // Create a oneshot-style channel (capacity 1) to signal completion + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + let drt = self.clone(); self.runtime().secondary().spawn(async move { let service_name = component.service_name(); - if let Err(err) = drt.add_stats_service(component).await { - tracing::error!(error = %err, component = service_name, "Failed starting stats service"); - } - }); - } - /// Gather NATS metrics - async fn add_stats_service(&self, component: Component) -> anyhow::Result<()> { - let service_name = component.service_name(); + // Pre-check to save cost of creating the service, but don't hold the lock + if drt + .component_registry() + .inner + .lock() + .await + .services + .contains_key(&service_name) + { + // The NATS service is per component, but it is called from `serve_endpoint`, and there + // are often multiple endpoints for a component (e.g. 
`clear_kv_blocks` and `generate`). + tracing::trace!("Service {service_name} already exists"); + // Signal success - service already exists + let _ = tx.send(Ok(())).await; + return; + } - // Pre-check to save cost of creating the service, but don't hold the lock - if self - .component_registry() - .inner - .lock() + let Some(nats_client) = drt.nats_client.as_ref() else { + tracing::error!("Cannot create NATS service without NATS."); + let _ = tx + .send(Err("Cannot create NATS service without NATS".to_string())) + .await; + return; + }; + let description = None; + let nats_service = match crate::component::service::build_nats_service( + nats_client, + &component, + description, + ) .await - .services - .contains_key(&service_name) - { - // The NATS service is per component, but it is called from `serve_endpoint`, and there - // are often multiple endpoints for a component (e.g. `clear_kv_blocks` and `generate`). - tracing::trace!("Service {service_name} already exists"); - return Ok(()); - } - - let Some(nats_client) = self.nats_client.as_ref() else { - anyhow::bail!("Cannot create NATS service without NATS."); - }; - let description = None; - let (nats_service, stats_reg) = - crate::component::service::build_nats_service(nats_client, &component, description) - .await?; - - let mut guard = self.component_registry().inner.lock().await; - if !guard.services.contains_key(&service_name) { - // Normal case - guard.services.insert(service_name.clone(), nats_service); - guard.stats_handlers.insert(service_name.clone(), stats_reg); + { + Ok(service) => service, + Err(err) => { + tracing::error!(error = %err, component = service_name, "Failed to build NATS service"); + let _ = tx.send(Err(format!("Failed to build NATS service: {err}"))).await; + return; + } + }; - tracing::info!("Added NATS / stats service {service_name}"); + let mut guard = drt.component_registry().inner.lock().await; + if !guard.services.contains_key(&service_name) { + // Normal case + guard.services.insert(service_name.clone(), nats_service); - drop(guard); - } else { - drop(guard); - let _ = nats_service.stop().await; - // The NATS service is per component, but it is called from `serve_endpoint`, and there - // are often multiple endpoints for a component (e.g. `clear_kv_blocks` and `generate`). - // TODO: Is this still true? - return Ok(()); - } + tracing::info!("Added NATS service {service_name}"); - let cancel_token = self.primary_token(); - let service_client = self - .nats_client - .as_ref() - .map(|nc| ServiceClient::new(nc.clone())) - .ok_or_else(|| { - anyhow::anyhow!("Stats service requires NATS client to collect service metrics.") - })?; - // If there is another component with the same service name, this will fail. - let component_metrics = ComponentNatsServerPrometheusMetrics::new(&component)?; - - self.runtime().secondary().spawn(nats_metrics_worker( - cancel_token, - service_client, - component_metrics, - component, - )); - Ok(()) - } -} - -/// Add Prometheus metrics for this component's NATS service stats. -/// -/// Starts a background task that periodically requests service statistics from NATS -/// and updates the corresponding Prometheus metrics. The first scrape happens immediately, -/// then subsequent scrapes occur at a fixed interval of 9.8 seconds (MAX_WAIT_MS), -/// which should be near or smaller than typical Prometheus scraping intervals to ensure -/// metrics are fresh when Prometheus collects them. 
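As a point of reference, a minimal sketch of the caller pattern implied by the register_nats_service doc comment above; the helper name and surrounding context are assumed, and the channel payload type is inferred from the send calls in this hunk:

// Hypothetical caller-side helper (name assumed); it must run outside the async
// runtime's worker threads because blocking_recv() blocks the current thread.
fn wait_for_nats_registration(drt: &DistributedRuntime, component: Component) {
    let mut rx = drt.register_nats_service(component);
    match rx.blocking_recv() {
        Some(Ok(())) => tracing::debug!("NATS service registration complete"),
        Some(Err(e)) => tracing::error!("NATS service registration failed: {e}"),
        None => tracing::warn!("registration task exited without signalling"),
    }
}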
-async fn nats_metrics_worker( - cancel_token: CancellationToken, - service_client: ServiceClient, - component_metrics: ComponentNatsServerPrometheusMetrics, - component: Component, -) { - const MAX_WAIT_MS: Duration = Duration::from_millis(9800); // Should be <= Prometheus scrape interval - let timeout = Duration::from_millis(500); - let mut interval = tokio::time::interval(MAX_WAIT_MS); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - let service_name = component.service_name(); - loop { - tokio::select! { - result = service_client.collect_services(&service_name, timeout) => { - match result { - Ok(service_set) => { - component_metrics.update_from_service_set(&service_set); - } - Err(err) => { - tracing::error!("Background scrape failed for {service_name}: {err}",); - component_metrics.reset_to_zeros(); - } - } - } - _ = cancel_token.cancelled() => { - tracing::trace!("nats_metrics_worker stopped"); - break; + drop(guard); + } else { + drop(guard); + let _ = nats_service.stop().await; + // The NATS service is per component, but it is called from `serve_endpoint`, and there + // are often multiple endpoints for a component (e.g. `clear_kv_blocks` and `generate`). + // TODO: Is this still true? } - } - interval.tick().await; + // Signal completion - service registered successfully + let _ = tx.send(Ok(())).await; + }); + + rx } } diff --git a/lib/runtime/src/engine_routes.rs b/lib/runtime/src/engine_routes.rs new file mode 100644 index 0000000000..447dabe39a --- /dev/null +++ b/lib/runtime/src/engine_routes.rs @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; + +/// Callback type for engine routes (async) +/// Takes JSON body, returns JSON response (or error) wrapped in a Future +pub type EngineRouteCallback = Arc< + dyn Fn( + serde_json::Value, + ) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Registry for engine route callbacks +/// +/// This registry stores callbacks that handle requests to `/engine/*` routes. +/// Routes are registered from Python via `runtime.register_engine_route()`. 
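The registry described above can also be exercised directly from Rust; a minimal sketch mirroring the unit tests below (the route name follows the start_profile example in the doc comment, and imports of DistributedRuntime and EngineRouteCallback are assumed):

use std::sync::Arc;

// Hypothetical registration helper; a real handler would start profiling
// instead of echoing the request body.
fn register_start_profile(drt: &DistributedRuntime) {
    let callback: EngineRouteCallback = Arc::new(|body| {
        Box::pin(async move { Ok(serde_json::json!({ "started": true, "request": body })) })
    });
    // After this call, a POST to /engine/start_profile on the system status
    // server is dispatched to this callback by engine_route_handler.
    drt.engine_routes().register("start_profile", callback);
}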
+#[derive(Clone, Default)] +pub struct EngineRouteRegistry { + routes: Arc>>, +} + +impl EngineRouteRegistry { + /// Create a new empty registry + pub fn new() -> Self { + Self { + routes: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Register a callback for a route (e.g., "start_profile" for /engine/start_profile) + pub fn register(&self, route: &str, callback: EngineRouteCallback) { + let mut routes = self.routes.write().unwrap(); + routes.insert(route.to_string(), callback); + tracing::debug!("Registered engine route: /engine/{}", route); + } + + /// Get callback for a route + pub fn get(&self, route: &str) -> Option { + let routes = self.routes.read().unwrap(); + routes.get(route).cloned() + } + + /// List all registered routes + pub fn routes(&self) -> Vec { + let routes = self.routes.read().unwrap(); + routes.keys().cloned().collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_registry_basic() { + let registry = EngineRouteRegistry::new(); + + // Register a simple callback + let callback: EngineRouteCallback = + Arc::new(|body| Box::pin(async move { Ok(serde_json::json!({"echo": body})) })); + + registry.register("test", callback); + + // Verify it's registered + assert!(registry.get("test").is_some()); + assert!(registry.get("nonexistent").is_none()); + + // Verify routes list + let routes = registry.routes(); + assert_eq!(routes.len(), 1); + assert!(routes.contains(&"test".to_string())); + } + + #[tokio::test] + async fn test_callback_execution() { + let registry = EngineRouteRegistry::new(); + + let callback: EngineRouteCallback = Arc::new(|body| { + Box::pin(async move { + let input = body.get("input").and_then(|v| v.as_str()).unwrap_or(""); + Ok(serde_json::json!({ + "output": format!("processed: {}", input) + })) + }) + }); + + registry.register("process", callback); + + // Get and execute callback + let cb = registry.get("process").unwrap(); + let result = cb(serde_json::json!({"input": "test"})).await.unwrap(); + + assert_eq!(result["output"], "processed: test"); + } + + #[tokio::test] + async fn test_clone_shares_routes() { + let registry = EngineRouteRegistry::new(); + + let callback: EngineRouteCallback = + Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": true})) })); + registry.register("test", callback); + + // Clone the registry + let cloned = registry.clone(); + + // Both should see the same route + assert!(registry.get("test").is_some()); + assert!(cloned.get("test").is_some()); + + // Register on clone + let callback2: EngineRouteCallback = + Arc::new(|_| Box::pin(async { Ok(serde_json::json!({"ok": false})) })); + cloned.register("test2", callback2); + + // Original should also see it (they share the Arc) + assert!(registry.get("test2").is_some()); + } +} diff --git a/lib/runtime/src/health_check.rs b/lib/runtime/src/health_check.rs index 18a1b671d7..a15cdda1bb 100644 --- a/lib/runtime/src/health_check.rs +++ b/lib/runtime/src/health_check.rs @@ -145,7 +145,7 @@ impl HealthCheckManager { tokio::select! 
{ _ = tokio::time::sleep(canary_wait) => { // Timeout - send health check for this specific endpoint - info!("Canary timer expired for {}, sending health check", endpoint_subject); + debug!("Canary timer expired for {}, sending health check", endpoint_subject); // Get the health check payload for this endpoint let target = manager.drt.system_health().lock().get_health_check_target(&endpoint_subject); @@ -292,7 +292,7 @@ impl HealthCheckManager { ); false } else { - info!("Health check successful for {}", endpoint_subject_owned); + debug!("Health check successful for {}", endpoint_subject_owned); true } } else { @@ -303,6 +303,11 @@ impl HealthCheckManager { false }; + tokio::spawn(async move { + // We need to consume the rest of the stream to avoid warnings on the frontend. + response_stream.for_each(|_| async {}).await; + }); + // Update health status based on response system_health.lock().set_endpoint_health_status( &endpoint_subject_owned, diff --git a/lib/runtime/src/lib.rs b/lib/runtime/src/lib.rs index 57ad8f8d05..a482601cf9 100644 --- a/lib/runtime/src/lib.rs +++ b/lib/runtime/src/lib.rs @@ -24,6 +24,7 @@ pub mod component; pub mod compute; pub mod discovery; pub mod engine; +pub mod engine_routes; pub mod health_check; pub mod local_endpoint_registry; pub mod system_status_server; diff --git a/lib/runtime/src/metrics.rs b/lib/runtime/src/metrics.rs index 674eddd951..e9fa84456d 100644 --- a/lib/runtime/src/metrics.rs +++ b/lib/runtime/src/metrics.rs @@ -21,8 +21,8 @@ use std::collections::HashMap; // Import commonly used items to avoid verbose prefixes use prometheus_names::{ - COMPONENT_NATS_METRICS, DRT_NATS_METRICS, build_component_metric_name, labels, name_prefix, - nats_client, nats_service, sanitize_prometheus_label, sanitize_prometheus_name, work_handler, + build_component_metric_name, labels, name_prefix, sanitize_prometheus_label, + sanitize_prometheus_name, work_handler, }; // Pipeline imports for endpoint creation @@ -646,10 +646,15 @@ pub type PrometheusUpdateCallback = Arc anyhow::Result<()> + Send + pub type PrometheusExpositionFormatCallback = Arc anyhow::Result + Send + Sync + 'static>; -/// Structure to hold Prometheus registries and associated callbacks for a given hierarchy +/// Structure to hold Prometheus registries and associated callbacks for a given hierarchy. +/// +/// All fields are Arc-wrapped, so cloning shares state. This ensures metrics registered +/// on cloned instances (e.g., cloned Client/Endpoint) are visible to the original. +#[derive(Clone)] pub struct MetricsRegistry { - /// The Prometheus registry for this hierarchy (with interior mutability for thread-safe access) - pub prometheus_registry: std::sync::RwLock, + /// The Prometheus registry for this hierarchy. + /// Arc-wrapped so clones share the same registry (metrics registered on clones are visible everywhere). + pub prometheus_registry: Arc>, /// Update callbacks invoked before metrics are scraped. /// Wrapped in Arc to preserve callbacks across clones (prevents callback loss when MetricsRegistry is cloned). @@ -683,25 +688,11 @@ impl std::fmt::Debug for MetricsRegistry { } } -impl Clone for MetricsRegistry { - fn clone(&self) -> Self { - Self { - prometheus_registry: std::sync::RwLock::new( - self.prometheus_registry.read().unwrap().clone(), - ), - // Clone the Arc to share callbacks across all clones (prevents callback loss). - // Previously used Vec::new() here, which caused vllm: metrics to disappear. 
- prometheus_update_callbacks: Arc::clone(&self.prometheus_update_callbacks), - prometheus_expfmt_callbacks: Arc::clone(&self.prometheus_expfmt_callbacks), - } - } -} - impl MetricsRegistry { /// Create a new metrics registry with an empty Prometheus registry and callback lists pub fn new() -> Self { Self { - prometheus_registry: std::sync::RwLock::new(prometheus::Registry::new()), + prometheus_registry: Arc::new(std::sync::RwLock::new(prometheus::Registry::new())), prometheus_update_callbacks: Arc::new(std::sync::RwLock::new(Vec::new())), prometheus_expfmt_callbacks: Arc::new(std::sync::RwLock::new(Vec::new())), } @@ -792,7 +783,6 @@ impl Default for MetricsRegistry { #[cfg(test)] mod test_helpers { use super::prometheus_names::name_prefix; - use super::prometheus_names::{nats_client, nats_service}; use super::*; /// Base function to filter Prometheus output lines based on a predicate. @@ -808,36 +798,6 @@ mod test_helpers { .collect::>() } - /// Filters out all NATS metrics from Prometheus output for test comparisons. - pub fn remove_nats_lines(input: &str) -> Vec { - filter_prometheus_lines(input, |line| { - !line.contains(&format!( - "{}_{}", - name_prefix::COMPONENT, - nats_client::PREFIX - )) && !line.contains(&format!( - "{}_{}", - name_prefix::COMPONENT, - nats_service::PREFIX - )) && !line.trim().is_empty() - }) - } - - /// Filters to only include NATS metrics from Prometheus output for test comparisons. - pub fn extract_nats_lines(input: &str) -> Vec { - filter_prometheus_lines(input, |line| { - line.contains(&format!( - "{}_{}", - name_prefix::COMPONENT, - nats_client::PREFIX - )) || line.contains(&format!( - "{}_{}", - name_prefix::COMPONENT, - nats_service::PREFIX - )) - }) - } - /// Extracts all component metrics (excluding help text and type definitions). /// Returns only the actual metric lines with values. 
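A short sketch of the clone-sharing behavior the Arc-wrapped MetricsRegistry above is meant to guarantee; it pokes at the public prometheus_registry field directly for illustration, whereas real callers would go through the create_* helpers:

let original = MetricsRegistry::new();
let cloned = original.clone();

// Register a metric through the clone only.
let counter = prometheus::Counter::new("demo_total", "demo counter").unwrap();
cloned
    .prometheus_registry
    .write()
    .unwrap()
    .register(Box::new(counter))
    .unwrap();

// Both handles share the same Arc<RwLock<prometheus::Registry>>, so the
// registration made through the clone is visible from the original as well.
assert_eq!(original.prometheus_registry.read().unwrap().gather().len(), 1);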
pub fn extract_metrics(input: &str) -> Vec { @@ -1208,8 +1168,6 @@ mod test_metricsregistry_prefixes { #[cfg(test)] mod test_metricsregistry_prometheus_fmt_outputs { use super::prometheus_names::name_prefix; - use super::prometheus_names::{COMPONENT_NATS_METRICS, DRT_NATS_METRICS}; - use super::prometheus_names::{nats_client, nats_service}; use super::*; use crate::distributed::distributed_test_utils::create_test_drt_async; use prometheus::Counter; @@ -1240,21 +1198,17 @@ mod test_metricsregistry_prometheus_fmt_outputs { println!("Endpoint output:"); println!("{}", endpoint_output_raw); - // Filter out NATS service metrics for test comparison - let endpoint_output = - super::test_helpers::remove_nats_lines(&endpoint_output_raw).join("\n"); - let expected_endpoint_output = r#"# HELP dynamo_component_testcounter A test counter # TYPE dynamo_component_testcounter counter dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345",dynamo_namespace="ns345"} 123.456789"#.to_string(); assert_eq!( - endpoint_output, expected_endpoint_output, + endpoint_output_raw, expected_endpoint_output, "\n=== ENDPOINT COMPARISON FAILED ===\n\ - Expected:\n{}\n\ Actual:\n{}\n\ + Expected:\n{}\n\ ==============================", - expected_endpoint_output, endpoint_output + endpoint_output_raw, expected_endpoint_output ); // Test Gauge creation @@ -1270,10 +1224,6 @@ dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345", println!("Component output:"); println!("{}", component_output_raw); - // Filter out NATS service metrics for test comparison - let component_output = - super::test_helpers::remove_nats_lines(&component_output_raw).join("\n"); - let expected_component_output = r#"# HELP dynamo_component_testcounter A test counter # TYPE dynamo_component_testcounter counter dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345",dynamo_namespace="ns345"} 123.456789 @@ -1282,12 +1232,12 @@ dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345", dynamo_component_testgauge{dynamo_component="comp345",dynamo_namespace="ns345"} 50000"#.to_string(); assert_eq!( - component_output, expected_component_output, + component_output_raw, expected_component_output, "\n=== COMPONENT COMPARISON FAILED ===\n\ - Expected:\n{}\n\ Actual:\n{}\n\ + Expected:\n{}\n\ ==============================", - expected_component_output, component_output + component_output_raw, expected_component_output ); let intcounter = namespace @@ -1302,10 +1252,6 @@ dynamo_component_testgauge{dynamo_component="comp345",dynamo_namespace="ns345"} println!("Namespace output:"); println!("{}", namespace_output_raw); - // Filter out NATS service metrics for test comparison - let namespace_output = - super::test_helpers::remove_nats_lines(&namespace_output_raw).join("\n"); - let expected_namespace_output = r#"# HELP dynamo_component_testcounter A test counter # TYPE dynamo_component_testcounter counter dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345",dynamo_namespace="ns345"} 123.456789 @@ -1317,12 +1263,12 @@ dynamo_component_testgauge{dynamo_component="comp345",dynamo_namespace="ns345"} dynamo_component_testintcounter{dynamo_namespace="ns345"} 12345"#.to_string(); assert_eq!( - namespace_output, expected_namespace_output, + namespace_output_raw, expected_namespace_output, "\n=== NAMESPACE COMPARISON FAILED ===\n\ - Expected:\n{}\n\ Actual:\n{}\n\ + Expected:\n{}\n\ ==============================", - expected_namespace_output, 
namespace_output + namespace_output_raw, expected_namespace_output ); // Test IntGauge creation @@ -1377,10 +1323,6 @@ dynamo_component_testintcounter{dynamo_namespace="ns345"} 12345"#.to_string(); println!("DRT output:"); println!("{}", drt_output_raw); - // Filter out all NATS metrics for comparison - let filtered_drt_output = - super::test_helpers::remove_nats_lines(&drt_output_raw).join("\n"); - let expected_drt_output = r#"# HELP dynamo_component_testcounter A test counter # TYPE dynamo_component_testcounter counter dynamo_component_testcounter{dynamo_component="comp345",dynamo_endpoint="ep345",dynamo_namespace="ns345"} 123.456789 @@ -1422,12 +1364,12 @@ dynamo_component_testintgaugevec{dynamo_namespace="ns345",instance="server2",ser dynamo_component_uptime_seconds 0"#.to_string(); assert_eq!( - filtered_drt_output, expected_drt_output, + drt_output_raw, expected_drt_output, "\n=== DRT COMPARISON FAILED ===\n\ Expected:\n{}\n\ Actual (filtered):\n{}\n\ ==============================", - expected_drt_output, filtered_drt_output + expected_drt_output, drt_output_raw ); println!("โœ“ All Prometheus format outputs verified successfully!"); @@ -1435,33 +1377,19 @@ dynamo_component_uptime_seconds 0"#.to_string(); #[test] fn test_refactored_filter_functions() { - // Test data with mixed content + // Test data with component metrics let test_input = r#"# HELP dynamo_component_requests Total requests # TYPE dynamo_component_requests counter dynamo_component_requests 42 -# HELP dynamo_component_nats_client_connection_state Connection state -# TYPE dynamo_component_nats_client_connection_state gauge -dynamo_component_nats_client_connection_state 1 # HELP dynamo_component_latency Response latency # TYPE dynamo_component_latency histogram dynamo_component_latency_bucket{le="0.1"} 10 dynamo_component_latency_bucket{le="0.5"} 25 -dynamo_component_nats_service_requests_total 100 -dynamo_component_nats_service_errors_total 5"#; - - // Test remove_nats_lines (excludes NATS lines but keeps help/type) - let filtered_out = super::test_helpers::remove_nats_lines(test_input); - assert_eq!(filtered_out.len(), 7); // 7 non-NATS lines - assert!(!filtered_out.iter().any(|line| line.contains("nats"))); - - // Test extract_nats_lines (includes all NATS lines including help/type) - let filtered_only = super::test_helpers::extract_nats_lines(test_input); - assert_eq!(filtered_only.len(), 5); // 5 NATS lines - assert!(filtered_only.iter().all(|line| line.contains("nats"))); +dynamo_component_errors_total 5"#; // Test extract_metrics (only actual metric lines, excluding help/type) let metrics_only = super::test_helpers::extract_metrics(test_input); - assert_eq!(metrics_only.len(), 6); // 6 actual metric lines (excluding help/type) + assert_eq!(metrics_only.len(), 4); // 4 actual metric lines (excluding help/type) assert!( metrics_only .iter() @@ -1471,490 +1399,3 @@ dynamo_component_nats_service_errors_total 5"#; println!("โœ“ All refactored filter functions work correctly!"); } } - -#[cfg(feature = "integration")] -#[cfg(test)] -mod test_metricsregistry_nats { - use super::prometheus_names::name_prefix; - use super::prometheus_names::{COMPONENT_NATS_METRICS, DRT_NATS_METRICS}; - use super::prometheus_names::{nats_client, nats_service}; - use super::*; - use crate::distributed::distributed_test_utils::create_test_drt_async; - use crate::pipeline::PushRouter; - use crate::{DistributedRuntime, Runtime}; - use tokio::time::{Duration, sleep}; - #[ignore = "Deprecated - NATS related code to be deleted soon"] - 
#[tokio::test] - async fn test_drt_nats_metrics() { - // Setup real DRT and registry using the test-friendly constructor - let drt = create_test_drt_async().await; - - // Get DRT output which should include NATS client metrics - let drt_output = drt.metrics().prometheus_expfmt().unwrap(); - println!("DRT output with NATS metrics:"); - println!("{}", drt_output); - - // Additional checks for NATS client metrics (without checking specific values) - let drt_nats_metrics = super::test_helpers::extract_nats_lines(&drt_output); - - // Check that NATS client metrics are present - assert!( - !drt_nats_metrics.is_empty(), - "NATS client metrics should be present" - ); - - // Check for specific NATS client metric names (without values) - // Extract only the metric lines from the already-filtered NATS metrics - let drt_nats_metric_lines = - super::test_helpers::extract_metrics(&drt_nats_metrics.join("\n")); - let actual_drt_nats_metrics_sorted: Vec<&str> = drt_nats_metric_lines - .iter() - .map(|line| { - let without_labels = line.split('{').next().unwrap_or(line); - // Remove the value part (everything after the last space) - without_labels.split(' ').next().unwrap_or(without_labels) - }) - .collect(); - - let expect_drt_nats_metrics_sorted = { - let mut temp = DRT_NATS_METRICS - .iter() - .map(|metric| build_component_metric_name(metric)) - .collect::>(); - temp.sort(); - temp - }; - - // Print both lists for comparison - println!( - "actual_drt_nats_metrics_sorted: {:?}", - actual_drt_nats_metrics_sorted - ); - println!( - "expect_drt_nats_metrics_sorted: {:?}", - expect_drt_nats_metrics_sorted - ); - - // Compare the sorted lists - assert_eq!( - actual_drt_nats_metrics_sorted, expect_drt_nats_metrics_sorted, - "DRT_NATS_METRICS with prefix and expected_nats_metrics should be identical when sorted" - ); - - println!("โœ“ DistributedRuntime NATS metrics integration test passed!"); - } - - #[ignore = "Deprecated - NATS related code to be deleted soon"] - #[tokio::test] - async fn test_nats_metric_names() { - // This test only tests the existence of the NATS metrics. It does not check - // the values of the metrics. 
- - // Setup real DRT and registry using the test-friendly constructor - let drt = create_test_drt_async().await; - - // Create a namespace and component from the DRT - let namespace = drt.namespace("ns789").unwrap(); - let component = namespace.component("comp789").unwrap(); - - // Get component output which should include NATS client metrics - // Additional checks for NATS client metrics (without checking specific values) - let component_nats_metrics = super::test_helpers::extract_nats_lines( - &component.metrics().prometheus_expfmt().unwrap(), - ); - println!( - "Component NATS metrics count: {}", - component_nats_metrics.len() - ); - - // Check that NATS client metrics are present - assert!( - !component_nats_metrics.is_empty(), - "NATS client metrics should be present" - ); - - // Check for specific NATS client metric names (without values) - let component_metrics = - super::test_helpers::extract_metrics(&component.metrics().prometheus_expfmt().unwrap()); - let actual_component_nats_metrics_sorted: Vec<&str> = component_metrics - .iter() - .map(|line| { - let without_labels = line.split('{').next().unwrap_or(line); - // Remove the value part (everything after the last space) - without_labels.split(' ').next().unwrap_or(without_labels) - }) - .collect(); - - let expect_component_nats_metrics_sorted = { - let mut temp = COMPONENT_NATS_METRICS - .iter() - .map(|metric| build_component_metric_name(metric)) - .collect::>(); - temp.sort(); - temp - }; - - // Print both lists for comparison - println!( - "actual_component_nats_metrics_sorted: {:?}", - actual_component_nats_metrics_sorted - ); - println!( - "expect_component_nats_metrics_sorted: {:?}", - expect_component_nats_metrics_sorted - ); - - // Compare the sorted lists - assert_eq!( - actual_component_nats_metrics_sorted, expect_component_nats_metrics_sorted, - "COMPONENT_NATS_METRICS with prefix and expected_nats_metrics should be identical when sorted" - ); - - // Get both DRT and component output and filter for NATS metrics only - let drt_output = drt.metrics().prometheus_expfmt().unwrap(); - let drt_nats_lines = super::test_helpers::extract_nats_lines(&drt_output); - let drt_and_component_nats_metrics = - super::test_helpers::extract_metrics(&drt_nats_lines.join("\n")); - println!( - "DRT and component NATS metrics count: {}", - drt_and_component_nats_metrics.len() - ); - - // Check that the NATS metrics are present in the component output - assert_eq!( - drt_and_component_nats_metrics.len(), - DRT_NATS_METRICS.len() + COMPONENT_NATS_METRICS.len(), - "DRT at this point should have both the DRT and component NATS metrics" - ); - - // Check that the NATS metrics are present in the component output - println!("โœ“ Component NATS metrics integration test passed!"); - } - - /// Tests NATS metrics values before and after endpoint activity with large message processing. - /// Creates endpoint, sends test messages + 10k byte message, validates metrics (NATS + work handler) - /// at initial state and post-activity state. Ensures byte thresholds, message counts, and processing - /// times are within expected ranges. Tests end-to-end client-server communication and metrics collection. 
- #[ignore = "Deprecated - NATS related code to be deleted soon"] - #[tokio::test] - async fn test_nats_metrics_values() -> anyhow::Result<()> { - struct MessageHandler {} - impl MessageHandler { - fn new() -> std::sync::Arc { - std::sync::Arc::new(Self {}) - } - } - - #[async_trait] - impl AsyncEngine, ManyOut>, Error> for MessageHandler { - async fn generate( - &self, - input: SingleIn, - ) -> Result>, Error> { - let (data, ctx) = input.into_parts(); - let response = data.to_string(); - let stream = stream::iter(vec![Annotated::from_data(response)]); - Ok(ResponseStream::new(Box::pin(stream), ctx.context())) - } - } - - println!("\n=== Initializing DistributedRuntime ==="); - let runtime = Runtime::from_current()?; - let drt = DistributedRuntime::from_settings(runtime.clone()).await?; - let namespace = drt.namespace("ns123").unwrap(); - let component = namespace.component("comp123").unwrap(); - let ingress = Ingress::for_engine(MessageHandler::new()).unwrap(); - - let _backend_handle = tokio::spawn(async move { - let endpoint = component - .endpoint("echo") - .endpoint_builder() - .handler(ingress); - endpoint.start().await.unwrap(); - }); - - sleep(Duration::from_millis(500)).await; - println!("โœ“ Launched endpoint service in background successfully"); - - let drt_output = drt.metrics().prometheus_expfmt().unwrap(); - let parsed_metrics: Vec<_> = drt_output - .lines() - .filter_map(super::test_helpers::parse_prometheus_metric) - .collect(); - - println!("=== Initial DRT metrics output ==="); - println!("{}", drt_output); - - println!("\n=== Checking Initial Metric Values ==="); - - let initial_expected_metric_values = [ - // DRT NATS metrics (ordered to match DRT_NATS_METRICS) - ( - build_component_metric_name(nats_client::CONNECTION_STATE), - 1.0, - 1.0, - ), // Should be connected - ( - build_component_metric_name(nats_client::CURRENT_CONNECTIONS), - 1.0, - 1.0, - ), // Should have 1 connection - ( - build_component_metric_name(nats_client::IN_TOTAL_BYTES), - 800.0, - 4000.0, - ), // Wide range around observed value of 1888 - ( - build_component_metric_name(nats_client::IN_MESSAGES), - 0.0, - 5.0, - ), // Wide range around 2 - ( - build_component_metric_name(nats_client::OUT_OVERHEAD_BYTES), - 1500.0, - 5000.0, - ), // Wide range around observed value of 2752 - ( - build_component_metric_name(nats_client::OUT_MESSAGES), - 0.0, - 5.0, - ), // Wide range around 2 - // Component NATS metrics (ordered to match COMPONENT_NATS_METRICS) - ( - build_component_metric_name(nats_service::PROCESSING_MS_AVG), - 0.0, - 0.0, - ), // No processing yet - ( - build_component_metric_name(nats_service::ERRORS_TOTAL), - 0.0, - 0.0, - ), // No errors yet - ( - build_component_metric_name(nats_service::REQUESTS_TOTAL), - 0.0, - 0.0, - ), // No requests yet - ( - build_component_metric_name(nats_service::PROCESSING_MS_TOTAL), - 0.0, - 0.0, - ), // No processing yet - ( - build_component_metric_name(nats_service::ACTIVE_SERVICES), - 0.0, - 2.0, - ), // Service may not be fully active yet - ( - build_component_metric_name(nats_service::ACTIVE_ENDPOINTS), - 0.0, - 2.0, - ), // Endpoint may not be fully active yet - ]; - - for (metric_name, min_value, max_value) in &initial_expected_metric_values { - let actual_value = parsed_metrics - .iter() - .find(|(name, _, _)| name == metric_name) - .map(|(_, _, value)| *value) - .unwrap_or_else(|| panic!("Could not find expected metric: {}", metric_name)); - - assert!( - actual_value >= *min_value && actual_value <= *max_value, - "Initial metric {} should be between {} 
and {}, but got {}", - metric_name, - min_value, - max_value, - actual_value - ); - } - - println!("\n=== Client Runtime to hit the endpoint ==="); - let client_runtime = Runtime::from_current()?; - let client_distributed = DistributedRuntime::from_settings(client_runtime.clone()).await?; - let namespace = client_distributed.namespace("ns123")?; - let component = namespace.component("comp123")?; - let client = component.endpoint("echo").client().await?; - - client.wait_for_instances().await?; - println!("โœ“ Connected to endpoint, waiting for instances..."); - - let router = - PushRouter::>::from_client(client, Default::default()) - .await?; - - for i in 0..10 { - let msg = i.to_string().repeat(2000); // 2k bytes message - let mut stream = router.random(msg.clone().into()).await?; - while let Some(resp) = stream.next().await { - // Check if response matches the original message - if let Some(data) = &resp.data { - let is_same = data == &msg; - println!( - "Response {}: {} bytes, matches original: {}", - i, - data.len(), - is_same - ); - } - } - } - println!("โœ“ Sent messages and received responses successfully"); - - println!("\n=== Waiting 500ms for metrics to update ==="); - sleep(Duration::from_millis(500)).await; - println!("โœ“ Wait complete, getting final metrics..."); - - let final_drt_output = drt.metrics().prometheus_expfmt().unwrap(); - println!("\n=== Final Prometheus DRT output ==="); - println!("{}", final_drt_output); - - let final_drt_nats_output = super::test_helpers::extract_nats_lines(&final_drt_output); - println!("\n=== Filtered NATS metrics from final DRT output ==="); - for line in &final_drt_nats_output { - println!("{}", line); - } - - let final_parsed_metrics: Vec<_> = super::test_helpers::extract_metrics(&final_drt_output) - .iter() - .filter_map(|line| super::test_helpers::parse_prometheus_metric(line.as_str())) - .collect(); - - let post_expected_metric_values = [ - // DRT NATS metrics - ( - build_component_metric_name(nats_client::CONNECTION_STATE), - 1.0, - 1.0, - ), // Connected - ( - build_component_metric_name(nats_client::CURRENT_CONNECTIONS), - 1.0, - 1.0, - ), // 1 connection - ( - build_component_metric_name(nats_client::IN_TOTAL_BYTES), - 20000.0, - 32000.0, - ), // Wide range around 26117 - ( - build_component_metric_name(nats_client::IN_MESSAGES), - 8.0, - 20.0, - ), // Wide range around 16 - ( - build_component_metric_name(nats_client::OUT_OVERHEAD_BYTES), - 2500.0, - 8000.0, - ), // Wide range around 5524 - ( - build_component_metric_name(nats_client::OUT_MESSAGES), - 8.0, - 20.0, - ), // Wide range around 16 - // Component NATS metrics - ( - build_component_metric_name(nats_service::PROCESSING_MS_AVG), - 0.0, - 1.0, - ), // Low processing time - ( - build_component_metric_name(nats_service::ERRORS_TOTAL), - 0.0, - 0.0, - ), // No errors - ( - build_component_metric_name(nats_service::REQUESTS_TOTAL), - 0.0, - 10.0, - ), // NATS service stats requests (may differ from work handler count) - ( - build_component_metric_name(nats_service::PROCESSING_MS_TOTAL), - 0.0, - 5.0, - ), // Low total processing time - ( - build_component_metric_name(nats_service::ACTIVE_SERVICES), - 0.0, - 2.0, - ), // Service may not be fully active - ( - build_component_metric_name(nats_service::ACTIVE_ENDPOINTS), - 0.0, - 2.0, - ), // Endpoint may not be fully active - // Work handler metrics - ( - build_component_metric_name(work_handler::REQUESTS_TOTAL), - 10.0, - 10.0, - ), // 10 messages - ( - build_component_metric_name(work_handler::REQUEST_BYTES_TOTAL), - 21000.0, - 
26000.0, - ), // ~75-125% of 23520 - ( - build_component_metric_name(work_handler::RESPONSE_BYTES_TOTAL), - 18000.0, - 23000.0, - ), // ~75-125% of 20660 - ( - build_component_metric_name(work_handler::INFLIGHT_REQUESTS), - 0.0, - 1.0, - ), // 0 or very low - // Histograms have _{count,sum} suffixes - ( - format!( - "{}_count", - build_component_metric_name(work_handler::REQUEST_DURATION_SECONDS) - ), - 10.0, - 10.0, - ), // 10 messages - ( - format!( - "{}_sum", - build_component_metric_name(work_handler::REQUEST_DURATION_SECONDS) - ), - 0.0001, - 1.0, - ), // Processing time sum (wide range) - ]; - - println!("\n=== Checking Post-Activity All Metrics (NATS + Work Handler) ==="); - for (metric_name, min_value, max_value) in &post_expected_metric_values { - let actual_value = final_parsed_metrics - .iter() - .find(|(name, _, _)| name == metric_name) - .map(|(_, _, value)| *value) - .unwrap_or_else(|| { - panic!( - "Could not find expected post-activity metric: {}", - metric_name - ) - }); - - assert!( - actual_value >= *min_value && actual_value <= *max_value, - "Post-activity metric {} should be between {} and {}, but got {}", - metric_name, - min_value, - max_value, - actual_value - ); - println!( - "โœ“ {}: {} (range: {} to {})", - metric_name, actual_value, min_value, max_value - ); - } - - println!("โœ“ All NATS and component metrics parsed successfully!"); - println!("โœ“ Byte metrics verified to be >= 100 bytes!"); - println!("โœ“ Post-activity metrics verified with higher thresholds!"); - println!("โœ“ Work handler metrics reflect increased activity!"); - - Ok(()) - } -} diff --git a/lib/runtime/src/metrics/prometheus_names.rs b/lib/runtime/src/metrics/prometheus_names.rs index 91153a134d..9a070fc1f5 100644 --- a/lib/runtime/src/metrics/prometheus_names.rs +++ b/lib/runtime/src/metrics/prometheus_names.rs @@ -113,6 +113,9 @@ pub mod frontend_service { /// Output sequence length in tokens pub const OUTPUT_SEQUENCE_TOKENS: &str = "output_sequence_tokens"; + /// Number of cached tokens (prefix cache hits) per request + pub const CACHED_TOKENS: &str = "cached_tokens"; + /// Total number of output tokens generated (counter that updates in real-time) pub const OUTPUT_TOKENS_TOTAL: &str = "output_tokens_total"; @@ -209,90 +212,6 @@ pub mod work_handler { } } -/// NATS client metrics. DistributedRuntime contains a NATS client shared by all children) -pub mod nats_client { - /// Macro to generate NATS client metric names with the prefix - macro_rules! 
nats_client_name { - ($name:expr) => { - concat!("nats_client_", $name) - }; - } - - /// Prefix for all NATS client metrics - pub const PREFIX: &str = nats_client_name!(""); - - /// Total number of bytes received by NATS client - pub const IN_TOTAL_BYTES: &str = nats_client_name!("in_total_bytes"); - - /// Total number of bytes sent by NATS client - pub const OUT_OVERHEAD_BYTES: &str = nats_client_name!("out_overhead_bytes"); - - /// Total number of messages received by NATS client - pub const IN_MESSAGES: &str = nats_client_name!("in_messages"); - - /// Total number of messages sent by NATS client - pub const OUT_MESSAGES: &str = nats_client_name!("out_messages"); - - /// Current number of active connections for NATS client - /// Note: Gauge metric measuring current connections, not cumulative total - pub const CURRENT_CONNECTIONS: &str = nats_client_name!("current_connections"); - - /// Current connection state of NATS client (0=disconnected, 1=connected, 2=reconnecting) - pub const CONNECTION_STATE: &str = nats_client_name!("connection_state"); -} - -/// NATS service metrics, from the $SRV.STATS. requests on NATS server -pub mod nats_service { - /// Macro to generate NATS service metric names with the prefix - macro_rules! nats_service_name { - ($name:expr) => { - concat!("nats_service_", $name) - }; - } - - /// Prefix for all NATS service metrics - pub const PREFIX: &str = nats_service_name!(""); - - /// Average processing time in milliseconds (maps to: average_processing_time in ms) - pub const PROCESSING_MS_AVG: &str = nats_service_name!("processing_ms_avg"); - - /// Total errors across all endpoints (maps to: num_errors) - pub const ERRORS_TOTAL: &str = nats_service_name!("errors_total"); - - /// Total requests across all endpoints (maps to: num_requests) - pub const REQUESTS_TOTAL: &str = nats_service_name!("requests_total"); - - /// Total processing time in milliseconds (maps to: processing_time in ms) - pub const PROCESSING_MS_TOTAL: &str = nats_service_name!("processing_ms_total"); - - /// Number of active services (derived from ServiceSet.services) - pub const ACTIVE_SERVICES: &str = nats_service_name!("active_services"); - - /// Number of active endpoints (derived from ServiceInfo.endpoints) - pub const ACTIVE_ENDPOINTS: &str = nats_service_name!("active_endpoints"); -} - -/// All NATS client Prometheus metric names as an array for iteration/validation -pub const DRT_NATS_METRICS: &[&str] = &[ - nats_client::CONNECTION_STATE, - nats_client::CURRENT_CONNECTIONS, - nats_client::IN_TOTAL_BYTES, - nats_client::IN_MESSAGES, - nats_client::OUT_OVERHEAD_BYTES, - nats_client::OUT_MESSAGES, -]; - -/// All component service Prometheus metric names as an array for iteration/validation -/// (ordered to match NatsStatsMetrics fields) -pub const COMPONENT_NATS_METRICS: &[&str] = &[ - nats_service::PROCESSING_MS_AVG, // maps to: average_processing_time (nanoseconds) - nats_service::ERRORS_TOTAL, // maps to: num_errors - nats_service::REQUESTS_TOTAL, // maps to: num_requests - nats_service::PROCESSING_MS_TOTAL, // maps to: processing_time (nanoseconds) - nats_service::ACTIVE_SERVICES, // derived from ServiceSet.services - nats_service::ACTIVE_ENDPOINTS, // derived from ServiceInfo.endpoints -]; - /// Task tracker Prometheus metric name suffixes pub mod task_tracker { /// Total number of tasks issued/submitted diff --git a/lib/runtime/src/pipeline/network/ingress/http_endpoint.rs b/lib/runtime/src/pipeline/network/ingress/http_endpoint.rs index d3e1aca14b..743e794cf0 100644 --- 
a/lib/runtime/src/pipeline/network/ingress/http_endpoint.rs +++ b/lib/runtime/src/pipeline/network/ingress/http_endpoint.rs @@ -105,7 +105,27 @@ impl SharedHttpServer { .system_health .lock() .set_endpoint_health_status(endpoint_name, HealthStatus::NotReady); - tracing::debug!("Unregistered endpoint handler for subject: {}", subject); + tracing::debug!( + endpoint_name = %endpoint_name, + subject = %subject, + "Unregistered HTTP endpoint handler" + ); + + let inflight_count = handler.inflight.load(Ordering::SeqCst); + if inflight_count > 0 { + tracing::info!( + endpoint_name = %endpoint_name, + inflight_count = inflight_count, + "Waiting for inflight HTTP requests to complete" + ); + while handler.inflight.load(Ordering::SeqCst) > 0 { + handler.notify.notified().await; + } + tracing::info!( + endpoint_name = %endpoint_name, + "All inflight HTTP requests completed" + ); + } } } diff --git a/lib/runtime/src/pipeline/network/ingress/nats_server.rs b/lib/runtime/src/pipeline/network/ingress/nats_server.rs index 2fb9f1d32b..e6c93c8417 100644 --- a/lib/runtime/src/pipeline/network/ingress/nats_server.rs +++ b/lib/runtime/src/pipeline/network/ingress/nats_server.rs @@ -32,6 +32,7 @@ pub struct NatsMultiplexedServer { struct EndpointTask { cancel_token: CancellationToken, + join_handle: tokio::task::JoinHandle<()>, _endpoint_name: String, } @@ -145,7 +146,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { // Spawn task to handle this endpoint using PushEndpoint // Note: PushEndpoint::start() is a blocking loop that runs until cancelled let endpoint_name_clone = endpoint_name.clone(); - tokio::spawn(async move { + let join_handle = tokio::spawn(async move { if let Err(e) = push_endpoint .start( service_endpoint, @@ -180,6 +181,7 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { endpoint_name.clone(), EndpointTask { cancel_token: endpoint_cancel, + join_handle, _endpoint_name: endpoint_name, }, ); @@ -193,7 +195,25 @@ impl super::unified_server::RequestPlaneServer for NatsMultiplexedServer { endpoint_name = %endpoint_name, "Unregistering NATS endpoint" ); + // Cancel the token to trigger graceful shutdown task.cancel_token.cancel(); + + // Wait for the endpoint task to complete (which includes waiting for inflight requests) + tracing::debug!( + endpoint_name = %endpoint_name, + "Waiting for NATS endpoint task to complete" + ); + if let Err(e) = task.join_handle.await { + tracing::warn!( + endpoint_name = %endpoint_name, + error = %e, + "NATS endpoint task panicked during shutdown" + ); + } + tracing::info!( + endpoint_name = %endpoint_name, + "NATS endpoint unregistration complete" + ); } Ok(()) } diff --git a/lib/runtime/src/pipeline/network/ingress/push_endpoint.rs b/lib/runtime/src/pipeline/network/ingress/push_endpoint.rs index adaa3e4756..4915f88925 100644 --- a/lib/runtime/src/pipeline/network/ingress/push_endpoint.rs +++ b/lib/runtime/src/pipeline/network/ingress/push_endpoint.rs @@ -135,16 +135,26 @@ impl PushEndpoint { // await for all inflight requests to complete if graceful shutdown if self.graceful_shutdown { - tracing::info!( - "Waiting for {} inflight requests to complete", - inflight.load(Ordering::SeqCst) - ); - while inflight.load(Ordering::SeqCst) > 0 { - notify.notified().await; + let inflight_count = inflight.load(Ordering::SeqCst); + if inflight_count > 0 { + tracing::info!( + endpoint_name = endpoint_name_local.as_str(), + inflight_count = inflight_count, + "Waiting for inflight NATS requests to complete" + ); + while 
inflight.load(Ordering::SeqCst) > 0 { + notify.notified().await; + } + tracing::info!( + endpoint_name = endpoint_name_local.as_str(), + "All inflight NATS requests completed" + ); } - tracing::info!("All inflight requests completed"); } else { - tracing::info!("Skipping graceful shutdown, not waiting for inflight requests"); + tracing::info!( + endpoint_name = endpoint_name_local.as_str(), + "Skipping graceful shutdown, not waiting for inflight requests" + ); } Ok(()) diff --git a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs index 0640cc6dcd..2b9d880fc2 100644 --- a/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs +++ b/lib/runtime/src/pipeline/network/ingress/shared_tcp_endpoint.rs @@ -100,11 +100,33 @@ impl SharedTcpServer { } pub async fn unregister_endpoint(&self, endpoint_path: &str, endpoint_name: &str) { - self.handlers.remove(endpoint_path); - tracing::info!( - "Unregistered endpoint '{}' from shared TCP server", - endpoint_name - ); + if let Some((_, handler)) = self.handlers.remove(endpoint_path) { + handler + .system_health + .lock() + .set_endpoint_health_status(endpoint_name, crate::HealthStatus::NotReady); + tracing::info!( + endpoint_name = %endpoint_name, + endpoint_path = %endpoint_path, + "Unregistered TCP endpoint handler" + ); + + let inflight_count = handler.inflight.load(Ordering::SeqCst); + if inflight_count > 0 { + tracing::info!( + endpoint_name = %endpoint_name, + inflight_count = inflight_count, + "Waiting for inflight TCP requests to complete" + ); + while handler.inflight.load(Ordering::SeqCst) > 0 { + handler.notify.notified().await; + } + tracing::info!( + endpoint_name = %endpoint_name, + "All inflight TCP requests completed" + ); + } + } } pub async fn start(self: Arc) -> Result<()> { @@ -369,3 +391,218 @@ impl super::unified_server::RequestPlaneServer for SharedTcpServer { true } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline::error::PipelineError; + use async_trait::async_trait; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::time::Duration; + use tokio::time::Instant; + + /// Mock handler that simulates slow request processing for testing + struct SlowMockHandler { + /// Tracks if a request is currently being processed + request_in_flight: Arc, + /// Notifies when request processing starts + request_started: Arc, + /// Notifies when request processing completes + request_completed: Arc, + /// Duration to simulate request processing + processing_duration: Duration, + } + + impl SlowMockHandler { + fn new(processing_duration: Duration) -> Self { + Self { + request_in_flight: Arc::new(AtomicBool::new(false)), + request_started: Arc::new(Notify::new()), + request_completed: Arc::new(Notify::new()), + processing_duration, + } + } + } + + #[async_trait] + impl PushWorkHandler for SlowMockHandler { + async fn handle_payload(&self, _payload: Bytes) -> Result<(), PipelineError> { + self.request_in_flight.store(true, Ordering::SeqCst); + self.request_started.notify_one(); + + tracing::debug!( + "SlowMockHandler: Request started, sleeping for {:?}", + self.processing_duration + ); + + // Simulate slow request processing + tokio::time::sleep(self.processing_duration).await; + + tracing::debug!("SlowMockHandler: Request completed"); + + self.request_in_flight.store(false, Ordering::SeqCst); + self.request_completed.notify_one(); + Ok(()) + } + + fn add_metrics( + &self, + _endpoint: &crate::component::Endpoint, + _metrics_labels: 
Option<&[(&str, &str)]>, + ) -> Result<()> { + Ok(()) + } + } + + #[tokio::test] + async fn test_graceful_shutdown_waits_for_inflight_tcp_requests() { + // Initialize tracing for test debugging + let _ = tracing_subscriber::fmt() + .with_test_writer() + .with_max_level(tracing::Level::DEBUG) + .try_init(); + + let cancellation_token = CancellationToken::new(); + let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); + + // Create SharedTcpServer + let server = SharedTcpServer::new(bind_addr, cancellation_token.clone()); + + // Create a handler that takes 1s to process requests + let handler = Arc::new(SlowMockHandler::new(Duration::from_secs(1))); + let request_started = handler.request_started.clone(); + let request_completed = handler.request_completed.clone(); + let request_in_flight = handler.request_in_flight.clone(); + + // Register endpoint + let endpoint_path = "test_endpoint".to_string(); + let system_health = Arc::new(Mutex::new(SystemHealth::new( + crate::HealthStatus::Ready, + vec![], + "/health".to_string(), + "/live".to_string(), + ))); + + server + .register_endpoint( + endpoint_path.clone(), + handler.clone() as Arc, + 1, + "test_namespace".to_string(), + "test_component".to_string(), + "test_endpoint".to_string(), + system_health, + ) + .await + .expect("Failed to register endpoint"); + + tracing::debug!("Endpoint registered"); + + // Get the endpoint handler to simulate request processing + let endpoint_handler = server + .handlers + .get(&endpoint_path) + .expect("Handler should be registered") + .clone(); + + // Spawn a task that simulates an inflight request + let request_task = tokio::spawn({ + let handler = handler.clone(); + async move { + let payload = Bytes::from("test payload"); + handler.handle_payload(payload).await + } + }); + + // Increment inflight counter manually to simulate the request being tracked + endpoint_handler.inflight.fetch_add(1, Ordering::SeqCst); + + // Wait for request to start processing + tokio::select! { + _ = request_started.notified() => { + tracing::debug!("Request processing started"); + } + _ = tokio::time::sleep(Duration::from_secs(2)) => { + panic!("Timeout waiting for request to start"); + } + } + + // Verify request is in flight + assert!( + request_in_flight.load(Ordering::SeqCst), + "Request should be in flight" + ); + + // Now unregister the endpoint while request is inflight + let unregister_start = Instant::now(); + tracing::debug!("Starting unregister_endpoint with inflight request"); + + // Spawn unregister in a separate task so we can monitor its behavior + let unregister_task = tokio::spawn({ + let server = server.clone(); + let endpoint_path = endpoint_path.clone(); + async move { + server + .unregister_endpoint(&endpoint_path, "test_endpoint") + .await; + Instant::now() + } + }); + + // Give unregister a moment to remove handler and start waiting + tokio::time::sleep(Duration::from_millis(50)).await; + + // Verify that unregister_endpoint hasn't returned yet (it should be waiting) + assert!( + !unregister_task.is_finished(), + "unregister_endpoint should still be waiting for inflight request" + ); + + tracing::debug!("Verified unregister is waiting, now waiting for request to complete"); + + // Wait for the request to complete + tokio::select! 
{ + _ = request_completed.notified() => { + tracing::debug!("Request completed"); + } + _ = tokio::time::sleep(Duration::from_secs(2)) => { + panic!("Timeout waiting for request to complete"); + } + } + + // Decrement inflight counter and notify (simulating what the real code does) + endpoint_handler.inflight.fetch_sub(1, Ordering::SeqCst); + endpoint_handler.notify.notify_one(); + + // Now wait for unregister to complete + let unregister_end = tokio::time::timeout(Duration::from_secs(2), unregister_task) + .await + .expect("unregister_endpoint should complete after inflight request finishes") + .expect("unregister task should not panic"); + + let unregister_duration = unregister_end - unregister_start; + + tracing::debug!("unregister_endpoint completed in {:?}", unregister_duration); + + // Verify unregister_endpoint waited for the inflight request + assert!( + unregister_duration >= Duration::from_secs(1), + "unregister_endpoint should have waited ~1s for inflight request, but only took {:?}", + unregister_duration + ); + + // Verify request completed successfully + assert!( + !request_in_flight.load(Ordering::SeqCst), + "Request should have completed" + ); + + // Wait for request task to finish + request_task + .await + .expect("Request task should complete") + .expect("Request should succeed"); + + tracing::info!("Test passed: unregister_endpoint properly waited for inflight TCP request"); + } +} diff --git a/lib/runtime/src/service.rs b/lib/runtime/src/service.rs index 8da224d508..5a54f12bc9 100644 --- a/lib/runtime/src/service.rs +++ b/lib/runtime/src/service.rs @@ -10,7 +10,7 @@ use crate::{ DistributedRuntime, component::Component, - metrics::{MetricsHierarchy, prometheus_names, prometheus_names::nats_service}, + metrics::{MetricsHierarchy, prometheus_names}, traits::*, transports::nats, utils::stream, @@ -294,150 +294,3 @@ mod tests { assert_eq!(endpoints.len(), 2); } } - -/// Prometheus metrics for component service statistics (ordered to match NatsStatsMetrics) -/// -/// โš ๏ธ IMPORTANT: These Prometheus Gauges are COPIES of NATS data, not live references! -/// -/// How it works: -/// 1. NATS provides source data via NatsStatsMetrics -/// 2. Metrics callbacks read current NATS values and update these Prometheus Gauges -/// 3. Prometheus scrapes these Gauge values (snapshots, not live data) -/// -/// Flow: NATS Service โ†’ NatsStatsMetrics (Counters) โ†’ Metrics Callback โ†’ Prometheus Gauge -/// Note: These are snapshots updated when execute_prometheus_update_callbacks() is called. -#[derive(Debug, Clone)] -/// Prometheus metrics for NATS server components. -/// Note: Metrics with `_total` names use IntGauge because we copy counter values -/// from underlying services rather than incrementing directly. 
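The unregister paths above (HTTP, NATS, and TCP) all rely on the same inflight-drain pattern; a standalone sketch, assuming an AtomicU64 counter and a tokio Notify that the request path decrements and signals as each request finishes:

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Notify;

// Wait until every inflight request tracked by `inflight` has completed.
async fn drain_inflight(inflight: Arc<AtomicU64>, notify: Arc<Notify>) {
    while inflight.load(Ordering::SeqCst) > 0 {
        // Each finished request decrements the counter and calls notify_one(),
        // waking this loop so it can re-check the count.
        notify.notified().await;
    }
}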
-pub struct ComponentNatsServerPrometheusMetrics { - /// Average processing time in milliseconds (maps to: average_processing_time) - pub service_processing_ms_avg: prometheus::Gauge, - /// Total errors across all endpoints (maps to: num_errors) - pub service_errors_total: prometheus::IntGauge, - /// Total requests across all endpoints (maps to: num_requests) - pub service_requests_total: prometheus::IntGauge, - /// Total processing time in milliseconds (maps to: processing_time) - pub service_processing_ms_total: prometheus::IntGauge, - /// Number of active services (derived from ServiceSet.services) - pub service_active_services: prometheus::IntGauge, - /// Number of active endpoints (derived from ServiceInfo.endpoints) - pub service_active_endpoints: prometheus::IntGauge, -} - -impl ComponentNatsServerPrometheusMetrics { - /// Create new ComponentServiceMetrics using Component's DistributedRuntime's Prometheus constructors - pub fn new(component: &Component) -> Result { - let service_name = component.service_name(); - - // Build labels: service_name first, then component's labels - let mut labels_vec = vec![("service_name", service_name.as_str())]; - - // Add component's labels (convert from (String, String) to (&str, &str)) - for (key, value) in component.labels() { - labels_vec.push((key.as_str(), value.as_str())); - } - - let labels: &[(&str, &str)] = &labels_vec; - - let service_processing_ms_avg = component.metrics().create_gauge( - nats_service::PROCESSING_MS_AVG, - "Average processing time across all component endpoints in milliseconds", - labels, - )?; - - let service_errors_total = component.metrics().create_intgauge( - nats_service::ERRORS_TOTAL, - "Total number of errors across all component endpoints", - labels, - )?; - - let service_requests_total = component.metrics().create_intgauge( - nats_service::REQUESTS_TOTAL, - "Total number of requests across all component endpoints", - labels, - )?; - - let service_processing_ms_total = component.metrics().create_intgauge( - nats_service::PROCESSING_MS_TOTAL, - "Total processing time across all component endpoints in milliseconds", - labels, - )?; - - let service_active_services = component.metrics().create_intgauge( - nats_service::ACTIVE_SERVICES, - "Number of active services in this component", - labels, - )?; - - let service_active_endpoints = component.metrics().create_intgauge( - nats_service::ACTIVE_ENDPOINTS, - "Number of active endpoints across all services", - labels, - )?; - - Ok(Self { - service_processing_ms_avg, - service_errors_total, - service_requests_total, - service_processing_ms_total, - service_active_services, - service_active_endpoints, - }) - } - - /// Update metrics from scraped ServiceSet data - pub fn update_from_service_set(&self, service_set: &ServiceSet) { - // Variables ordered to match NatsStatsMetrics fields - let mut processing_time_samples = 0u64; // for average_processing_time calculation - let mut total_errors = 0u64; // maps to: num_errors - let mut total_requests = 0u64; // maps to: num_requests - let mut total_processing_time_nanos = 0u64; // maps to: processing_time (nanoseconds from NATS) - let mut endpoint_count = 0u64; // for derived metrics - - let service_count = service_set.services().len() as i64; - - for service in service_set.services() { - for endpoint in &service.endpoints { - endpoint_count += 1; - - if let Some(ref stats) = endpoint.data { - total_errors += stats.num_errors; - total_requests += stats.num_requests; - total_processing_time_nanos += stats.processing_time; - - if 
stats.num_requests > 0 { - processing_time_samples += 1; - } - } - } - } - - // Update metrics (ordered to match NatsStatsMetrics fields) - // Calculate average processing time in milliseconds (maps to: average_processing_time) - if processing_time_samples > 0 && total_requests > 0 { - let avg_time_nanos = total_processing_time_nanos as f64 / total_requests as f64; - let avg_time_ms = avg_time_nanos / 1_000_000.0; // Convert nanoseconds to milliseconds - self.service_processing_ms_avg.set(avg_time_ms); - } else { - self.service_processing_ms_avg.set(0.0); - } - - self.service_errors_total.set(total_errors as i64); // maps to: num_errors - self.service_requests_total.set(total_requests as i64); // maps to: num_requests - self.service_processing_ms_total - .set((total_processing_time_nanos / 1_000_000) as i64); // maps to: processing_time (converted to milliseconds) - self.service_active_services.set(service_count); // derived from ServiceSet.services - self.service_active_endpoints.set(endpoint_count as i64); // derived from ServiceInfo.endpoints - } - - /// Reset all metrics to zero. Useful when no data is available or to clear stale values. - pub fn reset_to_zeros(&self) { - self.service_processing_ms_avg.set(0.0); - self.service_errors_total.set(0); - self.service_requests_total.set(0); - self.service_processing_ms_total.set(0); - self.service_active_services.set(0); - self.service_active_endpoints.set(0); - } -} diff --git a/lib/runtime/src/system_status_server.rs b/lib/runtime/src/system_status_server.rs index 69fcdc16ec..7f8360aafe 100644 --- a/lib/runtime/src/system_status_server.rs +++ b/lib/runtime/src/system_status_server.rs @@ -10,14 +10,14 @@ use crate::config::environment_names::runtime::canary as env_canary; use crate::config::environment_names::runtime::system as env_system; use crate::logging::make_request_span; use crate::metrics::MetricsHierarchy; -use crate::metrics::prometheus_names::{nats_client, nats_service}; use crate::traits::DistributedRuntimeProvider; use axum::{ Router, + body::Bytes, extract::{Json, Path, State}, http::StatusCode, response::IntoResponse, - routing::{delete, get, post}, + routing::{any, delete, get, post}, }; use futures::StreamExt; use serde::{Deserialize, Serialize}; @@ -184,6 +184,13 @@ pub async fn spawn_system_status_server( let state = Arc::clone(&server_state); move || metadata_handler(state) }), + ) + .route( + "/engine/{*path}", + any({ + let state = Arc::clone(&server_state); + move |path, body| engine_route_handler(state, path, body) + }), ); // Add LoRA routes only if DYN_LORA_ENABLED is set to true @@ -525,6 +532,77 @@ fn parse_lora_response(response_data: &serde_json::Value) -> LoraResponse { } } +/// Engine route handler for /engine/* routes +/// +/// This handler looks up registered callbacks in the engine routes registry +/// and invokes them with the request body, returning the response as JSON. 
+#[tracing::instrument(skip_all, level = "trace", fields(path = %path))] +async fn engine_route_handler( + state: Arc, + Path(path): Path, + body: Bytes, +) -> impl IntoResponse { + tracing::trace!("Engine route request to /engine/{}", path); + + // Parse body as JSON (empty object for GET/empty body) + let body_json: serde_json::Value = if body.is_empty() { + serde_json::json!({}) + } else { + match serde_json::from_slice(&body) { + Ok(json) => json, + Err(e) => { + tracing::warn!("Invalid JSON in request body: {}", e); + return ( + StatusCode::BAD_REQUEST, + json!({ + "error": "Invalid JSON", + "message": format!("{}", e) + }) + .to_string(), + ) + .into_response(); + } + } + }; + + // Look up callback + let callback = match state.drt().engine_routes().get(&path) { + Some(cb) => cb, + None => { + tracing::debug!("Route /engine/{} not found", path); + return ( + StatusCode::NOT_FOUND, + json!({ + "error": "Route not found", + "message": format!("Route /engine/{} not found", path) + }) + .to_string(), + ) + .into_response(); + } + }; + + // Call callback (it's async, so await it) + match callback(body_json).await { + Ok(response) => { + tracing::trace!("Engine route handler succeeded for /engine/{}", path); + (StatusCode::OK, response.to_string()).into_response() + } + Err(e) => { + tracing::error!("Engine route handler error for /engine/{}: {}", path, e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + json!({ + "error": "Handler error", + "message": format!("{}", e) + }) + .to_string(), + ) + .into_response() + } + } +} + // Regular tests: cargo test system_status_server --lib #[cfg(test)] mod tests { @@ -606,26 +684,17 @@ mod integration_tests { let response = drt.metrics().prometheus_expfmt().unwrap(); println!("Full metrics response:\n{}", response); - // Filter out NATS client metrics for comparison - let filtered_response: String = response - .lines() - .filter(|line| { - !line.contains(nats_client::PREFIX) && !line.contains(nats_service::PREFIX) - }) - .collect::>() - .join("\n"); - // Check that uptime_seconds metric is present with correct namespace assert!( - filtered_response.contains("# HELP dynamo_component_uptime_seconds"), + response.contains("# HELP dynamo_component_uptime_seconds"), "Should contain uptime_seconds help text" ); assert!( - filtered_response.contains("# TYPE dynamo_component_uptime_seconds gauge"), + response.contains("# TYPE dynamo_component_uptime_seconds gauge"), "Should contain uptime_seconds type" ); assert!( - filtered_response.contains("dynamo_component_uptime_seconds"), + response.contains("dynamo_component_uptime_seconds"), "Should contain uptime_seconds metric with correct namespace" ); }) @@ -918,7 +987,6 @@ mod integration_tests { // Start the service and endpoint with a health check payload // This will automatically register the endpoint for health monitoring tokio::spawn(async move { - component.add_stats_service().await.unwrap(); let _ = component.endpoint(ENDPOINT_NAME) .endpoint_builder() .handler(ingress) diff --git a/lib/runtime/src/transports/nats.rs b/lib/runtime/src/transports/nats.rs index c0588c8100..51c9cd73dd 100644 --- a/lib/runtime/src/transports/nats.rs +++ b/lib/runtime/src/transports/nats.rs @@ -39,7 +39,6 @@ use url::Url; use validator::{Validate, ValidationError}; use crate::config::environment_names::nats as env_nats; -use crate::metrics::prometheus_names::nats_client as nats_metrics; pub use crate::slug::Slug; use tracing as log; @@ -887,110 +886,6 @@ impl EventPublisher for NatsQueue { } } -/// Prometheus metrics that mirror 
the NATS client statistics (in primitive types) -/// to be used for the System Status Server. -/// -/// โš ๏ธ IMPORTANT: These Prometheus Gauges are COPIES of NATS client data, not live references! -/// -/// How it works: -/// 1. NATS client provides source data via client.statistics() and connection_state() -/// 2. set_from_client_stats() reads current NATS values and updates these Prometheus Gauges -/// 3. Prometheus scrapes these Gauge values (snapshots, not live data) -/// -/// Flow: NATS Client โ†’ Client Statistics โ†’ set_from_client_stats() โ†’ Prometheus Gauge -/// Note: These are snapshots updated when set_from_client_stats() is called. -#[derive(Debug, Clone)] -pub struct DRTNatsClientPrometheusMetrics { - nats_client: client::Client, - /// Number of bytes received (excluding protocol overhead) - pub in_bytes: IntGauge, - /// Number of bytes sent (excluding protocol overhead) - pub out_bytes: IntGauge, - /// Number of messages received - pub in_messages: IntGauge, - /// Number of messages sent - pub out_messages: IntGauge, - /// Number of times connection was established - pub connects: IntGauge, - /// Current connection state (0 = disconnected, 1 = connected, 2 = reconnecting) - pub connection_state: IntGauge, -} - -impl DRTNatsClientPrometheusMetrics { - /// Create a new instance of NATS client metrics using a DistributedRuntime's Prometheus constructors - pub fn new(drt: &crate::DistributedRuntime, nats_client: client::Client) -> Result { - let metrics = drt.metrics(); - let in_bytes = metrics.create_intgauge( - nats_metrics::IN_TOTAL_BYTES, - "Total number of bytes received by NATS client", - &[], - )?; - let out_bytes = metrics.create_intgauge( - nats_metrics::OUT_OVERHEAD_BYTES, - "Total number of bytes sent by NATS client", - &[], - )?; - let in_messages = metrics.create_intgauge( - nats_metrics::IN_MESSAGES, - "Total number of messages received by NATS client", - &[], - )?; - let out_messages = metrics.create_intgauge( - nats_metrics::OUT_MESSAGES, - "Total number of messages sent by NATS client", - &[], - )?; - let connects = metrics.create_intgauge( - nats_metrics::CURRENT_CONNECTIONS, - "Current number of active connections for NATS client", - &[], - )?; - let connection_state = metrics.create_intgauge( - nats_metrics::CONNECTION_STATE, - "Current connection state of NATS client (0=disconnected, 1=connected, 2=reconnecting)", - &[], - )?; - - Ok(Self { - nats_client, - in_bytes, - out_bytes, - in_messages, - out_messages, - connects, - connection_state, - }) - } - - /// Copy statistics from the stored NATS client to these Prometheus metrics - pub fn set_from_client_stats(&self) { - let stats = self.nats_client.statistics(); - - // Get current values from the client statistics - let in_bytes = stats.in_bytes.load(Ordering::Relaxed); - let out_bytes = stats.out_bytes.load(Ordering::Relaxed); - let in_messages = stats.in_messages.load(Ordering::Relaxed); - let out_messages = stats.out_messages.load(Ordering::Relaxed); - let connects = stats.connects.load(Ordering::Relaxed); - - // Get connection state - let connection_state = match self.nats_client.connection_state() { - State::Connected => 1, - // treat Disconnected and Pending as "down" - State::Disconnected | State::Pending => 0, - }; - - // Update Prometheus metrics - // Using gauges allows us to set absolute values directly - self.in_bytes.set(in_bytes as i64); - self.out_bytes.set(out_bytes as i64); - self.in_messages.set(in_messages as i64); - self.out_messages.set(out_messages as i64); - 
self.connects.set(connects as i64); - self.connection_state.set(connection_state); - } -} - /// The NATS subject / inbox to talk to an instance on. /// TODO: Do we need to sanitize the names? pub fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String { diff --git a/pyproject.toml b/pyproject.toml index 9e71366b55..2100d47951 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,19 +50,19 @@ Repository = "https://github.com/ai-dynamo/dynamo.git" [project.optional-dependencies] trtllm =[ "uvloop", - "tensorrt-llm==1.2.0rc3", + "tensorrt-llm==1.2.0rc5", ] vllm = [ "uvloop", "nixl[cu12]<=0.7.1", - "vllm[flashinfer]==0.11.0", + "vllm[flashinfer]==0.12.0", ] sglang = [ "uvloop", "nixl[cu12]<=0.7.1", - "sglang==0.5.4.post3", + "sglang==0.5.6", ] [project.entry-points.pytest11] @@ -181,6 +181,8 @@ filterwarnings = [ "ignore:Support for class-based `config`.*:pydantic.warnings.PydanticDeprecatedSince20", "ignore:Using extra keyword arguments on `Field`.*:pydantic.warnings.PydanticDeprecatedSince20", "ignore:The `schema` method is deprecated.*:pydantic.warnings.PydanticDeprecatedSince20", + # Pydantic warning about field shadowing in tensorrt_llm.serve.openai_protocol.ResponseFormat + 'ignore:Field name "schema" in "ResponseFormat" shadows an attribute in parent:UserWarning', # pytest-benchmark automatically disables when xdist is active, ignore the warning "ignore:.*Benchmarks are automatically disabled.*:pytest_benchmark.logger.PytestBenchmarkWarning", ] diff --git a/recipes/llama-3-70b/vllm/agg/deploy.yaml b/recipes/llama-3-70b/vllm/agg/deploy.yaml index 2cca281b96..54078054d2 100644 --- a/recipes/llama-3-70b/vllm/agg/deploy.yaml +++ b/recipes/llama-3-70b/vllm/agg/deploy.yaml @@ -43,7 +43,7 @@ spec: - name: HF_HOME value: /opt/models args: - - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 4 --data-parallel-size 1 --disable-log-requests --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" + - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 4 --data-parallel-size 1 --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" command: - /bin/sh - -c diff --git a/recipes/llama-3-70b/vllm/disagg-multi-node/deploy.yaml b/recipes/llama-3-70b/vllm/disagg-multi-node/deploy.yaml index 94acf7c846..b66870435a 100644 --- a/recipes/llama-3-70b/vllm/disagg-multi-node/deploy.yaml +++ b/recipes/llama-3-70b/vllm/disagg-multi-node/deploy.yaml @@ -43,7 +43,7 @@ spec: - name: HF_HOME value: /opt/models args: - - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 8 --data-parallel-size 1 --disable-log-requests --is-prefill-worker --gpu-memory-utilization 0.95 --no-enable-prefix-caching --block-size 128" + - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 8 --data-parallel-size 1 --is-prefill-worker --gpu-memory-utilization 0.95 --no-enable-prefix-caching --block-size 128" command: - /bin/sh - -c @@ -74,7 +74,7 @@ spec: - name: HF_HOME value: /opt/models args: - - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 8 --data-parallel-size 1 --disable-log-requests --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" + - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 8 --data-parallel-size 1 
--gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" command: - /bin/sh - -c diff --git a/recipes/llama-3-70b/vllm/disagg-single-node/deploy.yaml b/recipes/llama-3-70b/vllm/disagg-single-node/deploy.yaml index e67996f06a..7c91aaeda5 100644 --- a/recipes/llama-3-70b/vllm/disagg-single-node/deploy.yaml +++ b/recipes/llama-3-70b/vllm/disagg-single-node/deploy.yaml @@ -55,7 +55,7 @@ spec: - name: HF_HOME value: /opt/models args: - - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 2 --data-parallel-size 1 --disable-log-requests --is-prefill-worker --gpu-memory-utilization 0.95 --no-enable-prefix-caching --block-size 128" + - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 2 --data-parallel-size 1 --is-prefill-worker --gpu-memory-utilization 0.95 --no-enable-prefix-caching --block-size 128" command: - /bin/sh - -c @@ -98,7 +98,7 @@ spec: - name: HF_HOME value: /opt/models args: - - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 4 --data-parallel-size 1 --disable-log-requests --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" + - "python3 -m dynamo.vllm --model $MODEL_PATH --served-model-name $SERVED_MODEL_NAME --tensor-parallel-size 4 --data-parallel-size 1 --gpu-memory-utilization 0.90 --no-enable-prefix-caching --block-size 128" command: - /bin/sh - -c diff --git a/tests/conftest.py b/tests/conftest.py index e4a4b562a6..c477fb7056 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import logging import os import shutil import tempfile +import time from pathlib import Path from typing import Optional @@ -25,6 +14,14 @@ from tests.utils.constants import TEST_MODELS from tests.utils.managed_process import ManagedProcess +from tests.utils.port_utils import ( + allocate_port, + allocate_ports, + deallocate_port, + deallocate_ports, +) + +_logger = logging.getLogger(__name__) def pytest_configure(config): @@ -226,45 +223,140 @@ def pytest_collection_modifyitems(config, items): config.models_to_download = models_to_download +def pytest_runtestloop(session): + """Download models after collection but before any tests run. + + This hook runs after pytest_collection_modifyitems (so models are collected) + but before any test execution, ensuring model downloads don't count against test timeouts. 
+ """ + models = getattr(session.config, "models_to_download", None) + + if models: + logging.info( + f"Downloading {len(models)} models before test execution\nModels: {models}" + ) + start_time = time.time() + + download_models(model_list=list(models)) + + download_duration = time.time() - start_time + logging.info(f"Model download completed in {download_duration:.1f}s") + + class EtcdServer(ManagedProcess): def __init__(self, request, port=2379, timeout=300): + # Allocate free ports if port is 0 + use_random_port = port == 0 + if use_random_port: + # Need two ports: client port and peer port for parallel execution + # Start from 2380 (etcd default 2379 + 1) + port, peer_port = allocate_ports(2, 2380) + else: + peer_port = None + + self.port = port + self.peer_port = peer_port # Store for cleanup + self.use_random_port = use_random_port # Track if we allocated the port port_string = str(port) etcd_env = os.environ.copy() etcd_env["ALLOW_NONE_AUTHENTICATION"] = "yes" data_dir = tempfile.mkdtemp(prefix="etcd_") + command = [ "etcd", "--listen-client-urls", f"http://0.0.0.0:{port_string}", "--advertise-client-urls", f"http://0.0.0.0:{port_string}", - "--data-dir", - data_dir, ] + + # Add peer port configuration only for random ports (parallel execution) + if peer_port is not None: + peer_port_string = str(peer_port) + command.extend( + [ + "--listen-peer-urls", + f"http://0.0.0.0:{peer_port_string}", + "--initial-advertise-peer-urls", + f"http://localhost:{peer_port_string}", + "--initial-cluster", + f"default=http://localhost:{peer_port_string}", + ] + ) + + command.extend( + [ + "--data-dir", + data_dir, + ] + ) super().__init__( env=etcd_env, command=command, timeout=timeout, display_output=False, + terminate_existing=not use_random_port, # Disabled for parallel test execution with random ports health_check_ports=[port], data_dir=data_dir, log_dir=request.node.name, ) + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated ports when server exits.""" + try: + # Only deallocate ports that were dynamically allocated (not default ports) + if self.use_random_port: + ports_to_release = [self.port] + if self.peer_port is not None: + ports_to_release.append(self.peer_port) + deallocate_ports(ports_to_release) + except Exception as e: + logging.warning(f"Failed to release EtcdServer port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + class NatsServer(ManagedProcess): def __init__(self, request, port=4222, timeout=300): + # Allocate a free port if port is 0 + use_random_port = port == 0 + if use_random_port: + # Start from 4223 (nats-server default 4222 + 1) + port = allocate_port(4223) + + self.port = port + self.use_random_port = use_random_port # Track if we allocated the port data_dir = tempfile.mkdtemp(prefix="nats_") - command = ["nats-server", "-js", "--trace", "--store_dir", data_dir] + command = [ + "nats-server", + "-js", + "--trace", + "--store_dir", + data_dir, + "-p", + str(port), + ] super().__init__( command=command, timeout=timeout, display_output=False, + terminate_existing=not use_random_port, # Disabled for parallel test execution with random ports data_dir=data_dir, health_check_ports=[port], log_dir=request.node.name, ) + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when server exits.""" + try: + # Only deallocate ports that were dynamically allocated (not default ports) + if self.use_random_port: + deallocate_port(self.port) + except Exception as e: + logging.warning(f"Failed to release NatsServer port: {e}") + + 
return super().__exit__(exc_type, exc_val, exc_tb) + class SharedManagedProcess: """Base class for ManagedProcess with file-based reference counting for multi-process sharing.""" @@ -391,11 +483,98 @@ def _create_server(self) -> ManagedProcess: return server +@pytest.fixture +def store_kv(request): + """ + KV store for runtime. Defaults to "etcd". + + To iterate over multiple stores in a test: + @pytest.mark.parametrize("store_kv", ["file", "etcd"], indirect=True) + def test_example(runtime_services): + ... + """ + return getattr(request, "param", "etcd") + + +@pytest.fixture +def request_plane(request): + """ + Request plane for runtime. Defaults to "nats". + + To iterate over multiple transports in a test: + @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) + def test_example(runtime_services): + ... + """ + return getattr(request, "param", "nats") + + @pytest.fixture() -def runtime_services(request): - with NatsServer(request) as nats_process: +def runtime_services(request, store_kv, request_plane): + """ + Start runtime services (NATS and/or etcd) based on store_kv and request_plane. + + - If store_kv != "etcd", etcd is not started (returns None) + - If request_plane != "nats", NATS is not started (returns None) + + Returns a tuple of (nats_process, etcd_process) where each has a .port attribute. + """ + # Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods + if request_plane == "nats" and store_kv == "etcd": + with NatsServer(request) as nats_process: + with EtcdServer(request) as etcd_process: + yield nats_process, etcd_process + elif request_plane == "nats": + with NatsServer(request) as nats_process: + yield nats_process, None + elif store_kv == "etcd": with EtcdServer(request) as etcd_process: - yield nats_process, etcd_process + yield None, etcd_process + else: + yield None, None + + +@pytest.fixture() +def runtime_services_dynamic_ports(request, store_kv, request_plane): + """Provide NATS and Etcd servers with truly dynamic ports per test. + + This fixture actually allocates dynamic ports by passing port=0 to the servers. + It also sets the NATS_SERVER and ETCD_ENDPOINTS environment variables so that + Dynamo processes can find the services on the dynamic ports. + + - If store_kv != "etcd", etcd is not started (returns None) + - If request_plane != "nats", NATS is not started (returns None) + + Returns a tuple of (nats_process, etcd_process) where each has a .port attribute. + """ + import os + + # Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods + if request_plane == "nats" and store_kv == "etcd": + with NatsServer(request, port=0) as nats_process: + with EtcdServer(request, port=0) as etcd_process: + # Set environment variables for Rust/Python runtime to use. Note that xdist (parallel execution) + # will launch isolated tests in a new process, so no need to worry about environment pollution. + os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}" + os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd_process.port}" + + yield nats_process, etcd_process + + # No test should rely on these variables after the test, but clean up just in case. 
+ os.environ.pop("NATS_SERVER", None) + os.environ.pop("ETCD_ENDPOINTS", None) + elif request_plane == "nats": + with NatsServer(request, port=0) as nats_process: + os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}" + yield nats_process, None + os.environ.pop("NATS_SERVER", None) + elif store_kv == "etcd": + with EtcdServer(request, port=0) as etcd_process: + os.environ["ETCD_ENDPOINTS"] = f"http://localhost:{etcd_process.port}" + yield None, etcd_process + os.environ.pop("ETCD_ENDPOINTS", None) + else: + yield None, None @pytest.fixture(scope="session") diff --git a/tests/dependencies/test_kvbm_imports.py b/tests/dependencies/test_kvbm_imports.py new file mode 100644 index 0000000000..611f0a0eb3 --- /dev/null +++ b/tests/dependencies/test_kvbm_imports.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests to verify KVBM package and wheels are properly installed.""" + +import subprocess + +import pytest + + +# Helper functions for KVBM verification +def _check_kvbm_wheel_exists(): + """Helper to verify KVBM wheel file exists in expected location.""" + result = subprocess.run( + ["bash", "-c", "ls /opt/dynamo/wheelhouse/kvbm*.whl"], + capture_output=True, + text=True, + ) + + assert result.returncode == 0, ( + f"KVBM wheel not found in /opt/dynamo/wheelhouse/\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}" + ) + assert ( + "kvbm" in result.stdout + ), f"Expected kvbm wheel in output, got: {result.stdout}" + + +def _check_kvbm_imports(): + """Helper to verify KVBM package and core classes can be imported.""" + try: + import kvbm + from kvbm import BlockManager, KvbmLeader, KvbmWorker + + assert kvbm is not None, "kvbm module is None" + assert BlockManager is not None, "BlockManager class not available" + assert KvbmLeader is not None, "KvbmLeader class not available" + assert KvbmWorker is not None, "KvbmWorker class not available" + except ImportError as e: + pytest.fail(f"Failed to import KVBM package or core classes: {e}") + + +# Base tests (no framework markers) - run in main job with --framework none --enable-kvbm +@pytest.mark.pre_merge +def test_kvbm_wheel_exists(): + """Verify KVBM wheel file exists in expected location.""" + _check_kvbm_wheel_exists() + + +@pytest.mark.pre_merge +def test_kvbm_imports(): + """Verify KVBM package and core classes can be imported.""" + _check_kvbm_imports() + + +# vLLM-specific tests - run in vLLM job (vLLM auto-enables KVBM) +@pytest.mark.pre_merge +@pytest.mark.vllm +def test_kvbm_wheel_exists_vllm(): + """Verify KVBM wheel exists in vLLM image.""" + _check_kvbm_wheel_exists() + + +@pytest.mark.pre_merge +@pytest.mark.vllm +def test_kvbm_imports_vllm(): + """Verify KVBM package and core classes can be imported in vLLM image.""" + _check_kvbm_imports() + + +# TRT-LLM-specific tests - run in TRT-LLM job (TRT-LLM auto-enables KVBM) +@pytest.mark.pre_merge +@pytest.mark.trtllm +def test_kvbm_wheel_exists_trtllm(): + """Verify KVBM wheel exists in TRT-LLM image.""" + _check_kvbm_wheel_exists() + + +@pytest.mark.pre_merge +@pytest.mark.trtllm +def test_kvbm_imports_trtllm(): + """Verify KVBM package and core classes can be imported in TRT-LLM image.""" + _check_kvbm_imports() diff --git a/tests/dependencies/test_vllm_imports.py b/tests/dependencies/test_vllm_imports.py new file mode 100644 index 0000000000..d26aedfec8 --- /dev/null +++ b/tests/dependencies/test_vllm_imports.py @@ -0,0 +1,32 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests to sanity check that required dependencies can be imported.""" + +import pytest + + +@pytest.mark.vllm +@pytest.mark.unit +@pytest.mark.gpu_1 +def test_import_deep_ep(): + """Test that deep_ep module can be imported.""" + try: + import deep_ep + + assert deep_ep is not None + except ImportError as e: + pytest.fail(f"Failed to import deep_ep: {e}") + + +@pytest.mark.vllm +@pytest.mark.unit +@pytest.mark.gpu_1 +def test_import_pplx_kernels(): + """Test that pplx_kernels module can be imported.""" + try: + import pplx_kernels + + assert pplx_kernels is not None + except ImportError as e: + pytest.fail(f"Failed to import pplx_kernels: {e}") diff --git a/tests/fault_tolerance/cancellation/test_sglang.py b/tests/fault_tolerance/cancellation/test_sglang.py index b073420b14..01f16b43a8 100644 --- a/tests/fault_tolerance/cancellation/test_sglang.py +++ b/tests/fault_tolerance/cancellation/test_sglang.py @@ -1,6 +1,13 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_cancellation_sglang_aggregated: ~46s (gpu_1) +- test_request_cancellation_sglang_decode_cancel: ~60s (gpu_2, estimate) +- Total: 46.06s (0:00:46) for aggregated test only +""" + import logging import os import shutil @@ -15,22 +22,37 @@ send_cancellable_request, ) from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess from tests.utils.payloads import check_health_generate, check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port logger = logging.getLogger(__name__) +pytestmark = [ + pytest.mark.sglang, + pytest.mark.e2e, + pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), + pytest.mark.post_merge, # post_merge to pinpoint failure commit +] + class DynamoWorkerProcess(ManagedProcess): """Process manager for Dynamo worker with SGLang backend""" - def __init__(self, request, mode: str = "agg"): + def __init__( + self, + request, + system_port: int, + frontend_port: int, + mode: str = "agg", + ): """ Initialize SGLang worker process. 
Args: request: pytest request object + system_port: Port for system metrics server + frontend_port: Port where frontend is running mode: One of "agg", "prefill", "decode" """ command = [ @@ -59,7 +81,7 @@ def __init__(self, request, mode: str = "agg"): "--disaggregation-mode", mode, "--disaggregation-bootstrap-port", - "12345", + "12345", # TODO: use dynamic port allocation "--host", "0.0.0.0", "--disaggregation-transfer-backend", @@ -67,26 +89,33 @@ def __init__(self, request, mode: str = "agg"): ] ) - health_check_urls = [ - (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api), - (f"http://localhost:{FRONTEND_PORT}/health", check_health_generate), - ] + # Configure health check based on worker type + if mode in ["prefill", "decode"]: + # Prefill and decode workers check their own status endpoint + health_check_urls = [ + (f"http://localhost:{system_port}/health", self.is_ready) + ] + else: + # Aggregated workers check both system status and frontend + health_check_urls = [ + (f"http://localhost:{system_port}/health", self.is_ready), + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{frontend_port}/health", check_health_generate), + ] - # Set port based on worker type - if mode == "prefill": - port = "8082" - health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)] - elif mode == "decode": - port = "8081" - health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)] - else: # agg (aggregated mode) - port = "8081" - - # Set debug logging environment + # Set environment variables env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' - env["DYN_SYSTEM_PORT"] = port + env["DYN_SYSTEM_PORT"] = str(system_port) + env["DYN_HTTP_PORT"] = str(frontend_port) # Set GPU assignment for disaggregated mode (like disagg.sh) if mode == "decode": @@ -124,10 +153,17 @@ def __init__(self, request, mode: str = "agg"): ) self.mode = mode + self.system_port = system_port + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is a required parameter, always set in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release SGLang worker port: {e}") - def get_pid(self): - """Get the PID of the worker process""" - return self.proc.pid if self.proc else None + return super().__exit__(exc_type, exc_val, exc_tb) def is_ready(self, response) -> bool: """Check the health of the worker process""" @@ -146,14 +182,12 @@ def is_ready(self, response) -> bool: return False -@pytest.mark.e2e -@pytest.mark.sglang +@pytest.mark.timeout(160) # 3x average @pytest.mark.gpu_1 -@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) -@pytest.mark.nightly @pytest.mark.xfail(strict=False) +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) def test_request_cancellation_sglang_aggregated( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation functionality in aggregated mode. 
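The worker and frontend wrappers in these tests all follow the same allocate-in-`__init__`, release-in-`__exit__` pattern using `tests.utils.port_utils`. A minimal sketch of that pattern, assuming `allocate_port(start)` returns a free port at or above `start` and `deallocate_port` returns it to the pool; `PortScopedProcess` is a hypothetical stand-in, not a class from this diff.

```python
# Minimal sketch of the port allocate/release pattern used by the worker and
# frontend wrappers in this change set.
import logging

from tests.utils.port_utils import allocate_port, deallocate_port


class PortScopedProcess:  # illustrative stand-in for a ManagedProcess subclass
    def __init__(self):
        # Pick a system port away from the defaults so parallel tests don't collide.
        self.system_port = allocate_port(9100)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release the port, even if the wrapped process failed to start.
        try:
            deallocate_port(self.system_port)
        except Exception as e:
            logging.warning(f"Failed to release port {self.system_port}: {e}")
        return False
```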
@@ -162,16 +196,35 @@ def test_request_cancellation_sglang_aggregated( the system properly handles the cancellation and cleans up resources on the worker side in aggregated (agg) mode. + Tests 3 cancellation scenarios: + 1. Completion request + 2. Chat completion request + 3. Chat completion request (streaming) + + Timing (Last Run: 2025-12-09): ~46s total + - Engine initialization: ~14s + - Testing 3 scenarios: ~30s (~10s each) + - Teardown: ~2s + TODO: Test is currently flaky/failing due to SGLang limitations with prefill cancellation. See: https://github.com/sgl-project/sglang/issues/11139 """ logger.info("Sanity check if latest test is getting executed") - # Step 1: Start the frontend + + # Allocate ports to avoid conflicts with parallel tests + system_port = allocate_port(9100) + + # Step 1: Start the frontend (allocates its own port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") # Step 2: Start an aggregated worker - with DynamoWorkerProcess(request, mode="agg") as worker: + with DynamoWorkerProcess( + request, + system_port=system_port, + frontend_port=frontend.frontend_port, + mode="agg", + ) as worker: logger.info(f"Aggregated Worker PID: {worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? time.sleep(2) @@ -192,7 +245,9 @@ def test_request_cancellation_sglang_aggregated( logger.info(f"Testing {description.lower()}...") # Send the request (non-blocking) - cancellable_req = send_cancellable_request(request_type) + cancellable_req = send_cancellable_request( + frontend.frontend_port, request_type + ) # Poll for "New Request ID" pattern (Dynamo context ID) request_id, worker_log_offset = poll_for_pattern( @@ -236,13 +291,21 @@ def test_request_cancellation_sglang_aggregated( logger.info(f"{description} detected successfully") -@pytest.mark.e2e -@pytest.mark.sglang +@pytest.mark.timeout(185) # 3x average @pytest.mark.gpu_2 -@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) -@pytest.mark.nightly +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_cancellation_sglang_decode_cancel( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation during decode phase. @@ -252,18 +315,37 @@ def test_request_cancellation_sglang_decode_cancel( on both the prefill and decode workers in a disaggregated setup. Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs. 
+ + Timing (Last Run: 2025-12-09): ~60s total (estimated) + - Engine initialization: ~20s (decode + prefill workers) + - Testing stream cancellation during decode: ~38s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Allocate ports to avoid conflicts with parallel tests + decode_system_port = allocate_port(9100) + prefill_system_port = allocate_port(9200) + + # Step 1: Start the frontend (allocates its own port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") # Step 2: Start the decode worker - with DynamoWorkerProcess(request, mode="decode") as decode_worker: + with DynamoWorkerProcess( + request, + system_port=decode_system_port, + frontend_port=frontend.frontend_port, + mode="decode", + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # Step 3: Start the prefill worker - with DynamoWorkerProcess(request, mode="prefill") as prefill_worker: + with DynamoWorkerProcess( + request, + system_port=prefill_system_port, + frontend_port=frontend.frontend_port, + mode="prefill", + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? @@ -275,7 +357,9 @@ def test_request_cancellation_sglang_decode_cancel( ) # Send streaming request (non-blocking) - cancellable_req = send_cancellable_request("chat_completion_stream") + cancellable_req = send_cancellable_request( + frontend.frontend_port, "chat_completion_stream" + ) # Poll for "New Request ID" pattern in decode worker (Dynamo context ID) request_id, decode_log_offset = poll_for_pattern( diff --git a/tests/fault_tolerance/cancellation/test_trtllm.py b/tests/fault_tolerance/cancellation/test_trtllm.py index a9281417a1..1ace9f8bb5 100644 --- a/tests/fault_tolerance/cancellation/test_trtllm.py +++ b/tests/fault_tolerance/cancellation/test_trtllm.py @@ -1,6 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_cancellation_trtllm_aggregated: ~45s (gpu_1) +- test_request_cancellation_trtllm_decode_cancel: ~115s (gpu_1) +- test_request_cancellation_trtllm_prefill_cancel: ~115s (gpu_1) +- test_request_cancellation_trtllm_kv_transfer_cancel: ~115s (gpu_1, xfail) +- Total: ~390s (0:06:30) +""" + import logging import os import shutil @@ -15,9 +24,9 @@ send_cancellable_request, ) from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess from tests.utils.payloads import check_health_generate, check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port logger = logging.getLogger(__name__) @@ -26,20 +35,31 @@ pytest.mark.gpu_1, pytest.mark.e2e, pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), + pytest.mark.post_merge, # post_merge to pinpoint failure commit ] class DynamoWorkerProcess(ManagedProcess): """Process manager for Dynamo worker with TensorRT-LLM backend""" - def __init__(self, request, mode: str = "prefill_and_decode"): + def __init__( + self, + request, + frontend_port: int, + mode: str = "prefill_and_decode", + ): """ Initialize TensorRT-LLM worker process. 
Args: request: pytest request object + frontend_port: Port for the frontend server mode: One of "prefill_and_decode", "prefill", "decode" """ + # Allocate system port for this worker + system_port = allocate_port(9100) + self.system_port = system_port + self.frontend_port = frontend_port # Prefill workers require migration_limit=0 (no KV cache migration support) migration_limit = "0" if mode == "prefill" else "3" @@ -70,25 +90,28 @@ def __init__(self, request, mode: str = "prefill_and_decode"): ] health_check_urls = [ - (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api), - (f"http://localhost:{FRONTEND_PORT}/health", check_health_generate), + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{frontend_port}/health", check_health_generate), ] - # Set port based on worker type - if mode == "prefill": - port = "8082" - health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)] - elif mode == "decode": - port = "8081" - health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)] - else: # prefill_and_decode - port = "8081" - - # Set debug logging environment + # Set health check based on worker type + if mode in ["prefill", "decode"]: + health_check_urls = [ + (f"http://localhost:{system_port}/health", self.is_ready) + ] + + # Set environment variables env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' - env["DYN_SYSTEM_PORT"] = port + env["DYN_SYSTEM_PORT"] = str(system_port) # Set log directory based on worker type log_dir = f"{request.node.name}_{mode}_worker" @@ -113,10 +136,6 @@ def __init__(self, request, mode: str = "prefill_and_decode"): self.mode = mode - def get_pid(self): - """Get the PID of the worker process""" - return self.proc.pid if self.proc else None - def is_ready(self, response) -> bool: """Check the health of the worker process""" try: @@ -133,25 +152,47 @@ def is_ready(self, response) -> bool: ) return False + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is always allocated in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release TRT-LLM worker port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + -@pytest.mark.nightly +@pytest.mark.timeout(140) # 3x average +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) def test_request_cancellation_trtllm_aggregated( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation functionality in aggregated mode. This test verifies that when a request is cancelled by the client, the system properly handles the cancellation and cleans up resources - on the worker side in aggregated (prefill_and_decode) mode. + on the worker side in aggregated (prefill_and_decode) mode. Tests three scenarios: + 1. Completion request + 2. Chat completion request (non-streaming) + 3. 
Chat completion request (streaming) + + Timing (Last Run: 2025-12-09): ~45s total + - Engine initialization: ~27s (frontend + worker) + - Testing 3 scenarios: ~15s (~5s each) + - Teardown: ~3s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") # Step 2: Start an aggregated worker - with DynamoWorkerProcess(request, mode="prefill_and_decode") as worker: + # Step 2: Start a single worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="prefill_and_decode" + ) as worker: logger.info(f"Aggregated Worker PID: {worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? @@ -173,7 +214,9 @@ def test_request_cancellation_trtllm_aggregated( logger.info(f"Testing {description.lower()}...") # Send the request (non-blocking) - cancellable_req = send_cancellable_request(request_type) + cancellable_req = send_cancellable_request( + frontend.frontend_port, request_type + ) # Poll for "New Request ID" pattern request_id, worker_log_offset = poll_for_pattern( @@ -208,9 +251,20 @@ def test_request_cancellation_trtllm_aggregated( logger.info(f"{description} detected successfully") -@pytest.mark.nightly +@pytest.mark.timeout(350) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_cancellation_trtllm_decode_cancel( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation during decode phase with unified frontend. @@ -218,18 +272,27 @@ def test_request_cancellation_trtllm_decode_cancel( This test verifies that when a request is cancelled by the client during the decode phase, the system properly handles the cancellation and cleans up resources on the decode worker side in a disaggregated setup. + + Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each) + - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential) + - Testing stream cancellation during decode: ~20s + - Teardown: ~3s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start the prefill worker - with DynamoWorkerProcess(request, mode="prefill") as prefill_worker: + # Step 2: Start the prefill worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="prefill" + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") - # Step 3: Start the decode worker - with DynamoWorkerProcess(request, mode="decode") as decode_worker: + # Step 3: Start the decode worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="decode" + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? 
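The parametrization used above (NATS always, TCP marked xfail for multi-worker cases) together with `runtime_services_dynamic_ports` is the pattern the migrated tests share. A hedged sketch of how a test consumes it; the test name and assertions are illustrative only, while the fixture names, the returned `(nats_process, etcd_process)` tuple, and the `NATS_SERVER` / `ETCD_ENDPOINTS` environment variables come from the conftest changes in this diff.

```python
# Illustrative sketch: opting a test into both request planes while using
# dynamically allocated NATS/etcd services.
import os

import pytest


@pytest.mark.parametrize(
    "request_plane",
    [
        "nats",
        pytest.param(
            "tcp",
            marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False),
        ),
    ],
    indirect=True,
)
def test_uses_dynamic_services(runtime_services_dynamic_ports):
    nats_process, etcd_process = runtime_services_dynamic_ports
    # The fixture exports these so spawned Dynamo processes find the services.
    if nats_process is not None:
        assert os.environ["NATS_SERVER"] == f"nats://localhost:{nats_process.port}"
    if etcd_process is not None:
        assert os.environ["ETCD_ENDPOINTS"] == f"http://localhost:{etcd_process.port}"
```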
@@ -241,7 +304,9 @@ def test_request_cancellation_trtllm_decode_cancel( ) # Send streaming request (non-blocking) - cancellable_req = send_cancellable_request("chat_completion_stream") + cancellable_req = send_cancellable_request( + frontend.frontend_port, "chat_completion_stream" + ) # Poll for "Prefill Request ID" pattern in prefill worker (frontend routes here first) request_id, prefill_log_offset = poll_for_pattern( @@ -281,9 +346,20 @@ def test_request_cancellation_trtllm_decode_cancel( ) -@pytest.mark.nightly +@pytest.mark.timeout(350) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_cancellation_trtllm_prefill_cancel( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation during prefill phase with unified frontend. @@ -291,18 +367,27 @@ def test_request_cancellation_trtllm_prefill_cancel( This test verifies that when a request is cancelled by the client during the prefill phase, the system properly handles the cancellation and cleans up resources on the prefill worker. Since the request is cancelled before prefill completes, the decode worker never receives it. + + Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each) + - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential) + - Testing cancellation during prefill: ~20s + - Teardown: ~3s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start the prefill worker - with DynamoWorkerProcess(request, mode="prefill") as prefill_worker: + # Step 2: Start the prefill worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="prefill" + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") - # Step 3: Start the decode worker - with DynamoWorkerProcess(request, mode="decode") as decode_worker: + # Step 3: Start the decode worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="decode" + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? @@ -315,7 +400,7 @@ def test_request_cancellation_trtllm_prefill_cancel( # Send request with long prompt (non-blocking) cancellable_req = send_cancellable_request( - "completion", use_long_prompt=True + frontend.frontend_port, "completion", use_long_prompt=True ) # Poll for "Prefill Request ID" pattern in prefill worker (frontend routes here first) @@ -364,30 +449,41 @@ def test_request_cancellation_trtllm_prefill_cancel( ) +@pytest.mark.timeout(350) # 3x average +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.xfail( reason="May fail due to unknown reason with TRT-LLM or backend implementation", strict=False, ) def test_request_cancellation_trtllm_kv_transfer_cancel( - request, runtime_services, predownload_models + request, runtime_services_dynamic_ports ): """ End-to-end test for request cancellation during prefill to decode KV transfer phase. 
This test verifies that when a request is cancelled by the client during the KV transfer phase, the system properly handles the cancellation and cleans up resources on the workers. + + Timing (Last Run: 2025-12-09): ~115s total (2 workers at 45% GPU each) + - Engine initialization: ~92s (frontend: 2s, prefill worker: 45s, decode worker: 45s sequential) + - Testing KV transfer cancellation: ~20s + - Teardown: ~3s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start the prefill worker - with DynamoWorkerProcess(request, mode="prefill") as prefill_worker: + # Step 2: Start the prefill worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="prefill" + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") - # Step 3: Start the decode worker - with DynamoWorkerProcess(request, mode="decode") as decode_worker: + # Step 3: Start the decode worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, mode="decode" + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness? @@ -400,7 +496,7 @@ def test_request_cancellation_trtllm_kv_transfer_cancel( # Send request with long prompt cancellable_req = send_cancellable_request( - "completion", use_long_prompt=True + frontend.frontend_port, "completion", use_long_prompt=True ) # Poll for "Prefill Request ID" pattern in prefill worker @@ -441,7 +537,9 @@ def test_request_cancellation_trtllm_kv_transfer_cancel( ) # Verify the workers are still functional - cancellable_req = send_cancellable_request("chat_completion_stream") + cancellable_req = send_cancellable_request( + frontend.frontend_port, "chat_completion_stream" + ) _, decode_log_offset = poll_for_pattern( process=decode_worker, pattern="Decode Request ID: ", diff --git a/tests/fault_tolerance/cancellation/test_vllm.py b/tests/fault_tolerance/cancellation/test_vllm.py index 86c365e843..3c5731751b 100644 --- a/tests/fault_tolerance/cancellation/test_vllm.py +++ b/tests/fault_tolerance/cancellation/test_vllm.py @@ -1,6 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_cancellation_vllm_aggregated: ~55s (gpu_1) +- test_request_cancellation_vllm_decode_cancel: ~53s (gpu_2) +- test_request_cancellation_vllm_prefill_cancel: ~53s (gpu_2) +- Total: 161.65s (0:02:41) +""" + import logging import os import shutil @@ -14,17 +22,35 @@ send_cancellable_request, ) from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess from tests.utils.payloads import check_health_generate, check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port logger = logging.getLogger(__name__) +pytestmark = [ + pytest.mark.vllm, + pytest.mark.gpu_1, + pytest.mark.e2e, + pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), + pytest.mark.post_merge, # post_merge to pinpoint failure commit +] + class DynamoWorkerProcess(ManagedProcess): """Process manager for Dynamo worker with vLLM backend""" - def __init__(self, request, is_prefill: bool = False): + def __init__( + self, + request, + frontend_port: int, + is_prefill: bool = False, + ): + # Allocate system port for this worker + system_port = allocate_port(9100) + self.system_port = system_port + self.frontend_port = frontend_port + command = [ "python3", "-m", @@ -40,34 +66,43 @@ def __init__(self, request, is_prefill: bool = False): "3", ] - # Set port based on worker type - port = "8082" if is_prefill else "8081" - # Configure health check based on worker type if is_prefill: # Prefill workers check their own status endpoint command.append("--is-prefill-worker") - health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)] + health_check_urls = [ + (f"http://localhost:{system_port}/health", self.is_ready) + ] else: # Decode workers should also check their own status endpoint first, # then verify the frontend sees the model health_check_urls = [ - (f"http://localhost:{port}/health", self.is_ready), - (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api), - (f"http://localhost:{FRONTEND_PORT}/health", check_health_generate), + (f"http://localhost:{system_port}/health", self.is_ready), + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{frontend_port}/health", check_health_generate), ] - # Set debug logging environment + # Set environment variables env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' - env["DYN_SYSTEM_PORT"] = port + env["DYN_SYSTEM_PORT"] = str(system_port) + env["DYN_HTTP_PORT"] = str(frontend_port) # Set KV event port and NIXL side channel port only for prefill worker # to avoid conflicts with decode worker if is_prefill: - env["DYN_VLLM_KV_EVENT_PORT"] = "20082" - env["VLLM_NIXL_SIDE_CHANNEL_PORT"] = "5601" + env["DYN_VLLM_KV_EVENT_PORT"] = "20082" # TODO: use dynamic port allocation + env[ + "VLLM_NIXL_SIDE_CHANNEL_PORT" + ] = "5601" # TODO: use dynamic port allocation # Set log directory based on worker type worker_type = "prefill_worker" if is_prefill else "worker" @@ -104,6 +139,16 @@ def 
get_pid(self): """Get the PID of the worker process""" return self.proc.pid if self.proc else None + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is always allocated in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release vLLM worker port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + def is_ready(self, response) -> bool: """Check the health of the worker process""" try: @@ -120,14 +165,9 @@ def is_ready(self, response) -> bool: return False -@pytest.mark.vllm -@pytest.mark.gpu_1 -@pytest.mark.e2e -@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) -@pytest.mark.nightly -def test_request_cancellation_vllm_aggregated( - request, runtime_services, predownload_models -): +@pytest.mark.timeout(110) # 3x average +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) +def test_request_cancellation_vllm_aggregated(request, runtime_services_dynamic_ports): """ End-to-end test for request cancellation functionality in aggregated mode. @@ -137,14 +177,19 @@ def test_request_cancellation_vllm_aggregated( 1. Completion request 2. Chat completion request (non-streaming) 3. Chat completion request (streaming) + + Timing (Last Run: 2025-12-09): ~55s total + - Engine initialization: ~15s + - Testing 3 scenarios: ~38s (~12s each) + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start a single worker - with DynamoWorkerProcess(request) as worker: + # Step 2: Start a single worker (allocates its own system_port) + with DynamoWorkerProcess(request, frontend.frontend_port) as worker: logger.info(f"Worker PID: {worker.get_pid()}") # Step 3: Test request cancellation with polling approach @@ -163,7 +208,9 @@ def test_request_cancellation_vllm_aggregated( logger.info(f"Testing {description.lower()}...") # Send the request (non-blocking) - cancellable_req = send_cancellable_request(request_type) + cancellable_req = send_cancellable_request( + frontend.frontend_port, request_type + ) # Poll for "Decode Request ID" pattern (vLLM v2 pattern) request_id, worker_log_offset = poll_for_pattern( @@ -198,13 +245,20 @@ def test_request_cancellation_vllm_aggregated( logger.info(f"{description} detected successfully") -@pytest.mark.vllm -@pytest.mark.gpu_1 -@pytest.mark.e2e -@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) -@pytest.mark.nightly +@pytest.mark.timeout(150) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_cancellation_vllm_decode_cancel( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for request cancellation during decode phase. @@ -212,18 +266,27 @@ def test_request_cancellation_vllm_decode_cancel( This test verifies that when a request is cancelled by the client during the decode phase, the system properly handles the cancellation and cleans up resources on the decode worker side in a disaggregated setup. 
+ + Timing (Last Run: 2025-12-09): ~53s total (requires 2 GPUs) + - Engine initialization: ~23s (decode + prefill workers) + - Testing stream cancellation during decode: ~28s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start the prefill worker - with DynamoWorkerProcess(request, is_prefill=True) as prefill_worker: + # Step 2: Start the prefill worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, is_prefill=True + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") - # Step 3: Start the decode worker - with DynamoWorkerProcess(request, is_prefill=False) as decode_worker: + # Step 3: Start the decode worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, is_prefill=False + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # Step 4: Test request cancellation for streaming scenario @@ -232,7 +295,9 @@ def test_request_cancellation_vllm_decode_cancel( ) # Send streaming request (non-blocking) - cancellable_req = send_cancellable_request("chat_completion_stream") + cancellable_req = send_cancellable_request( + frontend.frontend_port, "chat_completion_stream" + ) # Poll for "Decode Request ID" pattern in decode worker (vLLM v2 pattern) request_id, decode_log_offset = poll_for_pattern( @@ -272,13 +337,20 @@ def test_request_cancellation_vllm_decode_cancel( ) -@pytest.mark.vllm -@pytest.mark.gpu_1 -@pytest.mark.e2e -@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) -@pytest.mark.nightly +@pytest.mark.timeout(150) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_cancellation_vllm_prefill_cancel( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for request cancellation during prefill phase. @@ -286,18 +358,27 @@ def test_request_cancellation_vllm_prefill_cancel( This test verifies that when a request is cancelled by the client during the prefill phase, the system properly handles the cancellation and cleans up resources on both the decode and prefill workers in a disaggregated setup. 
+ + Timing (Last Run: 2025-12-09): ~53s total (requires 2 GPUs) + - Engine initialization: ~23s (decode + prefill workers) + - Testing cancellation during prefill: ~28s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start the prefill worker - with DynamoWorkerProcess(request, is_prefill=True) as prefill_worker: + # Step 2: Start the prefill worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, is_prefill=True + ) as prefill_worker: logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") - # Step 3: Start the decode worker - with DynamoWorkerProcess(request, is_prefill=False) as decode_worker: + # Step 3: Start the decode worker (allocates its own system_port) + with DynamoWorkerProcess( + request, frontend.frontend_port, is_prefill=False + ) as decode_worker: logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") # Step 4: Test request cancellation during prefill phase @@ -309,7 +390,7 @@ def test_request_cancellation_vllm_prefill_cancel( # Send request with long prompt (non-blocking) cancellable_req = send_cancellable_request( - "completion", use_long_prompt=True + frontend.frontend_port, "completion", use_long_prompt=True ) # Poll for "Prefill Request ID" pattern in prefill worker (vLLM v2 pattern) diff --git a/tests/fault_tolerance/cancellation/utils.py b/tests/fault_tolerance/cancellation/utils.py index 822e9a143c..1d366e85ee 100644 --- a/tests/fault_tolerance/cancellation/utils.py +++ b/tests/fault_tolerance/cancellation/utils.py @@ -8,14 +8,14 @@ import socket import threading import time -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, cast import pytest import requests from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess +from tests.utils.port_utils import allocate_port, deallocate_port logger = logging.getLogger(__name__) @@ -24,11 +24,20 @@ class DynamoFrontendProcess(ManagedProcess): """Process manager for Dynamo frontend""" def __init__(self, request): - command = ["python", "-m", "dynamo.frontend"] + # Allocate frontend port + frontend_port = allocate_port(8100) + self.frontend_port = frontend_port + + command = ["python", "-m", "dynamo.frontend", "--http-port", str(frontend_port)] - # Set debug logging environment env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" # Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server env.pop("DYN_SYSTEM_PORT", None) @@ -46,10 +55,20 @@ def __init__(self, request): command=command, env=env, display_output=True, - terminate_existing=True, + terminate_existing=False, # Don't terminate other processes of the same name, we'll only terminate our own PID log_dir=log_dir, ) + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when frontend exits.""" + try: + # frontend_port is always allocated in __init__ + deallocate_port(self.frontend_port) + except Exception as e: + 
logger.warning(f"Failed to release frontend port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + class CancellableRequest: """A wrapper for a single request that can be explicitly cancelled. @@ -169,12 +188,15 @@ def get_response(self): return self.response -def send_completion_request(prompt: str, max_tokens: int) -> CancellableRequest: +def send_completion_request( + prompt: str, max_tokens: int, frontend_port: int +) -> CancellableRequest: """Send a completion request to the frontend Args: prompt: The prompt for completion max_tokens: Maximum tokens to generate + frontend_port: Port where the frontend is running Returns: A CancellableRequest object that can be explicitly cancelled @@ -194,7 +216,7 @@ def send_completion_request(prompt: str, max_tokens: int) -> CancellableRequest: # Return a cancellable request object cancellable_req = CancellableRequest() cancellable_req.post( - f"http://localhost:{FRONTEND_PORT}/v1/completions", + f"http://localhost:{frontend_port}/v1/completions", headers=headers, json=payload, ) @@ -202,13 +224,14 @@ def send_completion_request(prompt: str, max_tokens: int) -> CancellableRequest: def send_chat_completion_request( - prompt: str, max_tokens: int, stream: bool = False + prompt: str, max_tokens: int, frontend_port: int, stream: bool = False ) -> CancellableRequest: """Send a chat completion request to the frontend Args: prompt: The prompt for chat completion max_tokens: Maximum tokens to generate + frontend_port: Port where the frontend is running stream: Whether to stream the response Returns: @@ -230,7 +253,7 @@ def send_chat_completion_request( # Return a cancellable request object cancellable_req = CancellableRequest() cancellable_req.post( - f"http://localhost:{FRONTEND_PORT}/v1/chat/completions", + f"http://localhost:{frontend_port}/v1/chat/completions", headers=headers, json=payload, stream=stream, @@ -239,12 +262,14 @@ def send_chat_completion_request( def send_cancellable_request( + frontend_port: int, request_type: str = "completion", use_long_prompt: bool = False, ) -> CancellableRequest: """Send a request that can be manually cancelled. Args: + frontend_port: Port where the frontend is running request_type: Type of request - "completion", "chat_completion", or "chat_completion_stream" use_long_prompt: Whether to use an extremely long prompt @@ -256,11 +281,11 @@ def send_cancellable_request( prompt += " Make sure it is" + " long" * 16000 + "!" 
if request_type == "completion": - return send_completion_request(prompt, 16384) + return send_completion_request(prompt, 16384, frontend_port) elif request_type == "chat_completion": - return send_chat_completion_request(prompt, 16384, stream=False) + return send_chat_completion_request(prompt, 16384, frontend_port, stream=False) elif request_type == "chat_completion_stream": - return send_chat_completion_request(prompt, 16384, stream=True) + return send_chat_completion_request(prompt, 16384, frontend_port, stream=True) else: raise ValueError(f"Unknown request type: {request_type}") @@ -278,12 +303,15 @@ def read_streaming_responses( Raises: pytest.fail if stream ends before expected_count responses """ - response = cancellable_req.get_response() - if not response or response.status_code != 200: + response_raw = cancellable_req.get_response() + if response_raw is None: + pytest.fail("Failed to get streaming response: response is None") + if response_raw.status_code != 200: pytest.fail( - f"Failed to get streaming response: status_code={response.status_code if response else 'None'}" + f"Failed to get streaming response: status_code={response_raw.status_code}" ) + response = cast(requests.Response, response_raw) # Type narrowing after checks response_count = 0 for line in response.iter_lines(): response_count += 1 diff --git a/tests/fault_tolerance/deploy/client.py b/tests/fault_tolerance/deploy/client.py index 03432c0d38..89dd7daec7 100644 --- a/tests/fault_tolerance/deploy/client.py +++ b/tests/fault_tolerance/deploy/client.py @@ -18,12 +18,14 @@ import json import logging import os +import signal import subprocess import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import requests +from kr8s.objects import Pod from tests.utils.managed_deployment import ManagedDeployment @@ -44,7 +46,7 @@ def get_frontend_port( deployment_spec: Any, pod_ports: Dict[str, Any], logger: logging.Logger, -) -> Tuple[Optional[str], Optional[int], Optional[str]]: +) -> Tuple[Optional[str], Optional[int], Optional[Pod]]: """ Select a frontend pod using round-robin and setup port forwarding. @@ -60,7 +62,7 @@ def get_frontend_port( Returns: Tuple of (pod_name, local_port, pod_instance) or (None, None, None) if failed """ - pods = managed_deployment.get_pods(managed_deployment.frontend_service_name) + pods = managed_deployment.get_pods([managed_deployment.frontend_service_name]) port = 0 pod_name = None @@ -270,6 +272,7 @@ def run_aiperf( logger: logging.Logger, max_retries: int = 1, retry_delay: float = 1, + continuous_load: bool = False, ) -> bool: """ Execute AI-Perf with specified parameters. 
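(Editor's note, not part of the diff.) A minimal sketch of how the reworked cancellation-test helpers above are intended to compose: the frontend allocates its own HTTP port, and that port is passed explicitly to the request helpers instead of being read from a shared FRONTEND_PORT constant. DynamoFrontendProcess and send_cancellable_request are the helpers changed above; the worker startup, log polling, and assertions are elided as comments.

```python
from tests.fault_tolerance.cancellation.utils import (
    DynamoFrontendProcess,
    send_cancellable_request,
)


def sketch_cancellation_flow(request):
    # The frontend context manager allocates frontend_port in __init__ and
    # releases it again in __exit__.
    with DynamoFrontendProcess(request) as frontend:
        # Workers are started the same way and also receive frontend.frontend_port
        # (elided here).
        req = send_cancellable_request(frontend.frontend_port, "chat_completion_stream")
        # ... poll the worker log for the request ID, cancel the request, then
        # verify the worker reports the cancellation ...
```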
@@ -280,13 +283,14 @@ def run_aiperf( model: Model name pod_name: Selected pod name for logging port: Local port number - requests_per_client: Number of requests to send + requests_per_client: Number of requests to send (used if continuous load not enabled) input_token_length: Input token count output_token_length: Output token count output_dir: Directory for AI-Perf artifacts logger: Logger instance max_retries: Maximum number of retry attempts (default: 1) retry_delay: Delay in seconds between retries (default: 1) + continuous_load: If True, use continuous load instead of fixed request count Returns: True if successful, False otherwise @@ -315,8 +319,6 @@ def run_aiperf( # Enable streaming for TTFT and ITL metrics "--streaming", # Request parameters - "--request-count", - str(requests_per_client), # Required: how many requests "--concurrency", "1", # Optional: we set to 1 for sequential # Token configuration @@ -338,8 +340,13 @@ def run_aiperf( "100", # For reproducible results ] - # Calculate timeout (same as legacy would for all requests) - timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes + if continuous_load: + cmd.extend(["--benchmark-duration", "1800"]) # 30 minutes for continuous load + logger.info("Using continuous load with duration: 30 minutes") + timeout = 1860 # 31 minutes default for duration-based tests (30 minutes + 1 minute buffer) + else: + cmd.extend(["--request-count", str(requests_per_client)]) + timeout = max(requests_per_client * 2 + 60, 300) # At least 5 minutes # Log execution logger.info(f"Starting AI-Perf for Pod {pod_name} Local Port {port}") @@ -354,15 +361,19 @@ def run_aiperf( logger.info(f"Command: {' '.join(cmd)}") # Retry logic for fault tolerance - retry FULL request count until success - - max_attempts = max_retries if max_retries > 0 else 1 + # Note: For continuous load, we only run once and expect SIGINT to stop it + max_attempts = 1 if continuous_load else (max_retries if max_retries > 0 else 1) success = False - all_results = [] for attempt in range(max_attempts): - logger.info( - f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests" - ) + if continuous_load: + logger.info( + "AI-Perf continuous load (will run until interrupted by SIGINT)" + ) + else: + logger.info( + f"AI-Perf attempt {attempt + 1}/{max_attempts} with {requests_per_client} requests" + ) # Update output directory for this attempt attempt_dir = output_dir / f"attempt_{attempt}" @@ -374,13 +385,7 @@ def run_aiperf( cmd_attempt[artifact_dir_idx] = str(attempt_dir) try: - result = subprocess.run( - cmd_attempt, - capture_output=True, - text=True, - timeout=timeout, - stdin=subprocess.DEVNULL, # Prevent stdin reading which can cause process suspension - ) + result = run_aiperf_with_signal_handling(cmd_attempt, logger, timeout) # Save logs for this attempt with open(attempt_dir / "genai_perf.log", "w") as f: @@ -389,15 +394,6 @@ def run_aiperf( f.write("\n\n=== STDERR ===\n") f.write(result.stderr) - all_results.append( - { - "attempt": attempt + 1, - "returncode": result.returncode, - "stdout": result.stdout, - "stderr": result.stderr, - } - ) - if result.returncode == 0: # AI-Perf returns 0 even if all requests failed, so we need to check the output json_path = attempt_dir / "profile_export_aiperf.json" @@ -412,6 +408,19 @@ def run_aiperf( ) if success: break # Success - exit the retry loop + ## TODO: bug with aiperf git+https://github.com/ai-dynamo/aiperf.git@4d3fa29403c8f75da22a14f1f7b3aeb27db9288f + ## where sending a SIGINT on Mac 
can sometimes have an error code of -9 (SIGABRT) which results in profile_export_aiperf.json not being created + elif result.returncode == -9 and continuous_load: + logger.warning( + f""" + Attempt {attempt + 1} failed with return code {result.returncode} + This is a known bug with aiperf on Mac where sending a SIGINT can sometimes have an error code of -9 (SIGABRT) + which results in profile_export_aiperf.json not being created + """ + ) + logger.debug( + f"Stderr: {result.stderr[:500] if result.stderr else 'No stderr'}" + ) else: logger.warning( f"Attempt {attempt + 1} failed with return code {result.returncode}" @@ -421,22 +430,84 @@ def run_aiperf( ) except Exception as e: logger.error(f"Error in attempt {attempt + 1}: {str(e)}") - all_results.append({"attempt": attempt + 1, "error": str(e)}) - # Sleep before next attempt (if not the last attempt) - if not success and attempt < max_attempts - 1: + # Sleep before next attempt (if not the last attempt and not continuous load) + if not success and attempt < max_attempts - 1 and not continuous_load: time.sleep(retry_delay) - if success: + if success and not continuous_load: logger.info( f"AI-Perf successfully completed all {requests_per_client} requests for {pod_name}" ) + elif success and continuous_load: + logger.info( + f"AI-Perf sustained continuous load for {pod_name} and exited successfully" + ) else: logger.error(f"AI-Perf failed all {max_attempts} attempts for {pod_name}") return success +# TODO: use file redirection and wait() instead of pipes and communicate +def run_aiperf_with_signal_handling( + cmd_attempt: List[str], + logger: logging.Logger, + timeout: int, +) -> subprocess.CompletedProcess: + """ + Run aiperf with signal handling for graceful shutdown. + + Handles SIGINT and SIGTERM forwarding and timeout when running with subprocess.Popen. + This ensures that Ctrl-C (SIGINT) and graceful termination signals (SIGTERM) + are properly forwarded to the subprocess so it can clean up gracefully and write results files.
+ """ + proc = subprocess.Popen( + cmd_attempt, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + stdin=subprocess.DEVNULL, + ) + + def signal_handler(signum, frame): + signal_names = { + signal.SIGINT: "SIGINT", + signal.SIGTERM: "SIGTERM", + } + signal_name = signal_names.get(signum, f"signal {signum}") + logger.info(f"Received {signal_name}, forwarding to aiperf subprocess") + try: + proc.send_signal(signum) + except ProcessLookupError: + pass # Process already terminated + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + stdout, stderr = proc.communicate(timeout=timeout) + returncode = proc.returncode + except subprocess.TimeoutExpired: + logger.warning(f"AI-Perf subprocess timed out after {timeout}s") + proc.kill() + stdout, stderr = proc.communicate() + returncode = proc.returncode + except KeyboardInterrupt: + logger.info("Received KeyboardInterrupt, sending SIGINT to aiperf subprocess") + proc.send_signal(signal.SIGINT) + try: + stdout, stderr = proc.communicate(timeout=30) # Give it time to clean up + returncode = proc.returncode + except subprocess.TimeoutExpired: + logger.warning("Subprocess didn't terminate gracefully, killing it") + proc.kill() + stdout, stderr = proc.communicate() + returncode = proc.returncode + + return subprocess.CompletedProcess(cmd_attempt, returncode, stdout, stderr) + + def log_summary_metrics( output_dir: Path, logger: logging.Logger, pod_name: str, port: int ) -> None: @@ -513,6 +584,7 @@ def client( output_token_length: int, max_retries: int, retry_delay: float = 1, + continuous_load: bool = False, ): """ Generate load using AI-Perf for fault tolerance testing. @@ -527,11 +599,12 @@ def client( model: Model name log_dir: Directory for output logs and AI-Perf artifacts index: Client index used for round-robin pod selection - requests_per_client: Number of requests to generate + requests_per_client: Number of requests to generate (used if continuous load not enabled) input_token_length: Number of input tokens per request output_token_length: Number of output tokens per request max_retries: Maximum retry attempts for AI-Perf execution retry_delay: Delay in seconds between retry attempts + continuous_load: If True, use continuous load instead of fixed request count """ logger = logging.getLogger(f"CLIENT: {index}") logging.getLogger("httpx").setLevel(logging.WARNING) @@ -578,6 +651,7 @@ def client( logger=logger, max_retries=max_retries, retry_delay=retry_delay, + continuous_load=continuous_load, ) if not success: diff --git a/tests/fault_tolerance/deploy/client_factory.py b/tests/fault_tolerance/deploy/client_factory.py index 936122f082..d8f8e3f99f 100644 --- a/tests/fault_tolerance/deploy/client_factory.py +++ b/tests/fault_tolerance/deploy/client_factory.py @@ -42,6 +42,7 @@ def get_client_function(client_type: str) -> Callable: output_token_length, max_retries, retry_delay_or_rate, # Differs between implementations + continuous_load, ) Raises: diff --git a/tests/fault_tolerance/deploy/conftest.py b/tests/fault_tolerance/deploy/conftest.py index 70545b9526..2fb85fb5ad 100644 --- a/tests/fault_tolerance/deploy/conftest.py +++ b/tests/fault_tolerance/deploy/conftest.py @@ -35,6 +35,13 @@ def pytest_addoption(parser): help="Include tests that require custom builds (e.g., MoE models). " "By default, these tests are excluded.", ) + parser.addoption( + "--skip-service-restart", + action="store_true", + default=False, + help="Skip restarting NATS and etcd services before deployment. 
" + "By default, these services are restarted.", + ) def pytest_generate_tests(metafunc): @@ -109,3 +116,9 @@ def namespace(request): def client_type(request): """Get client type from command line or use scenario default.""" return request.config.getoption("--client-type") + + +@pytest.fixture +def skip_service_restart(request): + """Get skip restart services flag from command line.""" + return request.config.getoption("--skip-service-restart") diff --git a/tests/fault_tolerance/deploy/container/Dockerfile.local_vllm b/tests/fault_tolerance/deploy/container/Dockerfile.local_vllm index 2e2a1c6f96..ea4e877321 100644 --- a/tests/fault_tolerance/deploy/container/Dockerfile.local_vllm +++ b/tests/fault_tolerance/deploy/container/Dockerfile.local_vllm @@ -9,7 +9,7 @@ ARG LOCAL_VLLM_IMAGE="vllm-elastic-ep:latest_all2all_buffer_input" ARG DYNAMO_BASE_IMAGE="dynamo:latest-none" ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" -ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" +ARG RUNTIME_IMAGE_TAG="12.9.0-runtime-ubuntu24.04" # Other build arguments ARG PYTHON_VERSION=3.12 @@ -57,7 +57,7 @@ RUN apt-get update && \ # prometheus dependencies ca-certificates \ # DeepGemm uses 'cuobjdump' which does not come with CUDA image - cuda-command-line-tools-12-8 && \ + cuda-command-line-tools-12-9 && \ rm -rf /var/lib/apt/lists/* # Copy CUDA development tools from vLLM image (for JIT compilation) diff --git a/tests/fault_tolerance/deploy/legacy_client.py b/tests/fault_tolerance/deploy/legacy_client.py index 5cb4df4557..668145838c 100644 --- a/tests/fault_tolerance/deploy/legacy_client.py +++ b/tests/fault_tolerance/deploy/legacy_client.py @@ -192,6 +192,7 @@ def client( max_retries, max_request_rate, retry_delay=1, + continuous_load=False, ): """Legacy custom client for fault tolerance testing. 
@@ -211,7 +212,11 @@ def client( max_retries: Maximum retry attempts per request max_request_rate: Maximum requests per second (for rate limiting) retry_delay: Delay in seconds between retries + continuous_load: If True, use continuous load instead of fixed request count """ + if continuous_load: + raise ValueError("Continuous load is not supported for legacy client") + logger = logging.getLogger(f"CLIENT: {index}") logging.getLogger("httpx").setLevel(logging.WARNING) @@ -228,7 +233,7 @@ def client( for i in range(requests_per_client): # Get available pods pods = managed_deployment.get_pods( - managed_deployment.frontend_service_name + [managed_deployment.frontend_service_name] ) port = 0 pod_name = None diff --git a/tests/fault_tolerance/deploy/parse_results.py b/tests/fault_tolerance/deploy/parse_results.py index 66bc967e9b..00c1839468 100644 --- a/tests/fault_tolerance/deploy/parse_results.py +++ b/tests/fault_tolerance/deploy/parse_results.py @@ -341,6 +341,7 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: Returns: Dictionary with aggregated metrics and client count """ + logger = logging.getLogger(__name__) all_metrics: Dict[str, Any] = { "total_requests": 0, "successful_requests": 0, @@ -382,22 +383,28 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: with open(profile_json) as f: client_metrics = json.load(f) - # AI-Perf format has "records" dictionary at the top level + # AI-Perf format can have "records" dictionary or metrics at top level + # Try records first (older format), then fall back to top level (newer format) records = client_metrics.get("records", {}) - # Extract successful request count - request_count_record = records.get("request_count", {}) + # Extract successful request count - check both locations + request_count_record = records.get( + "request_count" + ) or client_metrics.get("request_count", {}) successful_count = ( int(request_count_record.get("avg", 0)) - if request_count_record + if request_count_record and isinstance(request_count_record, dict) else 0 ) - # Extract error request count - error_request_count_record = records.get("error_request_count", {}) + # Extract error request count - check both locations + error_request_count_record = records.get( + "error_request_count" + ) or client_metrics.get("error_request_count", {}) error_request_count = ( int(error_request_count_record.get("avg", 0)) if error_request_count_record + and isinstance(error_request_count_record, dict) else 0 ) @@ -418,9 +425,17 @@ def parse_aiperf_client_results(log_dir: str) -> Dict[str, Any]: # Sum up actual error counts from each error type error_count = sum(error.get("count", 0) for error in error_summary) - # Check if test was cancelled + # Log if test was cancelled (expected for continuous load mode) if client_metrics.get("was_cancelled", False): - error_count = request_count # Mark all as failed if cancelled + logger.info( + f"AI-Perf client {item} was cancelled - anticipated if running with continuous load mode. " + f"Completed {request_count} requests before cancellation." + ) + + # Note: If test was cancelled (was_cancelled=True), we still count the requests + # that were successfully completed before cancellation. The request_count + # represents successful requests, and error_count represents actual errors. + # We don't mark cancelled requests as failed - they were just interrupted. 
# Validate data consistency if request_count < error_count: diff --git a/tests/fault_tolerance/deploy/scenarios.py b/tests/fault_tolerance/deploy/scenarios.py index 0dc93e384c..817f28394d 100644 --- a/tests/fault_tolerance/deploy/scenarios.py +++ b/tests/fault_tolerance/deploy/scenarios.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import logging import re +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum, auto from typing import TYPE_CHECKING, Dict, List, Optional, Pattern -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict -from tests.utils.managed_deployment import DeploymentSpec +from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment if TYPE_CHECKING: from tests.fault_tolerance.deploy.base_checker import BaseChecker @@ -54,8 +57,8 @@ class DeploymentInfo(TypedDict, total=False): is_moe: Optional flag indicating if this is a Mixture-of-Experts model """ - spec: DeploymentSpec - backend: str + spec: Required[DeploymentSpec] + backend: Required[str] model: str is_moe: bool @@ -155,14 +158,144 @@ class Load: overflow_request_count: int = 15 # Number of overflow requests normal_request_count: int = 15 # Number of normal requests after overflow + continuous_load: bool = ( + False # If True, use continuous load instead of fixed request count + ) + @dataclass -class Failure: +class Failure(ABC): + """Base class for all failure types.""" + + # time to wait in seconds before the failure is injected time: int - pod_name: str - command: str - signal: str = "SIGINT" - replicas: int = 1 + + # names of DGD services to inject the failure into the corresponding pods for + service_names: list[str] + + @abstractmethod + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute the failure injection. 
+ + Args: + deployment: The managed deployment to inject the failure into + logger: Logger instance for logging failure injection + + Returns: List of affected pod names + """ + pass + + @abstractmethod + def get_failure_key(self) -> str: + """Get the failure key for the failure.""" + pass + + +@dataclass +class RollingUpgradeFailure(Failure): + """Failure type for triggering rolling upgrades.""" + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute rolling upgrade failure injection.""" + await deployment.trigger_rolling_upgrade(self.service_names) + + # Need to wait for the deployment to be unready so we know the rolling upgrade has started + await deployment.wait_for_unready(timeout=60, log_interval=10) + + await deployment._wait_for_ready(timeout=1800) # 30 minute timeout + + await asyncio.sleep( + self.time + ) # have some requests processed after the rolling upgrade has completed + + return await deployment.get_pod_names(self.service_names) + + def get_failure_key(self) -> str: + """Get the failure key for the rolling upgrade failure.""" + return f"rolling_upgrade:{','.join(self.service_names)}" + + +@dataclass +class DeletePodFailure(Failure): + """Failure type for deleting pods.""" + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute pod deletion failure injection.""" + service_pod_dict = deployment.get_pods(self.service_names) + pod_names: list[str] = [] + for service_name, pods in service_pod_dict.items(): + for pod in pods: + deployment.get_pod_manifest_logs_metrics( + service_name, pod, ".before_delete" + ) + pod.delete(force=True) # force means no graceful termination + pod_names.append(pod.name) + + return pod_names + + def get_failure_key(self) -> str: + """Get the failure key for the delete pod failure.""" + return f"delete_pod:{','.join(self.service_names)}" + + +class TerminateProcessFailure(Failure): + """Failure type for terminating specific processes by name.""" + + def __init__( + self, + time: int, + service_names: list[str], + signal: str = "SIGINT", + process_name: str = "", + ): + """Initialize TerminateProcessFailure. 
+ + Args: + time: Time to wait in seconds before the failure is injected + service_names: Names of DGD services to inject the failure into + signal: Signal to send (default: "SIGINT") + process_name: Name of the process to terminate (required) + """ + super().__init__( + time=time, + service_names=service_names, + ) + if not process_name or not signal: + raise ValueError( + "process_name and signal are required for TerminateProcessFailure" + ) + self.process_name = process_name + self.signal = signal + + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Execute process termination failure injection.""" + service_pod_dict = deployment.get_pods(self.service_names) + pod_names: list[str] = [] + for service_name, pods in service_pod_dict.items(): + for pod in pods: + processes = deployment.get_processes(pod) + for process in processes: + if self.process_name in process.command: + logger.info( + f"Terminating {service_name} pod {pod} Pid {process.pid} Command {process.command}" + ) + process.kill(self.signal) + pod_names.append(pod.name) + + return pod_names + + def get_failure_key(self) -> str: + """Get the failure key for the terminate process failure.""" + return f"terminate_process:{','.join(self.service_names)}:{self.process_name}:{self.signal}" @dataclass @@ -182,13 +315,25 @@ def __init__( ): super().__init__( time=time, - pod_name="Client", - command="token_overflow", + service_names=["Client"], ) self.max_seq_len = max_seq_len self.overflow_multiplier = overflow_multiplier self.overflow_token_count = int(max_seq_len * overflow_multiplier) + async def execute( + self, deployment: ManagedDeployment, logger: logging.Logger + ) -> list[str]: + """Token overflow is handled client-side, so this is a no-op.""" + # The actual overflow is handled by the client configuration + # which uses the input_token_length from the Load config + # This is just a placeholder for the abstract method + return [] + + def get_failure_key(self) -> str: + """Get the failure key for the token overflow failure.""" + return f"token_overflow:{self.overflow_token_count}" + @dataclass class Scenario: @@ -206,7 +351,7 @@ class Scenario: # Helper functions to create deployment specs -def _create_deployment_spec(backend: str, yaml_path: str) -> DeploymentInfo: +def _create_deployment_info(backend: str, yaml_path: str) -> DeploymentInfo: """Create a deployment spec with backend information.
Args: @@ -240,7 +385,9 @@ def _set_replicas(deployment_spec, backend, deploy_type, replicas): spec[WORKER_MAP[backend]["prefill"]].replicas = replicas -def _set_tensor_parallel(deployment_spec, backend, deploy_type, tp_size): +def _set_tensor_parallel( + deployment_spec: DeploymentInfo, backend: str, deploy_type: str, tp_size: int +): """Set tensor parallel size for worker components.""" spec = deployment_spec["spec"] @@ -308,7 +455,7 @@ def _create_deployments_for_backend(backend: str) -> Dict[str, DeploymentInfo]: scenario_name = "-".join(name_parts) # Create and configure the deployment - deployment = _create_deployment_spec(backend, yaml_files[deploy_type]) + deployment = _create_deployment_info(backend, yaml_files[deploy_type]) if tp_size > 1: _set_tensor_parallel(deployment, backend, deploy_type, tp_size) if dp_replicas > 1: @@ -397,34 +544,69 @@ def _create_backend_failures(backend, deploy_type="disagg"): process_name = f"dynamo.{backend}" failures = { - "frontend": [Failure(30, "Frontend", "dynamo.frontend")], - "frontend_pod": [Failure(30, "Frontend", "delete_pod")], - "decode_worker": [Failure(30, decode_worker, process_name, "SIGKILL")], - "decode_worker_pod": [Failure(30, decode_worker, "delete_pod")], - "prefill_worker": [Failure(30, prefill_worker, process_name, "SIGKILL")], - "prefill_worker_pod": [Failure(30, prefill_worker, "delete_pod")], + "frontend": [ + TerminateProcessFailure( + 30, ["Frontend"], "SIGINT", process_name="dynamo.frontend" + ) + ], + "frontend_pod": [DeletePodFailure(30, ["Frontend"])], + "decode_worker": [ + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name=process_name + ) + ], + "decode_worker_pod": [DeletePodFailure(30, [decode_worker])], + "prefill_worker": [ + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name=process_name + ) + ], + "prefill_worker_pod": [DeletePodFailure(30, [prefill_worker])], "none": [], } if backend == "vllm": failures["vllm_decode_engine_core"] = [ - Failure(30, decode_worker, "VLLM::EngineCore", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="VLLM::EngineCore" + ) ] failures["vllm_prefill_engine_core"] = [ - Failure(30, prefill_worker, "VLLM::EngineCore", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="VLLM::EngineCore" + ) ] elif backend == "sglang": failures["sglang_decode_scheduler"] = [ - Failure(30, decode_worker, "sglang::scheduler", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="sglang::scheduler" + ) ] failures["sglang_decode_detokenizer"] = [ - Failure(30, decode_worker, "sglang::detokenizer", "SIGKILL") + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="sglang::detokenizer" + ) ] failures["sglang_prefill_scheduler"] = [ - Failure(30, prefill_worker, "sglang::scheduler", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="sglang::scheduler" + ) ] failures["sglang_prefill_detokenizer"] = [ - Failure(30, prefill_worker, "sglang::detokenizer", "SIGKILL") + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="sglang::detokenizer" + ) + ] + elif backend == "trtllm": + failures["trtllm_decode_engine_core"] = [ + TerminateProcessFailure( + 30, [decode_worker], "SIGKILL", process_name="TRTLLM::EngineCore" + ) + ] + failures["trtllm_prefill_engine_core"] = [ + TerminateProcessFailure( + 30, [prefill_worker], "SIGKILL", process_name="TRTLLM::EngineCore" + ) ] return failures @@ -533,7 
+715,7 @@ def create_legacy_load( # Populate Scenarios -scenarios = {} +scenarios: dict[str, Scenario] = {} # Map of backend+deploy_type to failure definitions backend_failure_map = {} @@ -729,5 +911,59 @@ def add_token_overflow_scenarios(): ) +def add_rolling_upgrade_scenarios(): + for backend in ["vllm", "sglang", "trtllm"]: + for worker_mode in ["agg", "disagg"]: + yaml_files = { + "agg": f"examples/backends/{backend}/deploy/agg.yaml", + "disagg": f"examples/backends/{backend}/deploy/disagg.yaml", + } + deployment_info = _create_deployment_info(backend, yaml_files[worker_mode]) + deployment_spec: DeploymentSpec = deployment_info["spec"] + + service_names: list[str] = [] + + # setting replicas to 2 so we have availability of 1 replica at a time + if worker_mode == "agg" and backend == "trtllm": + service_names.append(WORKER_MAP[backend]["decode_agg"]) + else: + service_names.append(WORKER_MAP[backend]["decode"]) + + if worker_mode == "disagg": + service_names.append(WORKER_MAP[backend]["prefill"]) + + for service_name in service_names: + deployment_spec.set_service_replicas(service_name, 2) + + load = Load( + clients=10, + input_token_length=100, + output_token_length=100, + max_retries=1, + client_type="aiperf", + max_request_rate=1.0, + success_threshold=100.0, + continuous_load=True, + ) + + scenario_name = f"{backend}-{worker_mode}-rolling-upgrade" + model = "Qwen/Qwen3-0.6B" + + failure = RollingUpgradeFailure( + time=30, + service_names=service_names, + ) + scenarios[scenario_name] = Scenario( + deployment=deployment_info["spec"], + load=load, + failures=[failure], + model=model, + backend=backend, + ) + + # Add the token overflow scenarios add_token_overflow_scenarios() + +# Add the rolling upgrade scenarios +add_rolling_upgrade_scenarios() diff --git a/tests/fault_tolerance/deploy/templates/vllm/moe_agg.yaml b/tests/fault_tolerance/deploy/templates/vllm/moe_agg.yaml index fa76acfc2c..9de8641034 100644 --- a/tests/fault_tolerance/deploy/templates/vllm/moe_agg.yaml +++ b/tests/fault_tolerance/deploy/templates/vllm/moe_agg.yaml @@ -60,7 +60,6 @@ spec: - --model - deepseek-ai/DeepSeek-V2-Lite - --trust-remote-code - - --disable-log-requests - --tensor-parallel-size - "1" - --data-parallel-size diff --git a/tests/fault_tolerance/deploy/templates/vllm/moe_disagg.yaml b/tests/fault_tolerance/deploy/templates/vllm/moe_disagg.yaml index b45bcf97de..6e559164c7 100644 --- a/tests/fault_tolerance/deploy/templates/vllm/moe_disagg.yaml +++ b/tests/fault_tolerance/deploy/templates/vllm/moe_disagg.yaml @@ -63,7 +63,6 @@ spec: - --model - deepseek-ai/DeepSeek-V2-Lite - --trust-remote-code - - --disable-log-requests - --tensor-parallel-size - "1" - --data-parallel-size @@ -130,7 +129,6 @@ spec: - --model - deepseek-ai/DeepSeek-V2-Lite - --trust-remote-code - - --disable-log-requests - --is-prefill-worker - --tensor-parallel-size - "1" diff --git a/tests/fault_tolerance/deploy/test_deployment.py b/tests/fault_tolerance/deploy/test_deployment.py index caeb9039a2..8fe12dba20 100644 --- a/tests/fault_tolerance/deploy/test_deployment.py +++ b/tests/fault_tolerance/deploy/test_deployment.py @@ -1,12 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +import asyncio import logging import multiprocessing +import os import re -import time +import signal from contextlib import contextmanager -from typing import Any +from multiprocessing.context import SpawnProcess +from typing import Any, Optional import pytest @@ -17,11 +20,12 @@ from tests.fault_tolerance.deploy.scenarios import ( OVERFLOW_SUFFIX, RECOVERY_SUFFIX, + Failure, Load, - TokenOverflowFailure, + Scenario, scenarios, ) -from tests.utils.managed_deployment import ManagedDeployment +from tests.utils.managed_deployment import DeploymentSpec, ManagedDeployment @pytest.fixture @@ -55,18 +59,18 @@ def scenario(scenario_name, client_type): @contextmanager def _clients( - logger, - request, - deployment_spec, - namespace, - model, + logger: logging.Logger, + log_dir: str, + deployment_spec: DeploymentSpec, + namespace: str, + model: str, load_config: Load, ): """Start client processes using factory pattern for client selection. Args: logger: Logger instance - request: Pytest request fixture + log_dir: Log directory for output logs and client logs/artifacts deployment_spec: Deployment specification namespace: Kubernetes namespace model: Model name to test @@ -79,7 +83,7 @@ def _clients( f"Starting {load_config.clients} clients using '{load_config.client_type}' client" ) - procs = [] + procs: list[SpawnProcess] = [] ctx = multiprocessing.get_context("spawn") # Determine retry_delay_or_rate based on client type @@ -90,6 +94,9 @@ def _clients( # AI-Perf client uses retry_delay between attempts (default 5s) retry_delay_or_rate = 5 + # Check if this is a continuous load test (rolling upgrade scenarios) + continuous_load = getattr(load_config, "continuous_load", False) + # Check if this is a mixed token test (overflow + recovery) # If mixed_token_test is True, run two phases; otherwise run normally if hasattr(load_config, "mixed_token_test") and load_config.mixed_token_test: @@ -108,13 +115,14 @@ def _clients( deployment_spec, namespace, model, - request.node.name + OVERFLOW_SUFFIX, + f"{log_dir}{OVERFLOW_SUFFIX}", i, load_config.overflow_request_count, # 15 overflow requests load_config.overflow_token_length, # 2x max_seq_len tokens load_config.output_token_length, load_config.max_retries, retry_delay_or_rate, + continuous_load, ), ) proc_overflow.start() @@ -128,7 +136,7 @@ def _clients( logger.info("Overflow requests completed. Starting recovery phase...") # Second phase: Send normal requests to test recovery - procs_recovery = [] + procs_recovery: list[SpawnProcess] = [] for i in range(load_config.clients): proc_normal = ctx.Process( target=client_func, @@ -136,7 +144,7 @@ def _clients( deployment_spec, namespace, model, - request.node.name + RECOVERY_SUFFIX, + f"{log_dir}{RECOVERY_SUFFIX}", i, load_config.normal_request_count, # 15 normal requests load_config.input_token_length, # Normal token count @@ -161,13 +169,14 @@ def _clients( deployment_spec, namespace, model, - request.node.name, + log_dir, i, load_config.requests_per_client, load_config.input_token_length, load_config.output_token_length, load_config.max_retries, retry_delay_or_rate, + continuous_load, # Pass continuous_load flag ), ) ) @@ -182,65 +191,50 @@ def _clients( logger.debug(f"{proc} joined") -def _inject_failures(failures, logger, deployment: ManagedDeployment): # noqa: F811 - """Inject failures and return info about affected pods. 
- - Returns: - Dict mapping failure info to list of affected pod names - Example: {"VllmDecodeWorker:delete_pod": ["pod-abc123", "pod-xyz789"]} +def _terminate_client_processes( + client_procs: list[SpawnProcess], + logger: logging.Logger, +): """ - affected_pods: dict[str, list] = {} - - for failure in failures: - time.sleep(failure.time) - - # Handle TokenOverflowFailure differently - it's a client-side injection - if isinstance(failure, TokenOverflowFailure): - # The actual overflow is handled by the client configuration - # which uses the input_token_length from the Load config - # This is just logging for visibility - continue - - pods = deployment.get_pods(failure.pod_name)[failure.pod_name] - - num_pods = len(pods) + Terminate client processes. + """ + # Send SIGINT to client processes to stop continuous load + if client_procs: + logger.info(f"Sending SIGINT to {len(client_procs)} client processes...") + for proc in client_procs: + if proc.is_alive(): + try: + if proc.pid is not None: + logger.debug(f"Sending SIGINT to client process {proc.pid}") + os.kill(proc.pid, signal.SIGINT) + else: + raise ValueError(f"Process {proc} has no PID") + except ProcessLookupError: + logger.debug(f"Process {proc.pid} already terminated") + except Exception as e: + logger.warning(f"Failed to send SIGINT to process {proc.pid}: {e}") + logger.info( + "SIGINT sent to all client processes, waiting for graceful shutdown..." + ) + else: + logger.warning("No client processes provided to terminate") - if not pods: - continue - replicas = failure.replicas +async def _inject_failures( + failures: list[Failure], + logger: logging.Logger, + deployment: ManagedDeployment, +) -> dict[str, list]: # noqa: F811 + affected_pods: dict[str, list] = {} - if not replicas: - replicas = num_pods + for failure in failures: + await asyncio.sleep(failure.time) logger.info(f"Injecting failure for: {failure}") - # Track which pods were affected by this failure - failure_key = f"{failure.pod_name}:{failure.command}" - if failure_key not in affected_pods: - affected_pods[failure_key] = [] - - for x in range(replicas): - pod = pods[x % num_pods] - - # Capture the exact pod name before we kill it - pod_name = pod.name - affected_pods[failure_key].append(pod_name) - - logger.info(f"Target pod for failure: {pod_name}") - - if failure.command == "delete_pod": - deployment.get_pod_logs(failure.pod_name, pod, ".before_delete") - logger.info(f"Deleting pod: {pod_name}") - pod.delete(force=True) - else: - processes = deployment.get_processes(pod) - for process in processes: - if failure.command in process.command: - logger.info( - f"Terminating {failure.pod_name} Pid {process.pid} Command {process.command} in pod {pod_name}" - ) - process.kill(failure.signal) + affected_pods[failure.get_failure_key()] = await failure.execute( + deployment, logger + ) return affected_pods @@ -445,11 +439,12 @@ def results_summary(): @pytest.mark.slow @pytest.mark.filterwarnings("ignore::DeprecationWarning") async def test_fault_scenario( - scenario, # noqa: F811 + scenario: Scenario, # noqa: F811 request, - image, - namespace, + image: str, + namespace: str, validation_context, # noqa: F811 # Shared context for passing data to validation + skip_service_restart: bool, ): """ Test dynamo serve deployments with injected failures @@ -468,6 +463,7 @@ async def test_fault_scenario( if image: scenario.deployment.set_image(image) + model: Optional[str] = None if scenario.model: scenario.deployment.set_model(scenario.model) model = scenario.model @@ -500,6 +496,7 @@ 
async def test_fault_scenario( namespace=namespace, log_dir=request.node.name, deployment_spec=scenario.deployment, + skip_service_restart=skip_service_restart, ) as deployment: # Populate shared context for validation validation_context["deployment"] = deployment @@ -507,14 +504,17 @@ async def test_fault_scenario( with _clients( logger, - request, + request.node.name, scenario.deployment, namespace, model, scenario.load, # Pass entire Load config object - ): + ) as client_procs: # Inject failures and capture which pods were affected - affected_pods = _inject_failures(scenario.failures, logger, deployment) - validation_context["affected_pods"] = affected_pods - + affected_pods = await _inject_failures( + scenario.failures, logger, deployment + ) logger.info(f"Affected pods during test: {affected_pods}") + + if scenario.load.continuous_load: + _terminate_client_processes(client_procs, logger) diff --git a/tests/fault_tolerance/etcd_ha/test_sglang.py b/tests/fault_tolerance/etcd_ha/test_sglang.py index 35e6783b51..45d660ba1a 100644 --- a/tests/fault_tolerance/etcd_ha/test_sglang.py +++ b/tests/fault_tolerance/etcd_ha/test_sglang.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# TODO: Update to use dynamic port allocation (allocate_free_port) for parallel execution +# Currently uses hardcoded ports: FRONTEND_PORT (8000), system ports (8081, 8082) +# See tests/fault_tolerance/migration/test_sglang.py for dynamic port pattern + import logging import os import shutil diff --git a/tests/fault_tolerance/hardware/fault_injection_service/agents/gpu_fault_injector/gpu_xid_injector.py b/tests/fault_tolerance/hardware/fault_injection_service/agents/gpu_fault_injector/gpu_xid_injector.py index 9a7a95c2e9..fd9278bddc 100644 --- a/tests/fault_tolerance/hardware/fault_injection_service/agents/gpu_fault_injector/gpu_xid_injector.py +++ b/tests/fault_tolerance/hardware/fault_injection_service/agents/gpu_fault_injector/gpu_xid_injector.py @@ -153,6 +153,83 @@ def _check_privileged(self) -> bool: """Check if we have privileged access (required for nsenter)""" return os.geteuid() == 0 + def _get_pci_address_from_proc(self, gpu_id: int) -> str: + """ + Get PCI address for GPU by reading /host/proc/driver/nvidia/gpus/. + + This method works without nvidia-smi by reading NVIDIA kernel driver's procfs. + Maps GPU ID (Device Minor) to PCI address by scanning all GPU directories. + + Directory structure: + /host/proc/driver/nvidia/gpus/ + โ”œโ”€โ”€ 0001:00:00.0/information (Device Minor: 0 โ†’ GPU 0) + โ”œโ”€โ”€ 0002:00:00.0/information (Device Minor: 1 โ†’ GPU 1) + โ””โ”€โ”€ ... + + Args: + gpu_id: GPU device ID (0, 1, 2, ...) + + Returns: + PCI address (e.g., "0001:00:00.0") + + Raises: + FileNotFoundError: If /host/proc is not mounted + ValueError: If GPU ID not found + """ + proc_path = "/host/proc/driver/nvidia/gpus" + + # Check if proc path exists + if not os.path.exists(proc_path): + raise FileNotFoundError( + f"{proc_path} not found. Ensure /host/proc is mounted in pod spec." 
+ ) + + # Iterate through all GPU directories + try: + gpu_dirs = os.listdir(proc_path) + except (PermissionError, OSError) as e: + raise FileNotFoundError(f"Cannot access {proc_path}: {e}") + logger.debug( + f"Found {len(gpu_dirs)} GPU directories in {proc_path}: {gpu_dirs}" + ) + + available_minors = [] + for pci_addr in gpu_dirs: + info_file = f"{proc_path}/{pci_addr}/information" + + try: + with open(info_file, "r") as f: + for line in f: + if line.startswith("Device Minor:"): + # Parse: "Device Minor: 0" โ†’ 0 + parts = line.split(":") + if len(parts) < 2: + logger.warning( + f"Unexpected format in {info_file}: {line.strip()}" + ) + continue + device_minor = int(parts[1].strip()) + available_minors.append(device_minor) + + if device_minor == gpu_id: + logger.info( + f"GPU {gpu_id} mapped to PCI {pci_addr} " + f"via /proc (Device Minor: {device_minor})" + ) + return pci_addr + except (IOError, OSError) as e: + logger.warning(f"Could not read {info_file}: {e}") + continue + except (ValueError, IndexError) as e: + logger.warning(f"Could not parse Device Minor from {info_file}: {e}") + continue + + # GPU ID not found + raise ValueError( + f"GPU {gpu_id} not found in {proc_path}. " + f"Available Device Minors: {sorted(available_minors)}" + ) + def _normalize_pci_address(self, pci_addr: str) -> str: """ Normalize PCI address from nvidia-smi format to kernel sysfs format. @@ -247,28 +324,8 @@ def _inject_fake_xid_to_kmsg(self, gpu_id: int, xid: int) -> Tuple[bool, str]: message template for each XID type. """ try: - # Get PCI address for the GPU - pci_result = subprocess.run( - [ - "nvidia-smi", - "--query-gpu=pci.bus_id", - "--format=csv,noheader", - "-i", - str(gpu_id), - ], - capture_output=True, - text=True, - timeout=10, - ) - - if pci_result.returncode != 0: - return ( - False, - f"Failed to get PCI address for GPU {gpu_id}: {pci_result.stderr}", - ) - - pci_addr_full = pci_result.stdout.strip() - pci_addr = self._normalize_pci_address(pci_addr_full) + # Get PCI address using /proc method (works without nvidia-smi) + pci_addr = self._get_pci_address_from_proc(gpu_id) # Get appropriate error message for this XID type # If XID is known, use specific message; otherwise use generic format diff --git a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/README.md b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/README.md index 810d676830..045dfb5063 100644 --- a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/README.md +++ b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/README.md @@ -6,13 +6,16 @@ ## What This Does -Makes CUDA calls return error codes to simulate various GPU failures. Uses LD_PRELOAD to intercept CUDA library calls. +Intercepts CUDA calls to simulate GPU failures using LD_PRELOAD. Faults persist across pod restarts via hostPath volumes, enabling realistic hardware failure testing. ``` -Pod calls cudaMalloc() โ†’ LD_PRELOAD intercepts โ†’ Returns error โ†’ Pod crashes +Pod calls cudaMalloc() โ†’ LD_PRELOAD intercepts โ†’ Checks /host-fault/cuda_fault_enabled โ†’ Returns error โ†’ Pod crashes ``` -**Result**: Realistic GPU failure testing without hardware damage. 
+**Key Features**: +- **Persistent faults**: hostPath volume (`/var/lib/cuda-fault-test`) survives pod restarts on same node +- **Runtime toggle**: Enable/disable faults without pod restarts via `/host-fault/cuda_fault_enabled` +- **Node-specific**: Faults only on target node, healthy nodes unaffected ## Scope @@ -35,13 +38,20 @@ This library simulates **software/orchestration-level failures** that occur when | **43** | GPU stopped responding | `CUDA_ERROR_LAUNCH_TIMEOUT` | Hung kernel | | **74** | NVLink error | `CUDA_ERROR_PEER_ACCESS_UNSUPPORTED` | Multi-GPU communication failure | +## How It Works + +1. **Deployment patching**: Adds hostPath volume + init container to compile library +2. **LD_PRELOAD injection**: Environment variable loads library before CUDA +3. **Runtime control**: Toggle file (`/host-fault/cuda_fault_enabled`) controls fault state +4. **Node persistence**: hostPath ensures faults survive pod restarts on same node + ## Files in This Directory | File | Purpose | |------|---------| -| `cuda_intercept.c` | C library source that intercepts CUDA calls | -| `inject_into_pods.py` | Helper functions for patching Kubernetes deployments | -| `Makefile` | Builds the `.so` library locally (optional, for standalone testing) | +| `cuda_intercept.c` | C library that intercepts CUDA calls and checks fault markers | +| `inject_into_pods.py` | Kubernetes deployment patcher (adds hostPath volume + library) | +| `Makefile` | Local build (optional, for testing) | ## Prerequisites diff --git a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/cuda_intercept.c b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/cuda_intercept.c index 1052eeda05..dfc05a8e79 100644 --- a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/cuda_intercept.c +++ b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/cuda_intercept.c @@ -59,19 +59,20 @@ static const xid_mapping_t xid_mappings[] = { }; // Get XID type and corresponding CUDA error +// Supports runtime toggling via /tmp/cuda_fault_enabled file static void get_fault_config(int* inject, int* xid_type, cudaError_t* error_code) { static int initialized = 0; - static int cached_inject = 0; + static int env_inject = 0; // From environment variable static int cached_xid = 79; // Default to XID 79 static cudaError_t cached_error = cudaErrorNoDevice; if (!initialized) { - // Check if injection is enabled + // Check if injection is enabled via environment char* env = getenv("CUDA_FAULT_INJECTION_ENABLED"); if (env) { - cached_inject = (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); + env_inject = (strcmp(env, "1") == 0 || strcmp(env, "true") == 0); } // Get XID type @@ -85,8 +86,7 @@ get_fault_config(int* inject, int* xid_type, cudaError_t* error_code) if (xid_mappings[i].xid == cached_xid) { cached_error = xid_mappings[i].cuda_error; fprintf( - stderr, "[CUDA FAULT INJECTION] ENABLED - Simulating XID %d (%s)\n", cached_xid, - xid_mappings[i].description); + stderr, "[CUDA FAULT INJECTION] Library loaded - XID %d (%s)\n", cached_xid, xid_mappings[i].description); found = 1; break; } @@ -97,16 +97,37 @@ get_fault_config(int* inject, int* xid_type, cudaError_t* error_code) cached_xid = 79; cached_error = cudaErrorNoDevice; } - } else { - fprintf( - stderr, "[CUDA FAULT INJECTION] %s (default: XID 79 - GPU fell off bus)\n", - cached_inject ? 
"ENABLED" : "DISABLED"); } initialized = 1; } - *inject = cached_inject; + // Runtime toggle: Check node-persistent fault marker on EVERY call + // Use hostPath (/host-fault) so fault persists across pod restarts on same node + // Pod reschedules to different node โ†’ no file there โ†’ automatic recovery! + int runtime_inject = env_inject; // Default to env var + + // Check hostPath first (persistent across restarts on same node) + FILE* toggle_file = fopen("/host-fault/cuda_fault_enabled", "r"); + if (toggle_file) { + char toggle_value[4] = {0}; + if (fgets(toggle_value, sizeof(toggle_value), toggle_file)) { + runtime_inject = (toggle_value[0] == '1'); + } + fclose(toggle_file); + } else { + // Fallback to ephemeral /tmp for backwards compatibility + toggle_file = fopen("/tmp/cuda_fault_enabled", "r"); + if (toggle_file) { + char toggle_value[4] = {0}; + if (fgets(toggle_value, sizeof(toggle_value), toggle_file)) { + runtime_inject = (toggle_value[0] == '1'); + } + fclose(toggle_file); + } + } + + *inject = runtime_inject; *xid_type = cached_xid; *error_code = cached_error; } diff --git a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/inject_into_pods.py b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/inject_into_pods.py index 552ed46ee4..5083d7c2fb 100755 --- a/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/inject_into_pods.py +++ b/tests/fault_tolerance/hardware/fault_injection_service/cuda_fault_injection/inject_into_pods.py @@ -201,6 +201,18 @@ def _patch_service_for_injection( {"name": "cuda-fault-lib", "emptyDir": {}} ) + # Add hostPath volume for persistent fault marker (survives pod restarts on same node) + # This simulates persistent hardware failure! + service["extraPodSpec"]["volumes"].append( + { + "name": "node-fault-marker", + "hostPath": { + "path": "/var/lib/cuda-fault-test", + "type": "DirectoryOrCreate", + }, + } + ) + # Add init container to decode base64 if "initContainers" not in service["extraPodSpec"]: service["extraPodSpec"]["initContainers"] = [] @@ -247,7 +259,7 @@ def _patch_service_for_injection( if vm.get("name") != "cuda-fault-lib" ] - # Add mount + # Add mount for compiled library service["extraPodSpec"]["mainContainer"]["volumeMounts"].append( { "name": "cuda-fault-lib", @@ -256,8 +268,18 @@ def _patch_service_for_injection( } ) + # Add mount for persistent fault marker (hostPath) + service["extraPodSpec"]["mainContainer"]["volumeMounts"].append( + { + "name": "node-fault-marker", + "mountPath": "/host-fault", + "readOnly": False, # Need write access + } + ) + print(" โœ“ Added init container to compile library") print(" โœ“ Added ConfigMap volume mount") + print(" โœ“ Added hostPath volume for persistent fault marker") # Add node affinity to pin pods to target node (simulates real XID 79 behavior) if target_node and enable: @@ -287,14 +309,15 @@ def _patch_service_for_injection( service["extraPodSpec"]["volumes"] = [ v for v in service["extraPodSpec"]["volumes"] - if v.get("name") not in ["cuda-fault-lib", "cuda-fault-lib-source"] + if v.get("name") + not in ["cuda-fault-lib", "cuda-fault-lib-source", "node-fault-marker"] ] if "volumeMounts" in service["extraPodSpec"].get("mainContainer", {}): service["extraPodSpec"]["mainContainer"]["volumeMounts"] = [ vm for vm in service["extraPodSpec"]["mainContainer"]["volumeMounts"] - if vm.get("name") != "cuda-fault-lib" + if vm.get("name") not in ["cuda-fault-lib", "node-fault-marker"] ] # Remove init container @@ -323,6 +346,7 
@@ def patch_deployment_env( use_configmap=True, target_node=None, xid_type=79, + passthrough_mode=False, ): """Patch deployment to add/remove LD_PRELOAD environment variable. @@ -334,6 +358,8 @@ def patch_deployment_env( target_node: If provided, adds node affinity to pin pods to this node (simulates real XID where pods crash on the faulty node) xid_type: XID error type to simulate (79, 48, 94, 95, 43, 74). Default: 79 + passthrough_mode: If True, set CUDA_FAULT_INJECTION_ENABLED=0 (library loaded but disabled) + Allows baseline testing before enabling faults via toggle """ custom_api = client.CustomObjectsApi() apps_api = client.AppsV1Api() @@ -385,9 +411,14 @@ def patch_deployment_env( # Prepare environment variables new_envs = [] if enable: + # Set CUDA_FAULT_INJECTION_ENABLED based on passthrough_mode + fault_enabled_value = "0" if passthrough_mode else "1" new_envs = [ {"name": "LD_PRELOAD", "value": lib_path}, - {"name": "CUDA_FAULT_INJECTION_ENABLED", "value": "1"}, + { + "name": "CUDA_FAULT_INJECTION_ENABLED", + "value": fault_enabled_value, + }, {"name": "CUDA_XID_TYPE", "value": str(xid_type)}, ] @@ -400,6 +431,28 @@ def patch_deployment_env( available_services = list(services.keys()) print(f" โ†’ Available services: {available_services}") + # Set aggressive update strategy when enabling (allow all pods to update at once) + # This ensures all pods get CUDA faults, not just the first few + if enable: + if "updateStrategy" not in spec: + spec["updateStrategy"] = {} + if "rollingUpdate" not in spec["updateStrategy"]: + spec["updateStrategy"]["rollingUpdate"] = {} + + # Allow all pods to be unavailable during update + spec["updateStrategy"]["rollingUpdate"]["maxUnavailable"] = "100%" + # Don't create surge pods + spec["updateStrategy"]["rollingUpdate"]["maxSurge"] = 0 + print(" โ†’ Set update strategy: maxUnavailable=100%, maxSurge=0") + print(" (All pods will update simultaneously)") + else: + # Restore default update strategy when disabling + if "updateStrategy" in spec: + spec["updateStrategy"] = { + "rollingUpdate": {"maxUnavailable": "25%", "maxSurge": "25%"} + } + print(" โ†’ Restored default update strategy (maxUnavailable=25%)") + for service_name in services_to_patch: if service_name in services: print(f" โ†’ Patching service: {service_name}") @@ -465,6 +518,38 @@ def patch_deployment_env( print(f" Services patched: {', '.join(patched_services)}") if use_configmap and enable: print(f" Library mounted at: {lib_path}") + + # Force restart all worker pods when enabling to apply changes immediately + if enable: + print( + " โ†’ Force-deleting all worker pods to apply changes immediately..." 
+ ) + core_api = client.CoreV1Api() + try: + worker_pods = core_api.list_namespaced_pod( + namespace=namespace, + label_selector=f"nvidia.com/dynamo-graph-deployment-name={deployment_name},nvidia.com/dynamo-component-type=worker", + ) + deleted_count = 0 + for pod in worker_pods.items: + try: + core_api.delete_namespaced_pod( + name=pod.metadata.name, + namespace=namespace, + grace_period_seconds=0, + ) + deleted_count += 1 + except Exception as e: + print( + f" โš  Could not delete pod {pod.metadata.name}: {e}" + ) + print( + f" โœ“ Deleted {deleted_count} pod(s) - they will restart with CUDA library" + ) + except Exception as e: + print(f" โš  Could not list/delete pods: {e}") + print(" Pods will eventually restart, but may take longer") + return True except ApiException as e: @@ -505,11 +590,15 @@ def patch_deployment_env( if enable: # Add new env vars + # Set CUDA_FAULT_INJECTION_ENABLED based on passthrough_mode + fault_enabled_value = "0" if passthrough_mode else "1" container.env.append( client.V1EnvVar(name="LD_PRELOAD", value="/tmp/cuda_intercept.so") ) container.env.append( - client.V1EnvVar(name="CUDA_FAULT_INJECTION_ENABLED", value="1") + client.V1EnvVar( + name="CUDA_FAULT_INJECTION_ENABLED", value=fault_enabled_value + ) ) container.env.append( client.V1EnvVar(name="CUDA_XID_TYPE", value=str(xid_type)) diff --git a/tests/fault_tolerance/hardware/fault_injection_service/helpers/__init__.py b/tests/fault_tolerance/hardware/fault_injection_service/helpers/__init__.py index 5acad0802c..498555aa5b 100644 --- a/tests/fault_tolerance/hardware/fault_injection_service/helpers/__init__.py +++ b/tests/fault_tolerance/hardware/fault_injection_service/helpers/__init__.py @@ -9,11 +9,26 @@ """ __all__ = [ + # GPU discovery utilities + "get_available_gpu_ids", + "get_gpu_id_for_process", + "get_gpu_pci_address", + "get_gpu_info", + "get_processes_on_gpu", + # Inference testing utilities "InferenceLoadTester", "get_inference_endpoint", + # Kubernetes operations utilities "NodeOperations", "PodOperations", ] +from .gpu_discovery import ( + get_available_gpu_ids, + get_gpu_id_for_process, + get_gpu_info, + get_gpu_pci_address, + get_processes_on_gpu, +) from .inference_testing import InferenceLoadTester, get_inference_endpoint from .k8s_operations import NodeOperations, PodOperations diff --git a/tests/fault_tolerance/hardware/fault_injection_service/helpers/cuda_fault_injection.py b/tests/fault_tolerance/hardware/fault_injection_service/helpers/cuda_fault_injection.py index 994a8a7a65..fa3b58b09c 100644 --- a/tests/fault_tolerance/hardware/fault_injection_service/helpers/cuda_fault_injection.py +++ b/tests/fault_tolerance/hardware/fault_injection_service/helpers/cuda_fault_injection.py @@ -37,7 +37,7 @@ def __init__(self, lib_dir: Optional[Path] = None): lib_dir = Path(__file__).parent.parent / "cuda_fault_injection" self.lib_dir = lib_dir - self.lib_path = lib_dir / "fake_cuda_xid79.so" + self.lib_path = lib_dir / "cuda_intercept.so" self.lib_built = False def build_library(self) -> bool: @@ -101,12 +101,57 @@ def create_configmap_with_library(self, namespace: str) -> bool: traceback.print_exc() return False + def check_if_cuda_library_deployed( + self, deployment_name: str, namespace: str + ) -> bool: + """ + Check if CUDA fault injection is already deployed to the deployment. 
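Taken together, enabling injection leaves each patched worker service looking roughly like the sketch below (values taken from this patch; the exact LD_PRELOAD path depends on whether the ConfigMap-compiled library or /tmp/cuda_intercept.so is used, and unrelated fields are omitted). The rollout override lives on the deployment spec rather than on the individual service:

# Approximate shape produced by _patch_service_for_injection / patch_deployment_env with enable=True.
# passthrough_mode=True loads the library but starts with faults off ("0"); the /host-fault marker
# can then flip faults on later without restarting the pod.
patched_service_sketch = {
    "extraPodSpec": {
        "volumes": [
            {"name": "cuda-fault-lib", "emptyDir": {}},
            {
                "name": "node-fault-marker",
                "hostPath": {"path": "/var/lib/cuda-fault-test", "type": "DirectoryOrCreate"},
            },
        ],
        "mainContainer": {
            "volumeMounts": [
                {"name": "node-fault-marker", "mountPath": "/host-fault", "readOnly": False},
                # ...plus the cuda-fault-lib mount for the compiled library
            ],
            "env": [
                {"name": "LD_PRELOAD", "value": "<path to cuda_intercept.so>"},
                {"name": "CUDA_FAULT_INJECTION_ENABLED", "value": "0"},  # "1" unless passthrough_mode
                {"name": "CUDA_XID_TYPE", "value": "79"},
            ],
        },
    },
}

# While enabled, the deployment spec also carries an aggressive rollout so every pod picks up the
# library at once; the 25%/25% defaults are restored on disable.
update_strategy_sketch = {"rollingUpdate": {"maxUnavailable": "100%", "maxSurge": 0}}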
+ + Args: + deployment_name: Name of the deployment + namespace: Kubernetes namespace + + Returns: + True if CUDA fault library is already deployed, False otherwise + """ + try: + k8s_custom = client.CustomObjectsApi() + + # Get the DynamoGraphDeployment + dgd = k8s_custom.get_namespaced_custom_object( + group="nvidia.com", + version="v1alpha1", + namespace=namespace, + plural="dynamographdeployments", + name=deployment_name, + ) + + # Check for LD_PRELOAD in worker container env + spec = dgd.get("spec", {}) + worker_spec = spec.get("workerSpec", {}) + pod_spec = worker_spec.get("podSpec", {}) + containers = pod_spec.get("containers", []) + + for container in containers: + if container.get("name") in ["vllm-worker", "worker"]: + env = container.get("env", []) + for env_var in env: + if env_var.get("name") == "LD_PRELOAD": + return True + + return False + + except Exception: + # If we can't read the deployment, assume it's not deployed + return False + def patch_deployment_for_cuda_fault( self, deployment_name: str, namespace: str, target_node: Optional[str] = None, xid_type: int = 79, + passthrough_mode: bool = False, ) -> bool: """ Patch deployment to enable CUDA fault injection. @@ -116,6 +161,7 @@ def patch_deployment_for_cuda_fault( - Init container to compile library - LD_PRELOAD environment variable - CUDA_XID_TYPE environment variable + - CUDA_FAULT_INJECTION_ENABLED (0 in passthrough mode, 1 otherwise) - Node affinity (if target_node specified) Args: @@ -123,6 +169,8 @@ def patch_deployment_for_cuda_fault( namespace: Kubernetes namespace target_node: Node to pin pods to (simulates real XID behavior) xid_type: XID error type to simulate (79, 48, 94, 95, 43, 74). Default: 79 + passthrough_mode: If True, set CUDA_FAULT_INJECTION_ENABLED=0 + (library loaded but faults disabled for baseline) Returns: True if patch succeeded @@ -149,6 +197,7 @@ def patch_deployment_for_cuda_fault( use_configmap=True, target_node=target_node, xid_type=xid_type, + passthrough_mode=passthrough_mode, ) except Exception as e: @@ -339,6 +388,248 @@ def cleanup_cuda_fault_injection( traceback.print_exc() return False + def enable_cuda_faults_via_toggle( + self, pods: List[client.V1Pod], namespace: str, enable: bool = True + ) -> bool: + """ + Enable or disable CUDA faults on running pods via environment variable toggle. + + This modifies the CUDA_FAULT_INJECTION_ENABLED env var in running pods + without restarting them. Requires the CUDA library to already be loaded. + + Args: + pods: List of pods to toggle faults on + namespace: Kubernetes namespace + enable: True to enable faults, False to disable + + Returns: + True if toggle succeeded + """ + if not pods: + return False + + toggle_value = "1" if enable else "0" + action = "Enabling" if enable else "Disabling" + + print(f"\n[โ†’] {action} CUDA faults via toggle on {len(pods)} pods...") + + success_count = 0 + failed_pods = [] + + for pod in pods: + pod_name = pod.metadata.name + + try: + # Get the main container name from pod spec + container_name = ( + pod.spec.containers[0].name if pod.spec.containers else None + ) + if not container_name: + failed_pods.append((pod_name, "No container found")) + continue + + # Write toggle file to hostPath (persists across pod restarts on same node) + # This simulates persistent hardware failure! 
+ exec_command = [ + "sh", + "-c", + f'mkdir -p /host-fault && echo "{toggle_value}" > /host-fault/cuda_fault_enabled && cat /host-fault/cuda_fault_enabled', + ] + + result = subprocess.run( + [ + "kubectl", + "exec", + "-n", + namespace, + pod_name, + "-c", + container_name, + "--", + ] + + exec_command, + capture_output=True, + text=True, + timeout=10, + ) + + if result.returncode == 0: + actual_value = result.stdout.strip() + if actual_value == toggle_value: + print( + f" โœ“ Toggle={toggle_value} in {pod_name}/{container_name}" + ) + success_count += 1 + else: + failed_pods.append( + ( + pod_name, + f"Verify failed: expected '{toggle_value}', got '{actual_value}'", + ) + ) + else: + failed_pods.append( + (pod_name, f"Exec failed: {result.stderr.strip()}") + ) + + except Exception as e: + failed_pods.append((pod_name, str(e))) + continue + + if failed_pods: + print(f" โš  Failed to toggle {len(failed_pods)} pods:") + for pod_name, error in failed_pods: + print(f" - {pod_name}: {error}") + + print(f" โ†’ Result: {success_count}/{len(pods)} pods toggled successfully") + return success_count > 0 + + def disable_cuda_faults_via_toggle( + self, pods: List[client.V1Pod], namespace: str + ) -> bool: + """ + Disable CUDA faults on running pods via toggle. + + Args: + pods: List of pod objects to disable faults on + namespace: Kubernetes namespace + + Returns: + True if disable succeeded + """ + return self.enable_cuda_faults_via_toggle(pods, namespace, enable=False) + + def cleanup_node_fault_markers( + self, pods: List[client.V1Pod], namespace: str + ) -> bool: + """ + Remove persistent fault marker files from node hostPath. + This cleans up /host-fault/cuda_fault_enabled to prevent future tests from failing. + + Args: + pods: List of pods (to access nodes) + namespace: Kubernetes namespace + + Returns: + True if cleanup succeeded + """ + if not pods: + return True + + print(" [->] Cleaning persistent fault markers from nodes...") + + success_count = 0 + nodes_cleaned = set() + + for pod in pods: + pod_name = pod.metadata.name + node_name = pod.spec.node_name + + # Skip if we already cleaned this node + if node_name in nodes_cleaned: + continue + + try: + container_name = ( + pod.spec.containers[0].name if pod.spec.containers else None + ) + if not container_name: + continue + + # Remove the persistent marker file from hostPath + exec_command = [ + "sh", + "-c", + 'rm -f /host-fault/cuda_fault_enabled 2>/dev/null; echo "ok"', + ] + + result = subprocess.run( + [ + "kubectl", + "exec", + "-n", + namespace, + pod_name, + "-c", + container_name, + "--", + ] + + exec_command, + capture_output=True, + text=True, + timeout=10, + ) + + if result.returncode == 0: + print(f" โœ“ Cleaned fault marker on node {node_name}") + nodes_cleaned.add(node_name) + success_count += 1 + + except Exception: + continue + + return success_count > 0 + + def verify_env_var_set( + self, + deployment_name: str, + namespace: str, + expected_value: str, + max_wait: int = 30, + ) -> bool: + """ + Verify that CUDA_FAULT_INJECTION_ENABLED env var is set to expected value. + Polls until the value matches or timeout. 
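The toggle helpers are intended to be combined roughly as follows: deploy the interceptor in passthrough mode so workers come up healthy, flip the node-persistent marker on to simulate the failure, then flip it off and scrub the markers so later tests scheduled onto the same nodes start clean. A sketch assuming injector is an instance of the helper class defined in this module and worker_pods is the list of V1Pod objects for the target workers (deployment and namespace names are placeholders):

# 1. Load the interceptor but keep faults off (CUDA_FAULT_INJECTION_ENABLED=0).
injector.patch_deployment_for_cuda_fault(
    deployment_name="my-deployment",   # placeholder
    namespace="my-namespace",          # placeholder
    xid_type=79,
    passthrough_mode=True,
)
injector.verify_env_var_set("my-deployment", "my-namespace", expected_value="0")

# 2. Run baseline traffic, then turn the fault on without restarting pods.
injector.enable_cuda_faults_via_toggle(worker_pods, "my-namespace", enable=True)

# 3. Recover: turn the fault off and remove the hostPath markers so the next test
#    on these nodes does not inherit a "broken GPU".
injector.disable_cuda_faults_via_toggle(worker_pods, "my-namespace")
injector.cleanup_node_fault_markers(worker_pods, "my-namespace")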
+ + Args: + deployment_name: Name of the DynamoGraphDeployment + namespace: Kubernetes namespace + expected_value: Expected value ("0" or "1") + max_wait: Maximum seconds to wait + + Returns: + True if verified + """ + k8s_custom = client.CustomObjectsApi() + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + dgd = k8s_custom.get_namespaced_custom_object( + group="nvidia.com", + version="v1alpha1", + namespace=namespace, + plural="dynamographdeployments", + name=deployment_name, + ) + + # Check both worker services + for service_name in ["VllmDecodeWorker", "VllmPrefillWorker"]: + if service_name in dgd["spec"]["services"]: + service = dgd["spec"]["services"][service_name] + env_vars = ( + service.get("extraPodSpec", {}) + .get("mainContainer", {}) + .get("env", []) + ) + + for env_var in env_vars: + if env_var.get("name") == "CUDA_FAULT_INJECTION_ENABLED": + if env_var.get("value") != expected_value: + time.sleep(1) + break # Try again + else: + continue # This service is good + break # Inner loop broke, try again + else: + # All services verified + return True + + except Exception: + time.sleep(1) + + return False + def trigger_pod_restart(self, pods: List[client.V1Pod], namespace: str): """ Delete pods to trigger restart with new env vars. diff --git a/tests/fault_tolerance/hardware/fault_injection_service/helpers/gpu_discovery.py b/tests/fault_tolerance/hardware/fault_injection_service/helpers/gpu_discovery.py new file mode 100644 index 0000000000..0e651cd030 --- /dev/null +++ b/tests/fault_tolerance/hardware/fault_injection_service/helpers/gpu_discovery.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +""" +GPU discovery utilities for fault tolerance testing. + +Provides functions to discover GPU information from Kubernetes pods, +including mapping processes to GPUs and handling CUDA_VISIBLE_DEVICES remapping. +""" + +import logging +from typing import List, Optional + +from kr8s.objects import Pod + +logger = logging.getLogger(__name__) + + +def get_available_gpu_ids(pod: Pod) -> List[int]: + """ + Get list of actual GPU IDs available in the pod. + + Handles non-sequential GPU IDs correctly (e.g., [0, 1, 3, 7] with gaps). + + Args: + pod: Kubernetes pod object (kr8s pod with exec() method) + + Returns: + List of GPU IDs (e.g., [0, 1, 2, 3]) or empty list if no GPUs found + + Example: + >>> gpu_ids = get_available_gpu_ids(pod) + >>> print(gpu_ids) + [0, 1, 2, 3] + """ + try: + result = pod.exec(["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"]) + + # Parse GPU indices from output + gpu_ids = [] + for line in result.stdout.decode().splitlines(): + line = line.strip() + if line.isdigit(): + gpu_ids.append(int(line)) + + if not gpu_ids: + logger.warning(f"No GPUs found in pod {pod.name}") + return [] + + logger.debug(f"Available GPU IDs in pod {pod.name}: {gpu_ids}") + return gpu_ids + + except Exception as e: + logger.error(f"Failed to get GPU IDs from pod {pod.name}: {e}") + return [] + + +def get_gpu_id_for_process(pod: Pod, process_pid: int) -> int: + """ + Find which GPU a process is using. + + Queries nvidia-smi to determine the primary GPU for a given process. 
+ This correctly handles: + - Non-sequential GPU IDs + - CUDA_VISIBLE_DEVICES remapping + - Multi-GPU processes (returns primary GPU) + + Args: + pod: Kubernetes pod object (kr8s pod with exec() method) + process_pid: Process ID to find GPU for + + Returns: + GPU ID (0-N) where the process is running, or 0 if not found + + Example: + >>> gpu_id = get_gpu_id_for_process(pod, 603) + >>> print(gpu_id) + 1 # Process 603 is running on GPU 1 + """ + try: + # Get actual GPU IDs available in pod (handles non-sequential IDs) + gpu_ids = get_available_gpu_ids(pod) + + if not gpu_ids: + logger.error(f"No GPUs found in pod {pod.name}!") + return 0 + + logger.debug( + f"Searching for PID {process_pid} across {len(gpu_ids)} GPUs: {gpu_ids}" + ) + + # Check each GPU for our process + for gpu_id in gpu_ids: + result = pod.exec( + [ + "nvidia-smi", + "-i", + str(gpu_id), + "--query-compute-apps=pid", + "--format=csv,noheader", + ] + ) + + # Parse PIDs running on this GPU + pids_output = result.stdout.decode().strip() + + # Handle both single PID and multiple PIDs + # Output can be: + # "602" (single PID) + # "602\n603\n604" (multiple PIDs) + # " 602 " (with spaces) + pids_on_gpu = [p.strip() for p in pids_output.split("\n") if p.strip()] + + # Check if our PID is in the list + if str(process_pid) in pids_on_gpu: + logger.info( + f"PID {process_pid} found on GPU {gpu_id} in pod {pod.name}" + ) + return gpu_id + + # Process not found on any GPU + logger.warning( + f"PID {process_pid} not found on any GPU in pod {pod.name}. " + f"This may happen if the process hasn't initialized CUDA yet or " + f"if nvidia-smi doesn't track multi-process CUDA apps. " + f"Defaulting to first GPU: {gpu_ids[0]}" + ) + return gpu_ids[0] + + except Exception as e: + logger.error( + f"GPU discovery failed for PID {process_pid} in pod {pod.name}: {e}" + ) + return 0 + + +def get_gpu_pci_address(pod: Pod, gpu_id: int) -> Optional[str]: + """ + Get PCI bus address for a GPU. + + The PCI address is used in kernel XID messages and identifies + the physical hardware location of the GPU. + + Args: + pod: Kubernetes pod object + gpu_id: GPU index (0-N) as shown by nvidia-smi + + Returns: + PCI address (e.g., "00000000:8D:00.0") or None if failed + + Example: + >>> pci_addr = get_gpu_pci_address(pod, 1) + >>> print(pci_addr) + 00000000:91:00.0 + """ + try: + result = pod.exec( + [ + "nvidia-smi", + "-i", + str(gpu_id), + "--query-gpu=pci.bus_id", + "--format=csv,noheader", + ] + ) + + pci_addr = result.stdout.decode().strip() + + if not pci_addr: + logger.error(f"Empty PCI address for GPU {gpu_id}") + return None + + logger.debug(f"GPU {gpu_id} in pod {pod.name} has PCI address: {pci_addr}") + return pci_addr + + except Exception as e: + logger.error( + f"Failed to get PCI address for GPU {gpu_id} in pod {pod.name}: {e}" + ) + return None + + +def get_gpu_info(pod: Pod, gpu_id: int) -> Optional[dict]: + """ + Get comprehensive information about a GPU. 
+ + Args: + pod: Kubernetes pod object + gpu_id: GPU index (0-N) + + Returns: + Dict with keys: index, name, pci_bus_id, memory_total, driver_version + or None if failed + + Example: + >>> info = get_gpu_info(pod, 0) + >>> print(info) + { + 'index': 0, + 'name': 'NVIDIA H200', + 'pci_bus_id': '00000000:8D:00.0', + 'memory_total': '143771 MiB', + 'driver_version': '550.163.01' + } + """ + try: + result = pod.exec( + [ + "nvidia-smi", + "-i", + str(gpu_id), + "--query-gpu=index,name,pci.bus_id,memory.total,driver_version", + "--format=csv,noheader", + ] + ) + + output = result.stdout.decode().strip() + parts = [p.strip() for p in output.split(",")] + + if len(parts) < 5: + logger.error(f"Unexpected nvidia-smi output format: {output}") + return None + + return { + "index": int(parts[0]), + "name": parts[1], + "pci_bus_id": parts[2], + "memory_total": parts[3], + "driver_version": parts[4], + } + + except Exception as e: + logger.error(f"Failed to get GPU info for GPU {gpu_id}: {e}") + return None + + +def get_processes_on_gpu(pod: Pod, gpu_id: int) -> List[int]: + """ + Get list of process IDs running on a specific GPU. + + Args: + pod: Kubernetes pod object + gpu_id: GPU index (0-N) + + Returns: + List of PIDs running on this GPU, or empty list if none/error + + Example: + >>> pids = get_processes_on_gpu(pod, 1) + >>> print(pids) + [602, 603] + """ + try: + result = pod.exec( + [ + "nvidia-smi", + "-i", + str(gpu_id), + "--query-compute-apps=pid", + "--format=csv,noheader", + ] + ) + + pids_output = result.stdout.decode().strip() + + if not pids_output: + logger.debug(f"No processes found on GPU {gpu_id} in pod {pod.name}") + return [] + + # Parse PIDs (handle multiple PIDs on same GPU) + pids = [] + for line in pids_output.split("\n"): + line = line.strip() + if line.isdigit(): + pids.append(int(line)) + + logger.debug(f"GPU {gpu_id} in pod {pod.name} has processes: {pids}") + return pids + + except Exception as e: + logger.error(f"Failed to get processes for GPU {gpu_id}: {e}") + return [] diff --git a/tests/fault_tolerance/migration/test_sglang.py b/tests/fault_tolerance/migration/test_sglang.py new file mode 100644 index 0000000000..3540dab5c0 --- /dev/null +++ b/tests/fault_tolerance/migration/test_sglang.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
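A short usage sketch for the gpu_discovery helpers above, assuming a kr8s Pod handle for a running worker; the pod name and namespace are placeholders and the PIDs come from nvidia-smi inside the pod:

from kr8s.objects import Pod

from tests.fault_tolerance.hardware.fault_injection_service.helpers import (
    get_available_gpu_ids,
    get_gpu_id_for_process,
    get_gpu_info,
    get_gpu_pci_address,
    get_processes_on_gpu,
)

pod = Pod.get("vllm-worker-0", namespace="my-namespace")  # placeholder pod

gpu_ids = get_available_gpu_ids(pod)             # e.g. [0, 1, 2, 3]
pids = get_processes_on_gpu(pod, gpu_ids[0])     # compute PIDs on the first GPU
if pids:
    gpu_id = get_gpu_id_for_process(pod, pids[0])
    print(get_gpu_pci_address(pod, gpu_id))      # e.g. "00000000:8D:00.0"
    print(get_gpu_info(pod, gpu_id))             # index, name, pci_bus_id, memory, driver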
+# SPDX-License-Identifier: Apache-2.0 + +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_migration_sglang_worker_failure: ~58s (gpu_1) +- test_request_migration_sglang_graceful_shutdown: ~58s (gpu_1, skipped) +- test_no_request_migration_sglang_worker_failure: ~38s (gpu_1) +- test_no_request_migration_sglang_graceful_shutdown: ~38s (gpu_1, skipped) +- Total: 115.71s (0:01:55) for enabled tests +""" + +import logging +import os +import shutil + +import pytest + +from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME +from tests.utils.managed_process import ManagedProcess, terminate_process_tree +from tests.utils.payloads import check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port + +# Import utilities from the refactored utils module +from .utils import ( + DynamoFrontendProcess, + determine_request_receiving_worker, + start_completion_request, + validate_completion_response, + verify_migration_occurred, +) + +logger = logging.getLogger(__name__) + +pytestmark = [ + pytest.mark.sglang, + pytest.mark.gpu_1, + pytest.mark.e2e, + pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), + pytest.mark.post_merge, # post_merge to pinpoint failure commit +] + + +class DynamoWorkerProcess(ManagedProcess): + """Process manager for Dynamo worker with SGLang backend""" + + def __init__( + self, + request, + worker_id: str, + system_port: int, + frontend_port: int, + migration_limit: int = 3, + ): + self.worker_id = worker_id + self.system_port = system_port + + command = [ + "python3", + "-m", + "dynamo.sglang", + "--model-path", + FAULT_TOLERANCE_MODEL_NAME, + "--served-model-name", + FAULT_TOLERANCE_MODEL_NAME, + "--trust-remote-code", + "--skip-tokenizer-init", + "--mem-fraction-static", + "0.45", + "--context-length", + "8192", + "--migration-limit", + str(migration_limit), + ] + + # Set environment variables + env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" + env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' + env["DYN_SYSTEM_PORT"] = str(system_port) + env["DYN_HTTP_PORT"] = str(frontend_port) + + # TODO: Have the managed process take a command name explicitly to distinguish + # between processes started with the same command. 
+ log_dir = f"{request.node.name}_{worker_id}" + + # Clean up any existing log directory from previous runs + try: + shutil.rmtree(log_dir) + logger.info(f"Cleaned up existing log directory: {log_dir}") + except FileNotFoundError: + # Directory doesn't exist, which is fine + pass + + super().__init__( + command=command, + env=env, + health_check_urls=[ + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{system_port}/health", self.is_ready), + ], + timeout=300, + display_output=True, + terminate_existing=False, + stragglers=["SGLANG:EngineCore"], + straggler_commands=["-m dynamo.sglang"], + log_dir=log_dir, + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is a required parameter, always set in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release SGLang worker port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + + def is_ready(self, response) -> bool: + """Check the health of the worker process""" + try: + data = response.json() + if data.get("status") == "ready": + logger.info(f"{self.worker_id} status is ready") + return True + logger.warning( + f"{self.worker_id} status is not ready: {data.get('status')}" + ) + except ValueError: + logger.warning(f"{self.worker_id} health response is not valid JSON") + return False + + +@pytest.mark.timeout(235) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_request_migration_sglang_worker_failure( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with migration support using SGLang. + + This test verifies that when a worker is killed during request processing, + the system can handle the failure gracefully and migrate the request to + another worker. 
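The migration tests below run once per request plane: with indirect=True, pytest routes each parametrized value through a request_plane fixture instead of passing the raw string into the test, and the tcp case is marked xfail because multi-worker TCP is still unstable. The real fixture lives in the shared conftest and is not part of this patch; a minimal sketch of the mechanism it relies on:

import pytest

@pytest.fixture
def request_plane(request):
    # With indirect=True, each parametrized value ("nats" or "tcp") arrives here
    # as request.param. The tests read it back via
    # request.getfixturevalue("request_plane") and export it as DYN_REQUEST_PLANE.
    return request.param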
+ + Timing (Last Run: 2025-12-09): ~58s total + - Engine initialization: ~22s (Worker1: 12s, Worker2: 10s) + - Test execution (request + migration): ~21s + - Teardown: ~15s + """ + + # Allocate ports to avoid conflicts with parallel tests + worker1_system_port = allocate_port(9100) + worker2_system_port = allocate_port(9200) + + # Step 1: Start the frontend (allocates its own port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially + with DynamoWorkerProcess( + request, + "worker1", + system_port=worker1_system_port, + frontend_port=frontend.frontend_port, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + system_port=worker2_system_port, + frontend_port=frontend.frontend_port, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Kill the worker that has the request + logger.info( + f"Killing {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) + + # Step 6: Validate the completion response + validate_completion_response(request_thread, response_list) + + # Step 7: Verify migration occurred + verify_migration_occurred(frontend) + + +@pytest.mark.timeout(235) # 3x average +@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_request_migration_sglang_graceful_shutdown( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with graceful shutdown and migration support using SGLang. + + This test verifies that when a worker receives a graceful shutdown signal (SIGTERM) + during request processing, the system can handle the shutdown gracefully and migrate + the request to another worker. Unlike the abrupt kill test, this simulates a more + controlled shutdown scenario where the worker has time to clean up and notify the + system about its shutdown. 
+ + Timing (Last Run: 2025-12-09): ~58s total (estimated, similar to worker_failure) + - Engine initialization: ~22s (Worker1: 12s, Worker2: 10s) + - Test execution (request + graceful shutdown + migration): ~21s + - Teardown: ~15s + """ + + # Allocate ports to avoid conflicts with parallel tests + worker1_system_port = allocate_port(9100) + worker2_system_port = allocate_port(9200) + + # Step 1: Start the frontend (allocates its own port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially + with DynamoWorkerProcess( + request, + "worker1", + system_port=worker1_system_port, + frontend_port=frontend.frontend_port, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + system_port=worker2_system_port, + frontend_port=frontend.frontend_port, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Gracefully shutdown the worker that has the request + logger.info( + f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree( + worker.get_pid(), immediate_kill=False, timeout=10 + ) + + # Step 6: Validate the completion response + validate_completion_response(request_thread, response_list) + + # Step 7: Verify migration occurred during graceful shutdown + verify_migration_occurred(frontend) + + +@pytest.mark.timeout(135) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_no_request_migration_sglang_worker_failure( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with migration disabled using SGLang. + + This test verifies that when migration is disabled (migration_limit=0) and a worker + is killed during request processing, the request fails as expected without migration. + This is the opposite behavior of test_request_migration_sglang_worker_failure. 
+ + Timing (Last Run: 2025-12-09): ~38s total + - Engine initialization: ~23s (Worker1: 13s, Worker2: 10s) + - Test execution (failure validation): <1s + - Teardown: ~15s + """ + + # Allocate ports to avoid conflicts with parallel tests + worker1_system_port = allocate_port(9100) + worker2_system_port = allocate_port(9200) + + # Step 1: Start the frontend (allocates its own port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially with migration disabled + with DynamoWorkerProcess( + request, + "worker1", + system_port=worker1_system_port, + frontend_port=frontend.frontend_port, + migration_limit=0, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + system_port=worker2_system_port, + frontend_port=frontend.frontend_port, + migration_limit=0, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Kill the worker that has the request + logger.info( + f"Killing {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) + + # Step 6: Validate the completion response - should fail without migration + try: + validate_completion_response(request_thread, response_list) + pytest.fail( + "Request succeeded unexpectedly when migration was disabled" + ) + except AssertionError as e: + assert "Request failed with status 500: " in str( + e + ), f"Unexpected request error message: {e}" + + # Step 7: Verify migration did NOT occur - should fail + try: + verify_migration_occurred(frontend) + pytest.fail( + "Migration verification unexpectedly passed when migration was disabled" + ) + except AssertionError as e: + assert "'Cannot recreate stream: ...' error found in logs" in str( + e + ), f"Unexpected migration message: {e}" + + +@pytest.mark.timeout(135) # 3x average +@pytest.mark.skip(reason="SGLang graceful shutdown not yet implemented") +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_no_request_migration_sglang_graceful_shutdown( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with graceful shutdown and migration disabled using SGLang. + + This test verifies that when migration is disabled (migration_limit=0) and a worker + receives a graceful shutdown signal (SIGTERM) during request processing, the request + fails as expected without migration. This is the opposite behavior of + test_request_migration_sglang_graceful_shutdown. 
+ + Timing (Last Run: 2025-12-09): ~38s total (estimated, similar to no_migration_worker_failure) + - Engine initialization: ~23s (Worker1: 13s, Worker2: 10s) + - Test execution (graceful shutdown + failure validation): <1s + - Teardown: ~15s + """ + + # Allocate ports to avoid conflicts with parallel tests + worker1_system_port = allocate_port(9100) + worker2_system_port = allocate_port(9200) + + # Step 1: Start the frontend (allocates its own port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially with migration disabled + with DynamoWorkerProcess( + request, + "worker1", + system_port=worker1_system_port, + frontend_port=frontend.frontend_port, + migration_limit=0, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + system_port=worker2_system_port, + frontend_port=frontend.frontend_port, + migration_limit=0, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Gracefully shutdown the worker that has the request + logger.info( + f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree( + worker.get_pid(), immediate_kill=False, timeout=10 + ) + + # Step 6: Validate the completion response - should fail without migration + try: + validate_completion_response(request_thread, response_list) + pytest.fail( + "Request succeeded unexpectedly when migration was disabled" + ) + except AssertionError as e: + assert "Request failed with status 500: " in str( + e + ), f"Unexpected request error message: {e}" + + # Step 7: Verify migration did NOT occur - should fail + try: + verify_migration_occurred(frontend) + pytest.fail( + "Migration verification unexpectedly passed when migration was disabled" + ) + except AssertionError as e: + assert "'Cannot recreate stream: ...' error found in logs" in str( + e + ), f"Unexpected migration message: {e}" diff --git a/tests/fault_tolerance/migration/test_trtllm.py b/tests/fault_tolerance/migration/test_trtllm.py new file mode 100644 index 0000000000..0f7ef3dd03 --- /dev/null +++ b/tests/fault_tolerance/migration/test_trtllm.py @@ -0,0 +1,447 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_migration_trtllm_worker_failure: ~95s (gpu_1) +- test_request_migration_trtllm_graceful_shutdown: ~95s (gpu_1, skipped) +- test_no_request_migration_trtllm_worker_failure: ~60s (gpu_1) +- test_no_request_migration_trtllm_graceful_shutdown: ~60s (gpu_1, skipped) +- Total: ~155s (0:02:35) for enabled tests +""" + +import logging +import os +import shutil + +import pytest + +from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME +from tests.utils.managed_process import ManagedProcess, terminate_process_tree +from tests.utils.payloads import check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port + +# Import utilities from the refactored utils module +from .utils import ( + DynamoFrontendProcess, + determine_request_receiving_worker, + start_completion_request, + validate_completion_response, + verify_migration_occurred, +) + +logger = logging.getLogger(__name__) + +pytestmark = [ + pytest.mark.trtllm, + pytest.mark.gpu_1, + pytest.mark.e2e, + pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), + pytest.mark.post_merge, # post_merge to pinpoint failure commit +] + + +class DynamoWorkerProcess(ManagedProcess): + """Process manager for Dynamo worker with TRT-LLM backend""" + + def __init__( + self, + request, + worker_id: str, + frontend_port: int, + migration_limit: int = 3, + ): + self.worker_id = worker_id + self.frontend_port = frontend_port + + # Allocate system port for this worker + system_port = allocate_port(9100) + self.system_port = system_port + + command = [ + "python3", + "-m", + "dynamo.trtllm", + "--model", + FAULT_TOLERANCE_MODEL_NAME, + "--disaggregation-mode", + "prefill_and_decode", + "--free-gpu-memory-fraction", + "0.45", + "--max-seq-len", + "8192", + "--migration-limit", + str(migration_limit), + ] + + # Set environment variables + env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" + env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' + env["DYN_SYSTEM_PORT"] = str(system_port) + + # TODO: Have the managed process take a command name explicitly to distinguish + # between processes started with the same command. 
+ log_dir = f"{request.node.name}_{worker_id}" + + # Clean up any existing log directory from previous runs + try: + shutil.rmtree(log_dir) + logger.info(f"Cleaned up existing log directory: {log_dir}") + except FileNotFoundError: + # Directory doesn't exist, which is fine + pass + + super().__init__( + command=command, + env=env, + health_check_urls=[ + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{system_port}/health", self.is_ready), + ], + timeout=300, + display_output=True, + terminate_existing=False, + log_dir=log_dir, + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is always allocated in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release TRT-LLM worker port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + + def is_ready(self, response) -> bool: + """Check the health of the worker process""" + try: + data = response.json() + if data.get("status") == "ready": + logger.info(f"{self.worker_id} status is ready") + return True + logger.warning( + f"{self.worker_id} status is not ready: {data.get('status')}" + ) + except ValueError: + logger.warning(f"{self.worker_id} health response is not valid JSON") + return False + + +@pytest.mark.timeout(290) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_request_migration_trtllm_worker_failure( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with migration support using TRT-LLM. + + This test verifies that when a worker is killed during request processing, + the system can handle the failure gracefully and migrate the request to + another worker. 
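Each frontend and worker reserves its own port at construction and releases it in __exit__, so parallel test sessions on one host do not collide on health endpoints. The real allocate_port/deallocate_port helpers in tests.utils.port_utils are not shown in this patch and likely do more bookkeeping; a toy stand-in for the contract the tests rely on:

import socket

_allocated = set()

def allocate_port(start):
    """Toy stand-in: return the first bindable TCP port at or above `start`."""
    port = start
    while True:
        if port not in _allocated:
            with socket.socket() as s:
                try:
                    s.bind(("127.0.0.1", port))
                except OSError:
                    port += 1
                    continue
            _allocated.add(port)
            return port
        port += 1

def deallocate_port(port):
    """Release a previously reserved port so a later worker can reuse the range."""
    _allocated.discard(port)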
+ + Timing (Last Run: 2025-12-09): ~95s total (2 workers at 45% GPU each) + - Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) + - Test execution (request + migration): ~40s + - Teardown: ~3s + """ + + # Step 1: Start the frontend (allocates its own frontend_port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially + with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Kill the worker that has the request + logger.info( + f"Killing {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) + + # Step 6: Validate the completion response + validate_completion_response(request_thread, response_list) + + # Step 7: Verify migration occurred + verify_migration_occurred(frontend) + + +@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_request_migration_trtllm_graceful_shutdown( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with graceful shutdown and migration support using TRT-LLM. + + This test verifies that when a worker receives a graceful shutdown signal (SIGTERM) + during request processing, the system can handle the shutdown gracefully and migrate + the request to another worker. Unlike the abrupt kill test, this simulates a more + controlled shutdown scenario where the worker has time to clean up and notify the + system about its shutdown. 
+ + Timing (Last Run: 2025-12-09): ~95s total (2 workers at 45% GPU each) + - Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) + - Test execution (request + graceful migration): ~40s + - Teardown: ~3s + """ + + # Step 1: Start the frontend (allocates its own frontend_port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially + with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Gracefully shutdown the worker that has the request + logger.info( + f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree( + worker.get_pid(), immediate_kill=False, timeout=10 + ) + + # Step 6: Validate the completion response + validate_completion_response(request_thread, response_list) + + # Step 7: Verify migration occurred during graceful shutdown + verify_migration_occurred(frontend) + + +@pytest.mark.timeout(185) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_no_request_migration_trtllm_worker_failure( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with migration disabled using TRT-LLM. + + This test verifies that when migration is disabled (migration_limit=0) and a worker + is killed during request processing, the request fails as expected without migration. + This is the opposite behavior of test_request_migration_trtllm_worker_failure. 
+ + Timing (Last Run: 2025-12-09): ~60s total (2 workers at 45% GPU each) + - Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) + - Test execution (request failure): ~6s + - Teardown: ~2s + """ + + # Step 1: Start the frontend (allocates its own frontend_port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially with migration disabled + with DynamoWorkerProcess( + request, + "worker1", + frontend.frontend_port, + migration_limit=0, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + frontend.frontend_port, + migration_limit=0, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Kill the worker that has the request + logger.info( + f"Killing {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree(worker.get_pid(), immediate_kill=True, timeout=0) + + # Step 6: Validate the completion response - should fail without migration + try: + validate_completion_response(request_thread, response_list) + pytest.fail( + "Request succeeded unexpectedly when migration was disabled" + ) + except AssertionError as e: + assert "Request failed with status 500: " in str( + e + ), f"Unexpected request error message: {e}" + + # Step 7: Verify migration did NOT occur - should fail + try: + verify_migration_occurred(frontend) + pytest.fail( + "Migration verification unexpectedly passed when migration was disabled" + ) + except AssertionError as e: + assert "'Cannot recreate stream: ...' error found in logs" in str( + e + ), f"Unexpected migration message: {e}" + + +@pytest.mark.skip(reason="TRT-LLM graceful shutdown not yet implemented") +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) +def test_no_request_migration_trtllm_graceful_shutdown( + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm +): + """ + End-to-end test for worker fault tolerance with graceful shutdown and migration disabled using TRT-LLM. + + This test verifies that when migration is disabled (migration_limit=0) and a worker + receives a graceful shutdown signal (SIGTERM) during request processing, the request + fails as expected without migration. This is the opposite behavior of + test_request_migration_trtllm_graceful_shutdown. 
+ + Timing (Last Run: 2025-12-09): ~60s total (2 workers at 45% GPU each) + - Engine initialization: ~52s (frontend: 2s, worker1: 25s, worker2: 25s sequential) + - Test execution (graceful shutdown failure): ~6s + - Teardown: ~2s + """ + + # Step 1: Start the frontend (allocates its own frontend_port) + with DynamoFrontendProcess(request) as frontend: + logger.info("Frontend started successfully") + + # Step 2: Start 2 workers sequentially with migration disabled + with DynamoWorkerProcess( + request, + "worker1", + frontend.frontend_port, + migration_limit=0, + ) as worker1: + logger.info(f"Worker 1 PID: {worker1.get_pid()}") + + with DynamoWorkerProcess( + request, + "worker2", + frontend.frontend_port, + migration_limit=0, + ) as worker2: + logger.info(f"Worker 2 PID: {worker2.get_pid()}") + + # Step 3: Send the request + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) + + # Step 4: Use polling to determine which worker received the request + worker, worker_name = determine_request_receiving_worker( + worker1, worker2, receiving_pattern="New Request ID: " + ) + + # Step 5: Gracefully shutdown the worker that has the request + logger.info( + f"Gracefully shutting down {worker_name} with PID {worker.get_pid()} processing the request" + ) + terminate_process_tree( + worker.get_pid(), immediate_kill=False, timeout=10 + ) + + # Step 6: Validate the completion response - should fail without migration + try: + validate_completion_response(request_thread, response_list) + pytest.fail( + "Request succeeded unexpectedly when migration was disabled" + ) + except AssertionError as e: + assert "Request failed with status 500: " in str( + e + ), f"Unexpected request error message: {e}" + + # Step 7: Verify migration did NOT occur - should fail + try: + verify_migration_occurred(frontend) + pytest.fail( + "Migration verification unexpectedly passed when migration was disabled" + ) + except AssertionError as e: + assert "'Cannot recreate stream: ...' error found in logs" in str( + e + ), f"Unexpected migration message: {e}" diff --git a/tests/fault_tolerance/migration/test_vllm.py b/tests/fault_tolerance/migration/test_vllm.py index b797059415..710b79bc70 100644 --- a/tests/fault_tolerance/migration/test_vllm.py +++ b/tests/fault_tolerance/migration/test_vllm.py @@ -1,6 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +""" +Test Execution Times (Last Run: 2025-12-09): +- test_request_migration_vllm_worker_failure: ~90s (gpu_1) +- test_request_migration_vllm_graceful_shutdown: ~80s (gpu_1) +- test_no_request_migration_vllm_worker_failure: ~75s (gpu_1) +- test_no_request_migration_vllm_graceful_shutdown: ~75s (gpu_1) +- Total: 318.73s (0:05:18) +""" + import logging import os import shutil @@ -8,9 +17,9 @@ import pytest from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess, terminate_process_tree from tests.utils.payloads import check_models_api +from tests.utils.port_utils import allocate_port, deallocate_port # Import utilities from the refactored utils module from .utils import ( @@ -28,15 +37,26 @@ pytest.mark.gpu_1, pytest.mark.e2e, pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME), - pytest.mark.nightly, + pytest.mark.post_merge, # post_merge to pinpoint failure commit ] class DynamoWorkerProcess(ManagedProcess): """Process manager for Dynamo worker with vLLM backend""" - def __init__(self, request, worker_id: str, migration_limit: int = 3): + def __init__( + self, + request, + worker_id: str, + frontend_port: int, + migration_limit: int = 3, + ): self.worker_id = worker_id + self.frontend_port = frontend_port + + # Allocate system port for this worker + system_port = allocate_port(9100) + self.system_port = system_port command = [ "python3", @@ -53,14 +73,26 @@ def __init__(self, request, worker_id: str, migration_limit: int = 3): str(migration_limit), ] - # Set debug logging environment + # Set environment variables env = os.environ.copy() - env["DYN_VLLM_KV_EVENT_PORT"] = f"2008{worker_id[-1]}" - env["VLLM_NIXL_SIDE_CHANNEL_PORT"] = f"560{worker_id[-1]}" + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + + env[ + "DYN_VLLM_KV_EVENT_PORT" + ] = f"2008{worker_id[-1]}" # TODO: use dynamic port allocation + env[ + "VLLM_NIXL_SIDE_CHANNEL_PORT" + ] = f"560{worker_id[-1]}" # TODO: use dynamic port allocation env["DYN_LOG"] = "debug" + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]' - env["DYN_SYSTEM_PORT"] = f"808{worker_id[-1]}" + env["DYN_SYSTEM_PORT"] = str(system_port) + env["DYN_HTTP_PORT"] = str(frontend_port) # TODO: Have the managed process take a command name explicitly to distinguish # between processes started with the same command. 
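The two vLLM-specific ports are still derived from the last character of the worker id (the TODOs above flag this), so "worker1" and "worker2" map to fixed port pairs; only the system port moved to dynamic allocation in this change. For reference:

for worker_id in ("worker1", "worker2"):
    kv_event_port = f"2008{worker_id[-1]}"          # 20081 / 20082
    nixl_side_channel_port = f"560{worker_id[-1]}"  # 5601 / 5602
    print(worker_id, kv_event_port, nixl_side_channel_port)
# Fixed values like these can still clash when two sessions share a host, which is
# what the "use dynamic port allocation" TODOs refer to.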
@@ -78,8 +110,8 @@ def __init__(self, request, worker_id: str, migration_limit: int = 3): command=command, env=env, health_check_urls=[ - (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api), - (f"http://localhost:808{worker_id[-1]}/health", self.is_ready), + (f"http://localhost:{frontend_port}/v1/models", check_models_api), + (f"http://localhost:{system_port}/health", self.is_ready), ], timeout=300, display_output=True, @@ -89,9 +121,15 @@ def __init__(self, request, worker_id: str, migration_limit: int = 3): log_dir=log_dir, ) - def get_pid(self): - """Get the PID of the worker process""" - return self.proc.pid if self.proc else None + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when worker exits.""" + try: + # system_port is always allocated in __init__ + deallocate_port(self.system_port) + except Exception as e: + logging.warning(f"Failed to release vLLM worker port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) def is_ready(self, response) -> bool: """Check the health of the worker process""" @@ -108,8 +146,20 @@ def is_ready(self, response) -> bool: return False +@pytest.mark.timeout(290) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_migration_vllm_worker_failure( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for worker fault tolerance with migration support. @@ -117,24 +167,30 @@ def test_request_migration_vllm_worker_failure( This test verifies that when a worker is killed during request processing, the system can handle the failure gracefully and migrate the request to another worker. 
+ + Timing (Last Run: 2025-12-09): ~90s total + - Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) + - Test execution (request + migration): ~48s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start 2 workers sequentially - - # Start worker1 first and wait for it to be ready - logger.info("Starting worker 1...") - with DynamoWorkerProcess(request, "worker1") as worker1: + # Step 2: Start 2 workers sequentially (each allocates its own system_port) + with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: logger.info(f"Worker 1 PID: {worker1.get_pid()}") - with DynamoWorkerProcess(request, "worker2") as worker2: + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port + ) as worker2: logger.info(f"Worker 2 PID: {worker2.get_pid()}") # Step 3: Send the request - request_thread, response_list = start_completion_request() + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) # Step 4: Use polling to determine which worker received the request worker, worker_name = determine_request_receiving_worker( @@ -154,8 +210,20 @@ def test_request_migration_vllm_worker_failure( verify_migration_occurred(frontend) +@pytest.mark.timeout(280) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_request_migration_vllm_graceful_shutdown( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for worker fault tolerance with graceful shutdown and migration support. @@ -165,21 +233,30 @@ def test_request_migration_vllm_graceful_shutdown( the request to another worker. Unlike the abrupt kill test, this simulates a more controlled shutdown scenario where the worker has time to clean up and notify the system about its shutdown. 
+ + Timing (Last Run: 2025-12-09): ~80s total + - Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) + - Test execution (graceful shutdown + migration): ~38s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start 2 workers sequentially - with DynamoWorkerProcess(request, "worker1") as worker1: + # Step 2: Start 2 workers sequentially (each allocates its own system_port) + with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1: logger.info(f"Worker 1 PID: {worker1.get_pid()}") - with DynamoWorkerProcess(request, "worker2") as worker2: + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port + ) as worker2: logger.info(f"Worker 2 PID: {worker2.get_pid()}") # Step 3: Send the request - request_thread, response_list = start_completion_request() + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) # Step 4: Use polling to determine which worker received the request worker, worker_name = determine_request_receiving_worker( @@ -201,8 +278,20 @@ def test_request_migration_vllm_graceful_shutdown( verify_migration_occurred(frontend) +@pytest.mark.timeout(150) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_no_request_migration_vllm_worker_failure( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for worker fault tolerance with migration disabled. @@ -210,22 +299,32 @@ def test_no_request_migration_vllm_worker_failure( This test verifies that when migration is disabled (migration_limit=0) and a worker is killed during request processing, the request fails as expected without migration. This is the opposite behavior of test_request_migration_vllm_worker_failure. 
+ + Timing (Last Run: 2025-12-09): ~75s total + - Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) + - Test execution (failure validation): ~33s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start 2 workers sequentially with migration disabled - logger.info("Starting worker 1 with migration disabled...") - with DynamoWorkerProcess(request, "worker1", migration_limit=0) as worker1: + # Step 2: Start 2 workers sequentially with migration disabled (each allocates its own system_port) + with DynamoWorkerProcess( + request, "worker1", frontend.frontend_port, migration_limit=0 + ) as worker1: logger.info(f"Worker 1 PID: {worker1.get_pid()}") - with DynamoWorkerProcess(request, "worker2", migration_limit=0) as worker2: + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port, migration_limit=0 + ) as worker2: logger.info(f"Worker 2 PID: {worker2.get_pid()}") # Step 3: Send the request - request_thread, response_list = start_completion_request() + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) # Step 4: Use polling to determine which worker received the request worker, worker_name = determine_request_receiving_worker( @@ -261,8 +360,20 @@ def test_no_request_migration_vllm_worker_failure( ), f"Unexpected migration message: {e}" +@pytest.mark.timeout(140) # 3x average +@pytest.mark.parametrize( + "request_plane", + [ + "nats", + pytest.param( + "tcp", + marks=pytest.mark.xfail(reason="Multi-worker TCP unstable", strict=False), + ), + ], + indirect=True, +) def test_no_request_migration_vllm_graceful_shutdown( - request, runtime_services, predownload_models, set_ucx_tls_no_mm + request, runtime_services_dynamic_ports, set_ucx_tls_no_mm ): """ End-to-end test for worker fault tolerance with graceful shutdown and migration disabled. @@ -271,21 +382,32 @@ def test_no_request_migration_vllm_graceful_shutdown( receives a graceful shutdown signal (SIGTERM) during request processing, the request fails as expected without migration. This is the opposite behavior of test_request_migration_vllm_graceful_shutdown. 
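# --- Illustrative sketch (not part of this patch) ---------------------------
# The request_plane parametrization above uses indirect=True, so pytest routes
# the "nats"/"tcp" values through a fixture named request_plane that lives
# outside this diff (e.g. in a conftest.py). A minimal fixture compatible with
# the usage here, and with the request.getfixturevalue("request_plane") call
# in tests/fault_tolerance/migration/utils.py below, is assumed to look like:
import pytest

@pytest.fixture
def request_plane(request) -> str:
    """Expose the parametrized request plane ("nats" or "tcp") to tests and helpers."""
    return request.param
# -----------------------------------------------------------------------------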
+ + Timing (Last Run: 2025-12-09): ~75s total + - Engine initialization: ~40s (Worker1: 20s, Worker2: 20s) + - Test execution (graceful shutdown validation): ~33s + - Teardown: ~2s """ - # Step 1: Start the frontend + # Step 1: Start the frontend (allocates its own frontend_port) with DynamoFrontendProcess(request) as frontend: logger.info("Frontend started successfully") - # Step 2: Start 2 workers sequentially with migration disabled - with DynamoWorkerProcess(request, "worker1", migration_limit=0) as worker1: + # Step 2: Start 2 workers sequentially with migration disabled (each allocates its own system_port) + with DynamoWorkerProcess( + request, "worker1", frontend.frontend_port, migration_limit=0 + ) as worker1: logger.info(f"Worker 1 PID: {worker1.get_pid()}") - with DynamoWorkerProcess(request, "worker2", migration_limit=0) as worker2: + with DynamoWorkerProcess( + request, "worker2", frontend.frontend_port, migration_limit=0 + ) as worker2: logger.info(f"Worker 2 PID: {worker2.get_pid()}") # Step 3: Send the request - request_thread, response_list = start_completion_request() + request_thread, response_list = start_completion_request( + frontend.frontend_port + ) # Step 4: Use polling to determine which worker received the request worker, worker_name = determine_request_receiving_worker( diff --git a/tests/fault_tolerance/migration/utils.py b/tests/fault_tolerance/migration/utils.py index be70b703c3..4c3e0a57a7 100644 --- a/tests/fault_tolerance/migration/utils.py +++ b/tests/fault_tolerance/migration/utils.py @@ -11,8 +11,8 @@ import requests from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME -from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess +from tests.utils.port_utils import allocate_port, deallocate_port logger = logging.getLogger(__name__) @@ -21,10 +21,28 @@ class DynamoFrontendProcess(ManagedProcess): """Process manager for Dynamo frontend""" def __init__(self, request): - command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"] + # Allocate frontend port + frontend_port = allocate_port(8100) + self.frontend_port = frontend_port + + command = [ + "python", + "-m", + "dynamo.frontend", + "--router-mode", + "round-robin", + "--http-port", + str(frontend_port), + ] - # Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request.getfixturevalue("request_plane") + # Disable canary health check - these tests expect full control over requests + # sent to the workers where canary health check intermittently sends dummy + # requests to workers interfering with the test process which may cause + # intermittent failures + env["DYN_HEALTH_CHECK_ENABLED"] = "false" + # Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server env.pop("DYN_SYSTEM_PORT", None) log_dir = f"{request.node.name}_frontend" @@ -41,15 +59,28 @@ def __init__(self, request): command=command, env=env, display_output=True, - terminate_existing=True, + terminate_existing=False, # Don't terminate other processes of the same name, we'll only terminate our own PID log_dir=log_dir, ) + def __exit__(self, exc_type, exc_val, exc_tb): + """Release allocated port when frontend exits.""" + try: + # frontend_port is always allocated in __init__ + deallocate_port(self.frontend_port) + except Exception as e: + logger.warning(f"Failed to release frontend port: {e}") + + return super().__exit__(exc_type, exc_val, exc_tb) + -def start_completion_request() -> tuple: 
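# --- Illustrative sketch (not part of this patch) ---------------------------
# DynamoFrontendProcess above now reserves its HTTP port with allocate_port()
# and returns it in __exit__ via deallocate_port(), both imported from
# tests.utils.port_utils (not shown in this diff). The sketch below only
# illustrates the assumed contract: probe upward from a base port, reserve the
# first bindable one, and release it on teardown so parallel tests don't collide.
import socket

_reserved_ports: set[int] = set()

def allocate_port(start: int, limit: int = 1000) -> int:
    """Return the first free, unreserved TCP port >= start."""
    for port in range(start, start + limit):
        if port in _reserved_ports:
            continue
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
            except OSError:
                continue  # already bound by another process
        _reserved_ports.add(port)
        return port
    raise RuntimeError(f"no free port found in [{start}, {start + limit})")

def deallocate_port(port: int) -> None:
    """Release a port previously returned by allocate_port()."""
    _reserved_ports.discard(port)
# -----------------------------------------------------------------------------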
+def start_completion_request(frontend_port: int) -> tuple: """ Start a long-running completion request in a separate thread. + Args: + frontend_port: Port where the frontend is running + Returns: tuple: (request_thread, response_list) """ @@ -57,7 +88,7 @@ def start_completion_request() -> tuple: def send_request(): prompt = "Tell me a long long long story about yourself?" - max_tokens = 8192 + max_tokens = 8000 timeout = 240 # Extended timeout for long request payload = { @@ -74,7 +105,7 @@ def send_request(): try: response = requests.post( - f"http://localhost:{FRONTEND_PORT}/v1/completions", + f"http://localhost:{frontend_port}/v1/completions", headers=headers, json=payload, timeout=timeout, diff --git a/tests/frontend/grpc/echo_tensor_worker.py b/tests/frontend/grpc/echo_tensor_worker.py index db306c3d05..c498b23024 100644 --- a/tests/frontend/grpc/echo_tensor_worker.py +++ b/tests/frontend/grpc/echo_tensor_worker.py @@ -53,15 +53,12 @@ async def echo_tensor_worker(runtime: DistributedRuntime): ) assert model_config == retrieved_model_config - # [gluo FIXME] register_llm will attempt to load a LLM model, - # which is not well-defined for Tensor yet. Currently provide - # a valid model name to pass the registration. + # Use register_llm for tensor-based backends (skips HuggingFace downloads) await register_llm( ModelInput.Tensor, ModelType.TensorBased, endpoint, - "Qwen/Qwen3-0.6B", - "echo", + "echo", # model_path (used as display name for tensor-based models) runtime_config=runtime_config, ) diff --git a/tests/frontend/grpc/test_tensor_mocker_engine.py b/tests/frontend/grpc/test_tensor_mocker_engine.py index 2982f4f810..c07d9c7edc 100644 --- a/tests/frontend/grpc/test_tensor_mocker_engine.py +++ b/tests/frontend/grpc/test_tensor_mocker_engine.py @@ -125,5 +125,6 @@ def start_services(request, runtime_services): @pytest.mark.integration @pytest.mark.model(TEST_MODEL) def test_echo() -> None: + triton_echo_client.check_health() triton_echo_client.run_infer() triton_echo_client.get_config() diff --git a/tests/frontend/grpc/triton_echo_client.py b/tests/frontend/grpc/triton_echo_client.py index 3e94333ab5..b76c71a88f 100644 --- a/tests/frontend/grpc/triton_echo_client.py +++ b/tests/frontend/grpc/triton_echo_client.py @@ -1,19 +1,22 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -import sys import numpy as np import tritonclient.grpc as grpcclient +SERVER_URL = "localhost:8000" + + +def check_health(): + triton_client = grpcclient.InferenceServerClient(url=SERVER_URL) + assert triton_client.is_server_live() + assert triton_client.is_server_ready() + assert triton_client.is_model_ready("echo") + def run_infer(): - server_url = "localhost:8000" - try: - triton_client = grpcclient.InferenceServerClient(url=server_url) - except Exception as e: - print("channel creation failed: " + str(e)) - sys.exit() + triton_client = grpcclient.InferenceServerClient(url=SERVER_URL) model_name = "echo" @@ -46,12 +49,7 @@ def run_infer(): def get_config(): - server_url = "localhost:8000" - try: - triton_client = grpcclient.InferenceServerClient(url=server_url) - except Exception as e: - print("channel creation failed: " + str(e)) - sys.exit() + triton_client = grpcclient.InferenceServerClient(url=SERVER_URL) model_name = "echo" response = triton_client.get_model_config(model_name=model_name) diff --git a/tests/planner/unit/kube.py b/tests/planner/unit/kube.py index 87a10713ba..b1ef79d741 100644 --- a/tests/planner/unit/kube.py +++ b/tests/planner/unit/kube.py @@ -76,11 +76,87 @@ def test_get_graph_deployment_from_name(k8s_api, mock_custom_api): ) -def test_update_graph_replicas(k8s_api, mock_custom_api): +def test_update_service_replicas_uses_dgdsa_scale(k8s_api, mock_custom_api): + """Test that update_service_replicas uses DGDSA Scale API when available""" + mock_custom_api.patch_namespaced_custom_object_scale.return_value = None + + k8s_api.update_service_replicas("test-deployment", "Frontend", 3) + + # Should use Scale subresource with lowercase adapter name + mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once_with( + group="nvidia.com", + version="v1alpha1", + namespace=k8s_api.current_namespace, + plural="dynamographdeploymentscalingadapters", + name="test-deployment-frontend", # lowercase service name + body={"spec": {"replicas": 3}}, + ) + # Should NOT fall back to DGD patch + mock_custom_api.patch_namespaced_custom_object.assert_not_called() + + +def test_update_service_replicas_fallback_to_dgd(k8s_api, mock_custom_api): + """Test that update_service_replicas falls back to DGD when DGDSA not found""" + # DGDSA doesn't exist (404) + mock_custom_api.patch_namespaced_custom_object_scale.side_effect = ( + client.ApiException(status=404) + ) mock_custom_api.patch_namespaced_custom_object.return_value = None + k8s_api.update_service_replicas("test-deployment", "test-component", 1) + + # Should have tried DGDSA first + mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once() + + # Should fall back to DGD patch + mock_custom_api.patch_namespaced_custom_object.assert_called_once_with( + group="nvidia.com", + version="v1alpha1", + namespace=k8s_api.current_namespace, + plural="dynamographdeployments", + name="test-deployment", + body={"spec": {"services": {"test-component": {"replicas": 1}}}}, + ) + + +def test_update_service_replicas_propagates_other_errors(k8s_api, mock_custom_api): + """Test that update_service_replicas propagates non-404 errors""" + mock_custom_api.patch_namespaced_custom_object_scale.side_effect = ( + client.ApiException(status=500, reason="Internal Server Error") + ) + + with pytest.raises(client.ApiException) as exc_info: + k8s_api.update_service_replicas("test-deployment", "test-component", 1) + + assert exc_info.value.status == 500 + # Should NOT fall back to DGD + 
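# --- Illustrative sketch (not part of this patch) ---------------------------
# The kube.py unit tests above pin down the expected behaviour of
# KubernetesAPI.update_service_replicas(), whose implementation is not part of
# this diff. From the assertions, the method is assumed to behave roughly like
# the standalone sketch below: try the DGDSA Scale subresource first, fall back
# to patching the DGD only on a 404, and re-raise any other ApiException. The
# function signature here is hypothetical; the real method lives on the
# planner's Kubernetes API wrapper.
from kubernetes import client

def update_service_replicas(api, namespace, deployment, service, replicas):
    try:
        # Preferred path: scale the DynamoGraphDeploymentScalingAdapter (DGDSA)
        api.patch_namespaced_custom_object_scale(
            group="nvidia.com",
            version="v1alpha1",
            namespace=namespace,
            plural="dynamographdeploymentscalingadapters",
            name=f"{deployment}-{service.lower()}",  # adapter names are lowercased
            body={"spec": {"replicas": replicas}},
        )
    except client.ApiException as e:
        if e.status != 404:
            raise  # non-404 errors propagate; no fallback
        # DGDSA not found: patch the DynamoGraphDeployment (DGD) directly
        api.patch_namespaced_custom_object(
            group="nvidia.com",
            version="v1alpha1",
            namespace=namespace,
            plural="dynamographdeployments",
            name=deployment,
            body={"spec": {"services": {service: {"replicas": replicas}}}},
        )
# -----------------------------------------------------------------------------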
mock_custom_api.patch_namespaced_custom_object.assert_not_called() + + +def test_update_graph_replicas_calls_update_service_replicas(k8s_api, mock_custom_api): + """Test that deprecated update_graph_replicas calls update_service_replicas""" + mock_custom_api.patch_namespaced_custom_object_scale.return_value = None + + # Use the deprecated method k8s_api.update_graph_replicas("test-deployment", "test-component", 1) + # Should delegate to update_service_replicas which uses Scale API + mock_custom_api.patch_namespaced_custom_object_scale.assert_called_once_with( + group="nvidia.com", + version="v1alpha1", + namespace=k8s_api.current_namespace, + plural="dynamographdeploymentscalingadapters", + name="test-deployment-test-component", + body={"spec": {"replicas": 1}}, + ) + + +def test_update_dgd_replicas_directly(k8s_api, mock_custom_api): + """Test the internal _update_dgd_replicas method""" + mock_custom_api.patch_namespaced_custom_object.return_value = None + + k8s_api._update_dgd_replicas("test-deployment", "test-component", 1) + mock_custom_api.patch_namespaced_custom_object.assert_called_once_with( group="nvidia.com", version="v1alpha1", diff --git a/tests/planner/unit/test_prometheus.py b/tests/planner/unit/test_prometheus.py index 996b5083b4..3d6ec52dc3 100644 --- a/tests/planner/unit/test_prometheus.py +++ b/tests/planner/unit/test_prometheus.py @@ -18,9 +18,12 @@ import pytest +from dynamo import prometheus_names from dynamo.planner.utils.prometheus import ( + METRIC_SOURCE_MAP, FrontendMetric, FrontendMetricContainer, + MetricSource, PrometheusAPIClient, ) @@ -33,8 +36,8 @@ @pytest.fixture -def mock_prometheus_result(): - """Fixture providing mock prometheus result data for testing""" +def mock_prometheus_sum_result(): + """Fixture providing mock prometheus sum metric data for testing""" return [ { "metric": { @@ -75,6 +78,49 @@ def mock_prometheus_result(): ] +@pytest.fixture +def mock_prometheus_count_result(): + """Fixture providing mock prometheus count metric data for testing""" + return [ + { + "metric": { + "container": "main", + "dynamo_namespace": "different_namespace", + "model": "different_model", + "namespace": "dynamo-system", + }, + "value": [1758857776.071, 1.0], + }, + { + "metric": { + "container": "main", + "dynamo_namespace": "target_namespace", + "model": "target_model", + "namespace": "dynamo-system", + }, + "value": [1758857776.071, 1.0], + }, + { + "metric": { + "container": "worker", + "dynamo_namespace": "target_namespace", + "model": "target_model", + "namespace": "dynamo-system", + }, + "value": [1758857776.071, 1.0], + }, + { + "metric": { + "container": "sidecar", + "dynamo_namespace": "target_namespace", + "model": "target_model", + "namespace": "dynamo-system", + }, + "value": [30.0, 1.0], + }, + ] + + def test_frontend_metric_container_with_nan_value(): test_data = { "metric": { @@ -140,7 +186,7 @@ def test_get_average_metric_none_result(): mock_query.return_value = None result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", operation_name="test operation", model_name="test_model", @@ -157,7 +203,7 @@ def test_get_average_metric_empty_result(): mock_query.return_value = [] result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", operation_name="test operation", model_name="test_model", @@ -166,16 +212,21 @@ def 
test_get_average_metric_empty_result(): assert result == 0 -def test_get_average_metric_no_matching_containers(mock_prometheus_result): +def test_get_average_metric_no_matching_containers( + mock_prometheus_sum_result, mock_prometheus_count_result +): """Test _get_average_metric with valid containers but no matches""" client = PrometheusAPIClient("http://localhost:9090", "test_namespace") with patch.object(client.prom, "custom_query") as mock_query: # Use only the first container which doesn't match target criteria - mock_query.return_value = [mock_prometheus_result[0]] + mock_query.side_effect = [ + [mock_prometheus_sum_result[0]], # sum_result + [mock_prometheus_count_result[0]], # count_result + ] result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", operation_name="test operation", model_name="target_model", @@ -184,21 +235,41 @@ def test_get_average_metric_no_matching_containers(mock_prometheus_result): assert result == 0 -def test_get_average_metric_one_matching_container(mock_prometheus_result): +def test_get_average_metric_one_matching_container( + mock_prometheus_sum_result, mock_prometheus_count_result +): """Test _get_average_metric with one matching container""" client = PrometheusAPIClient("http://localhost:9090", "target_namespace") with patch.object(client.prom, "custom_query") as mock_query: # Use first two containers - one doesn't match, one does - mock_query.return_value = mock_prometheus_result[:2] + mock_query.side_effect = [ + mock_prometheus_sum_result[:2], # sum_result + mock_prometheus_count_result[:2], # count_result + ] result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", operation_name="test operation", model_name="target_model", ) + # Verify the correct queries were made + assert mock_query.call_count == 2 + + sum_call = mock_query.call_args_list[0] + assert ( + sum_call.kwargs["query"] + == f"increase({METRIC_SOURCE_MAP[MetricSource.FRONTEND][prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS]}_sum[60s])" + ) + + count_call = mock_query.call_args_list[1] + assert ( + count_call.kwargs["query"] + == f"increase({METRIC_SOURCE_MAP[MetricSource.FRONTEND][prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS]}_count[60s])" + ) + assert result == 42.7 @@ -206,7 +277,7 @@ def test_get_average_metric_with_validation_error(): """Test _get_average_metric with one valid container and one that fails validation""" client = PrometheusAPIClient("http://localhost:9090", "target_namespace") - mock_result = [ + mock_sum_result = [ { "metric": { "container": "main", @@ -223,11 +294,28 @@ def test_get_average_metric_with_validation_error(): }, ] + mock_count_result = [ + { + "metric": { + "container": "main", + "dynamo_namespace": "target_namespace", + "model": "target_model", + "namespace": "dynamo-system", + }, + "value": [1758857776.071, 1.0], + }, + { + # Invalid structure - missing required fields that will cause validation error + "invalid_structure": "bad_data", + "value": "not_a_tuple", + }, + ] + with patch.object(client.prom, "custom_query") as mock_query: - mock_query.return_value = mock_result + mock_query.side_effect = [mock_sum_result, mock_count_result] result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", 
operation_name="test operation", model_name="target_model", @@ -236,21 +324,26 @@ def test_get_average_metric_with_validation_error(): assert result == 25.5 -def test_get_average_metric_multiple_matching_containers(mock_prometheus_result): +def test_get_average_metric_multiple_matching_containers( + mock_prometheus_sum_result, mock_prometheus_count_result +): """Test _get_average_metric with multiple matching containers returns average""" client = PrometheusAPIClient("http://localhost:9090", "target_namespace") with patch.object(client.prom, "custom_query") as mock_query: # Use containers 1, 2, 3 which all match target criteria - mock_query.return_value = mock_prometheus_result[1:] + mock_query.side_effect = [ + mock_prometheus_sum_result[1:], # sum_result + mock_prometheus_count_result[1:], # count_result + ] result = client._get_average_metric( - full_metric_name="test_metric", + full_metric_name=prometheus_names.frontend_service.TIME_TO_FIRST_TOKEN_SECONDS, interval="60s", operation_name="test operation", model_name="target_model", ) - # Average of 42.7, 35.5, and 15.5 (using value[1] from each container) + # Total sum: 42.7 + 35.5 + 15.5 = 93.7, Total count: 1.0 + 1.0 + 1.0 = 3.0 expected = (42.7 + 35.5 + 15.5) / 3 assert result == expected diff --git a/tests/profiler/test_profile_sla_aiconfigurator.py b/tests/profiler/test_profile_sla_aiconfigurator.py index ed9a326cc9..49755ba18a 100644 --- a/tests/profiler/test_profile_sla_aiconfigurator.py +++ b/tests/profiler/test_profile_sla_aiconfigurator.py @@ -66,6 +66,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 8 self.deploy_after_profile = False + self.pick_with_webui = False # Provide minimal model_info to avoid HF queries self.model_info = ModelInfo( model_size=16384.0, diff --git a/tests/profiler/test_profile_sla_dryrun.py b/tests/profiler/test_profile_sla_dryrun.py index ff6ae0fc89..a98813189b 100644 --- a/tests/profiler/test_profile_sla_dryrun.py +++ b/tests/profiler/test_profile_sla_dryrun.py @@ -73,6 +73,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 8 self.deploy_after_profile = False + self.pick_with_webui = False # Provide minimal model_info to avoid HF queries self.model_info = ModelInfo( model_size=16384.0, @@ -116,6 +117,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 8 self.deploy_after_profile = False + self.pick_with_webui = False self.model_info = ModelInfo( model_size=16384.0, architecture="TestArchitecture", @@ -180,6 +182,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 8 self.deploy_after_profile = False + self.pick_with_webui = False self.model_info = ModelInfo( model_size=16384.0, architecture="TestArchitecture", @@ -233,6 +236,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 8 self.deploy_after_profile = False + self.pick_with_webui = False self.model_info = ModelInfo( model_size=65536.0, architecture="TestMoEArchitecture", @@ -309,6 +313,7 @@ def __init__(self): # Set to 0 to trigger auto-generation path self.num_gpus_per_node = 0 self.deploy_after_profile = False + self.pick_with_webui = False self.enable_gpu_discovery = True return Args() @@ -376,6 +381,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 0 self.deploy_after_profile = False + self.pick_with_webui = False self.enable_gpu_discovery = True return Args() @@ -443,6 +449,7 @@ def __init__(self): self.aic_backend_version = None self.num_gpus_per_node = 0 
self.deploy_after_profile = False + self.pick_with_webui = False self.enable_gpu_discovery = True return Args() diff --git a/tests/router/common.py b/tests/router/common.py index c39ce328e8..fd72c88251 100644 --- a/tests/router/common.py +++ b/tests/router/common.py @@ -4,6 +4,7 @@ import asyncio import json import logging +import os import random import string import time @@ -38,6 +39,7 @@ def __init__( store_backend: str = "etcd", enforce_disagg: bool = False, busy_threshold: float | None = None, + request_plane: str = "nats", ): command = [ "python3", @@ -61,8 +63,12 @@ def __init__( if busy_threshold is not None: command.extend(["--busy-threshold", str(busy_threshold)]) + env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request_plane + super().__init__( command=command, + env=env, timeout=60, display_output=True, health_check_ports=[frontend_port], @@ -87,6 +93,64 @@ def generate_random_suffix() -> str: return "".join(random.choices(string.ascii_lowercase, k=10)) # noqa: S311 +def verify_response_worker_ids( + response_worker_ids: list[dict[str, Optional[int]]], + key: str, + expected_worker_id: int, +) -> None: + """Verify that all responses have the same worker ID for a given key. + + Args: + response_worker_ids: List of dicts with worker ID info from responses. + key: The key to check (e.g., "decode_worker_id" or "prefill_worker_id"). + expected_worker_id: The expected worker ID value. + + Raises: + AssertionError: If any response is missing the key, values differ, or don't match expected. + """ + worker_ids = [r.get(key) for r in response_worker_ids] + logger.info(f"Response {key}s: {worker_ids}") + + # All responses should have the key + assert all( + wid is not None for wid in worker_ids + ), f"Expected all {len(response_worker_ids)} responses to have {key}, got: {worker_ids}" + + # All values should be the same (due to prefix reuse routing) + unique_ids = set(worker_ids) + assert len(unique_ids) == 1, ( + f"Expected all responses to have the same {key} (due to prefix reuse), " + f"but found {len(unique_ids)} unique values: {unique_ids}" + ) + + # The value should match the expected worker ID + actual_worker_id = worker_ids[0] + assert actual_worker_id == expected_worker_id, ( + f"Expected {key}={expected_worker_id} (forced in first request), " + f"but got {key}={actual_worker_id}" + ) + logger.info( + f"โœ“ Verified all {len(response_worker_ids)} responses have {key}={actual_worker_id}" + ) + + +def verify_response_timing(timing_info: dict[str, Any]) -> None: + """Verify timing info has valid values (ttft_ms > 0, total_time_ms > 0).""" + ttft_ms = timing_info.get("ttft_ms") + total_time_ms = timing_info.get("total_time_ms") + + assert ttft_ms is not None and ttft_ms > 0, f"Expected ttft_ms > 0, got: {ttft_ms}" + assert ( + total_time_ms is not None and total_time_ms > 0 + ), f"Expected total_time_ms > 0, got: {total_time_ms}" + assert ( + total_time_ms >= ttft_ms + ), f"Expected total_time_ms >= ttft_ms, got {total_time_ms} < {ttft_ms}" + logger.info( + f"โœ“ Verified timing: ttft_ms={ttft_ms:.2f}, total_time_ms={total_time_ms:.2f}" + ) + + ######################################################## # Utility functions ######################################################## @@ -420,9 +484,17 @@ async def send_request_via_python_kv_router( int ] = None, # If None, Router will select the best available worker dp_rank: Optional[int] = None, # Data parallel rank (defaults to 0) -) -> bool: + return_worker_ids: bool = False, # If True, return worker IDs from response +) -> bool | 
dict[str, Optional[int]]: """Send a request to the specified worker instance. - Returns True if workers respond, otherwise raises or returns False. + + Args: + return_worker_ids: If True, returns a dict with prefill_worker_id and decode_worker_id. + If False, returns True on success or False on failure. + + Returns: + If return_worker_ids=False: True if workers respond, otherwise raises or returns False. + If return_worker_ids=True: Dict with 'prefill_worker_id' and 'decode_worker_id' keys. """ wait_time = initial_wait @@ -463,8 +535,11 @@ async def send_request_via_python_kv_router( f"Failed to connect to workers after {max_retries + 1} attempts" ) from e - # Collect tokens from the SSE stream + # Collect tokens and worker IDs from the SSE stream generated_tokens = [] + prefill_worker_id: Optional[int] = None + decode_worker_id: Optional[int] = None + async for response in stream: if isinstance(response, dict): # Check if response has token_ids @@ -480,6 +555,17 @@ async def send_request_via_python_kv_router( f"Stream finished with reason: {response['finish_reason']}" ) + # Extract worker IDs from disaggregated_params if present + if return_worker_ids and "disaggregated_params" in response: + disagg_params = response["disaggregated_params"] + if isinstance(disagg_params, dict) and "worker_id" in disagg_params: + worker_id_info = disagg_params["worker_id"] + if isinstance(worker_id_info, dict): + if "prefill_worker_id" in worker_id_info: + prefill_worker_id = worker_id_info["prefill_worker_id"] + if "decode_worker_id" in worker_id_info: + decode_worker_id = worker_id_info["decode_worker_id"] + # Verify if expected number of tokens are generated if max_tokens specified and ignore_eos is True logger.debug(f"Total generated tokens: {len(generated_tokens)}") if ( @@ -497,9 +583,14 @@ async def send_request_via_python_kv_router( logger.debug( f"Successfully verified {max_tokens} tokens generated as expected via KvPushRouter with ignore_eos=True" ) - return True - return False + if return_worker_ids: + return { + "prefill_worker_id": prefill_worker_id, + "decode_worker_id": decode_worker_id, + } + + return True ######################################################## @@ -525,7 +616,7 @@ def _test_router_basic( Always waits for workers to be properly registered before sending requests to avoid flakiness. Args: - engine_workers: Backend workers (mocker/vllm) already initialized with __enter__() + engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__()) block_size: Block size for KV cache request: Pytest request fixture for managing resources frontend_port: Port to start the frontend HTTP server on @@ -907,7 +998,6 @@ def _test_router_query_instance_id( Raises: AssertionError: If annotation response structure is incorrect or contains generation content """ - import aiohttp try: # Start KV router (frontend) @@ -1084,9 +1174,6 @@ def _test_router_overload_503( Raises: AssertionError: If 503 response is not received when expected """ - import aiohttp - - from tests.utils.managed_process import ManagedProcess try: logger.info( @@ -1243,7 +1330,7 @@ def _test_router_indexers_sync( This validates that the snapshot mechanism works and routers can sync state from NATS. 
Args: - engine_workers: Backend workers (mocker/vllm) already initialized with __enter__() + engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__()) block_size: Block size for KV cache model_name: Model name to use for requests num_workers: Expected number of workers @@ -1252,7 +1339,6 @@ def _test_router_indexers_sync( Raises: AssertionError: If router states don't synchronize correctly or snapshot is missing """ - import nats # Use async to manage the test flow async def test_sync(): @@ -1498,7 +1584,7 @@ def sort_key(event): logger.info("Indexers sync test completed successfully") -def _test_router_disagg_decisions( +def _test_router_decisions_disagg( prefill_workers, decode_workers, block_size: int, @@ -1577,7 +1663,7 @@ async def send_progressive_requests(): # Each iteration adds more content to extend the prefix progressive_content = " ".join([base_content] * (i + 1)) - # Create payload with worker_id in extra_fields to get prefill/decode worker IDs + # Create payload with worker_id and timing in extra_fields payload = { **test_payload, "messages": [ @@ -1586,7 +1672,7 @@ async def send_progressive_requests(): "content": progressive_content, } ], - "nvext": {"extra_fields": ["worker_id"]}, + "nvext": {"extra_fields": ["worker_id", "timing"]}, "stream": True, } @@ -1600,9 +1686,10 @@ async def send_progressive_requests(): response.status == 200 ), f"Request {i + 1} failed with status {response.status}" - # Collect all chunks and look for nvext with worker_id + # Collect all chunks and look for nvext with worker_id and timing prefill_wid = None decode_wid = None + timing_info = None async for line in response.content: if not line: @@ -1618,24 +1705,29 @@ async def send_progressive_requests(): try: data = json.loads(data_str) - # Check for nvext.worker_id in the response + # Check for nvext in the response nvext = data.get("nvext", {}) - worker_id_info = nvext.get("worker_id", {}) - - if worker_id_info: - if "prefill_worker_id" in worker_id_info: - prefill_wid = worker_id_info[ - "prefill_worker_id" - ] - if "decode_worker_id" in worker_id_info: - decode_wid = worker_id_info["decode_worker_id"] + if nvext: + worker_id_info = nvext.get("worker_id", {}) + if worker_id_info: + if "prefill_worker_id" in worker_id_info: + prefill_wid = worker_id_info[ + "prefill_worker_id" + ] + if "decode_worker_id" in worker_id_info: + decode_wid = worker_id_info[ + "decode_worker_id" + ] + # Timing info appears in final chunk + if "timing" in nvext: + timing_info = nvext["timing"] except json.JSONDecodeError: continue logger.info( f"Request {i + 1}: prefill_worker_id={prefill_wid}, " - f"decode_worker_id={decode_wid}" + f"decode_worker_id={decode_wid}, timing={timing_info}" ) if prefill_wid is not None: @@ -1643,6 +1735,12 @@ async def send_progressive_requests(): if decode_wid is not None: decode_worker_ids.append(decode_wid) + # Verify timing info is present and valid + assert ( + timing_info is not None + ), f"Request {i + 1}: Expected timing info in final chunk, got None" + verify_response_timing(timing_info) + # Small delay between requests await asyncio.sleep(0.5) @@ -1694,15 +1792,16 @@ def _test_router_decisions( model_name: str, request, test_dp_rank: bool = False, + block_size: int = BLOCK_SIZE, ): """Validate KV cache prefix reuse and worker routing by sending progressive requests with overlapping prefixes. Assumes engine workers are already initialized. 
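# --- Illustrative sketch (not part of this patch) ---------------------------
# _test_router_decisions_disagg() above requests extra response fields via
# "nvext": {"extra_fields": ["worker_id", "timing"]} and parses them out of the
# SSE chunks: the parsing loop checks every chunk for worker_id, while timing
# is only attached to the final chunk. The literal below is a hypothetical
# example of the nvext payload of such a final chunk (values invented), just to
# show the shape the parsing loop and verify_response_timing() expect.
example_final_chunk_nvext = {
    "worker_id": {"prefill_worker_id": 101, "decode_worker_id": 202},
    "timing": {"ttft_ms": 42.5, "total_time_ms": 310.0},
}
# -----------------------------------------------------------------------------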
Sends 4 progressive requests where each extends - the previous tokens by BLOCK_SIZE. The first request is forced to a specific worker (and optionally + the previous tokens by `block_size`. The first request is forced to a specific worker (and optionally dp_rank), and subsequent requests should naturally route to the same worker due to prefix reuse. Args: - engine_workers: MockerProcess or VLLMProcess instance (already initialized with __enter__()) + engine_workers: Backend worker instance ({MockerProcess, VLLMProcess, TRTLLMProcess}) (already initialized with __enter__()) endpoint: Endpoint of the engine workers model_name: Name of the model request: Pytest request fixture @@ -1715,7 +1814,7 @@ def _test_router_decisions( kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_push_router = KvPushRouter( endpoint=endpoint, - block_size=BLOCK_SIZE, + block_size=block_size, kv_router_config=kv_router_config, ) @@ -1743,10 +1842,11 @@ async def test_sync(): # Send 4 progressive requests with overlapping prefixes cumulative_tokens = [] + response_worker_ids: list[dict[str, Optional[int]]] = [] for i in range(4): - # Add BLOCK_SIZE new random tokens - new_tokens = [random.randint(1, 10000) for _ in range(BLOCK_SIZE)] + # Add `block_size` new random tokens + new_tokens = [random.randint(1, 10000) for _ in range(block_size)] cumulative_tokens.extend(new_tokens) # Force first request to specific worker_id (and dp_rank if testing DP), let subsequent requests follow naturally @@ -1764,7 +1864,7 @@ async def test_sync(): log_msg += f" - FORCING worker_id={worker_id_override}" logger.info(log_msg) - await send_request_via_python_kv_router( + result = await send_request_via_python_kv_router( kv_python_router=kv_push_router, model_name=model_name, token_ids=cumulative_tokens.copy(), @@ -1776,6 +1876,13 @@ async def test_sync(): }, worker_id=worker_id_override, dp_rank=dp_rank_override, + return_worker_ids=True, + ) + assert isinstance(result, dict), f"Expected dict result, got {type(result)}" + response_worker_ids.append(result) + logger.info( + f"Request {i + 1} response: prefill_worker_id={result.get('prefill_worker_id')}, " + f"decode_worker_id={result.get('decode_worker_id')}" ) # Wait a bit between requests @@ -1787,10 +1894,23 @@ async def test_sync(): # Dump events from the router events_json = await kv_push_router.dump_events() - return events_json, forced_worker_id, forced_dp_rank + return events_json, forced_worker_id, forced_dp_rank, response_worker_ids # Run the async test - events_json, expected_worker_id, expected_dp_rank = asyncio.run(test_sync()) + ( + events_json, + expected_worker_id, + expected_dp_rank, + response_worker_ids, + ) = asyncio.run(test_sync()) + + # Verify worker IDs from responses + verify_response_worker_ids( + response_worker_ids, "decode_worker_id", expected_worker_id + ) + verify_response_worker_ids( + response_worker_ids, "prefill_worker_id", expected_worker_id + ) # Parse events and count by worker routing key (worker_id or (worker_id, dp_rank)) events = json.loads(events_json) @@ -1895,6 +2015,7 @@ def _test_busy_threshold_endpoint( frontend_port: int, test_payload: dict, store_backend: str = "etcd", + request_plane: str = "nats", ): """Test that the /busy_threshold endpoint can be hit and responds correctly. @@ -1906,12 +2027,13 @@ def _test_busy_threshold_endpoint( For now, this test only verifies the endpoint is accessible and returns valid responses. 
Args: - engine_workers: Backend workers (mocker/vllm) already initialized with __enter__() + engine_workers: MockerProcess instance (already initialized with __enter__()) block_size: Block size for KV cache request: Pytest request fixture for managing resources frontend_port: Port for the frontend HTTP server test_payload: Base test payload (used to extract model name) store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". + request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats". Raises: AssertionError: If endpoint responses are incorrect @@ -1929,6 +2051,7 @@ def _test_busy_threshold_endpoint( engine_workers.namespace, store_backend, busy_threshold=initial_threshold, + request_plane=request_plane, ) kv_router.__enter__() diff --git a/tests/router/test_router_e2e_with_mockers.py b/tests/router/test_router_e2e_with_mockers.py index fcff8428e9..6432ba550d 100644 --- a/tests/router/test_router_e2e_with_mockers.py +++ b/tests/router/test_router_e2e_with_mockers.py @@ -11,7 +11,7 @@ _test_python_router_bindings, _test_router_basic, _test_router_decisions, - _test_router_disagg_decisions, + _test_router_decisions_disagg, _test_router_indexers_sync, _test_router_overload_503, _test_router_query_instance_id, @@ -30,7 +30,6 @@ pytest.mark.pre_merge, pytest.mark.gpu_0, pytest.mark.integration, - pytest.mark.parallel, pytest.mark.model(MODEL_NAME), ] NUM_MOCKERS = 2 @@ -41,19 +40,23 @@ def get_unique_ports( - request, num_ports: int = 1, store_backend: str = "etcd" + request, + num_ports: int = 1, + store_backend: str = "etcd", + request_plane: str = "nats", ) -> list[int]: """Generate unique ports for parallel test execution. Ports are unique based on: - Test function name (each test gets a base offset) - - Parametrization value (etcd=0, file=50) + - Parametrization value (etcd=0, file=50; nats=0, tcp=25) - Port index (for multi-port tests) Args: request: Pytest request fixture num_ports: Number of ports needed (1 for single router, 2 for two routers) store_backend: Storage backend parameter ("etcd" or "file") + request_plane: Request plane parameter ("nats" or "tcp") Returns: List of unique port numbers @@ -67,17 +70,21 @@ def get_unique_ports( "test_mocker_two_kv_router": 100, "test_mocker_kv_router_overload_503": 200, "test_query_instance_id_returns_worker_and_tokens": 300, - "test_router_disagg_decisions": 400, + "test_router_decisions_disagg": 400, "test_busy_threshold_endpoint": 500, } base_offset = test_offsets.get(test_name, 0) - # Parametrization offset (etcd=0, file=50) - param_offset = 0 if store_backend == "etcd" else 50 + # Parametrization offset (etcd=0, file=50; nats=0, tcp=25) + store_offset = 0 if store_backend == "etcd" else 50 + plane_offset = 0 if request_plane == "nats" else 25 # Generate ports - ports = [BASE_PORT + base_offset + param_offset + i for i in range(num_ports)] + ports = [ + BASE_PORT + base_offset + store_offset + plane_offset + i + for i in range(num_ports) + ] return ports @@ -176,6 +183,7 @@ def __init__( mocker_args: Optional[Dict[str, Any]] = None, num_mockers: int = 1, store_backend: str = "etcd", + request_plane: str = "nats", ): namespace_suffix = generate_random_suffix() self.namespace = f"test-namespace-{namespace_suffix}" @@ -192,8 +200,12 @@ def __init__( mocker_args=mocker_args, ) + env = os.environ.copy() + env["DYN_REQUEST_PLANE"] = request_plane + self._process = ManagedProcess( command=command, + env=env, timeout=60, display_output=True, health_check_ports=[], @@ -287,6 +299,7 @@ def __exit__(self, 
exc_type, exc_val, exc_tb): self._process.__exit__(exc_type, exc_val, exc_tb) +@pytest.mark.parallel def test_mocker_kv_router(request, runtime_services_session, predownload_tokenizers): """ Test KV router with multiple mocker engine instances. @@ -326,6 +339,7 @@ def test_mocker_kv_router(request, runtime_services_session, predownload_tokeniz mockers.__exit__(None, None, None) +@pytest.mark.parallel @pytest.mark.parametrize("store_backend", ["etcd", "file"]) def test_mocker_two_kv_router( request, @@ -381,6 +395,7 @@ def test_mocker_two_kv_router( mockers.__exit__(None, None, None) +@pytest.mark.parallel @pytest.mark.skip(reason="Flaky, temporarily disabled") def test_mocker_kv_router_overload_503( request, runtime_services_session, predownload_tokenizers @@ -419,6 +434,7 @@ def test_mocker_kv_router_overload_503( mockers.__exit__(None, None, None) +@pytest.mark.parallel def test_kv_push_router_bindings( request, runtime_services_session, predownload_tokenizers ): @@ -504,6 +520,7 @@ def test_indexers_sync( mockers.__exit__(None, None, None) +@pytest.mark.parallel def test_query_instance_id_returns_worker_and_tokens( request, runtime_services_session, predownload_tokenizers ): @@ -538,6 +555,7 @@ def test_query_instance_id_returns_worker_and_tokens( mockers.__exit__(None, None, None) +@pytest.mark.parallel def test_router_decisions(request, runtime_services_session, predownload_tokenizers): """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.""" @@ -577,7 +595,8 @@ def test_router_decisions(request, runtime_services_session, predownload_tokeniz mockers.__exit__(None, None, None) -def test_router_disagg_decisions( +@pytest.mark.parallel +def test_router_decisions_disagg( request, runtime_services_session, predownload_tokenizers ): """Validate KV cache prefix reuse in disaggregated prefill-decode setup. @@ -626,7 +645,7 @@ def test_router_disagg_decisions( frontend_port = get_unique_ports(request, num_ports=1)[0] # Run disagg routing test - _test_router_disagg_decisions( + _test_router_decisions_disagg( prefill_workers=prefill_workers, decode_workers=decode_workers, block_size=BLOCK_SIZE, @@ -642,8 +661,10 @@ def test_router_disagg_decisions( prefill_workers.__exit__(None, None, None) +@pytest.mark.parallel +@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) def test_busy_threshold_endpoint( - request, runtime_services_session, predownload_tokenizers + request, runtime_services_session, predownload_tokenizers, request_plane ): """Test that the /busy_threshold endpoint can be hit and responds correctly. @@ -654,19 +675,26 @@ def test_busy_threshold_endpoint( For now, this test only verifies the endpoint is accessible and returns valid responses. 
""" - logger.info("Starting busy_threshold endpoint test") + logger.info( + f"Starting busy_threshold endpoint test with request_plane={request_plane}" + ) mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE} try: logger.info(f"Starting {NUM_MOCKERS} mocker instances") mockers = MockerProcess( - request, mocker_args=mocker_args, num_mockers=NUM_MOCKERS + request, + mocker_args=mocker_args, + num_mockers=NUM_MOCKERS, + request_plane=request_plane, ) logger.info(f"All mockers using endpoint: {mockers.endpoint}") mockers.__enter__() - frontend_port = get_unique_ports(request, num_ports=1)[0] + frontend_port = get_unique_ports( + request, num_ports=1, request_plane=request_plane + )[0] _test_busy_threshold_endpoint( engine_workers=mockers, @@ -674,6 +702,7 @@ def test_busy_threshold_endpoint( request=request, frontend_port=frontend_port, test_payload=TEST_PAYLOAD, + request_plane=request_plane, ) finally: diff --git a/tests/router/test_router_e2e_with_sglang.py b/tests/router/test_router_e2e_with_sglang.py new file mode 100644 index 0000000000..f8895ea73e --- /dev/null +++ b/tests/router/test_router_e2e_with_sglang.py @@ -0,0 +1,454 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +import time +from typing import Any, Dict, Optional + +import pytest + +from tests.router.common import ( # utilities + _test_router_basic, + _test_router_decisions, + _test_router_indexers_sync, + generate_random_suffix, + get_runtime, +) +from tests.utils.managed_process import ManagedProcess + +logger = logging.getLogger(__name__) + +MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + +pytestmark = [ + pytest.mark.e2e, + pytest.mark.sglang, + pytest.mark.model(MODEL_NAME), +] +SPEEDUP_RATIO = 10.0 +PORTS = [ + 8011, + 8022, +] # Frontend ports: use PORTS[0] for single router, PORTS for multi-router +NUM_REQUESTS = 10 +PAGE_SIZE = 16 # SGLang uses "page_size" instead of "block_size" + +# Shared test payload for all tests +TEST_PAYLOAD: Dict[str, Any] = { + "model": MODEL_NAME, + "messages": [ + { + "role": "user", + "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacksโ€”an endless search for the tastiest patch of greens and the softest spot to nap.", + } + ], + "stream": True, + "max_tokens": 10, +} + +# Shared SGLang configuration for all tests +# mem_fraction_static limits actual VRAM allocation (required for multi-worker on same GPU) +SGLANG_ARGS: Dict[str, Any] = { + "page_size": PAGE_SIZE, + "model": MODEL_NAME, + "mem_fraction_static": 0.4, # Limit VRAM allocation per worker (equivalent to vLLM's gpu_memory_utilization) + "context_length": 1024, # Limit context length to reduce KV cache size (equivalent to vLLM's max_model_len) + "disable_cuda_graph": True, # Disable CUDA graphs for faster startup & lower memory (equivalent to vLLM's enforce_eager) +} + + +class SGLangProcess: + """Manages SGLang workers using dynamo.sglang (HTTP API + KV events). + + This is a drop-in replacement for MockerProcess that uses real SGLang workers. 
+ The key difference: dynamo.sglang automatically handles: + - HTTP API serving + - KV cache event publishing (ZMQ โ†’ NATS bridge) + - Integration with dynamo.frontend router + """ + + def __init__( + self, + request, + sglang_args: Optional[Dict[str, Any]] = None, + num_workers: int = 2, + single_gpu: bool = False, + data_parallel_size: Optional[int] = None, + ): + """Initialize SGLang workers with dynamo integration. + + Args: + request: pytest request fixture for log directory + sglang_args: Configuration dict with keys: + - page_size: KV cache page size (default: 16) + - model: Model name/path (default: TinyLlama-1.1B) + - mem_fraction_static: Fraction of GPU memory to allocate (optional) + - context_length: Maximum sequence length (optional) + - disable_cuda_graph: Disable CUDA graphs (default: False) + num_workers: Number of SGLang worker processes + single_gpu: If True, all workers share GPU 0 + data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size) + """ + # Generate unique namespace for isolation + namespace_suffix = generate_random_suffix() + self.namespace = f"test-namespace-{namespace_suffix}" + self.component_name = "backend" + self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" + self.num_workers = num_workers + self.worker_processes = [] + + if sglang_args is None: + sglang_args = {} + + page_size = sglang_args.get("page_size", PAGE_SIZE) + model = sglang_args.get("model", MODEL_NAME) + mem_fraction_static = sglang_args.get("mem_fraction_static") + context_length = sglang_args.get("context_length") + disable_cuda_graph = sglang_args.get("disable_cuda_graph", False) + + self.model_name = model + + for worker_idx in range(num_workers): + # Calculate GPU device for this process + if single_gpu: + # Force all processes to GPU 0 (for single-GPU testing) + gpu_device = "0" + elif data_parallel_size is not None: + # Worker sees dp_rank GPUs (each DP rank gets its own GPU) + worker_start_gpu = worker_idx * data_parallel_size + gpu_device = ",".join( + str(i) + for i in range( + worker_start_gpu, worker_start_gpu + data_parallel_size + ) + ) + else: + # No DP; worker sees one GPU + gpu_device = str(worker_idx) + + command = [ + "python3", + "-m", + "dynamo.sglang", + "--model-path", + model, + "--page-size", + str(page_size), + ] + + # Disable CUDA graphs for faster startup & lower memory + if disable_cuda_graph: + command.append("--disable-cuda-graph") + + # Limit VRAM allocation (required for multi-worker on same GPU) + if mem_fraction_static is not None: + command.extend(["--mem-fraction-static", str(mem_fraction_static)]) + + # Add optional context_length if specified + if context_length is not None: + command.extend(["--context-length", str(context_length)]) + + if data_parallel_size is not None: + # Add DP configuration + command.extend( + [ + "--dp-size", + str(data_parallel_size), + ] + ) + + # Add per-worker KV events config for ZMQ publishing + # Each worker needs a unique port to avoid conflicts + kv_events_port = 20080 + worker_idx + kv_events_config = f'{{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:{kv_events_port}"}}' + command.extend(["--kv-events-config", kv_events_config]) + + env = os.environ.copy() # Copy parent environment + env.update( + { + "CUDA_VISIBLE_DEVICES": gpu_device, + "DYN_NAMESPACE": self.namespace, + "PYTHONHASHSEED": "0", # for deterministic event id's + } + ) + + # Create managed process for the worker + process = ManagedProcess( + command=command, + 
env=env, + timeout=120, # Allow time for model loading + display_output=True, + health_check_ports=[], + health_check_urls=[], + log_dir=request.node.name, + terminate_existing=False, + ) + self.worker_processes.append(process) + if data_parallel_size is not None: + logger.info( + f"Created {data_parallel_size} DP ranks per worker on GPU(s) {gpu_device} " + f"(mem_frac={mem_fraction_static}, kv_port={kv_events_port}) " + f"with endpoint: {self.endpoint}" + ) + else: + logger.info( + f"Created SGLang worker {worker_idx} on GPU {gpu_device} " + f"(mem_frac={mem_fraction_static}, kv_port={kv_events_port}) " + f"with endpoint: {self.endpoint}" + ) + + def __enter__(self): + """Start all SGLang worker processes with sequential initialization. + + Workers are started sequentially with a delay between each to avoid + resource contention during initialization. This prevents + shared memory handle allocation failures when multiple workers + try to initialize simultaneously on the same GPU. + """ + logger.info( + f"[SGLangProcess] Starting {len(self.worker_processes)} worker processes sequentially..." + ) + + # Start each process sequentially, waiting for initialization before next + for i, process in enumerate(self.worker_processes): + logger.info(f"[SGLangProcess] Starting SGLang worker {i}...") + try: + # Manually initialize the process without blocking on health checks + process._logger = logging.getLogger(process.__class__.__name__) + process._command_name = process.command[0] + os.makedirs(process.log_dir, exist_ok=True) + log_name = f"{process._command_name}.log.txt" + process._log_path = os.path.join(process.log_dir, log_name) + + if process.data_dir: + process._remove_directory(process.data_dir) + + process._terminate_existing() + logger.info( + f"[SGLangProcess] Launching process {i} (pid will be assigned)..." + ) + process._start_process() # Start the process but don't wait + logger.info( + f"[SGLangProcess] Worker {i} launched with PID: {process.proc.pid if process.proc else 'unknown'}" + ) + time.sleep(process.delayed_start) + + # Wait for initialization before starting next worker + # This prevents shared memory contention + if i < len(self.worker_processes) - 1: + init_delay = 5 # seconds + logger.info( + f"[SGLangProcess] Waiting {init_delay}s for worker {i} to initialize before starting next worker..." + ) + time.sleep(init_delay) + + except Exception: + logger.exception(f"[SGLangProcess] Failed to start worker {i}") + # Clean up on failure + try: + process.__exit__(None, None, None) + except Exception as cleanup_err: + logger.warning( + f"[SGLangProcess] Error during cleanup: {cleanup_err}" + ) + raise + + logger.info( + f"[SGLangProcess] All {len(self.worker_processes)} workers launched with sequential initialization." + ) + logger.info("[SGLangProcess] Waiting for health checks to complete...") + + # Now wait for health checks for all processes + for i, process in enumerate(self.worker_processes): + logger.info(f"[SGLangProcess] Checking health for worker {i}...") + try: + elapsed = process._check_ports(process.timeout) + process._check_urls(process.timeout - elapsed) + process._check_funcs(process.timeout - elapsed) + logger.info(f"[SGLangProcess] Worker {i} health checks passed") + except Exception: + logger.error(f"[SGLangProcess] Worker {i} health check failed") + # Clean up all processes on failure + self.__exit__(None, None, None) + raise + + logger.info( + "[SGLangProcess] All workers started successfully and passed health checks!" 
+ ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all SGLang worker processes gracefully.""" + for i, process in enumerate(self.worker_processes): + logger.info(f"Stopping SGLang worker {i}") + process.__exit__(exc_type, exc_val, exc_tb) + + # Add delay to ensure full cleanup of NATS/ETCD/ZMQ resources + # This prevents test isolation issues when running multiple tests + logger.info("Waiting for SGLang worker resources to fully clean up...") + time.sleep(2) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_sglang_kv_router_basic( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """ + Quick e2e sanity test for KV router with SGLang engine instances. + """ + + # runtime_services starts etcd and nats + N_SGLANG_WORKERS = 2 + logger.info(f"Starting SGLang KV router test with {N_SGLANG_WORKERS} workers") + + try: + # Start SGLang workers + logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers") + sglang_workers = SGLangProcess( + request, + sglang_args=SGLANG_ARGS, + num_workers=N_SGLANG_WORKERS, + single_gpu=True, # fit workers into one GPU + ) + logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") + sglang_workers.__enter__() + + # Run basic router test (starts router internally and waits for workers to be ready) + _test_router_basic( + engine_workers=sglang_workers, + block_size=PAGE_SIZE, + request=request, + frontend_port=PORTS[0], + test_payload=TEST_PAYLOAD, + num_requests=NUM_REQUESTS, + frontend_timeout=180, # 3 minutes should be plenty for TinyLlama + store_backend="etcd", # Explicit for clarity + ) + + finally: + if "sglang_workers" in locals(): + sglang_workers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_router_decisions_sglang_multiple_workers( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + # runtime_services starts etcd and nats + logger.info("Starting SGLang router prefix reuse test with two workers") + N_WORKERS = 2 + + try: + # Start 2 worker processes on the same GPU + logger.info("Starting 2 SGLang worker processes on single GPU (mem_frac=0.4)") + sglang_workers = SGLangProcess( + request, + sglang_args=SGLANG_ARGS, + num_workers=N_WORKERS, + single_gpu=True, # Worker uses GPU 0 + ) + logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") + + # Initialize SGLang workers + sglang_workers.__enter__() + + # Get runtime and create endpoint + runtime = get_runtime() + namespace = runtime.namespace(sglang_workers.namespace) + component = namespace.component("backend") + endpoint = component.endpoint("generate") + + _test_router_decisions( + sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=False + ) + + finally: + # Clean up SGLang workers + if "sglang_workers" in locals(): + sglang_workers.__exit__(None, None, None) + + +@pytest.mark.gpu_2 +def test_router_decisions_sglang_dp( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """Validate KV cache prefix reuse with SGLang by sending progressive requests with overlapping prefixes. + Same flow as test_router_decisions_sglang_multiple_workers; force first request to (worker_id, dp_rank=1). 
+ Dump events from router and verify: + * All but one (worker_id, dp_rank) should have no events (due to prefix reuse) + * The (worker_id, dp_rank) with events should have exactly 4 events (one per request) + * All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse) + """ + N_WORKERS = 1 + DP_SIZE = 2 + + try: + logger.info("Starting 2 SGLang DP ranks (dp_size=2) (mem_frac=0.4)") + sglang_workers = SGLangProcess( + request, + sglang_args=SGLANG_ARGS, + num_workers=N_WORKERS, # Ignored when data_parallel_size is set + single_gpu=False, + data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) + ) + logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") + sglang_workers.__enter__() + + # Get runtime and create endpoint + runtime = get_runtime() + # Use the namespace from the SGLang workers + namespace = runtime.namespace(sglang_workers.namespace) + component = namespace.component("backend") # endpoint is backend.generate + endpoint = component.endpoint("generate") + + _test_router_decisions( + sglang_workers, endpoint, MODEL_NAME, request, test_dp_rank=True + ) + + finally: + # Clean up SGLang workers + if "sglang_workers" in locals(): + sglang_workers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_sglang_indexers_sync( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """ + Test that two KV routers have synchronized indexer states after processing requests + with SGLang workers. This test verifies that both routers converge to the same internal state. + """ + logger.info("Starting SGLang indexers sync test") + N_SGLANG_WORKERS = 2 + + try: + # Start SGLang workers + logger.info(f"Starting {N_SGLANG_WORKERS} SGLang workers") + sglang_workers = SGLangProcess( + request, + sglang_args=SGLANG_ARGS, + num_workers=N_SGLANG_WORKERS, + single_gpu=True, # fit workers into one GPU + ) + logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") + sglang_workers.__enter__() + + # Use the common test implementation (creates its own runtimes for each router) + # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive + _test_router_indexers_sync( + engine_workers=sglang_workers, + block_size=PAGE_SIZE, + model_name=MODEL_NAME, + num_workers=N_SGLANG_WORKERS, + store_backend="etcd", + ) + + logger.info("SGLang indexers sync test completed successfully") + + finally: + if "sglang_workers" in locals(): + sglang_workers.__exit__(None, None, None) diff --git a/tests/router/test_router_e2e_with_trtllm.py b/tests/router/test_router_e2e_with_trtllm.py new file mode 100644 index 0000000000..4ee7ca1117 --- /dev/null +++ b/tests/router/test_router_e2e_with_trtllm.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+import logging
+import os
+import time
+from typing import Any, Dict, Optional
+
+import pytest
+
+from tests.router.common import (  # utilities
+    _test_router_basic,
+    _test_router_decisions,
+    _test_router_indexers_sync,
+    generate_random_suffix,
+    get_runtime,
+)
+from tests.utils.managed_process import ManagedProcess
+
+logger = logging.getLogger(__name__)
+
+MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+TRTLLM_BLOCK_SIZE = 32  # fixed internally to 32
+
+pytestmark = [
+    pytest.mark.e2e,
+    pytest.mark.trtllm,
+    pytest.mark.model(MODEL_NAME),
+]
+PORTS = [
+    8011,
+    8022,
+]  # Frontend ports: use PORTS[0] for single router, PORTS for multi-router
+NUM_REQUESTS = 10
+
+# Shared test payload for all tests
+TEST_PAYLOAD: Dict[str, Any] = {
+    "model": MODEL_NAME,
+    "messages": [
+        {
+            "role": "user",
+            "content": "In a quiet meadow tucked between rolling hills, a plump gray rabbit nibbled on clover beneath the shade of a gnarled oak tree. Its ears twitched at the faint rustle of leaves, but it remained calm, confident in the safety of its burrow just a few hops away. The late afternoon sun warmed its fur, and tiny dust motes danced in the golden light as bees hummed lazily nearby. Though the rabbit lived a simple life, every day was an adventure of scents, shadows, and snacks: an endless search for the tastiest patch of greens and the softest spot to nap.",
+        }
+    ],
+    "stream": True,
+    "max_tokens": 10,
+}
+
+# Shared TRT-LLM configuration for all tests
+# free_gpu_memory_fraction limits actual VRAM allocation (required for multi-worker on same GPU)
+TRTLLM_ARGS: Dict[str, Any] = {
+    "kv_block_size": TRTLLM_BLOCK_SIZE,
+    "model": MODEL_NAME,
+    "free_gpu_memory_fraction": 0.4,  # Limit VRAM allocation per worker
+    "max_seq_len": 1024,  # Limit context length to reduce KV cache size
+}
+
+
+class TRTLLMProcess:
+    """Manages TRT-LLM workers using dynamo.trtllm (HTTP API + KV events).
+
+    This is a drop-in replacement for MockerProcess that uses real TRT-LLM workers.
+    The key difference: dynamo.trtllm automatically handles:
+    - HTTP API serving
+    - KV cache event publishing
+    - Integration with dynamo.frontend router
+    """
+
+    def __init__(
+        self,
+        request,
+        trtllm_args: Optional[Dict[str, Any]] = None,
+        num_workers: int = 2,
+        single_gpu: bool = False,
+    ):
+        """Initialize TRT-LLM workers with dynamo integration.
+
+        Args:
+            request: pytest request fixture for log directory
+            trtllm_args: Configuration dict with keys:
+                - kv_block_size: KV cache block size (default: 32)
+                - model: Model name/path (default: TinyLlama-1.1B)
+                - free_gpu_memory_fraction: Fraction of GPU memory to allocate (optional)
+                - max_seq_len: Maximum sequence length (optional)
+            num_workers: Number of TRT-LLM worker processes
+            single_gpu: If True, all workers share GPU 0
+
+        Note: TRT-LLM doesn't support data parallelism like vLLM (dp_rank is always 0).
+            Tensor parallelism (TP) is supported but creates 1 worker spanning multiple GPUs,
+            not multiple routing targets.
+ """ + # Generate unique namespace for isolation + namespace_suffix = generate_random_suffix() + self.namespace = f"test-namespace-{namespace_suffix}" + self.component_name = "tensorrt_llm" + self.endpoint = f"dyn://{self.namespace}.{self.component_name}.generate" + self.num_workers = num_workers + self.worker_processes = [] + + if trtllm_args is None: + trtllm_args = {} + + model = trtllm_args.get("model", MODEL_NAME) + free_gpu_memory_fraction = trtllm_args.get("free_gpu_memory_fraction") + max_seq_len = trtllm_args.get("max_seq_len") + + self.model_name = model + + for worker_idx in range(num_workers): + # Calculate GPU device for this process + if single_gpu: + # Force all processes to GPU 0 (for single-GPU testing) + gpu_device = "0" + else: + # Each worker sees one GPU + gpu_device = str(worker_idx) + + # Single-node TRT-LLM workers use python3 -m dynamo.trtllm directly + # (trtllm-llmapi-launch is only needed for multi-node MPI deployments) + command = [ + "python3", + "-m", + "dynamo.trtllm", + "--model-path", + model, + "--kv-block-size", + str(TRTLLM_BLOCK_SIZE), + # Enable KV events publishing for router integration + "--publish-events-and-metrics", + ] + + # Limit VRAM allocation (required for multi-worker on same GPU) + if free_gpu_memory_fraction is not None: + command.extend( + ["--free-gpu-memory-fraction", str(free_gpu_memory_fraction)] + ) + + # Add optional max_seq_len if specified + if max_seq_len is not None: + command.extend(["--max-seq-len", str(max_seq_len)]) + + # Each TRT-LLM worker needs a unique DYN_SYSTEM_PORT to avoid conflicts. + # See examples/backends/trtllm/launch/disagg_same_gpu.sh for reference. + system_port = 8081 + worker_idx + + env = os.environ.copy() # Copy parent environment + env.update( + { + "CUDA_VISIBLE_DEVICES": gpu_device, + "DYN_NAMESPACE": self.namespace, + "PYTHONHASHSEED": "0", # for deterministic event id's + # Set unique system port for each worker to avoid port conflicts + "DYN_SYSTEM_PORT": str(system_port), + } + ) + + # Create managed process for the worker + process = ManagedProcess( + command=command, + env=env, + timeout=180, # Allow time for model loading (TRT-LLM may take longer) + display_output=True, + health_check_ports=[], + health_check_urls=[], + log_dir=request.node.name, + terminate_existing=False, + ) + self.worker_processes.append(process) + logger.info( + f"Created TRT-LLM worker {worker_idx} on GPU {gpu_device} " + f"(gpu_mem_frac={free_gpu_memory_fraction}, system_port={system_port}) " + f"with endpoint: {self.endpoint}" + ) + + def __enter__(self): + """Start all TRT-LLM worker processes with sequential initialization. + + Workers are started sequentially with a delay between each to avoid + resource contention during initialization. This prevents + MPI initialization conflicts when multiple workers + try to initialize simultaneously on the same GPU. + """ + logger.info( + f"[TRTLLMProcess] Starting {len(self.worker_processes)} worker processes sequentially..." 
+ ) + + # Start each process sequentially, waiting for initialization before next + for i, process in enumerate(self.worker_processes): + logger.info(f"[TRTLLMProcess] Starting TRT-LLM worker {i}...") + try: + # Manually initialize the process without blocking on health checks + process._logger = logging.getLogger(process.__class__.__name__) + process._command_name = process.command[0] + os.makedirs(process.log_dir, exist_ok=True) + log_name = f"{process._command_name}.log.txt" + process._log_path = os.path.join(process.log_dir, log_name) + + if process.data_dir: + process._remove_directory(process.data_dir) + + process._terminate_existing() + logger.info( + f"[TRTLLMProcess] Launching process {i} (pid will be assigned)..." + ) + process._start_process() # Start the process but don't wait + logger.info( + f"[TRTLLMProcess] Worker {i} launched with PID: {process.proc.pid if process.proc else 'unknown'}" + ) + time.sleep(process.delayed_start) + + # Wait for initialization before starting next worker + # This prevents MPI initialization conflicts + if i < len(self.worker_processes) - 1: + init_delay = 5 # seconds + logger.info( + f"[TRTLLMProcess] Waiting {init_delay}s for worker {i} to initialize before starting next worker..." + ) + time.sleep(init_delay) + + except Exception: + logger.exception(f"[TRTLLMProcess] Failed to start worker {i}") + # Clean up on failure + try: + process.__exit__(None, None, None) + except Exception as cleanup_err: + logger.warning( + f"[TRTLLMProcess] Error during cleanup: {cleanup_err}" + ) + raise + + logger.info( + f"[TRTLLMProcess] All {len(self.worker_processes)} workers launched with sequential initialization." + ) + logger.info("[TRTLLMProcess] Waiting for health checks to complete...") + + # Now wait for health checks for all processes + for i, process in enumerate(self.worker_processes): + logger.info(f"[TRTLLMProcess] Checking health for worker {i}...") + try: + elapsed = process._check_ports(process.timeout) + process._check_urls(process.timeout - elapsed) + process._check_funcs(process.timeout - elapsed) + logger.info(f"[TRTLLMProcess] Worker {i} health checks passed") + except Exception: + logger.error(f"[TRTLLMProcess] Worker {i} health check failed") + # Clean up all processes on failure + self.__exit__(None, None, None) + raise + + logger.info( + "[TRTLLMProcess] All workers started successfully and passed health checks!" + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all TRT-LLM worker processes gracefully.""" + for i, process in enumerate(self.worker_processes): + logger.info(f"Stopping TRT-LLM worker {i}") + process.__exit__(exc_type, exc_val, exc_tb) + + # Add delay to ensure full cleanup of NATS/ETCD/MPI resources + # This prevents test isolation issues when running multiple tests + logger.info("Waiting for TRT-LLM worker resources to fully clean up...") + time.sleep(2) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_trtllm_kv_router_basic( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """ + Quick e2e sanity test for KV router with TRT-LLM engine instances. 
+ """ + + # runtime_services starts etcd and nats + N_TRTLLM_WORKERS = 2 + logger.info(f"Starting TRT-LLM KV router test with {N_TRTLLM_WORKERS} workers") + + try: + # Start TRT-LLM workers + logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers") + trtllm_workers = TRTLLMProcess( + request, + trtllm_args=TRTLLM_ARGS, + num_workers=N_TRTLLM_WORKERS, + single_gpu=True, # fit workers into one GPU + ) + logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") + trtllm_workers.__enter__() + + # Run basic router test (starts router internally and waits for workers to be ready) + _test_router_basic( + engine_workers=trtllm_workers, + block_size=TRTLLM_BLOCK_SIZE, + request=request, + frontend_port=PORTS[0], + test_payload=TEST_PAYLOAD, + num_requests=NUM_REQUESTS, + frontend_timeout=180, # 3 minutes should be plenty for TinyLlama + store_backend="etcd", # Explicit for clarity + ) + + finally: + if "trtllm_workers" in locals(): + trtllm_workers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_router_decisions_trtllm_multiple_workers( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + # runtime_services starts etcd and nats + logger.info("Starting TRT-LLM router prefix reuse test with two workers") + N_WORKERS = 2 + + try: + # Start 2 worker processes on the same GPU + logger.info( + "Starting 2 TRT-LLM worker processes on single GPU (gpu_mem_frac=0.4)" + ) + trtllm_workers = TRTLLMProcess( + request, + trtllm_args=TRTLLM_ARGS, + num_workers=N_WORKERS, + single_gpu=True, # Worker uses GPU 0 + ) + logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") + + # Initialize TRT-LLM workers + trtllm_workers.__enter__() + + # Get runtime and create endpoint + runtime = get_runtime() + namespace = runtime.namespace(trtllm_workers.namespace) + component = namespace.component("tensorrt_llm") + endpoint = component.endpoint("generate") + + _test_router_decisions( + trtllm_workers, + endpoint, + MODEL_NAME, + request, + test_dp_rank=False, + block_size=TRTLLM_BLOCK_SIZE, + ) + + finally: + # Clean up TRT-LLM workers + if "trtllm_workers" in locals(): + trtllm_workers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_trtllm_indexers_sync( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """ + Test that two KV routers have synchronized indexer states after processing requests + with TRT-LLM workers. This test verifies that both routers converge to the same internal state. 
+ """ + logger.info("Starting TRT-LLM indexers sync test") + N_TRTLLM_WORKERS = 2 + + try: + # Start TRT-LLM workers + logger.info(f"Starting {N_TRTLLM_WORKERS} TRT-LLM workers") + trtllm_workers = TRTLLMProcess( + request, + trtllm_args=TRTLLM_ARGS, + num_workers=N_TRTLLM_WORKERS, + single_gpu=True, # fit workers into one GPU + ) + logger.info(f"All TRT-LLM workers using namespace: {trtllm_workers.namespace}") + trtllm_workers.__enter__() + + # Use the common test implementation (creates its own runtimes for each router) + # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive + _test_router_indexers_sync( + engine_workers=trtllm_workers, + block_size=TRTLLM_BLOCK_SIZE, + model_name=MODEL_NAME, + num_workers=N_TRTLLM_WORKERS, + store_backend="etcd", + ) + + logger.info("TRT-LLM indexers sync test completed successfully") + + finally: + if "trtllm_workers" in locals(): + trtllm_workers.__exit__(None, None, None) diff --git a/tests/router/test_router_e2e_with_vllm.py b/tests/router/test_router_e2e_with_vllm.py index 4ef0e24835..284c9f4bcf 100644 --- a/tests/router/test_router_e2e_with_vllm.py +++ b/tests/router/test_router_e2e_with_vllm.py @@ -10,6 +10,7 @@ from tests.router.common import ( # utilities _test_router_basic, _test_router_decisions, + _test_router_indexers_sync, generate_random_suffix, get_runtime, ) @@ -20,7 +21,6 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" pytestmark = [ - pytest.mark.pre_merge, pytest.mark.e2e, pytest.mark.vllm, pytest.mark.model(MODEL_NAME), @@ -46,6 +46,16 @@ "max_tokens": 10, } +# Shared vLLM configuration for all tests +# gpu_memory_utilization limits actual VRAM allocation (required for multi-worker on same GPU) +VLLM_ARGS: Dict[str, Any] = { + "block_size": BLOCK_SIZE, + "model": MODEL_NAME, + "gpu_memory_utilization": 0.4, # Limit VRAM allocation per worker + "max_model_len": 1024, # Limit context length to reduce KV cache size + "enforce_eager": True, # Disable CUDA graphs for faster startup & lower memory +} + class VLLMProcess: """Manages vLLM workers using dynamo.vllm (HTTP API + KV events). 
@@ -72,11 +82,12 @@ def __init__( vllm_args: Configuration dict with keys: - block_size: KV cache block size (default: 16) - model: Model name/path (default: TinyLlama-1.1B) - - gpu_memory_utilization: GPU memory fraction per worker (default: 0.9) + - gpu_memory_utilization: Fraction of GPU memory to allocate (optional) + - num_gpu_blocks_override: Cap on number of KV cache blocks (optional) - max_model_len: Maximum sequence length (optional) - - speedup_ratio: IGNORED (vLLM runs at real speed) + - enforce_eager: Disable CUDA graphs (default: False) num_workers: Number of vLLM worker processes - single_gpu: If True, all workers share GPU 0 (requires gpu_memory_utilization < 1.0/num_workers) + single_gpu: If True, all workers share GPU 0 data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size) """ # Generate unique namespace for isolation @@ -92,8 +103,10 @@ def __init__( block_size = vllm_args.get("block_size", BLOCK_SIZE) model = vllm_args.get("model", MODEL_NAME) - gpu_memory_utilization = vllm_args.get("gpu_memory_utilization", 0.9) + gpu_memory_utilization = vllm_args.get("gpu_memory_utilization") + num_gpu_blocks_override = vllm_args.get("num_gpu_blocks_override") max_model_len = vllm_args.get("max_model_len") + enforce_eager = vllm_args.get("enforce_eager", False) self.model_name = model @@ -130,15 +143,28 @@ def __init__( model, "--block-size", str(block_size), - "--enforce-eager", # Disable CUDA graphs for faster startup - "--gpu-memory-utilization", - str(gpu_memory_utilization), ] + # Disable CUDA graphs for faster startup & lower memory + if enforce_eager: + command.append("--enforce-eager") + + # Limit VRAM allocation (required for multi-worker on same GPU) + if gpu_memory_utilization is not None: + command.extend( + ["--gpu-memory-utilization", str(gpu_memory_utilization)] + ) + # Add optional max_model_len if specified if max_model_len is not None: command.extend(["--max-model-len", str(max_model_len)]) + # Cap block count for predictable KV cache behavior + if num_gpu_blocks_override is not None: + command.extend( + ["--num-gpu-blocks-override", str(num_gpu_blocks_override)] + ) + if data_parallel_size is not None: # Add DP configuration for external load balancing # See: https://docs.vllm.ai/en/v0.10.0/serving/data_parallel_deployment.html#external-load-balancing @@ -157,6 +183,8 @@ def __init__( { "CUDA_VISIBLE_DEVICES": gpu_device, "DYN_NAMESPACE": self.namespace, + "DYN_VLLM_KV_EVENT_PORT": str(20080 + worker_idx), + "VLLM_NIXL_SIDE_CHANNEL_PORT": str(20090 + worker_idx), "PYTHONHASHSEED": "0", # for deterministic event id's } ) @@ -176,13 +204,13 @@ def __init__( if data_parallel_size is not None: logger.info( f"Created {data_parallel_size} DP ranks per worker on GPU(s) {gpu_device} " - f"(gpu_memory_utilization={gpu_memory_utilization}) " + f"(gpu_mem={gpu_memory_utilization}) " f"with endpoint: {self.endpoint}" ) else: logger.info( f"Created vLLM worker {worker_idx} on GPU {gpu_device} " - f"(gpu_memory_utilization={gpu_memory_utilization}) " + f"(gpu_mem={gpu_memory_utilization}) " f"with endpoint: {self.endpoint}" ) @@ -276,9 +304,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): time.sleep(2) +@pytest.mark.pre_merge @pytest.mark.gpu_1 -@pytest.mark.skip(reason="All vLLM tests disabled for now") -def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers): +def test_vllm_kv_router_basic( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): """ Quick e2e sanity test 
for KV router with vLLM engine instances. """ @@ -287,19 +317,12 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers) N_VLLM_WORKERS = 2 logger.info(f"Starting vLLM KV router test with {N_VLLM_WORKERS} workers") - vllm_args = { - "block_size": BLOCK_SIZE, - "model": MODEL_NAME, - "gpu_memory_utilization": 0.35, - "max_model_len": 1024, # Limit context length to reduce KV cache size - } - try: # Start vLLM workers logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers") vllm_workers = VLLMProcess( request, - vllm_args=vllm_args, + vllm_args=VLLM_ARGS, num_workers=N_VLLM_WORKERS, single_gpu=True, # fit workers into one GPU ) @@ -323,32 +346,22 @@ def test_vllm_kv_router_basic(request, runtime_services, predownload_tokenizers) vllm_workers.__exit__(None, None, None) +@pytest.mark.pre_merge @pytest.mark.gpu_1 -@pytest.mark.skip(reason="All vLLM tests disabled for now") def test_router_decisions_vllm_multiple_workers( - request, runtime_services, predownload_tokenizers + request, runtime_services, predownload_models, set_ucx_tls_no_mm ): # runtime_services starts etcd and nats logger.info("Starting vLLM router prefix reuse test with two workers") - - # Create vLLM args - one worker with dp_size=2, sharing GPU 0 - vllm_args = { - "block_size": BLOCK_SIZE, - "model": MODEL_NAME, - "gpu_memory_utilization": 0.35, - "max_model_len": 1024, # Limit context length to reduce KV cache size - } N_WORKERS = 2 try: - # Start 2 worker processes (dp_rank 0 and dp_rank 1) on the same GPU - logger.info( - "Starting 2 vLLM worker processes with dp_size=2 on single GPU (gpu_memory_utilization=0.35, max_model_len=1024)" - ) + # Start 2 worker processes on the same GPU + logger.info("Starting 2 vLLM worker processes on single GPU (gpu_mem=0.4)") vllm_workers = VLLMProcess( request, - vllm_args=vllm_args, - num_workers=N_WORKERS, # One worker process with dp_size=2 + vllm_args=VLLM_ARGS, + num_workers=N_WORKERS, single_gpu=True, # Worker uses GPU 0 ) logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") @@ -373,8 +386,9 @@ def test_router_decisions_vllm_multiple_workers( @pytest.mark.gpu_2 -@pytest.mark.skip(reason="All vLLM tests disabled for now") -def test_router_decisions_vllm_dp(request, runtime_services, predownload_tokenizers): +def test_router_decisions_vllm_dp( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): """Validate KV cache prefix reuse with vLLM by sending progressive requests with overlapping prefixes. Same flow as test_router_decisions_vllm_multiple_workers; force first request to (worker_id, dp_rank=1). 
Dump events from router and verify: @@ -382,23 +396,14 @@ def test_router_decisions_vllm_dp(request, runtime_services, predownload_tokeniz * The (worker_id, dp_rank) with events should have exactly 4 events (one per request) * All events should be on the forced (worker_id, dp_rank=1) (verifying forced routing and prefix reuse) """ - # Create vLLM args - one worker with dp_size=2, sharing GPU 0 - vllm_args = { - "block_size": BLOCK_SIZE, - "model": MODEL_NAME, - "gpu_memory_utilization": 0.35, - "max_model_len": 1024, # Limit context length to reduce KV cache size - } N_WORKERS = 1 DP_SIZE = 2 try: - logger.info( - "Starting 2 vLLM DP ranks (dp_size=2) on single GPU (gpu_memory_utilization=0.35, max_model_len=1024)" - ) + logger.info("Starting 2 vLLM DP ranks (dp_size=2) (gpu_mem=0.4)") vllm_workers = VLLMProcess( request, - vllm_args=vllm_args, + vllm_args=VLLM_ARGS, num_workers=N_WORKERS, # Ignored when data_parallel_size is set single_gpu=False, data_parallel_size=DP_SIZE, # Creates DP_SIZE processes (one per rank) @@ -421,3 +426,44 @@ def test_router_decisions_vllm_dp(request, runtime_services, predownload_tokeniz # Clean up vLLM workers if "vllm_workers" in locals(): vllm_workers.__exit__(None, None, None) + + +@pytest.mark.pre_merge +@pytest.mark.gpu_1 +def test_vllm_indexers_sync( + request, runtime_services, predownload_models, set_ucx_tls_no_mm +): + """ + Test that two KV routers have synchronized indexer states after processing requests + with vLLM workers. This test verifies that both routers converge to the same internal state. + """ + logger.info("Starting vLLM indexers sync test") + N_VLLM_WORKERS = 2 + + try: + # Start vLLM workers + logger.info(f"Starting {N_VLLM_WORKERS} vLLM workers") + vllm_workers = VLLMProcess( + request, + vllm_args=VLLM_ARGS, + num_workers=N_VLLM_WORKERS, + single_gpu=True, # fit workers into one GPU + ) + logger.info(f"All vLLM workers using namespace: {vllm_workers.namespace}") + vllm_workers.__enter__() + + # Use the common test implementation (creates its own runtimes for each router) + # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive + _test_router_indexers_sync( + engine_workers=vllm_workers, + block_size=BLOCK_SIZE, + model_name=MODEL_NAME, + num_workers=N_VLLM_WORKERS, + store_backend="etcd", + ) + + logger.info("vLLM indexers sync test completed successfully") + + finally: + if "vllm_workers" in locals(): + vllm_workers.__exit__(None, None, None) diff --git a/tests/serve/conftest.py b/tests/serve/conftest.py index 2456865a09..51fd2b00f2 100644 --- a/tests/serve/conftest.py +++ b/tests/serve/conftest.py @@ -8,6 +8,7 @@ from pytest_httpserver import HTTPServer from dynamo.common.utils.paths import WORKSPACE_DIR +from tests.serve.lora_utils import MinioLoraConfig, MinioService # Shared constants for multimodal testing IMAGE_SERVER_PORT = 8765 @@ -50,3 +51,47 @@ def test_multimodal(image_server): ) return httpserver + + +@pytest.fixture(scope="function") +def minio_lora_service(): + """ + Provide a MinIO service with a pre-uploaded LoRA adapter for testing. + + This fixture: + 1. Starts a MinIO Docker container + 2. Creates the required S3 bucket + 3. Downloads the LoRA adapter from Hugging Face Hub + 4. Uploads it to MinIO + 5. Yields the MinioLoraConfig with connection details + 6. 
Cleans up after the test + + Usage: + def test_lora(minio_lora_service): + config = minio_lora_service + # Use config.get_env_vars() for environment setup + # Use config.get_s3_uri() to get the S3 URI for loading LoRA + """ + config = MinioLoraConfig() + service = MinioService(config) + + try: + # Start MinIO + service.start() + + # Create bucket + service.create_bucket() + + # Download and upload LoRA + local_path = service.download_lora() + service.upload_lora(local_path) + + # Clean up downloaded files (keep MinIO data intact) + service.cleanup_download() + + yield config + + finally: + # Stop MinIO and clean up + service.stop() + service.cleanup_temp() diff --git a/tests/serve/lora_utils.py b/tests/serve/lora_utils.py new file mode 100644 index 0000000000..434a92c796 --- /dev/null +++ b/tests/serve/lora_utils.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import logging +import os +import shutil +import subprocess +import tempfile +import time +from dataclasses import dataclass +from typing import Optional + +import requests + +logger = logging.getLogger(__name__) + +# LoRA testing constants +MINIO_ENDPOINT = "http://localhost:9000" +MINIO_ACCESS_KEY = "minioadmin" +MINIO_SECRET_KEY = "minioadmin" +MINIO_BUCKET = "my-loras" +DEFAULT_LORA_REPO = "codelion/Qwen3-0.6B-accuracy-recovery-lora" +DEFAULT_LORA_NAME = "codelion/Qwen3-0.6B-accuracy-recovery-lora" + + +@dataclass +class MinioLoraConfig: + """Configuration for MinIO and LoRA setup""" + + endpoint: str = MINIO_ENDPOINT + access_key: str = MINIO_ACCESS_KEY + secret_key: str = MINIO_SECRET_KEY + bucket: str = MINIO_BUCKET + lora_repo: str = DEFAULT_LORA_REPO + lora_name: str = DEFAULT_LORA_NAME + data_dir: Optional[str] = None + + def get_s3_uri(self) -> str: + """Get the S3 URI for the LoRA adapter""" + return f"s3://{self.bucket}/{self.lora_name}" + + def get_env_vars(self) -> dict: + """Get environment variables for AWS/MinIO access""" + return { + "AWS_ENDPOINT": self.endpoint, + "AWS_ACCESS_KEY_ID": self.access_key, + "AWS_SECRET_ACCESS_KEY": self.secret_key, + "AWS_REGION": "us-east-1", + "AWS_ALLOW_HTTP": "true", + "DYN_LORA_ENABLED": "true", + "DYN_LORA_PATH": "/tmp/dynamo_loras_minio_test", + } + + +class MinioService: + """Manages MinIO Docker container lifecycle for tests""" + + CONTAINER_NAME = "dynamo-minio-test" + + def __init__(self, config: MinioLoraConfig): + self.config = config + self._logger = logging.getLogger(self.__class__.__name__) + self._temp_download_dir: Optional[str] = None + + def start(self) -> None: + """Start MinIO container""" + self._logger.info("Starting MinIO container...") + + # Create data directory + if self.config.data_dir: + data_dir = self.config.data_dir + else: + data_dir = tempfile.mkdtemp(prefix="minio_test_") + self.config.data_dir = data_dir + + # Stop existing container if running + self.stop() + + # Start MinIO container + cmd = [ + "docker", + "run", + "-d", + "--name", + self.CONTAINER_NAME, + "-p", + "9000:9000", + "-p", + "9001:9001", + "-v", + f"{data_dir}:/data", + "quay.io/minio/minio", + "server", + "/data", + "--console-address", + ":9001", + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Failed to start MinIO: {result.stderr}") + + # Wait for MinIO to be ready + self._wait_for_ready() + self._logger.info("MinIO started successfully") + + def _wait_for_ready(self, timeout: int = 30) -> 
None: + """Wait for MinIO to be ready""" + health_url = f"{self.config.endpoint}/minio/health/live" + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = requests.get(health_url, timeout=2) + if response.status_code == 200: + return + except requests.RequestException: + pass + time.sleep(1) + + raise RuntimeError(f"MinIO did not become ready within {timeout}s") + + def stop(self) -> None: + """Stop and remove MinIO container""" + self._logger.info("Stopping MinIO container...") + + # Stop container + subprocess.run( + ["docker", "stop", self.CONTAINER_NAME], + capture_output=True, + ) + + # Remove container + subprocess.run( + ["docker", "rm", self.CONTAINER_NAME], + capture_output=True, + ) + + def create_bucket(self) -> None: + """Create the S3 bucket using AWS CLI""" + env = os.environ.copy() + env.update( + { + "AWS_ACCESS_KEY_ID": self.config.access_key, + "AWS_SECRET_ACCESS_KEY": self.config.secret_key, + } + ) + + # Check if bucket exists + result = subprocess.run( + [ + "aws", + "--endpoint-url", + self.config.endpoint, + "s3", + "ls", + f"s3://{self.config.bucket}", + ], + capture_output=True, + text=True, + env=env, + ) + + if result.returncode != 0: + # Create bucket + self._logger.info(f"Creating bucket: {self.config.bucket}") + result = subprocess.run( + [ + "aws", + "--endpoint-url", + self.config.endpoint, + "s3", + "mb", + f"s3://{self.config.bucket}", + ], + capture_output=True, + text=True, + env=env, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to create bucket: {result.stderr}") + + def download_lora(self) -> str: + """Download LoRA from Hugging Face Hub, returns temp directory path""" + self._temp_download_dir = tempfile.mkdtemp(prefix="lora_download_") + self._logger.info( + f"Downloading LoRA {self.config.lora_repo} to {self._temp_download_dir}" + ) + + result = subprocess.run( + [ + "huggingface-cli", + "download", + self.config.lora_repo, + "--local-dir", + self._temp_download_dir, + "--local-dir-use-symlinks", + "False", + ], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to download LoRA: {result.stderr}") + + # Clean up cache directory + cache_dir = os.path.join(self._temp_download_dir, ".cache") + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + return self._temp_download_dir + + def upload_lora(self, local_path: str) -> None: + """Upload LoRA to MinIO""" + self._logger.info( + f"Uploading LoRA to s3://{self.config.bucket}/{self.config.lora_name}" + ) + + env = os.environ.copy() + env.update( + { + "AWS_ACCESS_KEY_ID": self.config.access_key, + "AWS_SECRET_ACCESS_KEY": self.config.secret_key, + } + ) + + result = subprocess.run( + [ + "aws", + "--endpoint-url", + self.config.endpoint, + "s3", + "sync", + local_path, + f"s3://{self.config.bucket}/{self.config.lora_name}", + "--exclude", + "*.git*", + ], + capture_output=True, + text=True, + env=env, + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to upload LoRA: {result.stderr}") + + def cleanup_download(self) -> None: + """Clean up temporary download directory only""" + if self._temp_download_dir and os.path.exists(self._temp_download_dir): + shutil.rmtree(self._temp_download_dir) + self._temp_download_dir = None + + def cleanup_temp(self) -> None: + """Clean up all temporary directories including MinIO data dir""" + self.cleanup_download() + + if self.config.data_dir and os.path.exists(self.config.data_dir): + shutil.rmtree(self.config.data_dir, ignore_errors=True) + + 
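+# NOTE (illustrative only): a minimal usage sketch of MinioService outside pytest,
+# mirroring the minio_lora_service fixture in tests/serve/conftest.py. It assumes
+# Docker and the AWS CLI are available on PATH; nothing below is required by this module.
+#
+#     config = MinioLoraConfig()
+#     service = MinioService(config)
+#     try:
+#         service.start()
+#         service.create_bucket()
+#         service.upload_lora(service.download_lora())
+#         service.cleanup_download()
+#         print(config.get_s3_uri())  # e.g. s3://my-loras/<lora_name>
+#     finally:
+#         service.stop()
+#         service.cleanup_temp()
+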
+def load_lora_adapter( + system_port: int, lora_name: str, s3_uri: str, timeout: int = 60 +) -> None: + """Load a LoRA adapter via the system API""" + url = f"http://localhost:{system_port}/v1/loras" + payload = {"lora_name": lora_name, "source": {"uri": s3_uri}} + + logger.info(f"Loading LoRA adapter: {lora_name} from {s3_uri}") + + response = requests.post(url, json=payload, timeout=timeout) + if response.status_code != 200: + raise RuntimeError( + f"Failed to load LoRA adapter: {response.status_code} - {response.text}" + ) + + logger.info(f"LoRA adapter loaded successfully: {response.json()}") diff --git a/tests/serve/test_sglang.py b/tests/serve/test_sglang.py index 894a569ed4..9591d13571 100644 --- a/tests/serve/test_sglang.py +++ b/tests/serve/test_sglang.py @@ -37,6 +37,9 @@ class SGLangConfig(EngineConfig): WORKSPACE_DIR, "examples/backends/sglang" ) +# SGLang test configurations +# NOTE: pytest.mark.gpu_1 tests take ~167s (2m 47s) total to run sequentially (with models pre-cached) +# TODO: Parallelize these tests to reduce total execution time sglang_configs = { "aggregated": SGLangConfig( # Uses backend agg.sh (with metrics enabled) for testing standard @@ -44,7 +47,11 @@ class SGLangConfig(EngineConfig): name="aggregated", directory=sglang_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(240), # 3x measured time (39s) + download time (120s) + ], model="Qwen/Qwen3-0.6B", env={}, models_port=8000, @@ -120,7 +127,12 @@ class SGLangConfig(EngineConfig): name="template_verification", directory=SERVE_TEST_DIR, # special directory for test-specific scripts script_name="template_verifier.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.nightly], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.nightly, + pytest.mark.timeout(240), # 3x measured time (20s) + download time (180s) + ], model="Qwen/Qwen3-0.6B", env={}, models_port=8000, @@ -163,10 +175,14 @@ class SGLangConfig(EngineConfig): name="embedding_agg", directory=sglang_dir, script_name="agg_embed.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.nightly], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.nightly, + pytest.mark.timeout(270), # 3x measured time (29s) + download time (180s) + ], model="Qwen/Qwen3-Embedding-4B", delayed_start=0, - timeout=180, models_port=8000, request_payloads=[ # Test default payload with multiple inputs @@ -196,7 +212,12 @@ class SGLangConfig(EngineConfig): name="completions_only", directory=sglang_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1], + marks=[ + pytest.mark.gpu_1, + pytest.mark.timeout( + 420 + ), # Total test timeout: 2x measured average (79.36s) + download time (240s) for 7B model + ], model="deepseek-ai/deepseek-llm-7b-base", script_args=[ "--model-path", diff --git a/tests/serve/test_trtllm.py b/tests/serve/test_trtllm.py index f3384dbdea..ccda5cca5f 100644 --- a/tests/serve/test_trtllm.py +++ b/tests/serve/test_trtllm.py @@ -14,7 +14,10 @@ ) from tests.utils.engine_process import EngineConfig from tests.utils.payload_builder import ( + TEXT_PROMPT, + chat_payload, chat_payload_default, + completion_payload, completion_payload_default, metric_payload_default, multimodal_payload_default, @@ -34,13 +37,22 @@ class TRTLLMConfig(EngineConfig): WORKSPACE_DIR, "examples/backends/trtllm" ) -# trtllm test configurations +# TensorRT-LLM test configurations +# NOTE: pytest.mark.gpu_1 tests take ~442s 
(7m 22s) total to run sequentially (with models pre-cached) +# TODO: Parallelize these tests to reduce total execution time trtllm_configs = { "aggregated": TRTLLMConfig( name="aggregated", directory=trtllm_dir, script_name="agg_metrics.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.trtllm, + pytest.mark.timeout( + 300 + ), # 3x measured time (44.66s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", models_port=8000, request_payloads=[ @@ -65,7 +77,15 @@ class TRTLLMConfig(EngineConfig): name="disaggregated_same_gpu", directory=trtllm_dir, script_name="disagg_same_gpu.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.trtllm, + pytest.mark.skip(reason="unstable"), + pytest.mark.timeout( + 480 + ), # 3x measured time (103.66s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", models_port=8000, request_payloads=[ @@ -75,11 +95,46 @@ class TRTLLMConfig(EngineConfig): metric_payload_default(port=8082, min_num_requests=6, backend="trtllm"), ], ), + "aggregated_logprobs": TRTLLMConfig( + name="aggregated_logprobs", + directory=trtllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm], + model="Qwen/Qwen3-0.6B", + models_port=8000, + request_payloads=[ + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5), + chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5), + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None), + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0), + ], + ), + "disaggregated_logprobs": TRTLLMConfig( + name="disaggregated_logprobs", + directory=trtllm_dir, + script_name="disagg.sh", + marks=[pytest.mark.gpu_2, pytest.mark.post_merge, pytest.mark.trtllm], + model="Qwen/Qwen3-0.6B", + models_port=8000, + request_payloads=[ + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=5), + chat_payload(content=TEXT_PROMPT, logprobs=False, top_logprobs=5), + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=None), + chat_payload(content=TEXT_PROMPT, logprobs=True, top_logprobs=0), + ], + ), "aggregated_router": TRTLLMConfig( name="aggregated_router", directory=trtllm_dir, script_name="agg_router.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge, pytest.mark.trtllm], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.trtllm, + pytest.mark.timeout( + 300 + ), # 3x measured time (37.91s) + download time (180s) + ], model="Qwen/Qwen3-0.6B", models_port=8000, request_payloads=[ @@ -121,7 +176,13 @@ class TRTLLMConfig(EngineConfig): name="completions_only", directory=trtllm_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1, pytest.mark.trtllm], + marks=[ + pytest.mark.gpu_1, + pytest.mark.trtllm, + pytest.mark.timeout( + 480 + ), # 3x measured time (83.85s) + download time (210s) for 7B model + ], model="deepseek-ai/deepseek-llm-7b-base", script_args=["--dyn-endpoint-types", "completions"], env={ @@ -130,6 +191,7 @@ class TRTLLMConfig(EngineConfig): }, request_payloads=[ completion_payload_default(), + completion_payload(prompt=TEXT_PROMPT, logprobs=3), ], ), } @@ -156,6 +218,7 @@ def test_deployment(trtllm_config_test, request, runtime_services, predownload_m @pytest.mark.e2e @pytest.mark.gpu_1 @pytest.mark.trtllm +@pytest.mark.timeout(660) # 3x measured time (159.68s) + download time (180s) def 
test_chat_only_aggregated_with_test_logits_processor( request, runtime_services, predownload_models, monkeypatch ): diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py index f1437fb83f..38fea121d7 100644 --- a/tests/serve/test_vllm.py +++ b/tests/serve/test_vllm.py @@ -6,6 +6,7 @@ import os import random from dataclasses import dataclass, field +from typing import Optional import pytest @@ -15,13 +16,17 @@ run_serve_deployment, ) from tests.serve.conftest import MULTIMODAL_IMG_PATH, MULTIMODAL_IMG_URL +from tests.serve.lora_utils import MinioLoraConfig from tests.utils.engine_process import EngineConfig from tests.utils.payload_builder import ( chat_payload, chat_payload_default, + chat_payload_with_logprobs, completion_payload_default, + completion_payload_with_logprobs, metric_payload_default, ) +from tests.utils.payloads import LoraTestChatPayload, ToolCallingChatPayload logger = logging.getLogger(__name__) @@ -39,12 +44,18 @@ class VLLMConfig(EngineConfig): # vLLM test configurations +# NOTE: pytest.mark.gpu_1 tests take ~5.5 minutes total to run sequentially (with models pre-cached) +# TODO: Parallelize these tests to reduce total execution time vllm_configs = { "aggregated": VLLMConfig( name="aggregated", directory=vllm_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(300), # 3x measured time (43s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", request_payloads=[ chat_payload_default(), @@ -52,11 +63,38 @@ class VLLMConfig(EngineConfig): metric_payload_default(min_num_requests=6, backend="vllm"), ], ), + "aggregated_logprobs": VLLMConfig( + name="aggregated_logprobs", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + top_logprobs=3, + ), + completion_payload_with_logprobs( + repeat_count=2, + expected_response=["AI", "knock", "joke"], + max_tokens=30, + temperature=0.0, + logprobs=5, + ), + ], + ), "aggregated_lmcache": VLLMConfig( name="aggregated_lmcache", directory=vllm_dir, script_name="agg_lmcache.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(360), # 3x estimated time (70s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", request_payloads=[ chat_payload_default(), @@ -69,7 +107,10 @@ class VLLMConfig(EngineConfig): name="aggregated_lmcache_multiproc", directory=vllm_dir, script_name="agg_lmcache_multiproc.sh", - marks=[pytest.mark.gpu_1], + marks=[ + pytest.mark.gpu_1, + pytest.mark.timeout(360), # 3x estimated time (70s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", env={ "PROMETHEUS_MULTIPROC_DIR": f"/tmp/prometheus_multiproc_test_{os.getpid()}_{random.randint(0, 10000)}" @@ -85,7 +126,11 @@ class VLLMConfig(EngineConfig): name="agg-request-plane-tcp", directory=vllm_dir, script_name="agg_request_planes.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(300), # 3x measured time (43s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", script_args=["--tcp"], request_payloads=[ @@ -97,7 +142,11 @@ class VLLMConfig(EngineConfig): name="agg-request-plane-http", directory=vllm_dir, script_name="agg_request_planes.sh", - marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + 
marks=[ + pytest.mark.gpu_1, + pytest.mark.pre_merge, + pytest.mark.timeout(300), # 3x measured time (43s) + download time (150s) + ], model="Qwen/Qwen3-0.6B", script_args=["--http"], request_payloads=[ @@ -333,6 +382,74 @@ class VLLMConfig(EngineConfig): ) ], ), + "aggregated_toolcalling": VLLMConfig( + name="aggregated_toolcalling", + directory=vllm_dir, + script_name="agg_multimodal.sh", + marks=[pytest.mark.gpu_2, pytest.mark.multimodal], + model="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8", + script_args=[ + "--model", + "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8", + "--max-model-len", + "10000", + "--dyn-tool-call-parser", + "hermes", + ], + delayed_start=0, + timeout=600, + request_payloads=[ + ToolCallingChatPayload( + body={ + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe what you see in this image in detail.", + }, + { + "type": "image_url", + "image_url": {"url": MULTIMODAL_IMG_URL}, + }, + ], + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "describe_image", + "description": "Provides detailed description of objects and scenes in an image", + "parameters": { + "type": "object", + "properties": { + "objects": { + "type": "array", + "items": {"type": "string"}, + "description": "List of objects detected in the image", + }, + "scene": { + "type": "string", + "description": "Overall scene description", + }, + }, + "required": ["objects", "scene"], + }, + }, + } + ], + "tool_choice": "auto", + "max_tokens": 1024, + }, + repeat_count=1, + expected_response=["purple"], # Validate image understanding + expected_log=[], + expected_tool_name="describe_image", # Validate tool call happened + ) + ], + ), # TODO: Enable this test case when we have 4 GPUs runners. # "multimodal_disagg": VLLMConfig( # name="multimodal_disagg", @@ -347,7 +464,12 @@ class VLLMConfig(EngineConfig): name="completions_only", directory=vllm_dir, script_name="agg.sh", - marks=[pytest.mark.gpu_1], + marks=[ + pytest.mark.gpu_1, + pytest.mark.timeout( + 420 + ), # 3x estimated time (60s) + download time (240s) for 7B model + ], model="deepseek-ai/deepseek-llm-7b-base", script_args=[ "--model", @@ -359,6 +481,66 @@ class VLLMConfig(EngineConfig): completion_payload_default(), ], ), + "guided_decoding_json": VLLMConfig( + name="guided_decoding_json", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload( + "Generate a person with name and age", + repeat_count=1, + expected_response=['"name"', '"age"'], + temperature=0.0, + max_tokens=100, + extra_body={ + "guided_json": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + }, + ) + ], + ), + "guided_decoding_regex": VLLMConfig( + name="guided_decoding_regex", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload( + "Generate a color name (red, blue, or green)", + repeat_count=1, + expected_response=["red", "blue", "green"], + temperature=0.0, + max_tokens=20, + extra_body={"guided_regex": r"(red|blue|green)"}, + ) + ], + ), + "guided_decoding_choice": VLLMConfig( + name="guided_decoding_choice", + directory=vllm_dir, + script_name="agg.sh", + marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], + model="Qwen/Qwen3-0.6B", + request_payloads=[ + chat_payload( + "Generate a color name (red, blue, or green)", + 
repeat_count=1, + expected_response=["red", "blue", "green"], + temperature=0.0, + max_tokens=20, + extra_body={"guided_choice": ["red", "blue", "green"]}, + ) + ], + ), } @@ -426,3 +608,153 @@ def test_multimodal_b64(request, runtime_services, predownload_models): ) run_serve_deployment(config, request) + + +# LoRA Test Directory +lora_dir = os.path.join(vllm_dir, "launch/lora") + + +def lora_chat_payload( + lora_name: str, + s3_uri: str, + system_port: int = 8081, + repeat_count: int = 2, + expected_response: Optional[list] = None, + expected_log: Optional[list] = None, + max_tokens: int = 100, + temperature: float = 0.0, +) -> LoraTestChatPayload: + """Create a LoRA-enabled chat payload for testing""" + return LoraTestChatPayload( + body={ + "model": lora_name, + "messages": [ + { + "role": "user", + "content": "What is deep learning? Answer in one sentence.", + } + ], + "max_tokens": max_tokens, + "temperature": temperature, + "stream": False, + }, + lora_name=lora_name, + s3_uri=s3_uri, + system_port=system_port, + repeat_count=repeat_count, + expected_response=expected_response + or ["learning", "neural", "network", "AI", "model"], + expected_log=expected_log or [], + ) + + +@pytest.mark.vllm +@pytest.mark.e2e +@pytest.mark.gpu_1 +@pytest.mark.model("Qwen/Qwen3-0.6B") +@pytest.mark.timeout(600) +@pytest.mark.nightly +def test_lora_aggregated( + request, runtime_services, predownload_models, minio_lora_service +): + """ + Test LoRA inference with aggregated vLLM deployment. + + This test: + 1. Uses MinIO fixture to provide S3-compatible storage with uploaded LoRA + 2. Starts vLLM with LoRA support enabled + 3. Loads the LoRA adapter via system API + 4. Runs inference with the LoRA model + """ + minio_config: MinioLoraConfig = minio_lora_service + + # Create payload that loads LoRA and tests inference + lora_payload = lora_chat_payload( + lora_name=minio_config.lora_name, + s3_uri=minio_config.get_s3_uri(), + system_port=8081, + repeat_count=2, + ) + + # Create test config with MinIO environment variables + config = VLLMConfig( + name="test_lora_aggregated", + directory=vllm_dir, + script_name="lora/agg_lora.sh", + marks=[], # markers at function-level + model="Qwen/Qwen3-0.6B", + timeout=600, + env=minio_config.get_env_vars(), + request_payloads=[lora_payload], + ) + + run_serve_deployment(config, request, extra_env=minio_config.get_env_vars()) + + +@pytest.mark.vllm +@pytest.mark.e2e +@pytest.mark.gpu_2 +@pytest.mark.model("Qwen/Qwen3-0.6B") +@pytest.mark.timeout(600) +@pytest.mark.nightly +def test_lora_aggregated_router( + request, runtime_services, predownload_models, minio_lora_service +): + """ + Test LoRA inference with aggregated vLLM deployment using KV router. + + This test: + 1. Uses MinIO fixture to provide S3-compatible storage with uploaded LoRA + 2. Starts multiple vLLM workers with LoRA support and KV router + 3. Loads the LoRA adapter on both workers via system API + 4. 
Runs inference with the LoRA model, verifying KV cache routing + """ + minio_config: MinioLoraConfig = minio_lora_service + + # Create payloads that load LoRA on both workers and test inference + # Worker 1 (port 8081) + lora_payload_worker1 = lora_chat_payload( + lora_name=minio_config.lora_name, + s3_uri=minio_config.get_s3_uri(), + system_port=8081, + repeat_count=1, + ) + + # Worker 2 (port 8082) + lora_payload_worker2 = lora_chat_payload( + lora_name=minio_config.lora_name, + s3_uri=minio_config.get_s3_uri(), + system_port=8082, + repeat_count=1, + ) + + # Additional inference payload to test routing (LoRA already loaded) + inference_payload = chat_payload( + content="Explain machine learning in simple terms.", + repeat_count=2, + expected_response=["learn", "data", "algorithm", "model", "pattern"], + max_tokens=150, + temperature=0.0, + ).with_model(minio_config.lora_name) + + # Add env vars including PYTHONHASHSEED for deterministic KV event IDs + env_vars = minio_config.get_env_vars() + env_vars["PYTHONHASHSEED"] = "0" + + # Create test config with MinIO environment variables + config = VLLMConfig( + name="test_lora_aggregated_router", + directory=vllm_dir, + script_name="lora/agg_lora_router.sh", + marks=[], # markers at function-level + model="Qwen/Qwen3-0.6B", + timeout=600, + env=env_vars, + request_payloads=[ + lora_payload_worker1, + lora_payload_worker2, + inference_payload, + ], + ) + + run_serve_deployment(config, request, extra_env=env_vars) diff --git a/tests/utils/managed_deployment.py b/tests/utils/managed_deployment.py index 8dd008a61a..5ee541833d 100644 --- a/tests/utils/managed_deployment.py +++ b/tests/utils/managed_deployment.py @@ -5,18 +5,18 @@ import logging import os import re +import secrets import shlex import time from dataclasses import dataclass, field from typing import Any, List, Optional import kr8s -import kubernetes import requests import yaml -from kr8s.objects import Pod as kr8s_Pod -from kr8s.objects import Service as kr8s_Service +from kr8s.objects import Pod, Service from kubernetes_asyncio import client, config +from kubernetes_asyncio.client import exceptions def _get_workspace_dir() -> str: @@ -65,6 +65,15 @@ def image(self, value: str): self._spec["extraPodSpec"]["mainContainer"] = {} self._spec["extraPodSpec"]["mainContainer"]["image"] = value + @property + def envs(self) -> list[dict[str, str]]: + """Environment variables for the service""" + return self._spec.get("envs", []) + + @envs.setter + def envs(self, value: list[dict[str, str]]): + self._spec["envs"] = value + # ----- Replicas ----- @property def replicas(self) -> int: @@ -314,8 +323,36 @@ def get_logging_config(self) -> dict: return {"jsonl_enabled": jsonl_enabled, "log_level": log_level} + def set_service_env_var(self, service_name: str, name: str, value: str): + """ + Set an environment variable for a specific service + """ + service = self.get_service(service_name) + envs = service.envs if service.envs is not None else [] + + # if env var already exists, update it + for env in envs: + if env["name"] == name: + env["value"] = value + service.envs = envs # Save back to trigger the setter + return + + # if env var does not exist, add it + envs.append({"name": name, "value": value}) + service.envs = envs # Save back to trigger the setter + + def get_service_env_vars(self, service_name: str) -> list[dict]: + """ + Get all environment variables for a specific service + + Returns: + List of environment variable dicts (e.g., [{"name": "VAR", "value": "val"}]) + """ + service = 
self.get_service(service_name) + return service.envs + @property - def services(self) -> list: + def services(self) -> list[ServiceSpec]: """List of ServiceSpec objects""" return [ ServiceSpec(svc, spec) @@ -340,28 +377,25 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str): arg_name: Argument name (e.g., "--max-model-len", "--max-seq-len") arg_value: Argument value (e.g., "1024") """ - # Get the service - if service_name not in self._deployment_spec["spec"]["services"]: - raise ValueError(f"Service '{service_name}' not found in deployment spec") - - service = self._deployment_spec["spec"]["services"][service_name] + service = self.get_service(service_name) + service_spec = service._spec # Ensure args list exists - if "extraPodSpec" not in service: - service["extraPodSpec"] = {"mainContainer": {}} - if "mainContainer" not in service["extraPodSpec"]: - service["extraPodSpec"]["mainContainer"] = {} - if "args" not in service["extraPodSpec"]["mainContainer"]: - service["extraPodSpec"]["mainContainer"]["args"] = [] + if "extraPodSpec" not in service_spec: + service_spec["extraPodSpec"] = {"mainContainer": {}} + if "mainContainer" not in service_spec["extraPodSpec"]: + service_spec["extraPodSpec"]["mainContainer"] = {} + if "args" not in service_spec["extraPodSpec"]["mainContainer"]: + service_spec["extraPodSpec"]["mainContainer"]["args"] = [] - args_list = service["extraPodSpec"]["mainContainer"]["args"] + args_list = service_spec["extraPodSpec"]["mainContainer"]["args"] # Convert to list if needed (sometimes it's a single string) if isinstance(args_list, str): import shlex args_list = shlex.split(args_list) - service["extraPodSpec"]["mainContainer"]["args"] = args_list + service_spec["extraPodSpec"]["mainContainer"]["args"] = args_list # Find existing argument arg_index = None @@ -384,6 +418,24 @@ def add_arg_to_service(self, service_name: str, arg_name: str, arg_value: str): # Add new argument args_list.extend([arg_name, arg_value]) + def get_service(self, service_name: str) -> ServiceSpec: + """ + Get a specific service from the deployment spec + """ + if service_name not in self._deployment_spec["spec"]["services"]: + raise ValueError(f"Service '{service_name}' not found in deployment spec") + + return ServiceSpec( + service_name, self._deployment_spec["spec"]["services"][service_name] + ) + + def set_service_replicas(self, service_name: str, replicas: int): + """ + Set the number of replicas for a specific service + """ + service = self.get_service(service_name) + service.replicas = replicas + def save(self, out_file: str): """Save updated deployment to file""" with open(out_file, "w") as f: @@ -391,7 +443,7 @@ def save(self, out_file: str): class PodProcess: - def __init__(self, pod: kr8s_Pod, line: str): + def __init__(self, pod: Pod, line: str): self.pid = int(re.split(r"\s+", line)[1]) self.command = " ".join( re.split(r"\s+", line)[10:] @@ -439,10 +491,13 @@ class ManagedDeployment: log_dir: str deployment_spec: DeploymentSpec namespace: str - frontend_service_name: Optional[str] = "Frontend" + # TODO: this should be determined by the deployment_spec + # the service containing component_type: Frontend determines what is actually the frontend service + frontend_service_name: str = "Frontend" + skip_service_restart: bool = False - _custom_api: Optional[Any] = None - _core_api: Optional[Any] = None + _custom_api: Optional[client.CustomObjectsApi] = None + _core_api: Optional[client.CoreV1Api] = None _in_cluster: bool = False _logger: logging.Logger = 
logging.getLogger() _port_forward: Optional[Any] = None @@ -457,7 +512,7 @@ async def _init_kubernetes(self): """Initialize kubernetes client""" try: # Try in-cluster config first (for pods with service accounts) - await config.load_incluster_config() + config.load_incluster_config() self._in_cluster = True except Exception: # Fallback to kube config file (for local development) @@ -511,6 +566,17 @@ async def _restart_stateful(self, name, label): self._logger.info(f"Restarted {name} {label}") + async def wait_for_unready(self, timeout: int = 1800, sleep=1, log_interval=60): + """ + Wait for the custom resource to be unready. + + Args: + timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while) + """ + return await self._wait_for_condition( + timeout, sleep, log_interval, False, "pending" + ) + async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): """ Wait for the custom resource to be ready. @@ -518,9 +584,23 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): Args: timeout: Maximum time to wait in seconds, default to 30 mins (image pulling can take a while) """ + return await self._wait_for_condition( + timeout, sleep, log_interval, True, "successful" + ) + + async def _wait_for_condition( + self, + timeout: int = 1800, + sleep=1, + log_interval=60, + desired_ready_condition_val: bool = True, + desired_state_val: str = "successful", + ): start_time = time.time() - self._logger.info(f"Waiting for Deployment {self._deployment_name}") + self._logger.info( + f"Waiting for Deployment {self._deployment_name} to have Ready condition {desired_ready_condition_val} and state {desired_state_val}" + ) attempt = 0 @@ -528,7 +608,7 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): try: attempt += 1 assert self._custom_api is not None, "Kubernetes API not initialized" - status = await self._custom_api.get_namespaced_custom_object( + status = await self._custom_api.get_namespaced_custom_object( # type: ignore[awaitable-is-not-coroutine] group="nvidia.com", version="v1alpha1", namespace=self.namespace, @@ -538,29 +618,34 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): # Check both conditions: # 1. Ready condition is True # 2. 
State is successful - status_obj = status.get("status", {}) - conditions = status_obj.get("conditions", []) - current_state = status_obj.get("state", "unknown") + status_obj = status.get("status", {}) # type: ignore[attr-defined] + conditions = status_obj.get("conditions", []) # type: ignore[attr-defined] + current_state = status_obj.get("state", "unknown") # type: ignore[attr-defined] - ready_condition = False + observed_ready_condition_val = "" for condition in conditions: - if ( - condition.get("type") == "Ready" - and condition.get("status") == "True" - ): - ready_condition = True - break - - state_successful = status_obj.get("state") == "successful" - - if ready_condition and state_successful: + if condition.get("type") == "Ready": + observed_ready_condition_val = condition.get("status") + if observed_ready_condition_val == str( + desired_ready_condition_val + ): + break + + observed_state_val = status_obj.get("state") # type: ignore[attr-defined] + + if ( + observed_ready_condition_val == str(desired_ready_condition_val) + and observed_state_val == desired_state_val + ): self._logger.info(f"Current deployment state: {current_state}") self._logger.info(f"Current conditions: {conditions}") self._logger.info( f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" ) - self._logger.info(f"Deployment {self._deployment_name} is ready") + self._logger.info( + f"Deployment {self._deployment_name} has Ready condition {desired_ready_condition_val} and state {desired_state_val}" + ) return True else: if attempt % log_interval == 0: @@ -570,10 +655,10 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60): f"Elapsed time: {time.time() - start_time:.1f}s / {timeout}s" ) self._logger.info( - f"Deployment not ready yet - Ready condition: {ready_condition}, State successful: {state_successful}" + f"Deployment has Ready condition {observed_ready_condition_val} and state {observed_state_val}, desired condition {desired_ready_condition_val} and state {desired_state_val}" ) - except kubernetes.client.rest.ApiException as e: + except exceptions.ApiException as e: self._logger.info( f"API Exception while checking deployment status: {e}" ) @@ -624,7 +709,7 @@ async def _create_deployment(self): ) self._logger.info(self.deployment_spec.spec()) self._logger.info(f"Deployment Started {self._deployment_name}") - except kubernetes.client.rest.ApiException as e: + except exceptions.ApiException as e: if e.status == 409: # Already exists self._logger.info(f"Deployment {self._deployment_name} already exists") else: @@ -633,7 +718,64 @@ async def _create_deployment(self): ) raise - def get_processes(self, pod) -> list: + async def trigger_rolling_upgrade(self, service_names: list[str]): + """ + Triggers a rolling update for a list of services + This is a dummy update - sets an env var on the service + """ + + if not service_names: + raise ValueError( + "service_names cannot be empty for trigger_rolling_upgrade" + ) + + patch_body: dict[str, Any] = {"spec": {"services": {}}} + + for service_name in service_names: + self.deployment_spec.set_service_env_var( + service_name, "TEST_ROLLING_UPDATE_TRIGGER", secrets.token_hex(8) + ) + + updated_envs = self.deployment_spec.get_service_env_vars(service_name) + patch_body["spec"]["services"][service_name] = {"envs": updated_envs} + + try: + assert self._custom_api is not None, "Kubernetes API not initialized" + await self._custom_api.patch_namespaced_custom_object( + group="nvidia.com", + version="v1alpha1", + namespace=self.namespace, + 
plural="dynamographdeployments", + name=self._deployment_name, + body=patch_body, + _content_type="application/merge-patch+json", + ) + except exceptions.ApiException as e: + self._logger.info( + f"Failed to patch deployment {self._deployment_name}: {e}" + ) + raise + + async def get_pod_names(self, service_names: list[str] | None = None) -> list[str]: + if not service_names: + service_names = [service.name for service in self.deployment_spec.services] + + pod_names: list[str] = [] + + for service_name in service_names: + label_selector = ( + f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}" + ) + assert self._core_api is not None, "Kubernetes API not initialized" + pods: client.V1PodList = await self._core_api.list_namespaced_pod( + self.namespace, label_selector=label_selector + ) + for pod in pods.items: + pod_names.append(pod.metadata.name) + + return pod_names + + def get_processes(self, pod: Pod) -> list[PodProcess]: """Get list of processes in the given pod""" result = pod.exec(["ps", "-aux"]) lines = result.stdout.decode().splitlines() @@ -646,38 +788,34 @@ def get_service(self, service_name=None): service_name = "" full_service_name = f"{self._deployment_name}-{service_name.lower()}" - return kr8s_Service.get(full_service_name, namespace=self.namespace) + return Service.get(full_service_name, namespace=self.namespace) - def get_pods(self, service_name=None): - result = {} + def get_pods(self, service_names: list[str] | None = None) -> dict[str, list[Pod]]: + result: dict[str, list[Pod]] = {} - service_list = [] + if not service_names: + service_names = [service.name for service in self.deployment_spec.services] - if not service_name: - service_list = [service.name for service in self.deployment_spec.services] - else: - service_list = [service_name] - - for service in service_list: + for service_name in service_names: # List pods for this service using the selector label # nvidia.com/selector: deployment-name-service label_selector = ( - f"nvidia.com/selector={self._deployment_name}-{service.lower()}" + f"nvidia.com/selector={self._deployment_name}-{service_name.lower()}" ) - pods = [] + pods: list[Pod] = [] for pod in kr8s.get( "pods", namespace=self.namespace, label_selector=label_selector ): - pods.append(pod) + pods.append(pod) # type: ignore[arg-type] - result[service] = pods + result[service_name] = pods return result - def get_pod_logs(self, service, pod, suffix=""): - directory = os.path.join(self.log_dir, service) + def get_pod_manifest_logs_metrics(self, service_name: str, pod: Pod, suffix=""): + directory = os.path.join(self.log_dir, service_name) os.makedirs(directory, exist_ok=True) try: @@ -699,16 +837,20 @@ def get_pod_logs(self, service, pod, suffix=""): except Exception as e: self._logger.debug(e) - self._get_pod_metrics(pod, service, suffix) + self._get_pod_metrics(pod, service_name, suffix) def _get_service_logs(self, service_name=None, suffix=""): - service_pods = self.get_pods(service_name) + service_names = None + if service_name: + service_names = [service_name] + + service_pods = self.get_pods(service_names) for service, pods in service_pods.items(): - for i, pod in enumerate(pods): - self.get_pod_logs(service, pod, suffix) + for pod in pods: + self.get_pod_manifest_logs_metrics(service, pod, suffix) - def _get_pod_metrics(self, pod, service_name, suffix=""): + def _get_pod_metrics(self, pod: Pod, service_name: str, suffix=""): directory = os.path.join(self.log_dir, service_name) os.makedirs(directory, exist_ok=True) port = None @@ 
-757,11 +899,13 @@ async def _delete_deployment(self): plural="dynamographdeployments", name=self._deployment_name, ) - except client.exceptions.ApiException as e: + except exceptions.ApiException as e: if e.status != 404: # Ignore if already deleted raise - def port_forward(self, pod, remote_port, max_connection_attempts=3): + def port_forward( + self, pod: Pod, remote_port: int, max_connection_attempts: int = 3 + ): """Attempt to connect to a pod and return the port-forward object on success. Note: Port forwards run in background threads. When pods are terminated, @@ -866,9 +1010,13 @@ async def __aenter__(self): self._deployment_name = self.deployment_spec.name logging.getLogger("httpx").setLevel(logging.WARNING) await self._init_kubernetes() - await self._delete_deployment() - await self._restart_etcd() - await self._restart_nats() + + # Run delete deployment and service restarts in parallel + tasks = [self._delete_deployment()] + if not self.skip_service_restart: + tasks.extend([self._restart_etcd(), self._restart_nats()]) + await asyncio.gather(*tasks) + await self._create_deployment() await self._wait_for_ready() diff --git a/tests/utils/managed_process.py b/tests/utils/managed_process.py index 618858109e..1de018b978 100644 --- a/tests/utils/managed_process.py +++ b/tests/utils/managed_process.py @@ -448,9 +448,16 @@ def _check_url(self, url, timeout=30, sleep=1, log_interval=20): elapsed = time.time() - start_time self._logger.error( - "FAILED: Check URL: %s (attempts=%d, elapsed=%.1fs)", url, attempt, elapsed + "TIMEOUT: Check URL: %s failed after %.1fs (attempts=%d, timeout=%.1fs)", + url, + elapsed, + attempt, + timeout, + ) + raise RuntimeError( + "TIMEOUT: Check URL: %s failed after %.1fs (timeout=%.1fs)" + % (url, elapsed, timeout) ) - raise RuntimeError("FAILED: Check URL: %s" % url) def _check_funcs(self, timeout): elapsed = 0.0 @@ -552,6 +559,10 @@ def is_running(self) -> bool: hasattr(self, "proc") and self.proc is not None and self.proc.poll() is None ) + def get_pid(self) -> int | None: + """Get the PID of the managed process.""" + return self.proc.pid if self.proc else None + def subprocesses(self) -> list[psutil.Process]: """Find child processes of the current process.""" if ( @@ -598,10 +609,6 @@ def __init__(self, request): log_dir=log_dir, ) - def get_pid(self) -> int | None: - """Get the PID of the worker process""" - return self.proc.pid if self.proc else None - def main(): with ManagedProcess( diff --git a/tests/utils/payload_builder.py b/tests/utils/payload_builder.py index 1b2e8bf963..017a35a461 100644 --- a/tests/utils/payload_builder.py +++ b/tests/utils/payload_builder.py @@ -6,7 +6,9 @@ from tests.utils.client import send_request from tests.utils.payloads import ( ChatPayload, + ChatPayloadWithLogprobs, CompletionPayload, + CompletionPayloadWithLogprobs, EmbeddingPayload, MetricsPayload, ) @@ -134,6 +136,9 @@ def chat_payload( max_tokens: int = 300, temperature: Optional[float] = None, stream: bool = False, + logprobs: bool = False, + top_logprobs: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, ) -> ChatPayload: body: Dict[str, Any] = { "messages": [ @@ -144,16 +149,35 @@ def chat_payload( ], "max_tokens": max_tokens, "stream": stream, + "logprobs": logprobs, } if temperature is not None: body["temperature"] = temperature + if logprobs is not None: + body["logprobs"] = logprobs + if top_logprobs is not None: + body["top_logprobs"] = top_logprobs - return ChatPayload( - body=body, - repeat_count=repeat_count, - 
expected_log=expected_log or [], - expected_response=expected_response or [], - ) + if top_logprobs is not None: + body["top_logprobs"] = top_logprobs + + if extra_body: + body.update(extra_body) + + if logprobs: + return ChatPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=expected_log or [], + expected_response=expected_response or [], + ) + else: + return ChatPayload( + body=body, + repeat_count=repeat_count, + expected_log=expected_log or [], + expected_response=expected_response or [], + ) def completion_payload( @@ -164,18 +188,29 @@ def completion_payload( max_tokens: int = 150, temperature: float = 0.1, stream: bool = False, + logprobs: Optional[int] = None, ) -> CompletionPayload: - return CompletionPayload( - body={ - "prompt": prompt, - "max_tokens": max_tokens, - "temperature": temperature, - "stream": stream, - }, - repeat_count=repeat_count, - expected_log=expected_log or [], - expected_response=expected_response or [], - ) + body: Dict[str, Any] = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": temperature, + "stream": stream, + } + if logprobs is not None: + body["logprobs"] = logprobs + return CompletionPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=expected_log or [], + expected_response=expected_response or [], + ) + else: + return CompletionPayload( + body=body, + repeat_count=repeat_count, + expected_log=expected_log or [], + expected_response=expected_response or [], + ) def embedding_payload_default( @@ -276,3 +311,83 @@ def _check_completions_endpoint(remaining_timeout: float = 30.0) -> bool: return False return _check_completions_endpoint + + +def chat_payload_with_logprobs( + content: Union[str, List[Dict[str, Any]]] = TEXT_PROMPT, + repeat_count: int = 1, + expected_response: Optional[List[str]] = None, + max_tokens: int = 50, + temperature: float = 0.0, + top_logprobs: int = 3, +) -> ChatPayloadWithLogprobs: + """ + Create a chat payload that requests and validates logprobs in the response. + + Args: + content: Message content (text or structured content list) + repeat_count: Number of times to repeat the request + expected_response: List of strings expected in the response text + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_logprobs: Number of top logprobs to return per token + + Returns: + ChatPayloadWithLogprobs that validates logprobs in response + """ + body: Dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": content, + } + ], + "max_tokens": max_tokens, + "temperature": temperature, + "logprobs": True, + "top_logprobs": top_logprobs, + } + + return ChatPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=[], + expected_response=expected_response or ["AI", "knock", "joke"], + ) + + +def completion_payload_with_logprobs( + prompt: str = TEXT_PROMPT, + repeat_count: int = 1, + expected_response: Optional[List[str]] = None, + max_tokens: int = 50, + temperature: float = 0.0, + logprobs: int = 5, +) -> CompletionPayloadWithLogprobs: + """ + Create a completion payload that requests and validates logprobs in the response. 
+ + Args: + prompt: Text prompt + repeat_count: Number of times to repeat the request + expected_response: List of strings expected in the response text + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + logprobs: Number of logprobs to return per token + + Returns: + CompletionPayloadWithLogprobs that validates logprobs in response + """ + body: Dict[str, Any] = { + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": temperature, + "logprobs": logprobs, + } + + return CompletionPayloadWithLogprobs( + body=body, + repeat_count=repeat_count, + expected_log=[], + expected_response=expected_response or ["AI", "knock", "joke"], + ) diff --git a/tests/utils/payloads.py b/tests/utils/payloads.py index 3a18dfdf44..3eac19d5ba 100644 --- a/tests/utils/payloads.py +++ b/tests/utils/payloads.py @@ -14,13 +14,16 @@ # limitations under the License. import logging +import math import re import time from copy import deepcopy from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional -from dynamo import prometheus_names +import requests + +from dynamo import prometheus_names # type: ignore[attr-defined] logger = logging.getLogger(__name__) @@ -155,6 +158,177 @@ def response_handler(self, response: Any) -> str: return ChatPayload.extract_content(response) +@dataclass +class ChatPayloadWithLogprobs(ChatPayload): + """Chat payload that validates logprobs in response.""" + + def validate(self, response: Any, content: str) -> None: + """Validate response contains logprobs fields.""" + super().validate(response, content) + + result = response.json() + choice = result["choices"][0] + + # Validate logprobs field exists + assert "logprobs" in choice, "Missing 'logprobs' in choice" + + logprobs_data = choice["logprobs"] + if logprobs_data is not None: + assert "content" in logprobs_data, "Missing 'content' in logprobs" + content_logprobs = logprobs_data["content"] + + if content_logprobs: + # Validate structure of logprobs + for item in content_logprobs: + assert "token" in item, "Missing 'token' in logprobs content" + assert "logprob" in item, "Missing 'logprob' in logprobs content" + assert ( + "top_logprobs" in item + ), "Missing 'top_logprobs' in logprobs content" + + # Sanity check: logprob should be valid (not nan/inf/positive) + logprob_val = item["logprob"] + assert not math.isnan(logprob_val), "logprob is NaN" + assert not math.isinf(logprob_val), "logprob is infinite" + assert ( + logprob_val <= 0 + ), f"logprob should be <= 0, got {logprob_val}" + + logger.info( + f"โœ“ Logprobs validation passed: found {len(content_logprobs)} tokens with logprobs" + ) + + +@dataclass +class ToolCallingChatPayload(ChatPayload): + """ChatPayload that validates tool calls in the response.""" + + def __init__(self, *args, expected_tool_name: Optional[str] = None, **kwargs): + super().__init__(*args, **kwargs) + self.expected_tool_name = expected_tool_name + + def validate(self, response, content: str) -> None: + """Validate that tool calls exist in the response.""" + # First run the standard validation + super().validate(response, content) + + # Then validate tool calls specifically + response_data = response.json() + choices = response_data.get("choices", []) + assert choices, "Response missing choices" + + message = choices[0].get("message", {}) + tool_calls = message.get("tool_calls", []) + + assert tool_calls, "Expected model to generate tool calls but none found" + logger.info(f"Tool calls detected: {len(tool_calls)} call(s)") + + # Validate tool call 
structure + for i, tc in enumerate(tool_calls): + assert "function" in tc, f"Tool call {i} missing 'function' field" + function = tc.get("function", {}) + assert "name" in function, f"Tool call {i} missing function name" + assert "arguments" in function, f"Tool call {i} missing function arguments" + logger.info( + f" [{i}] Function: {function.get('name')}, Args: {function.get('arguments')[:100]}..." + ) + + # If expected tool name is provided, validate it + if self.expected_tool_name: + tool_names = [tc.get("function", {}).get("name") for tc in tool_calls] + assert ( + self.expected_tool_name in tool_names + ), f"Expected tool '{self.expected_tool_name}' not found. Available tools: {tool_names}" + logger.info(f"Expected tool '{self.expected_tool_name}' was called") + + +@dataclass +class LoraTestChatPayload(ChatPayload): + """ + Chat payload that loads a LoRA adapter before sending inference requests. + + This payload first loads the specified LoRA adapter via the system API, + then sends chat completion requests using the LoRA model. + """ + + def __init__( + self, + body: dict, + lora_name: str, + s3_uri: str, + system_port: int = 8081, + repeat_count: int = 1, + expected_response: Optional[list] = None, + expected_log: Optional[list] = None, + timeout: int = 60, + ): + super().__init__( + body=body, + repeat_count=repeat_count, + expected_response=expected_response or [], + expected_log=expected_log or [], + timeout=timeout, + ) + self.system_port = system_port + self.lora_name = lora_name + self.s3_uri = s3_uri + self._lora_loaded = False + + def _ensure_lora_loaded(self) -> None: + """Ensure the LoRA adapter is loaded before making inference requests""" + if not self._lora_loaded: + # Import the load_lora_adapter function + # Note: This import is done here to avoid circular dependencies + from tests.serve.lora_utils import load_lora_adapter + + load_lora_adapter( + system_port=self.system_port, + lora_name=self.lora_name, + s3_uri=self.s3_uri, + timeout=self.timeout, + ) + + # Wait for the LoRA model to appear in /v1/models + models_url = f"http://{self.host}:{self.port}/v1/models" + start_time = time.time() + + logger.info( + f"Waiting for LoRA model '{self.lora_name}' to appear in /v1/models..." + ) + + while time.time() - start_time < self.timeout: + try: + response = requests.get(models_url, timeout=5) + if response.status_code == 200: + data = response.json() + models = data.get("data", []) + model_ids = [m.get("id", "") for m in models] + + if self.lora_name in model_ids: + logger.info( + f"LoRA model '{self.lora_name}' is now available" + ) + self._lora_loaded = True + return + + logger.debug( + f"Available models: {model_ids}, waiting for '{self.lora_name}'..." 
+ ) + except requests.RequestException as e: + logger.debug(f"Error checking /v1/models: {e}") + + time.sleep(1) + + raise RuntimeError( + f"Timeout: LoRA model '{self.lora_name}' did not appear in /v1/models within {self.timeout}s" + ) + + def url(self) -> str: + """Load LoRA before first request, then return URL""" + self._ensure_lora_loaded() + return super().url() + + @dataclass class CompletionPayload(BasePayload): """Payload for completions endpoint.""" @@ -177,6 +351,53 @@ def response_handler(self, response: Any) -> str: return CompletionPayload.extract_text(response) +@dataclass +class CompletionPayloadWithLogprobs(CompletionPayload): + """Completion payload that validates logprobs in response.""" + + def validate(self, response: Any, content: str) -> None: + """Validate response contains logprobs fields.""" + super().validate(response, content) + + result = response.json() + choice = result["choices"][0] + + # Validate logprobs field exists + assert "logprobs" in choice, "Missing 'logprobs' in choice" + + logprobs_data = choice["logprobs"] + if logprobs_data is not None: + assert ( + "token_logprobs" in logprobs_data + ), "Missing 'token_logprobs' in logprobs" + assert "tokens" in logprobs_data, "Missing 'tokens' in logprobs" + + token_logprobs = logprobs_data["token_logprobs"] + tokens = logprobs_data["tokens"] + + if token_logprobs: + assert len(token_logprobs) == len( + tokens + ), "Mismatch between token_logprobs and tokens length" + + # Sanity check: each logprob should be valid (not nan/inf/positive) + for i, logprob_val in enumerate(token_logprobs): + if logprob_val is not None: # First token can be None + assert not math.isnan( + logprob_val + ), f"logprob at index {i} is NaN" + assert not math.isinf( + logprob_val + ), f"logprob at index {i} is infinite" + assert ( + logprob_val <= 0 + ), f"logprob at index {i} should be <= 0, got {logprob_val}" + + logger.info( + f"โœ“ Logprobs validation passed: found {len(token_logprobs)} tokens with logprobs" + ) + + @dataclass class EmbeddingPayload(BasePayload): """Payload for embeddings endpoint.""" @@ -274,9 +495,9 @@ def metric_pattern(name): name=f"{prefix}_*", pattern=lambda name: rf"^{prefix}_\w+", validator=lambda value: len(set(value)) - >= 23, # 80% of typical ~29 metrics (excluding _bucket) as of 2025-10-22 (but will grow) - error_msg=lambda name, value: f"Expected at least 23 unique {prefix}_* metrics, but found only {len(set(value))}", - success_msg=lambda name, value: f"SUCCESS: Found {len(set(value))} unique {prefix}_* metrics (minimum required: 23)", + >= 11, # 80% of typical ~17 metrics (excluding _bucket) as of 2025-12-02 + error_msg=lambda name, value: f"Expected at least 11 unique {prefix}_* metrics, but found only {len(set(value))}", + success_msg=lambda name, value: f"SUCCESS: Found {len(set(value))} unique {prefix}_* metrics (minimum required: 11)", multiline=True, ), MetricCheck( diff --git a/tests/utils/port_utils.py b/tests/utils/port_utils.py new file mode 100644 index 0000000000..89ac77724d --- /dev/null +++ b/tests/utils/port_utils.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Port allocation utilities for tests. + +Port allocation with flock-based locking to prevent race conditions in parallel tests. 
+""" + +import fcntl +import inspect +import json +import os +import random +import socket +import tempfile +import time +from pathlib import Path + +# Port allocation lock file +_PORT_LOCK_FILE = Path(tempfile.gettempdir()) / "pytest_port_allocations.lock" +_PORT_REGISTRY_FILE = Path(tempfile.gettempdir()) / "pytest_port_allocations.json" + +# Port range for allocation (i16 range for Rust compatibility) +# TODO: Get Rust backend to use u16 instead of i16 so we can use full 1024-65535 range +_PORT_MIN = 1024 +_PORT_MAX = 32767 + + +def _load_port_registry() -> dict: + """Load the port registry from disk. + + Returns: + dict: Port registry mapping port numbers (as strings) to allocation info. + Example: { + "30001": { + "timestamp": 1732647123.456, + "caller_file": "/workspace/tests/test_foo.py", + "caller_function": "test_bar", + "caller_line": 42 + } + } + """ + if not _PORT_REGISTRY_FILE.exists(): + return {} + try: + with open(_PORT_REGISTRY_FILE, "r") as f: + return json.load(f) + except (json.JSONDecodeError, OSError): + return {} + + +def _save_port_registry(registry: dict) -> None: + """Save the port registry to disk.""" + with open(_PORT_REGISTRY_FILE, "w") as f: + json.dump(registry, f) + + +def _cleanup_stale_allocations(registry: dict, max_age: float = 900.0) -> dict: + """Remove port allocations older than max_age seconds.""" + current_time = time.time() + cleaned = {} + for port, info in registry.items(): + # Handle both old format (timestamp only) and new format (dict with timestamp) + if isinstance(info, dict): + timestamp = info.get("timestamp", 0) + else: + timestamp = info + + if current_time - timestamp < max_age: + cleaned[str(port)] = info + + return cleaned + + +def allocate_ports(count: int, start_port: int) -> list[int]: + """Find and return available ports in i16 range with flock-based locking. + + Uses file locking (flock) to prevent race conditions when multiple test processes + allocate ports simultaneously. + + Port range is limited to i16 (1024-32767) due to Rust backend expecting i16. + + Searches from a random offset (start_port + random(100)) and walks up incrementally. + Wraps around to _PORT_MIN (1024) when exceeding _PORT_MAX. Retries up to 100 times. + + Args: + count: Number of unique ports to allocate + start_port: Starting port number for allocation (required) + + Returns: + list[int]: List of available port numbers + """ + # Get caller information for debugging + caller_file = "unknown" + caller_function = "unknown" + caller_line = 0 + + frame = inspect.currentframe() + if frame and frame.f_back: + caller_frame = frame.f_back + caller_info = inspect.getframeinfo(caller_frame) + caller_function = caller_frame.f_code.co_name + caller_file = caller_info.filename + caller_line = caller_info.lineno + + # Validate start_port is in valid i16 range. 
Note that <1024 is reserved for system services (root only) + if start_port < _PORT_MIN or start_port > _PORT_MAX: + raise ValueError( + f"start_port must be between {_PORT_MIN} and {_PORT_MAX}, got {start_port}" + ) + + # Ensure lock file exists and is writable + _PORT_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True) + _PORT_LOCK_FILE.touch(exist_ok=True) + + if not os.access(_PORT_LOCK_FILE, os.W_OK): + raise PermissionError( + f"Port allocation lock file is not writable: {_PORT_LOCK_FILE}" + ) + + with open(_PORT_LOCK_FILE, "r+") as lock_file: + # Acquire exclusive lock + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + + try: + # Load registry and clean up stale allocations + registry = _load_port_registry() + registry = _cleanup_stale_allocations(registry) + + allocated_ports = set(int(p) for p in registry.keys()) + ports: list[int] = [] + + # Start searching from desired port + random offset + current_port = start_port + random.randint(0, 100) + if current_port > _PORT_MAX: + current_port = _PORT_MIN + (current_port - _PORT_MAX - 1) + + # Retry limit + max_retries = 100 + attempts = 0 + + while len(ports) < count and attempts < max_retries: + attempts += 1 + + # Try current port + port = current_port + + # Increment and wrap around to _PORT_MIN + current_port += 1 + if current_port > _PORT_MAX: + current_port = _PORT_MIN + + # Skip if already allocated or in our current list + if port in allocated_ports or port in ports: + continue + + # Try to bind to verify it's actually free + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", port)) + sock.close() + ports.append(port) + registry[str(port)] = { + "timestamp": time.time(), + "caller_file": caller_file, + "caller_function": caller_function, + "caller_line": caller_line, + } + except OSError: + continue + + if len(ports) < count: + raise RuntimeError( + f"Could not find {count} available ports after {max_retries} retries" + ) + + # Save updated registry + _save_port_registry(registry) + + return ports + + finally: + # Release lock + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + + +def allocate_port(start_port: int) -> int: + """Find and return a single available port in i16 range. + + Args: + start_port: Starting port number for allocation (required) + + Returns: + int: An available port number between start_port and 32767 (i16 max) + """ + return allocate_ports(1, start_port)[0] + + +def deallocate_ports(ports: list[int]) -> None: + """Release previously allocated ports back to the pool. + + Args: + ports: List of port numbers to release + """ + if not ports: + return + + # Ensure lock file exists + _PORT_LOCK_FILE.parent.mkdir(parents=True, exist_ok=True) + _PORT_LOCK_FILE.touch(exist_ok=True) + + with open(_PORT_LOCK_FILE, "r+") as lock_file: + # Acquire exclusive lock + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + + try: + # Load registry + registry = _load_port_registry() + + # Remove the specified ports + for port in ports: + registry.pop(str(port), None) + + # Save updated registry + _save_port_registry(registry) + + finally: + # Release lock + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + + +def deallocate_port(port: int) -> None: + """Release a previously allocated port back to the pool. + + Args: + port: Port number to release + """ + deallocate_ports([port])
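
Usage sketch for the new port allocation helpers (not part of the patch above): the fixture name, the start port of 8000, and the pytest-fixture shape are illustrative assumptions; only `allocate_ports`, `deallocate_ports`, and the module path `tests/utils/port_utils.py` come from this change. The idea is to reserve ports through the flock-protected registry before starting a test process and to release them on teardown so parallel pytest workers cannot collide.

    import pytest

    from tests.utils.port_utils import allocate_ports, deallocate_ports


    @pytest.fixture
    def service_ports():
        # Reserve two free ports (e.g. HTTP + metrics), searching upward from 8000.
        # allocate_ports() binds each candidate to confirm it is free and records the
        # allocation in the flock-protected registry shared by parallel test workers.
        ports = allocate_ports(2, start_port=8000)
        try:
            yield ports
        finally:
            # Return the ports to the pool once the test (and its process) is done;
            # stale entries are also aged out automatically after ~15 minutes.
            deallocate_ports(ports)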