diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 00000000000..7e5d0062bb4 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1 @@ +Refer to [AGENTS.MD](../AGENTS.md) for all repo instructions. diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cfdb3c15a97..e18d1e2e51c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,11 +3,18 @@ name: release on: schedule: - cron: '0 13 * * *' # This schedule runs every 13:00:00Z(21:00:00+08:00) + # https://github.com/orgs/community/discussions/26286?utm_source=chatgpt.com#discussioncomment-3251208 + # "The create event does not support branch filter and tag filter." # The "create tags" trigger is specifically focused on the creation of new tags, while the "push tags" trigger is activated when tags are pushed, including both new tag creations and updates to existing tags. - create: + push: tags: - "v*.*.*" # normal release - - "nightly" # the only one mutable tag + +permissions: + contents: write + actions: read + checks: read + statuses: read # https://docs.github.com/en/actions/using-jobs/using-concurrency concurrency: @@ -21,9 +28,9 @@ jobs: - name: Ensure workspace ownership run: echo "chown -R ${USER} ${GITHUB_WORKSPACE}" && sudo chown -R ${USER} ${GITHUB_WORKSPACE} - # https://github.com/actions/checkout/blob/v3/README.md + # https://github.com/actions/checkout/blob/v6/README.md - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: token: ${{ secrets.GITHUB_TOKEN }} # Use the secret as an environment variable fetch-depth: 0 @@ -31,12 +38,12 @@ jobs: - name: Prepare release body run: | - if [[ ${GITHUB_EVENT_NAME} == "create" ]]; then + if [[ ${GITHUB_EVENT_NAME} != "schedule" ]]; then RELEASE_TAG=${GITHUB_REF#refs/tags/} - if [[ ${RELEASE_TAG} == "nightly" ]]; then - PRERELEASE=true - else + if [[ ${RELEASE_TAG} == v* ]]; then PRERELEASE=false + else + PRERELEASE=true fi echo "Workflow triggered by create tag: ${RELEASE_TAG}" else @@ -55,7 +62,7 @@ jobs: git fetch --tags if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then # Determine if a given tag exists and matches a specific Git commit. - # actions/checkout@v4 fetch-tags doesn't work when triggered by schedule + # actions/checkout@v6 fetch-tags doesn't work when triggered by schedule if [ "$(git rev-parse -q --verify "refs/tags/${RELEASE_TAG}")" = "${GITHUB_SHA}" ]; then echo "mutable tag ${RELEASE_TAG} exists and matches ${GITHUB_SHA}" else @@ -75,6 +82,14 @@ jobs: # The body field does not support environment variable substitution directly. body_path: release_body.md + - name: Build and push image + run: | + sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }} + sudo docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile . + sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest + sudo docker push infiniflow/ragflow:${RELEASE_TAG} + sudo docker push infiniflow/ragflow:latest + - name: Build and push ragflow-sdk if: startsWith(github.ref, 'refs/tags/v') run: | @@ -84,11 +99,3 @@ jobs: if: startsWith(github.ref, 'refs/tags/v') run: | cd admin/client && uv build && uv publish --token ${{ secrets.PYPI_API_TOKEN }} - - - name: Build and push image - run: | - sudo docker login --username infiniflow --password-stdin <<< ${{ secrets.DOCKERHUB_TOKEN }} - sudo docker build --build-arg NEED_MIRROR=1 -t infiniflow/ragflow:${RELEASE_TAG} -f Dockerfile . - sudo docker tag infiniflow/ragflow:${RELEASE_TAG} infiniflow/ragflow:latest - sudo docker push infiniflow/ragflow:${RELEASE_TAG} - sudo docker push infiniflow/ragflow:latest diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4357bf98278..37c666173a4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,6 @@ name: tests +permissions: + contents: read on: push: @@ -12,7 +14,7 @@ on: # The only difference between pull_request and pull_request_target is the context in which the workflow runs: # — pull_request_target workflows use the workflow files from the default branch, and secrets are available. # — pull_request workflows use the workflow files from the pull request branch, and secrets are unavailable. - pull_request_target: + pull_request: types: [ synchronize, ready_for_review ] paths-ignore: - 'docs/**' @@ -31,12 +33,9 @@ jobs: name: ragflow_tests # https://docs.github.com/en/actions/using-jobs/using-conditions-to-control-job-execution # https://github.com/orgs/community/discussions/26261 - if: ${{ github.event_name != 'pull_request_target' || contains(github.event.pull_request.labels.*.name, 'ci') }} + if: ${{ github.event_name != 'pull_request' || (github.event.pull_request.draft == false && contains(github.event.pull_request.labels.*.name, 'ci')) }} runs-on: [ "self-hosted", "ragflow-test" ] steps: - # https://github.com/hmarr/debug-action - #- uses: hmarr/debug-action@v2 - - name: Ensure workspace ownership run: | echo "Workflow triggered by ${{ github.event_name }}" @@ -44,7 +43,7 @@ jobs: # https://github.com/actions/checkout/issues/1781 - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: ref: ${{ (github.event_name == 'pull_request' || github.event_name == 'pull_request_target') && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.sha }} fetch-depth: 0 @@ -53,7 +52,7 @@ jobs: - name: Check workflow duplication if: ${{ !cancelled() && !failure() }} run: | - if [[ ${GITHUB_EVENT_NAME} != "pull_request_target" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then + if [[ ${GITHUB_EVENT_NAME} != "pull_request" && ${GITHUB_EVENT_NAME} != "schedule" ]]; then HEAD=$(git rev-parse HEAD) # Find a PR that introduced a given commit gh auth login --with-token <<< "${{ secrets.GITHUB_TOKEN }}" @@ -78,7 +77,7 @@ jobs: fi fi fi - elif [[ ${GITHUB_EVENT_NAME} == "pull_request_target" ]]; then + elif [[ ${GITHUB_EVENT_NAME} == "pull_request" ]]; then PR_NUMBER=${{ github.event.pull_request.number }} PR_SHA_FP=${RUNNER_WORKSPACE_PREFIX}/artifacts/${GITHUB_REPOSITORY}/PR_${PR_NUMBER} # Calculate the hash of the current workspace content @@ -95,13 +94,53 @@ jobs: version: ">=0.11.x" args: "check" + - name: Check comments of changed Python files + if: ${{ false }} + run: | + if [[ ${{ github.event_name }} == 'pull_request' || ${{ github.event_name }} == 'pull_request_target' ]]; then + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \ + | grep -E '\.(py)$' || true) + + if [ -n "$CHANGED_FILES" ]; then + echo "Check comments of changed Python files with check_comment_ascii.py" + + readarray -t files <<< "$CHANGED_FILES" + HAS_ERROR=0 + + for file in "${files[@]}"; do + if [ -f "$file" ]; then + if python3 check_comment_ascii.py "$file"; then + echo "✅ $file" + else + echo "❌ $file" + HAS_ERROR=1 + fi + fi + done + + if [ $HAS_ERROR -ne 0 ]; then + exit 1 + fi + else + echo "No Python files changed" + fi + fi + + - name: Run unit test + run: | + uv sync --python 3.12 --group test --frozen + source .venv/bin/activate + which pytest || echo "pytest not in PATH" + echo "Start to run unit test" + python3 run_tests.py + - name: Build ragflow:nightly run: | RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}} RAGFLOW_IMAGE=infiniflow/ragflow:${GITHUB_RUN_ID} echo "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> ${GITHUB_ENV} sudo docker pull ubuntu:22.04 - sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 -f Dockerfile -t ${RAGFLOW_IMAGE} . + sudo DOCKER_BUILDKIT=1 docker build --build-arg NEED_MIRROR=1 --build-arg HTTPS_PROXY=${HTTPS_PROXY} --build-arg HTTP_PROXY=${HTTP_PROXY} -f Dockerfile -t ${RAGFLOW_IMAGE} . if [[ ${GITHUB_EVENT_NAME} == "schedule" ]]; then export HTTP_API_TEST_LEVEL=p3 else @@ -161,34 +200,34 @@ jobs: echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV} sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d - uv sync --python 3.10 --only-group test --no-default-groups --frozen && uv pip install sdk/python + uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test - name: Run sdk tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api + source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log - name: Run frontend api tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py + source .venv/bin/activate && set -o pipefail; pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log - name: Run http api tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api + source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log - name: Stop ragflow:nightly if: always() # always run this step even if previous steps failed @@ -204,29 +243,29 @@ jobs: - name: Run sdk tests against Infinity run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log - name: Run frontend api tests against Infinity run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log - name: Run http api tests against Infinity run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do + until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null; do echo "Waiting for service to be available..." sleep 5 done - source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log - name: Stop ragflow:nightly if: always() # always run this step even if previous steps failed diff --git a/.gitignore b/.gitignore index fbf80b3aabd..11aa5449312 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,6 @@ ragflow_cli.egg-info # Default backup dir backup + + +.hypothesis \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..82d23b99039 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,110 @@ +# RAGFlow Project Instructions for GitHub Copilot + +This file provides context, build instructions, and coding standards for the RAGFlow project. +It is structured to follow GitHub Copilot's [customization guidelines](https://docs.github.com/en/copilot/concepts/prompting/response-customization). + +## 1. Project Overview +RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. It is a full-stack application with a Python backend and a React/TypeScript frontend. + +- **Backend**: Python 3.10+ (Flask/Quart) +- **Frontend**: TypeScript, React, UmiJS +- **Architecture**: Microservices based on Docker. + - `api/`: Backend API server. + - `rag/`: Core RAG logic (indexing, retrieval). + - `deepdoc/`: Document parsing and OCR. + - `web/`: Frontend application. + +## 2. Directory Structure +- `api/`: Backend API server (Flask/Quart). + - `apps/`: API Blueprints (Knowledge Base, Chat, etc.). + - `db/`: Database models and services. +- `rag/`: Core RAG logic. + - `llm/`: LLM, Embedding, and Rerank model abstractions. +- `deepdoc/`: Document parsing and OCR modules. +- `agent/`: Agentic reasoning components. +- `web/`: Frontend application (React + UmiJS). +- `docker/`: Docker deployment configurations. +- `sdk/`: Python SDK. +- `test/`: Backend tests. + +## 3. Build Instructions + +### Backend (Python) +The project uses **uv** for dependency management. + +1. **Setup Environment**: + ```bash + uv sync --python 3.12 --all-extras + uv run download_deps.py + ``` + +2. **Run Server**: + - **Pre-requisite**: Start dependent services (MySQL, ES/Infinity, Redis, MinIO). + ```bash + docker compose -f docker/docker-compose-base.yml up -d + ``` + - **Launch**: + ```bash + source .venv/bin/activate + export PYTHONPATH=$(pwd) + bash docker/launch_backend_service.sh + ``` + +### Frontend (TypeScript/React) +Located in `web/`. + +1. **Install Dependencies**: + ```bash + cd web + npm install + ``` + +2. **Run Dev Server**: + ```bash + npm run dev + ``` + Runs on port 8000 by default. + +### Docker Deployment +To run the full stack using Docker: +```bash +cd docker +docker compose -f docker-compose.yml up -d +``` + +## 4. Testing Instructions + +### Backend Tests +- **Run All Tests**: + ```bash + uv run pytest + ``` +- **Run Specific Test**: + ```bash + uv run pytest test/test_api.py + ``` + +### Frontend Tests +- **Run Tests**: + ```bash + cd web + npm run test + ``` + +## 5. Coding Standards & Guidelines +- **Python Formatting**: Use `ruff` for linting and formatting. + ```bash + ruff check + ruff format + ``` +- **Frontend Linting**: + ```bash + cd web + npm run lint + ``` +- **Pre-commit**: Ensure pre-commit hooks are installed. + ```bash + pre-commit install + pre-commit run --all-files + ``` + diff --git a/CLAUDE.md b/CLAUDE.md index 7e5d43f9d68..d774fc376c6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -45,7 +45,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d ### Backend Development ```bash # Install Python dependencies -uv sync --python 3.10 --all-extras +uv sync --python 3.12 --all-extras uv run download_deps.py pre-commit install diff --git a/Dockerfile b/Dockerfile index b16a0d7d518..5f2c5f6cf8a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # base stage -FROM ubuntu:22.04 AS base +FROM ubuntu:24.04 AS base USER root SHELL ["/bin/bash", "-c"] @@ -10,11 +10,10 @@ WORKDIR /ragflow # Copy models downloaded via download_deps.py RUN mkdir -p /ragflow/rag/res/deepdoc /root/.ragflow RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/huggingface.co,target=/huggingface.co \ - cp /huggingface.co/InfiniFlow/huqie/huqie.txt.trie /ragflow/rag/res/ && \ tar --exclude='.*' -cf - \ /huggingface.co/InfiniFlow/text_concat_xgb_v1.0 \ /huggingface.co/InfiniFlow/deepdoc \ - | tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc + | tar -xf - --strip-components=3 -C /ragflow/rag/res/deepdoc # https://github.com/chrismattmann/tika-python # This is the only way to run python-tika without internet access. Without this set, the default is to check the tika version and pull latest every time from Apache. @@ -34,34 +33,41 @@ ENV DEBIAN_FRONTEND=noninteractive # selenium: libatk-bridge2.0-0 chrome-linux64-121-0-6167-85 # Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ + apt update && \ + apt --no-install-recommends install -y ca-certificates; \ if [ "$NEED_MIRROR" == "1" ]; then \ - sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ - sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \ + sed -i 's|http://archive.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \ + sed -i 's|http://security.ubuntu.com/ubuntu|https://mirrors.tuna.tsinghua.edu.cn/ubuntu|g' /etc/apt/sources.list.d/ubuntu.sources; \ fi; \ rm -f /etc/apt/apt.conf.d/docker-clean && \ echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \ chmod 1777 /tmp && \ apt update && \ - apt --no-install-recommends install -y ca-certificates && \ - apt update && \ apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \ apt install -y pkg-config libicu-dev libgdiplus && \ apt install -y default-jdk && \ apt install -y libatk-bridge2.0-0 && \ apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \ apt install -y libjemalloc-dev && \ - apt install -y python3-pip pipx nginx unzip curl wget git vim less && \ - apt install -y ghostscript + apt install -y nginx unzip curl wget git vim less && \ + apt install -y ghostscript && \ + apt install -y pandoc && \ + apt install -y texlive && \ + apt install -y fonts-freefont-ttf fonts-noto-cjk -RUN if [ "$NEED_MIRROR" == "1" ]; then \ - pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \ - pip3 config set global.trusted-host pypi.tuna.tsinghua.edu.cn; \ +# Install uv +RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps \ + if [ "$NEED_MIRROR" == "1" ]; then \ mkdir -p /etc/uv && \ - echo "[[index]]" > /etc/uv/uv.toml && \ + echo 'python-install-mirror = "https://registry.npmmirror.com/-/binary/python-build-standalone/"' > /etc/uv/uv.toml && \ + echo '[[index]]' >> /etc/uv/uv.toml && \ echo 'url = "https://pypi.tuna.tsinghua.edu.cn/simple"' >> /etc/uv/uv.toml && \ - echo "default = true" >> /etc/uv/uv.toml; \ + echo 'default = true' >> /etc/uv/uv.toml; \ fi; \ - pipx install uv + tar xzf /deps/uv-x86_64-unknown-linux-gnu.tar.gz \ + && cp uv-x86_64-unknown-linux-gnu/* /usr/local/bin/ \ + && rm -rf uv-x86_64-unknown-linux-gnu \ + && uv python install 3.11 ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 ENV PATH=/root/.local/bin:$PATH @@ -77,12 +83,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ # A modern version of cargo is needed for the latest version of the Rust compiler. RUN apt update && apt install -y curl build-essential \ && if [ "$NEED_MIRROR" == "1" ]; then \ - # Use TUNA mirrors for rustup/rust dist files + # Use TUNA mirrors for rustup/rust dist files \ export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \ export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \ echo "Using TUNA mirrors for Rustup."; \ fi; \ - # Force curl to use HTTP/1.1 + # Force curl to use HTTP/1.1 \ curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \ && echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc @@ -99,10 +105,10 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \ apt update && \ arch="$(uname -m)"; \ if [ "$arch" = "arm64" ] || [ "$arch" = "aarch64" ]; then \ - # ARM64 (macOS/Apple Silicon or Linux aarch64) + # ARM64 (macOS/Apple Silicon or Linux aarch64) \ ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql18; \ else \ - # x86_64 or others + # x86_64 or others \ ACCEPT_EULA=Y apt install -y unixodbc-dev msodbcsql17; \ fi || \ { echo "Failed to install ODBC driver"; exit 1; } @@ -146,7 +152,7 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \ else \ sed -i 's|pypi.tuna.tsinghua.edu.cn|pypi.org|g' uv.lock; \ fi; \ - uv sync --python 3.10 --frozen + uv sync --python 3.12 --frozen COPY web web COPY docs docs @@ -186,6 +192,7 @@ COPY pyproject.toml uv.lock ./ COPY mcp mcp COPY plugin plugin COPY common common +COPY memory memory COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template COPY docker/entrypoint.sh ./ diff --git a/Dockerfile.deps b/Dockerfile.deps index c16ad446201..c683ebf7cb7 100644 --- a/Dockerfile.deps +++ b/Dockerfile.deps @@ -3,7 +3,7 @@ FROM scratch # Copy resources downloaded via download_deps.py -COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb / +COPY chromedriver-linux64-121-0-6167-85 chrome-linux64-121-0-6167-85 cl100k_base.tiktoken libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb tika-server-standard-3.0.0.jar tika-server-standard-3.0.0.jar.md5 libssl*.deb uv-x86_64-unknown-linux-gnu.tar.gz / COPY nltk_data /nltk_data diff --git a/README.md b/README.md index 299bd67fd0a..4aa670b2e09 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Latest Release @@ -37,7 +37,7 @@

Document | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -85,7 +85,9 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Latest Updates -- 2025-11-12 Supports data synchronization from Confluence, AWS S3, Discord, Google Drive. +- 2025-12-26 Supports 'Memory' for AI agent. +- 2025-11-19 Supports Gemini 3 Pro. +- 2025-11-12 Supports data synchronization from Confluence, S3, Notion, Discord, Google Drive. - 2025-10-23 Supports MinerU & Docling as document parsing methods. - 2025-10-15 Supports orchestrable ingestion pipeline. - 2025-08-08 Supports OpenAI's latest GPT-5 series models. @@ -93,8 +95,6 @@ Try our demo at [https://demo.ragflow.io](https://demo.ragflow.io). - 2025-05-23 Adds a Python/JavaScript code executor component to Agent. - 2025-05-05 Supports cross-language query. - 2025-03-19 Supports using a multi-modal model to make sense of images within PDF or DOCX files. -- 2024-12-18 Upgrades Document Layout Analysis model in DeepDoc. -- 2024-08-22 Support text to SQL statements through RAG. ## 🎉 Stay Tuned @@ -188,13 +188,15 @@ releases! 🌟 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64. > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system. -> The command below downloads the `v0.22.0` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.22.0`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. +> The command below downloads the `v0.23.1` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.23.1`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. ```bash $ cd ragflow/docker + + # git checkout v0.23.1 + # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) + # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. - # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0 - # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -205,10 +207,10 @@ releases! 🌟 > Note: Prior to `v0.22.0`, we provided both images with embedding models and slim images without embedding models. Details as follows: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > Starting with `v0.22.0`, we ship only the slim edition and no longer append the **-slim** suffix to the image tag. @@ -231,7 +233,7 @@ releases! 🌟 * Running on all addresses (0.0.0.0) ``` - > If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network anormal` + > If you skip this confirmation step and directly log in to RAGFlow, your browser may prompt a `network abnormal` > error because, at that moment, your RAGFlow may not be fully initialized. > 5. In your web browser, enter the IP address of your server and log in to RAGFlow. @@ -301,6 +303,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +Or if you are behind a proxy, you can pass proxy arguments: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 Launch service from source for development 1. Install `uv` and `pre-commit`, or skip this step if they are already installed: @@ -313,7 +324,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -385,7 +396,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 Roadmap -See the [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) +See the [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) ## 🏄 Community diff --git a/README_id.md b/README_id.md index c9017ddd126..51fe841175a 100644 --- a/README_id.md +++ b/README_id.md @@ -22,7 +22,7 @@ Lencana Daring - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Rilis Terbaru @@ -37,7 +37,7 @@

Dokumentasi | - Peta Jalan | + Peta Jalan | Twitter | Discord | Demo @@ -85,7 +85,9 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Pembaruan Terbaru -- 2025-11-12 Mendukung sinkronisasi data dari Confluence, AWS S3, Discord, Google Drive. +- 2025-12-26 Mendukung 'Memori' untuk agen AI. +- 2025-11-19 Mendukung Gemini 3 Pro. +- 2025-11-12 Mendukung sinkronisasi data dari Confluence, S3, Notion, Discord, Google Drive. - 2025-10-23 Mendukung MinerU & Docling sebagai metode penguraian dokumen. - 2025-10-15 Dukungan untuk jalur data yang terorkestrasi. - 2025-08-08 Mendukung model seri GPT-5 terbaru dari OpenAI. @@ -186,12 +188,14 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64. > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image). -> Perintah di bawah ini mengunduh edisi v0.22.0 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.22.0, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. +> Perintah di bawah ini mengunduh edisi v0.23.1 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.23.1, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. ```bash $ cd ragflow/docker - # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0 + # git checkout v0.23.1 + # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases) + # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -203,10 +207,10 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). > Catatan: Sebelum `v0.22.0`, kami menyediakan image dengan model embedding dan image slim tanpa model embedding. Detailnya sebagai berikut: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > Mulai dari `v0.22.0`, kami hanya menyediakan edisi slim dan tidak lagi menambahkan akhiran **-slim** pada tag image. @@ -229,7 +233,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io). * Running on all addresses (0.0.0.0) ``` - > Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network anormal` + > Jika Anda melewatkan langkah ini dan langsung login ke RAGFlow, browser Anda mungkin menampilkan error `network abnormal` > karena RAGFlow mungkin belum sepenuhnya siap. > 2. Buka browser web Anda, masukkan alamat IP server Anda, dan login ke RAGFlow. @@ -273,6 +277,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +Jika berada di belakang proxy, Anda dapat melewatkan argumen proxy: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 Menjalankan Aplikasi dari untuk Pengembangan 1. Instal `uv` dan `pre-commit`, atau lewati langkah ini jika sudah terinstal: @@ -285,7 +298,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -355,7 +368,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 Roadmap -Lihat [Roadmap RAGFlow 2025](https://github.com/infiniflow/ragflow/issues/4214) +Lihat [Roadmap RAGFlow 2026](https://github.com/infiniflow/ragflow/issues/12241) ## 🏄 Komunitas diff --git a/README_ja.md b/README_ja.md index 24bce0874c3..cd65acffddb 100644 --- a/README_ja.md +++ b/README_ja.md @@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Latest Release @@ -37,7 +37,7 @@

Document | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -66,7 +66,9 @@ ## 🔥 最新情報 -- 2025-11-12 Confluence、AWS S3、Discord、Google Drive からのデータ同期をサポートします。 +- 2025-12-26 AIエージェントの「メモリ」機能をサポート。 +- 2025-11-19 Gemini 3 Proをサポートしています。 +- 2025-11-12 Confluence、S3、Notion、Discord、Google Drive からのデータ同期をサポートします。 - 2025-10-23 ドキュメント解析方法として MinerU と Docling をサポートします。 - 2025-10-15 オーケストレーションされたデータパイプラインのサポート。 - 2025-08-08 OpenAI の最新 GPT-5 シリーズモデルをサポートします。 @@ -166,12 +168,14 @@ > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。 -> 以下のコマンドは、RAGFlow Docker イメージの v0.22.0 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.22.0 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 +> 以下のコマンドは、RAGFlow Docker イメージの v0.23.1 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.23.1 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 ```bash $ cd ragflow/docker - # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0 + # git checkout v0.23.1 + # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) + # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -183,10 +187,10 @@ > 注意:`v0.22.0` より前のバージョンでは、embedding モデルを含むイメージと、embedding モデルを含まない slim イメージの両方を提供していました。詳細は以下の通りです: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > `v0.22.0` 以降、当プロジェクトでは slim エディションのみを提供し、イメージタグに **-slim** サフィックスを付けなくなりました。 @@ -273,6 +277,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +プロキシ環境下にいる場合は、プロキシ引数を指定できます: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 ソースコードからサービスを起動する方法 1. `uv` と `pre-commit` をインストールする。すでにインストールされている場合は、このステップをスキップしてください: @@ -285,7 +298,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -355,7 +368,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 ロードマップ -[RAGFlow ロードマップ 2025](https://github.com/infiniflow/ragflow/issues/4214) を参照 +[RAGFlow ロードマップ 2026](https://github.com/infiniflow/ragflow/issues/12241) を参照 ## 🏄 コミュニティ diff --git a/README_ko.md b/README_ko.md index bd5acf82d4d..b6551fa264b 100644 --- a/README_ko.md +++ b/README_ko.md @@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Latest Release @@ -37,7 +37,7 @@

Document | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -67,7 +67,9 @@ ## 🔥 업데이트 -- 2025-11-12 Confluence, AWS S3, Discord, Google Drive에서 데이터 동기화를 지원합니다. +- 2025-12-26 AI 에이전트의 '메모리' 기능 지원. +- 2025-11-19 Gemini 3 Pro를 지원합니다. +- 2025-11-12 Confluence, S3, Notion, Discord, Google Drive에서 데이터 동기화를 지원합니다. - 2025-10-23 문서 파싱 방법으로 MinerU 및 Docling을 지원합니다. - 2025-10-15 조정된 데이터 파이프라인 지원. - 2025-08-08 OpenAI의 최신 GPT-5 시리즈 모델을 지원합니다. @@ -168,12 +170,14 @@ > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다. > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image). - > 아래 명령어는 RAGFlow Docker 이미지의 v0.22.0 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.22.0과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. + > 아래 명령어는 RAGFlow Docker 이미지의 v0.23.1 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.23.1과 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. ```bash $ cd ragflow/docker - # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0 + # git checkout v0.23.1 + # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) + # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -185,10 +189,10 @@ > 참고: `v0.22.0` 이전 버전에서는 embedding 모델이 포함된 이미지와 embedding 모델이 포함되지 않은 slim 이미지를 모두 제공했습니다. 자세한 내용은 다음과 같습니다: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > `v0.22.0`부터는 slim 에디션만 배포하며 이미지 태그에 **-slim** 접미사를 더 이상 붙이지 않습니다. @@ -210,7 +214,7 @@ * Running on all addresses (0.0.0.0) ``` - > 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network anormal` 오류가 발생할 수 있습니다. + > 만약 확인 단계를 건너뛰고 바로 RAGFlow에 로그인하면, RAGFlow가 완전히 초기화되지 않았기 때문에 브라우저에서 `network abnormal` 오류가 발생할 수 있습니다. 2. 웹 브라우저에 서버의 IP 주소를 입력하고 RAGFlow에 로그인하세요. > 기본 설정을 사용할 경우, `http://IP_OF_YOUR_MACHINE`만 입력하면 됩니다 (포트 번호는 제외). 기본 HTTP 서비스 포트 `80`은 기본 구성으로 사용할 때 생략할 수 있습니다. @@ -267,6 +271,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +프록시 환경인 경우, 프록시 인수를 전달할 수 있습니다: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 소스 코드로 서비스를 시작합니다. 1. `uv` 와 `pre-commit` 을 설치하거나, 이미 설치된 경우 이 단계를 건너뜁니다: @@ -280,7 +293,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -359,7 +372,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 로드맵 -[RAGFlow 로드맵 2025](https://github.com/infiniflow/ragflow/issues/4214)을 확인하세요. +[RAGFlow 로드맵 2026](https://github.com/infiniflow/ragflow/issues/12241)을 확인하세요. ## 🏄 커뮤니티 diff --git a/README_pt_br.md b/README_pt_br.md index 0769ea5e5ae..bd196bf6dae 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -22,7 +22,7 @@ Badge Estático - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Última Versão @@ -37,7 +37,7 @@

Documentação | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -86,7 +86,9 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). ## 🔥 Últimas Atualizações -- 12-11-2025 Suporta a sincronização de dados do Confluence, AWS S3, Discord e Google Drive. +- 26-12-2025 Suporte à função 'Memória' para agentes de IA. +- 19-11-2025 Suporta Gemini 3 Pro. +- 12-11-2025 Suporta a sincronização de dados do Confluence, S3, Notion, Discord e Google Drive. - 23-10-2025 Suporta MinerU e Docling como métodos de análise de documentos. - 15-10-2025 Suporte para pipelines de dados orquestrados. - 08-08-2025 Suporta a mais recente série GPT-5 da OpenAI. @@ -186,12 +188,14 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64. > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema. - > O comando abaixo baixa a edição`v0.22.0` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.22.0`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. + > O comando abaixo baixa a edição`v0.23.1` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.23.1`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. ```bash $ cd ragflow/docker - # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0 + # git checkout v0.23.1 + # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases) + # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker. # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -203,10 +207,10 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). > Nota: Antes da `v0.22.0`, fornecíamos imagens com modelos de embedding e imagens slim sem modelos de embedding. Detalhes a seguir: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > A partir da `v0.22.0`, distribuímos apenas a edição slim e não adicionamos mais o sufixo **-slim** às tags das imagens. @@ -228,7 +232,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io). * Rodando em todos os endereços (0.0.0.0) ``` - > Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network anormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado. + > Se você pular essa etapa de confirmação e acessar diretamente o RAGFlow, seu navegador pode exibir um erro `network abnormal`, pois, nesse momento, seu RAGFlow pode não estar totalmente inicializado. > 5. No seu navegador, insira o endereço IP do seu servidor e faça login no RAGFlow. @@ -290,6 +294,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +Se você estiver atrás de um proxy, pode passar argumentos de proxy: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 Lançar o serviço a partir do código-fonte para desenvolvimento 1. Instale o `uv` e o `pre-commit`, ou pule esta etapa se eles já estiverem instalados: @@ -302,7 +315,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # instala os módulos Python dependentes do RAGFlow + uv sync --python 3.12 # instala os módulos Python dependentes do RAGFlow uv run download_deps.py pre-commit install ``` @@ -372,7 +385,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 Roadmap -Veja o [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) +Veja o [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) ## 🏄 Comunidade diff --git a/README_tzh.md b/README_tzh.md index a788964538f..a33f6f8f80c 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Latest Release @@ -37,7 +37,7 @@

Document | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -85,14 +85,16 @@ ## 🔥 近期更新 -- 2025-11-12 支援從 Confluence、AWS S3、Discord、Google Drive 進行資料同步。 +- 2025-12-26 支援AI代理的「記憶」功能。 +- 2025-11-19 支援 Gemini 3 Pro。 +- 2025-11-12 支援從 Confluence、S3、Notion、Discord、Google Drive 進行資料同步。 - 2025-10-23 支援 MinerU 和 Docling 作為文件解析方法。 - 2025-10-15 支援可編排的資料管道。 - 2025-08-08 支援 OpenAI 最新的 GPT-5 系列模型。 -- 2025-08-01 支援 agentic workflow 和 MCP +- 2025-08-01 支援 agentic workflow 和 MCP。 - 2025-05-23 為 Agent 新增 Python/JS 程式碼執行器元件。 - 2025-05-05 支援跨語言查詢。 -- 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述. +- 2025-03-19 PDF和DOCX中的圖支持用多模態大模型去解析得到描述。 - 2024-12-18 升級了 DeepDoc 的文檔佈局分析模型。 - 2024-08-22 支援用 RAG 技術實現從自然語言到 SQL 語句的轉換。 @@ -123,7 +125,7 @@ ### 🍔 **相容各類異質資料來源** -- 支援豐富的文件類型,包括 Word 文件、PPT、excel 表格、txt 檔案、圖片、PDF、影印件、影印件、結構化資料、網頁等。 +- 支援豐富的文件類型,包括 Word 文件、PPT、excel 表格、txt 檔案、圖片、PDF、影印件、複印件、結構化資料、網頁等。 ### 🛀 **全程無憂、自動化的 RAG 工作流程** @@ -185,12 +187,14 @@ > 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。 > 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。 -> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.22.0`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.22.0` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 +> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.23.1`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.23.1` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 ```bash $ cd ragflow/docker - # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0 + # git checkout v0.23.1 + # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases) + # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -202,10 +206,10 @@ > 注意:在 `v0.22.0` 之前的版本,我們會同時提供包含 embedding 模型的映像和不含 embedding 模型的 slim 映像。具體如下: -| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | -| ----------------- | --------------- | --------------------- | ------------------------ | -| v0.21.1 | ≈9 | ✔️ | Stable release | -| v0.21.1-slim | ≈2 | ❌ | Stable release | +| RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | +|-------------------|-----------------|-----------------------|----------------| +| v0.21.1 | ≈9 | ✔️ | Stable release | +| v0.21.1-slim | ≈2 | ❌ | Stable release | > 從 `v0.22.0` 開始,我們只發佈 slim 版本,並且不再在映像標籤後附加 **-slim** 後綴。 @@ -233,7 +237,7 @@ * Running on all addresses (0.0.0.0) ``` - > 如果您跳過這一步驟系統確認步驟就登入 RAGFlow,你的瀏覽器有可能會提示 `network anormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。 + > 如果您跳過這一步驟系統確認步驟就登入 RAGFlow,你的瀏覽器有可能會提示 `network abnormal` 或 `網路異常`,因為 RAGFlow 可能並未完全啟動成功。 > 5. 在你的瀏覽器中輸入你的伺服器對應的 IP 位址並登入 RAGFlow。 @@ -299,6 +303,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +若您位於代理環境,可傳遞代理參數: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 以原始碼啟動服務 1. 安裝 `uv` 和 `pre-commit`。如已安裝,可跳過此步驟: @@ -312,7 +325,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -386,7 +399,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 路線圖 -詳見 [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) 。 +詳見 [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) 。 ## 🏄 開源社群 diff --git a/README_zh.md b/README_zh.md index c70073a3e0f..2aa34a788eb 100644 --- a/README_zh.md +++ b/README_zh.md @@ -22,7 +22,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.22.0 + docker pull infiniflow/ragflow:v0.23.1 Latest Release @@ -37,7 +37,7 @@

Document | - Roadmap | + Roadmap | Twitter | Discord | Demo @@ -85,14 +85,16 @@ ## 🔥 近期更新 -- 2025-11-12 支持从 Confluence、AWS S3、Discord、Google Drive 进行数据同步。 +- 2025-12-26 支持AI代理的“记忆”功能。 +- 2025-11-19 支持 Gemini 3 Pro。 +- 2025-11-12 支持从 Confluence、S3、Notion、Discord、Google Drive 进行数据同步。 - 2025-10-23 支持 MinerU 和 Docling 作为文档解析方法。 - 2025-10-15 支持可编排的数据管道。 - 2025-08-08 支持 OpenAI 最新的 GPT-5 系列模型。 - 2025-08-01 支持 agentic workflow 和 MCP。 - 2025-05-23 Agent 新增 Python/JS 代码执行器组件。 - 2025-05-05 支持跨语言查询。 -- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述. +- 2025-03-19 PDF 和 DOCX 中的图支持用多模态大模型去解析得到描述。 - 2024-12-18 升级了 DeepDoc 的文档布局分析模型。 - 2024-08-22 支持用 RAG 技术实现从自然语言到 SQL 语句的转换。 @@ -186,12 +188,14 @@ > 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。 > 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。 - > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.22.0`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.22.0` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 + > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.23.1`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.23.1` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 ```bash $ cd ragflow/docker - # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0 + # git checkout v0.23.1 + # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases) + # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。 # Use CPU for DeepDoc tasks: $ docker compose -f docker-compose.yml up -d @@ -203,10 +207,10 @@ > 注意:在 `v0.22.0` 之前的版本,我们会同时提供包含 embedding 模型的镜像和不含 embedding 模型的 slim 镜像。具体如下: - | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | - | ----------------- | --------------- | --------------------- | ------------------------ | - | v0.21.1 | ≈9 | ✔️ | Stable release | - | v0.21.1-slim | ≈2 | ❌ | Stable release | + | RAGFlow image tag | Image size (GB) | Has embedding models? | Stable? | + |-------------------|-----------------|-----------------------|----------------| + | v0.21.1 | ≈9 | ✔️ | Stable release | + | v0.21.1-slim | ≈2 | ❌ | Stable release | > 从 `v0.22.0` 开始,我们只发布 slim 版本,并且不再在镜像标签后附加 **-slim** 后缀。 @@ -234,7 +238,7 @@ * Running on all addresses (0.0.0.0) ``` - > 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow,你的浏览器有可能会提示 `network anormal` 或 `网络异常`。 + > 如果您在没有看到上面的提示信息出来之前,就尝试登录 RAGFlow,你的浏览器有可能会提示 `network abnormal` 或 `网络异常`。 5. 在你的浏览器中输入你的服务器对应的 IP 地址并登录 RAGFlow。 > 上面这个例子中,您只需输入 http://IP_OF_YOUR_MACHINE 即可:未改动过配置则无需输入端口(默认的 HTTP 服务端口 80)。 @@ -298,6 +302,15 @@ cd ragflow/ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly . ``` +如果您处在代理环境下,可以传递代理参数: + +```bash +docker build --platform linux/amd64 \ + --build-arg http_proxy=http://YOUR_PROXY:PORT \ + --build-arg https_proxy=http://YOUR_PROXY:PORT \ + -f Dockerfile -t infiniflow/ragflow:nightly . +``` + ## 🔨 以源代码启动服务 1. 安装 `uv` 和 `pre-commit`。如已经安装,可跳过本步骤: @@ -312,7 +325,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ```bash git clone https://github.com/infiniflow/ragflow.git cd ragflow/ - uv sync --python 3.10 # install RAGFlow dependent python modules + uv sync --python 3.12 # install RAGFlow dependent python modules uv run download_deps.py pre-commit install ``` @@ -389,7 +402,7 @@ docker build --platform linux/amd64 -f Dockerfile -t infiniflow/ragflow:nightly ## 📜 路线图 -详见 [RAGFlow Roadmap 2025](https://github.com/infiniflow/ragflow/issues/4214) 。 +详见 [RAGFlow Roadmap 2026](https://github.com/infiniflow/ragflow/issues/12241) 。 ## 🏄 开源社区 diff --git a/SECURITY.md b/SECURITY.md index 3ccc48b67bd..7b95ba4cc70 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -6,8 +6,8 @@ Use this section to tell people about which versions of your project are currently being supported with security updates. | Version | Supported | -| ------- | ------------------ | -| <=0.7.0 | :white_check_mark: | +|---------|--------------------| +| <=0.7.0 | :white_check_mark: | ## Reporting a Vulnerability diff --git a/admin/client/README.md b/admin/client/README.md index 1964a41d40e..1f77a45d696 100644 --- a/admin/client/README.md +++ b/admin/client/README.md @@ -4,7 +4,7 @@ Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently. -The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime. +The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime. For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents. @@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple 1. Ensure the Admin Service is running. 2. Install ragflow-cli. ```bash - pip install ragflow-cli==0.22.0 + pip install ragflow-cli==0.23.1 ``` 3. Launch the CLI client: ```bash diff --git a/admin/client/admin_client.py b/admin/client/admin_client.py index b52e6749454..f70e1624e1b 100644 --- a/admin/client/admin_client.py +++ b/admin/client/admin_client.py @@ -16,14 +16,14 @@ import argparse import base64 +import getpass from cmd import Cmd +from typing import Any, Dict, List -from Cryptodome.PublicKey import RSA +import requests from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 -from typing import Dict, List, Any +from Cryptodome.PublicKey import RSA from lark import Lark, Transformer, Tree -import requests -import getpass GRAMMAR = r""" start: command @@ -141,7 +141,6 @@ class AdminTransformer(Transformer): - def start(self, items): return items[0] @@ -149,7 +148,7 @@ def command(self, items): return items[0] def list_services(self, items): - result = {'type': 'list_services'} + result = {"type": "list_services"} return result def show_service(self, items): @@ -236,11 +235,7 @@ def revoke_permission(self, items): action_list = items[1] resource = items[3] role_name = items[6] - return { - "type": "revoke_permission", - "role_name": role_name, - "resource": resource, "actions": action_list - } + return {"type": "revoke_permission", "role_name": role_name, "resource": resource, "actions": action_list} def alter_user_role(self, items): user_name = items[2] @@ -264,12 +259,12 @@ def meta_command(self, items): # handle quoted parameter parsed_args = [] for arg in args: - if hasattr(arg, 'value'): + if hasattr(arg, "value"): parsed_args.append(arg.value) else: parsed_args.append(str(arg)) - return {'type': 'meta', 'command': command_name, 'args': parsed_args} + return {"type": "meta", "command": command_name, "args": parsed_args} def meta_command_name(self, items): return items[0] @@ -279,22 +274,22 @@ def meta_args(self, items): def encrypt(input_string): - pub = '-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----' + pub = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEArq9XTUSeYr2+N1h3Afl/z8Dse/2yD0ZGrKwx+EEEcdsBLca9Ynmx3nIB5obmLlSfmskLpBo0UACBmB5rEjBp2Q2f3AG3Hjd4B+gNCG6BDaawuDlgANIhGnaTLrIqWrrcm4EMzJOnAOI1fgzJRsOOUEfaS318Eq9OVO3apEyCCt0lOQK6PuksduOjVxtltDav+guVAA068NrPYmRNabVKRNLJpL8w4D44sfth5RvZ3q9t+6RTArpEtc5sh5ChzvqPOzKGMXW83C95TxmXqpbK6olN4RevSfVjEAgCydH6HN6OhtOQEcnrU97r9H0iZOWwbw3pVrZiUkuRD1R56Wzs2wIDAQAB\n-----END PUBLIC KEY-----" pub_key = RSA.importKey(pub) cipher = Cipher_pkcs1_v1_5.new(pub_key) - cipher_text = cipher.encrypt(base64.b64encode(input_string.encode('utf-8'))) + cipher_text = cipher.encrypt(base64.b64encode(input_string.encode("utf-8"))) return base64.b64encode(cipher_text).decode("utf-8") def encode_to_base64(input_string): - base64_encoded = base64.b64encode(input_string.encode('utf-8')) - return base64_encoded.decode('utf-8') + base64_encoded = base64.b64encode(input_string.encode("utf-8")) + return base64_encoded.decode("utf-8") class AdminCLI(Cmd): def __init__(self): super().__init__() - self.parser = Lark(GRAMMAR, start='start', parser='lalr', transformer=AdminTransformer()) + self.parser = Lark(GRAMMAR, start="start", parser="lalr", transformer=AdminTransformer()) self.command_history = [] self.is_interactive = False self.admin_account = "admin@ragflow.io" @@ -312,7 +307,7 @@ def onecmd(self, command: str) -> bool: result = self.parse_command(command) if isinstance(result, dict): - if 'type' in result and result.get('type') == 'empty': + if "type" in result and result.get("type") == "empty": return False self.execute_command(result) @@ -320,7 +315,7 @@ def onecmd(self, command: str) -> bool: if isinstance(result, Tree): return False - if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']: + if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]: return True except KeyboardInterrupt: @@ -338,7 +333,7 @@ def default(self, line: str) -> bool: def parse_command(self, command_str: str) -> dict[str, str]: if not command_str.strip(): - return {'type': 'empty'} + return {"type": "empty"} self.command_history.append(command_str) @@ -346,12 +341,12 @@ def parse_command(self, command_str: str) -> dict[str, str]: result = self.parser.parse(command_str) return result except Exception as e: - return {'type': 'error', 'message': f'Parse error: {str(e)}'} + return {"type": "error", "message": f"Parse error: {str(e)}"} def verify_admin(self, arguments: dict, single_command: bool): - self.host = arguments['host'] - self.port = arguments['port'] - print(f"Attempt to access ip: {self.host}, port: {self.port}") + self.host = arguments["host"] + self.port = arguments["port"] + print("Attempt to access server for admin login") url = f"http://{self.host}:{self.port}/api/v1/admin/login" attempt_count = 3 @@ -365,35 +360,33 @@ def verify_admin(self, arguments: dict, single_command: bool): return False if single_command: - admin_passwd = arguments['password'] + admin_passwd = arguments["password"] else: admin_passwd = getpass.getpass(f"password for {self.admin_account}: ").strip() try: self.admin_password = encrypt(admin_passwd) - response = self.session.post(url, json={'email': self.admin_account, 'password': self.admin_password}) + response = self.session.post(url, json={"email": self.admin_account, "password": self.admin_password}) if response.status_code == 200: res_json = response.json() - error_code = res_json.get('code', -1) + error_code = res_json.get("code", -1) if error_code == 0: - self.session.headers.update({ - 'Content-Type': 'application/json', - 'Authorization': response.headers['Authorization'], - 'User-Agent': 'RAGFlow-CLI/0.22.0' - }) + self.session.headers.update({"Content-Type": "application/json", "Authorization": response.headers["Authorization"], "User-Agent": "RAGFlow-CLI/0.23.1"}) print("Authentication successful.") return True else: - error_message = res_json.get('message', 'Unknown error') + error_message = res_json.get("message", "Unknown error") print(f"Authentication failed: {error_message}, try again") continue else: print(f"Bad response,status: {response.status_code}, password is wrong") except Exception as e: print(str(e)) - print(f"Can't access {self.host}, port: {self.port}") + print("Can't access server for admin login (connection failed)") def _format_service_detail_table(self, data): - if not any([isinstance(v, list) for v in data.values()]): + if isinstance(data, list): + return data + if not all([isinstance(v, list) for v in data.values()]): # normal table return data # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}] @@ -401,10 +394,14 @@ def _format_service_detail_table(self, data): for k, v in data.items(): # display latest status heartbeats = sorted(v, key=lambda x: x["now"], reverse=True) - task_executor_list.append({ - "task_executor_name": k, - **heartbeats[0], - }) + task_executor_list.append( + { + "task_executor_name": k, + **heartbeats[0], + } + if heartbeats + else {"task_executor_name": k} + ) return task_executor_list def _print_table_simple(self, data): @@ -415,16 +412,12 @@ def _print_table_simple(self, data): # handle single row data data = [data] - columns = list(data[0].keys()) + columns = list(set().union(*(d.keys() for d in data))) + columns.sort() col_widths = {} def get_string_width(text): - half_width_chars = ( - " !\"#$%&'()*+,-./0123456789:;<=>?@" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`" - "abcdefghijklmnopqrstuvwxyz{|}~" - "\t\n\r" - ) + half_width_chars = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\t\n\r" width = 0 for char in text: if char in half_width_chars: @@ -436,7 +429,7 @@ def get_string_width(text): for col in columns: max_width = get_string_width(str(col)) for item in data: - value_len = get_string_width(str(item.get(col, ''))) + value_len = get_string_width(str(item.get(col, ""))) if value_len > max_width: max_width = value_len col_widths[col] = max(2, max_width) @@ -454,16 +447,15 @@ def get_string_width(text): for item in data: row = "|" for col in columns: - value = str(item.get(col, '')) + value = str(item.get(col, "")) if get_string_width(value) > col_widths[col]: - value = value[:col_widths[col] - 3] + "..." + value = value[: col_widths[col] - 3] + "..." row += f" {value:<{col_widths[col] - (get_string_width(value) - len(value))}} |" print(row) print(separator) def run_interactive(self): - self.is_interactive = True print("RAGFlow Admin command line interface - Type '\\?' for help, '\\q' to quit") @@ -480,7 +472,7 @@ def run_interactive(self): if isinstance(result, Tree): continue - if result.get('type') == 'meta' and result.get('command') in ['q', 'quit', 'exit']: + if result.get("type") == "meta" and result.get("command") in ["q", "quit", "exit"]: break except KeyboardInterrupt: @@ -494,36 +486,30 @@ def run_single_command(self, command: str): self.execute_command(result) def parse_connection_args(self, args: List[str]) -> Dict[str, Any]: - parser = argparse.ArgumentParser(description='Admin CLI Client', add_help=False) - parser.add_argument('-h', '--host', default='localhost', help='Admin service host') - parser.add_argument('-p', '--port', type=int, default=9381, help='Admin service port') - parser.add_argument('-w', '--password', default='admin', type=str, help='Superuser password') - parser.add_argument('command', nargs='?', help='Single command') + parser = argparse.ArgumentParser(description="Admin CLI Client", add_help=False) + parser.add_argument("-h", "--host", default="localhost", help="Admin service host") + parser.add_argument("-p", "--port", type=int, default=9381, help="Admin service port") + parser.add_argument("-w", "--password", default="admin", type=str, help="Superuser password") + parser.add_argument("command", nargs="?", help="Single command") try: parsed_args, remaining_args = parser.parse_known_args(args) if remaining_args: command = remaining_args[0] - return { - 'host': parsed_args.host, - 'port': parsed_args.port, - 'password': parsed_args.password, - 'command': command - } + return {"host": parsed_args.host, "port": parsed_args.port, "password": parsed_args.password, "command": command} else: return { - 'host': parsed_args.host, - 'port': parsed_args.port, + "host": parsed_args.host, + "port": parsed_args.port, } except SystemExit: - return {'error': 'Invalid connection arguments'} + return {"error": "Invalid connection arguments"} def execute_command(self, parsed_command: Dict[str, Any]): - command_dict: dict if isinstance(parsed_command, Tree): command_dict = parsed_command.children[0] else: - if parsed_command['type'] == 'error': + if parsed_command["type"] == "error": print(f"Error: {parsed_command['message']}") return else: @@ -531,56 +517,56 @@ def execute_command(self, parsed_command: Dict[str, Any]): # print(f"Parsed command: {command_dict}") - command_type = command_dict['type'] + command_type = command_dict["type"] match command_type: - case 'list_services': + case "list_services": self._handle_list_services(command_dict) - case 'show_service': + case "show_service": self._handle_show_service(command_dict) - case 'restart_service': + case "restart_service": self._handle_restart_service(command_dict) - case 'shutdown_service': + case "shutdown_service": self._handle_shutdown_service(command_dict) - case 'startup_service': + case "startup_service": self._handle_startup_service(command_dict) - case 'list_users': + case "list_users": self._handle_list_users(command_dict) - case 'show_user': + case "show_user": self._handle_show_user(command_dict) - case 'drop_user': + case "drop_user": self._handle_drop_user(command_dict) - case 'alter_user': + case "alter_user": self._handle_alter_user(command_dict) - case 'create_user': + case "create_user": self._handle_create_user(command_dict) - case 'activate_user': + case "activate_user": self._handle_activate_user(command_dict) - case 'list_datasets': + case "list_datasets": self._handle_list_datasets(command_dict) - case 'list_agents': + case "list_agents": self._handle_list_agents(command_dict) - case 'create_role': + case "create_role": self._create_role(command_dict) - case 'drop_role': + case "drop_role": self._drop_role(command_dict) - case 'alter_role': + case "alter_role": self._alter_role(command_dict) - case 'list_roles': + case "list_roles": self._list_roles(command_dict) - case 'show_role': + case "show_role": self._show_role(command_dict) - case 'grant_permission': + case "grant_permission": self._grant_permission(command_dict) - case 'revoke_permission': + case "revoke_permission": self._revoke_permission(command_dict) - case 'alter_user_role': + case "alter_user_role": self._alter_user_role(command_dict) - case 'show_user_permission': + case "show_user_permission": self._show_user_permission(command_dict) - case 'show_version': + case "show_version": self._show_version(command_dict) - case 'meta': + case "meta": self._handle_meta_command(command_dict) case _: print(f"Command '{command_type}' would be executed with API") @@ -588,29 +574,29 @@ def execute_command(self, parsed_command: Dict[str, Any]): def _handle_list_services(self, command): print("Listing all services") - url = f'http://{self.host}:{self.port}/api/v1/admin/services' + url = f"http://{self.host}:{self.port}/api/v1/admin/services" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to get all services, code: {res_json['code']}, message: {res_json['message']}") def _handle_show_service(self, command): - service_id: int = command['number'] + service_id: int = command["number"] print(f"Showing service: {service_id}") - url = f'http://{self.host}:{self.port}/api/v1/admin/services/{service_id}' + url = f"http://{self.host}:{self.port}/api/v1/admin/services/{service_id}" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - res_data = res_json['data'] - if 'status' in res_data and res_data['status'] == 'alive': + res_data = res_json["data"] + if "status" in res_data and res_data["status"] == "alive": print(f"Service {res_data['service_name']} is alive, ") - if isinstance(res_data['message'], str): - print(res_data['message']) + if isinstance(res_data["message"], str): + print(res_data["message"]) else: - data = self._format_service_detail_table(res_data['message']) + data = self._format_service_detail_table(res_data["message"]) self._print_table_simple(data) else: print(f"Service {res_data['service_name']} is down, {res_data['message']}") @@ -618,47 +604,47 @@ def _handle_show_service(self, command): print(f"Fail to show service, code: {res_json['code']}, message: {res_json['message']}") def _handle_restart_service(self, command): - service_id: int = command['number'] + service_id: int = command["number"] print(f"Restart service {service_id}") def _handle_shutdown_service(self, command): - service_id: int = command['number'] + service_id: int = command["number"] print(f"Shutdown service {service_id}") def _handle_startup_service(self, command): - service_id: int = command['number'] + service_id: int = command["number"] print(f"Startup service {service_id}") def _handle_list_users(self, command): print("Listing all users") - url = f'http://{self.host}:{self.port}/api/v1/admin/users' + url = f"http://{self.host}:{self.port}/api/v1/admin/users" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to get all users, code: {res_json['code']}, message: {res_json['message']}") def _handle_show_user(self, command): - username_tree: Tree = command['user_name'] + username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Showing user: {user_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}' + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - table_data = res_json['data'] - table_data.pop('avatar') + table_data = res_json["data"] + table_data.pop("avatar") self._print_table_simple(table_data) else: print(f"Fail to get user {user_name}, code: {res_json['code']}, message: {res_json['message']}") def _handle_drop_user(self, command): - username_tree: Tree = command['user_name'] + username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Drop user: {user_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}' + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}" response = self.session.delete(url) res_json = response.json() if response.status_code == 200: @@ -667,13 +653,13 @@ def _handle_drop_user(self, command): print(f"Fail to drop user, code: {res_json['code']}, message: {res_json['message']}") def _handle_alter_user(self, command): - user_name_tree: Tree = command['user_name'] + user_name_tree: Tree = command["user_name"] user_name: str = user_name_tree.children[0].strip("'\"") - password_tree: Tree = command['password'] + password_tree: Tree = command["password"] password: str = password_tree.children[0].strip("'\"") - print(f"Alter user: {user_name}, password: {password}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password' - response = self.session.put(url, json={'new_password': encrypt(password)}) + print(f"Alter user: {user_name}, password: ******") + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/password" + response = self.session.put(url, json={"new_password": encrypt(password)}) res_json = response.json() if response.status_code == 200: print(res_json["message"]) @@ -681,32 +667,29 @@ def _handle_alter_user(self, command): print(f"Fail to alter password, code: {res_json['code']}, message: {res_json['message']}") def _handle_create_user(self, command): - user_name_tree: Tree = command['user_name'] + user_name_tree: Tree = command["user_name"] user_name: str = user_name_tree.children[0].strip("'\"") - password_tree: Tree = command['password'] + password_tree: Tree = command["password"] password: str = password_tree.children[0].strip("'\"") - role: str = command['role'] - print(f"Create user: {user_name}, password: {password}, role: {role}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users' - response = self.session.post( - url, - json={'user_name': user_name, 'password': encrypt(password), 'role': role} - ) + role: str = command["role"] + print(f"Create user: {user_name}, password: ******, role: {role}") + url = f"http://{self.host}:{self.port}/api/v1/admin/users" + response = self.session.post(url, json={"user_name": user_name, "password": encrypt(password), "role": role}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to create user {user_name}, code: {res_json['code']}, message: {res_json['message']}") def _handle_activate_user(self, command): - user_name_tree: Tree = command['user_name'] + user_name_tree: Tree = command["user_name"] user_name: str = user_name_tree.children[0].strip("'\"") - activate_tree: Tree = command['activate_status'] + activate_tree: Tree = command["activate_status"] activate_status: str = activate_tree.children[0].strip("'\"") - if activate_status.lower() in ['on', 'off']: + if activate_status.lower() in ["on", "off"]: print(f"Alter user {user_name} activate status, turn {activate_status.lower()}.") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate' - response = self.session.put(url, json={'activate_status': activate_status}) + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/activate" + response = self.session.put(url, json={"activate_status": activate_status}) res_json = response.json() if response.status_code == 200: print(res_json["message"]) @@ -716,202 +699,182 @@ def _handle_activate_user(self, command): print(f"Unknown activate status: {activate_status}.") def _handle_list_datasets(self, command): - username_tree: Tree = command['user_name'] + username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Listing all datasets of user: {user_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets' + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/datasets" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - table_data = res_json['data'] + table_data = res_json["data"] for t in table_data: - t.pop('avatar') + t.pop("avatar") self._print_table_simple(table_data) else: print(f"Fail to get all datasets of {user_name}, code: {res_json['code']}, message: {res_json['message']}") def _handle_list_agents(self, command): - username_tree: Tree = command['user_name'] + username_tree: Tree = command["user_name"] user_name: str = username_tree.children[0].strip("'\"") print(f"Listing all agents of user: {user_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents' + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name}/agents" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - table_data = res_json['data'] + table_data = res_json["data"] for t in table_data: - t.pop('avatar') + t.pop("avatar") self._print_table_simple(table_data) else: print(f"Fail to get all agents of {user_name}, code: {res_json['code']}, message: {res_json['message']}") def _create_role(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") - desc_str: str = '' - if 'description' in command: - desc_tree: Tree = command['description'] + desc_str: str = "" + if "description" in command: + desc_tree: Tree = command["description"] desc_str = desc_tree.children[0].strip("'\"") print(f"create role name: {role_name}, description: {desc_str}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles' - response = self.session.post( - url, - json={'role_name': role_name, 'description': desc_str} - ) + url = f"http://{self.host}:{self.port}/api/v1/admin/roles" + response = self.session.post(url, json={"role_name": role_name, "description": desc_str}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to create role {role_name}, code: {res_json['code']}, message: {res_json['message']}") def _drop_role(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") print(f"drop role name: {role_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}' + url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}" response = self.session.delete(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to drop role {role_name}, code: {res_json['code']}, message: {res_json['message']}") def _alter_role(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") - desc_tree: Tree = command['description'] + desc_tree: Tree = command["description"] desc_str: str = desc_tree.children[0].strip("'\"") print(f"alter role name: {role_name}, description: {desc_str}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}' - response = self.session.put( - url, - json={'description': desc_str} - ) + url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}" + response = self.session.put(url, json={"description": desc_str}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: - print( - f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to update role {role_name} with description: {desc_str}, code: {res_json['code']}, message: {res_json['message']}") def _list_roles(self, command): print("Listing all roles") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles' + url = f"http://{self.host}:{self.port}/api/v1/admin/roles" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}") def _show_role(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name: str = role_name_tree.children[0].strip("'\"") print(f"show role: {role_name}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission' + url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name}/permission" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to list roles, code: {res_json['code']}, message: {res_json['message']}") def _grant_permission(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name_str: str = role_name_tree.children[0].strip("'\"") - resource_tree: Tree = command['resource'] + resource_tree: Tree = command["resource"] resource_str: str = resource_tree.children[0].strip("'\"") - action_tree_list: list = command['actions'] + action_tree_list: list = command["actions"] actions: list = [] for action_tree in action_tree_list: action_str: str = action_tree.children[0].strip("'\"") actions.append(action_str) print(f"grant role_name: {role_name_str}, resource: {resource_str}, actions: {actions}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission' - response = self.session.post( - url, - json={'actions': actions, 'resource': resource_str} - ) + url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission" + response = self.session.post(url, json={"actions": actions, "resource": resource_str}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: - print( - f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to grant role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") def _revoke_permission(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name_str: str = role_name_tree.children[0].strip("'\"") - resource_tree: Tree = command['resource'] + resource_tree: Tree = command["resource"] resource_str: str = resource_tree.children[0].strip("'\"") - action_tree_list: list = command['actions'] + action_tree_list: list = command["actions"] actions: list = [] for action_tree in action_tree_list: action_str: str = action_tree.children[0].strip("'\"") actions.append(action_str) print(f"revoke role_name: {role_name_str}, resource: {resource_str}, actions: {actions}") - url = f'http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission' - response = self.session.delete( - url, - json={'actions': actions, 'resource': resource_str} - ) + url = f"http://{self.host}:{self.port}/api/v1/admin/roles/{role_name_str}/permission" + response = self.session.delete(url, json={"actions": actions, "resource": resource_str}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: - print( - f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to revoke role {role_name_str} with {actions} on {resource_str}, code: {res_json['code']}, message: {res_json['message']}") def _alter_user_role(self, command): - role_name_tree: Tree = command['role_name'] + role_name_tree: Tree = command["role_name"] role_name_str: str = role_name_tree.children[0].strip("'\"") - user_name_tree: Tree = command['user_name'] + user_name_tree: Tree = command["user_name"] user_name_str: str = user_name_tree.children[0].strip("'\"") print(f"alter_user_role user_name: {user_name_str}, role_name: {role_name_str}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role' - response = self.session.put( - url, - json={'role_name': role_name_str} - ) + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/role" + response = self.session.put(url, json={"role_name": role_name_str}) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: - print( - f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to alter user: {user_name_str} to role {role_name_str}, code: {res_json['code']}, message: {res_json['message']}") def _show_user_permission(self, command): - user_name_tree: Tree = command['user_name'] + user_name_tree: Tree = command["user_name"] user_name_str: str = user_name_tree.children[0].strip("'\"") print(f"show_user_permission user_name: {user_name_str}") - url = f'http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission' + url = f"http://{self.host}:{self.port}/api/v1/admin/users/{user_name_str}/permission" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: - print( - f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}") + print(f"Fail to show user: {user_name_str} permission, code: {res_json['code']}, message: {res_json['message']}") def _show_version(self, command): print("show_version") - url = f'http://{self.host}:{self.port}/api/v1/admin/version' + url = f"http://{self.host}:{self.port}/api/v1/admin/version" response = self.session.get(url) res_json = response.json() if response.status_code == 200: - self._print_table_simple(res_json['data']) + self._print_table_simple(res_json["data"]) else: print(f"Fail to show version, code: {res_json['code']}, message: {res_json['message']}") def _handle_meta_command(self, command): - meta_command = command['command'] - args = command.get('args', []) + meta_command = command["command"] + args = command.get("args", []) - if meta_command in ['?', 'h', 'help']: + if meta_command in ["?", "h", "help"]: self.show_help() - elif meta_command in ['q', 'quit', 'exit']: + elif meta_command in ["q", "quit", "exit"]: print("Goodbye!") else: print(f"Meta command '{meta_command}' with args {args}") @@ -947,17 +910,17 @@ def main(): cli = AdminCLI() args = cli.parse_connection_args(sys.argv) - if 'error' in args: - print(f"Error: {args['error']}") + if "error" in args: + print("Error: Invalid connection arguments") return - if 'command' in args: - if 'password' not in args: + if "command" in args: + if "password" not in args: print("Error: password is missing") return if cli.verify_admin(args, single_command=True): - command: str = args['command'] - print(f"Run single command: {command}") + command: str = args["command"] + # print(f"Run single command: {command}") cli.run_single_command(command) else: if cli.verify_admin(args, single_command=False): @@ -971,5 +934,5 @@ def main(): cli.cmdloop() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml index 6dad77a2b8a..de6bf7bc348 100644 --- a/admin/client/pyproject.toml +++ b/admin/client/pyproject.toml @@ -1,14 +1,14 @@ [project] name = "ragflow-cli" -version = "0.22.0" +version = "0.23.1" description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. " authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }] license = { text = "Apache License, Version 2.0" } readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.12,<3.15" dependencies = [ "requests>=2.30.0,<3.0.0", - "beartype>=0.18.5,<0.19.0", + "beartype>=0.20.0,<1.0.0", "pycryptodomex>=3.10.0", "lark>=1.1.0", ] diff --git a/admin/client/uv.lock b/admin/client/uv.lock new file mode 100644 index 00000000000..7e38b7144c0 --- /dev/null +++ b/admin/client/uv.lock @@ -0,0 +1,298 @@ +version = 1 +revision = 3 +requires-python = ">=3.10, <3.13" + +[[package]] +name = "beartype" +version = "0.22.6" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/88/e2/105ceb1704cb80fe4ab3872529ab7b6f365cf7c74f725e6132d0efcf1560/beartype-0.22.6.tar.gz", hash = "sha256:97fbda69c20b48c5780ac2ca60ce3c1bb9af29b3a1a0216898ffabdd523e48f4", size = 1588975, upload-time = "2025-11-20T04:47:14.736Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/98/c9/ceecc71fe2c9495a1d8e08d44f5f31f5bca1350d5b2e27a4b6265424f59e/beartype-0.22.6-py3-none-any.whl", hash = "sha256:0584bc46a2ea2a871509679278cda992eadde676c01356ab0ac77421f3c9a093", size = 1324807, upload-time = "2025-11-20T04:47:11.837Z" }, +] + +[[package]] +name = "certifi" +version = "2025.11.12" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/b8/6d51fc1d52cbd52cd4ccedd5b5b2f0f6a11bbf6765c782298b0f3e808541/charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d", size = 209709, upload-time = "2025-10-14T04:40:11.385Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/5c/af/1f9d7f7faafe2ddfb6f72a2e07a548a629c61ad510fe60f9630309908fef/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8", size = 148814, upload-time = "2025-10-14T04:40:13.135Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/79/3d/f2e3ac2bbc056ca0c204298ea4e3d9db9b4afe437812638759db2c976b5f/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad", size = 144467, upload-time = "2025-10-14T04:40:14.728Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ec/85/1bf997003815e60d57de7bd972c57dc6950446a3e4ccac43bc3070721856/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8", size = 162280, upload-time = "2025-10-14T04:40:16.14Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3e/8e/6aa1952f56b192f54921c436b87f2aaf7c7a7c3d0d1a765547d64fd83c13/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d", size = 159454, upload-time = "2025-10-14T04:40:17.567Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/36/3b/60cbd1f8e93aa25d1c669c649b7a655b0b5fb4c571858910ea9332678558/charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313", size = 153609, upload-time = "2025-10-14T04:40:19.08Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/64/91/6a13396948b8fd3c4b4fd5bc74d045f5637d78c9675585e8e9fbe5636554/charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e", size = 151849, upload-time = "2025-10-14T04:40:20.607Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b7/7a/59482e28b9981d105691e968c544cc0df3b7d6133152fb3dcdc8f135da7a/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93", size = 151586, upload-time = "2025-10-14T04:40:21.719Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/92/59/f64ef6a1c4bdd2baf892b04cd78792ed8684fbc48d4c2afe467d96b4df57/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0", size = 145290, upload-time = "2025-10-14T04:40:23.069Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6b/63/3bf9f279ddfa641ffa1962b0db6a57a9c294361cc2f5fcac997049a00e9c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84", size = 163663, upload-time = "2025-10-14T04:40:24.17Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/09/c9e38fc8fa9e0849b172b581fd9803bdf6e694041127933934184e19f8c3/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e", size = 151964, upload-time = "2025-10-14T04:40:25.368Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d2/d1/d28b747e512d0da79d8b6a1ac18b7ab2ecfd81b2944c4c710e166d8dd09c/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db", size = 161064, upload-time = "2025-10-14T04:40:26.806Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/bb/9a/31d62b611d901c3b9e5500c36aab0ff5eb442043fb3a1c254200d3d397d9/charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6", size = 155015, upload-time = "2025-10-14T04:40:28.284Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/f3/107e008fa2bff0c8b9319584174418e5e5285fef32f79d8ee6a430d0039c/charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f", size = 99792, upload-time = "2025-10-14T04:40:29.613Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/eb/66/e396e8a408843337d7315bab30dbf106c38966f1819f123257f5520f8a96/charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d", size = 107198, upload-time = "2025-10-14T04:40:30.644Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/58/01b4f815bf0312704c267f2ccb6e5d42bcc7752340cd487bc9f8c3710597/charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69", size = 100262, upload-time = "2025-10-14T04:40:32.108Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978, upload-time = "2025-10-14T04:40:50.844Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "exceptiongroup" +version = "1.3.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598", size = 16740, upload-time = "2025-11-21T23:01:53.443Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "lark" +version = "1.3.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pycryptodomex" +version = "3.23.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/85/e24bf90972a30b0fcd16c73009add1d7d7cd9140c2498a68252028899e41/pycryptodomex-3.23.0.tar.gz", hash = "sha256:71909758f010c82bc99b0abf4ea12012c98962fbf0583c2164f8b84533c2e4da", size = 4922157, upload-time = "2025-05-17T17:23:41.434Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/dd/9c/1a8f35daa39784ed8adf93a694e7e5dc15c23c741bbda06e1d45f8979e9e/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:06698f957fe1ab229a99ba2defeeae1c09af185baa909a31a5d1f9d42b1aaed6", size = 2499240, upload-time = "2025-05-17T17:22:46.953Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7a/62/f5221a191a97157d240cf6643747558759126c76ee92f29a3f4aee3197a5/pycryptodomex-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2c2537863eccef2d41061e82a881dcabb04944c5c06c5aa7110b577cc487545", size = 1644042, upload-time = "2025-05-17T17:22:49.098Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8c/fd/5a054543c8988d4ed7b612721d7e78a4b9bf36bc3c5ad45ef45c22d0060e/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:43c446e2ba8df8889e0e16f02211c25b4934898384c1ec1ec04d7889c0333587", size = 2186227, upload-time = "2025-05-17T17:22:51.139Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c8/a9/8862616a85cf450d2822dbd4fff1fcaba90877907a6ff5bc2672cafe42f8/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f489c4765093fb60e2edafdf223397bc716491b2b69fe74367b70d6999257a5c", size = 2272578, upload-time = "2025-05-17T17:22:53.676Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/46/9f/bda9c49a7c1842820de674ab36c79f4fbeeee03f8ff0e4f3546c3889076b/pycryptodomex-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bdc69d0d3d989a1029df0eed67cc5e8e5d968f3724f4519bd03e0ec68df7543c", size = 2312166, upload-time = "2025-05-17T17:22:56.585Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/cc/870b9bf8ca92866ca0186534801cf8d20554ad2a76ca959538041b7a7cf4/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:6bbcb1dd0f646484939e142462d9e532482bc74475cecf9c4903d4e1cd21f003", size = 2185467, upload-time = "2025-05-17T17:22:59.237Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/96/e3/ce9348236d8e669fea5dd82a90e86be48b9c341210f44e25443162aba187/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:8a4fcd42ccb04c31268d1efeecfccfd1249612b4de6374205376b8f280321744", size = 2346104, upload-time = "2025-05-17T17:23:02.112Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a5/e9/e869bcee87beb89040263c416a8a50204f7f7a83ac11897646c9e71e0daf/pycryptodomex-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:55ccbe27f049743a4caf4f4221b166560d3438d0b1e5ab929e07ae1702a4d6fd", size = 2271038, upload-time = "2025-05-17T17:23:04.872Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/8d/67/09ee8500dd22614af5fbaa51a4aee6e342b5fa8aecf0a6cb9cbf52fa6d45/pycryptodomex-3.23.0-cp37-abi3-win32.whl", hash = "sha256:189afbc87f0b9f158386bf051f720e20fa6145975f1e76369303d0f31d1a8d7c", size = 1771969, upload-time = "2025-05-17T17:23:07.115Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/69/96/11f36f71a865dd6df03716d33bd07a67e9d20f6b8d39820470b766af323c/pycryptodomex-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:52e5ca58c3a0b0bd5e100a9fbc8015059b05cffc6c66ce9d98b4b45e023443b9", size = 1803124, upload-time = "2025-05-17T17:23:09.267Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/93/45c1cdcbeb182ccd2e144c693eaa097763b08b38cded279f0053ed53c553/pycryptodomex-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:02d87b80778c171445d67e23d1caef279bf4b25c3597050ccd2e13970b57fd51", size = 1707161, upload-time = "2025-05-17T17:23:11.414Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/b8/3e76d948c3c4ac71335bbe75dac53e154b40b0f8f1f022dfa295257a0c96/pycryptodomex-3.23.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ebfff755c360d674306e5891c564a274a47953562b42fb74a5c25b8fc1fb1cb5", size = 1627695, upload-time = "2025-05-17T17:23:17.38Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6a/cf/80f4297a4820dfdfd1c88cf6c4666a200f204b3488103d027b5edd9176ec/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eca54f4bb349d45afc17e3011ed4264ef1cc9e266699874cdd1349c504e64798", size = 1675772, upload-time = "2025-05-17T17:23:19.202Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/d1/42/1e969ee0ad19fe3134b0e1b856c39bd0b70d47a4d0e81c2a8b05727394c9/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f2596e643d4365e14d0879dc5aafe6355616c61c2176009270f3048f6d9a61f", size = 1668083, upload-time = "2025-05-17T17:23:21.867Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/c3/1de4f7631fea8a992a44ba632aa40e0008764c0fb9bf2854b0acf78c2cf2/pycryptodomex-3.23.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fdfac7cda115bca3a5abb2f9e43bc2fb66c2b65ab074913643803ca7083a79ea", size = 1706056, upload-time = "2025-05-17T17:23:24.031Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f2/5f/af7da8e6f1e42b52f44a24d08b8e4c726207434e2593732d39e7af5e7256/pycryptodomex-3.23.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:14c37aaece158d0ace436f76a7bb19093db3b4deade9797abfc39ec6cd6cc2fe", size = 1806478, upload-time = "2025-05-17T17:23:26.066Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "9.0.1" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/07/56/f013048ac4bc4c1d9be45afd4ab209ea62822fb1598f40687e6bf45dcea4/pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8", size = 1564125, upload-time = "2025-11-12T13:05:09.333Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0b/8b/6300fb80f858cda1c51ffa17075df5d846757081d11ab4aa35cef9e6258b/pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad", size = 373668, upload-time = "2025-11-12T13:05:07.379Z" }, +] + +[[package]] +name = "ragflow-cli" +version = "0.23.1" +source = { virtual = "." } +dependencies = [ + { name = "beartype" }, + { name = "lark" }, + { name = "pycryptodomex" }, + { name = "requests" }, +] + +[package.dev-dependencies] +test = [ + { name = "pytest" }, + { name = "requests" }, + { name = "requests-toolbelt" }, +] + +[package.metadata] +requires-dist = [ + { name = "beartype", specifier = ">=0.20.0,<1.0.0" }, + { name = "lark", specifier = ">=1.1.0" }, + { name = "pycryptodomex", specifier = ">=3.10.0" }, + { name = "requests", specifier = ">=2.30.0,<3.0.0" }, +] + +[package.metadata.requires-dev] +test = [ + { name = "pytest", specifier = ">=8.3.5" }, + { name = "requests", specifier = ">=2.32.3" }, + { name = "requests-toolbelt", specifier = ">=1.0.0" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "requests-toolbelt" +version = "1.0.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f3/61/d7545dafb7ac2230c70d38d31cbfe4cc64f7144dc41f6e4e4b78ecd9f5bb/requests-toolbelt-1.0.0.tar.gz", hash = "sha256:7681a0a3d047012b5bdc0ee37d7f8f07ebe76ab08caeccfc3921ce23c88d5bc6", size = 206888, upload-time = "2023-05-01T04:11:33.229Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" }, +] + +[[package]] +name = "tomli" +version = "2.3.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] diff --git a/admin/server/admin_server.py b/admin/server/admin_server.py index cfc5c4bee55..b8c96a62c45 100644 --- a/admin/server/admin_server.py +++ b/admin/server/admin_server.py @@ -20,8 +20,11 @@ import time import threading import traceback -from werkzeug.serving import run_simple +import faulthandler + from flask import Flask +from flask_login import LoginManager +from werkzeug.serving import run_simple from routes import admin_bp from common.log_utils import init_root_logger from common.constants import SERVICE_CONF @@ -30,12 +33,12 @@ from config import load_configurations, SERVICE_CONFIGS from auth import init_default_admin, setup_auth from flask_session import Session -from flask_login import LoginManager from common.versions import get_ragflow_version stop_event = threading.Event() if __name__ == '__main__': + faulthandler.enable() init_root_logger("admin_service") logging.info(r""" ____ ___ ______________ ___ __ _ diff --git a/admin/server/auth.py b/admin/server/auth.py index 564c348e3f6..486b9a4fbf7 100644 --- a/admin/server/auth.py +++ b/admin/server/auth.py @@ -19,7 +19,8 @@ import uuid from functools import wraps from datetime import datetime -from flask import request, jsonify + +from flask import jsonify, request from flask_login import current_user, login_user from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer @@ -30,7 +31,7 @@ from api.utils.crypt import decrypt from common.misc_utils import get_uuid from common.time_utils import current_timestamp, datetime_format, get_format_time -from common.connection_utils import construct_response +from common.connection_utils import sync_construct_response from common import settings @@ -129,7 +130,7 @@ def login_admin(email: str, password: str): user.last_login_time = get_format_time() user.save() msg = "Welcome back!" - return construct_response(data=resp, auth=user.get_id(), message=msg) + return sync_construct_response(data=resp, auth=user.get_id(), message=msg) def check_admin(username: str, password: str): @@ -169,17 +170,17 @@ def decorated(*args, **kwargs): username = auth.parameters['username'] password = auth.parameters['password'] try: - if check_admin(username, password) is False: + if not check_admin(username, password): return jsonify({ "code": 500, "message": "Access denied", "data": None }), 200 - except Exception as e: - error_msg = str(e) + except Exception: + logging.exception("An error occurred during admin login verification.") return jsonify({ "code": 500, - "message": error_msg + "message": "An internal server error occurred." }), 200 return f(*args, **kwargs) diff --git a/admin/server/config.py b/admin/server/config.py index e2c7d11ef90..43f079d4f2b 100644 --- a/admin/server/config.py +++ b/admin/server/config.py @@ -25,8 +25,21 @@ from urllib.parse import urlparse +class BaseConfig(BaseModel): + id: int + name: str + host: str + port: int + service_type: str + detail_func_name: str + + def to_dict(self) -> dict[str, Any]: + return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, + 'service_type': self.service_type} + + class ServiceConfigs: - configs = dict + configs = list[BaseConfig] def __init__(self): self.configs = [] @@ -45,19 +58,6 @@ class ServiceType(Enum): FILE_STORE = "file_store" -class BaseConfig(BaseModel): - id: int - name: str - host: str - port: int - service_type: str - detail_func_name: str - - def to_dict(self) -> dict[str, Any]: - return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port, - 'service_type': self.service_type} - - class MetaConfig(BaseConfig): meta_type: str @@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]: ragflow_count = 0 id_count = 0 for k, v in raw_configs.items(): - match (k): + match k: case "ragflow": name: str = f'ragflow_{ragflow_count}' host: str = v['host'] diff --git a/admin/server/responses.py b/admin/server/responses.py index 54f841a8307..c41c4512eb5 100644 --- a/admin/server/responses.py +++ b/admin/server/responses.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - - from flask import jsonify diff --git a/admin/server/routes.py b/admin/server/routes.py index 2c70fbd7af6..e83f3ff08e1 100644 --- a/admin/server/routes.py +++ b/admin/server/routes.py @@ -17,7 +17,7 @@ import secrets from flask import Blueprint, request -from flask_login import current_user, logout_user, login_required +from flask_login import current_user, login_required, logout_user from auth import login_verify, login_admin, check_admin_auth from responses import success_response, error_response @@ -29,6 +29,11 @@ admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin') +@admin_bp.route('/ping', methods=['GET']) +def ping(): + return success_response('PONG') + + @admin_bp.route('/login', methods=['POST']) def login(): if not request.json: diff --git a/admin/server/services.py b/admin/server/services.py index e8cf4eb5d44..c394dae3a65 100644 --- a/admin/server/services.py +++ b/admin/server/services.py @@ -14,7 +14,8 @@ # limitations under the License. # - +import os +import logging import re from werkzeug.security import check_password_hash from common.constants import ActiveEnum @@ -180,17 +181,22 @@ class ServiceMgr: @staticmethod def get_all_services(): + doc_engine = os.getenv('DOC_ENGINE', 'elasticsearch') result = [] configs = SERVICE_CONFIGS.configs for service_id, config in enumerate(configs): config_dict = config.to_dict() + if config_dict['service_type'] == 'retrieval': + if config_dict['extra']['retrieval_type'] != doc_engine: + continue try: service_detail = ServiceMgr.get_service_details(service_id) if "status" in service_detail: config_dict['status'] = service_detail['status'] else: config_dict['status'] = 'timeout' - except Exception: + except Exception as e: + logging.warning(f"Can't get service details, error: {e}") config_dict['status'] = 'timeout' if not config_dict['host']: config_dict['host'] = '-' @@ -205,17 +211,13 @@ def get_services_by_type(service_type_str: str): @staticmethod def get_service_details(service_id: int): - service_id = int(service_id) + service_idx = int(service_id) configs = SERVICE_CONFIGS.configs - service_config_mapping = { - c.id: { - 'name': c.name, - 'detail_func_name': c.detail_func_name - } for c in configs - } - service_info = service_config_mapping.get(service_id, {}) - if not service_info: - raise AdminException(f"invalid service_id: {service_id}") + if service_idx < 0 or service_idx >= len(configs): + raise AdminException(f"invalid service_index: {service_idx}") + + service_config = configs[service_idx] + service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name} detail_func = getattr(health_utils, service_info.get('detail_func_name')) res = detail_func() diff --git a/agent/__init__.py b/agent/__init__.py index 643f79713c8..177b91dd051 100644 --- a/agent/__init__.py +++ b/agent/__init__.py @@ -13,6 +13,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from beartype.claw import beartype_this_package -beartype_this_package() diff --git a/agent/canvas.py b/agent/canvas.py index bc7a45e3e60..6368e10e355 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -13,7 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import base64 +import inspect +import binascii import json import logging import re @@ -26,7 +29,9 @@ from agent.component import component_class from agent.component.base import ComponentBase from api.db.services.file_service import FileService +from api.db.services.llm_service import LLMBundle from api.db.services.task_service import has_canceled +from common.constants import LLMType from common.misc_utils import get_uuid, hash_str2int from common.exceptions import TaskCanceledException from rag.prompts.generator import chunks_format @@ -80,14 +85,12 @@ def __init__(self, dsl: str, tenant_id=None, task_id=None): self.dsl = json.loads(dsl) self._tenant_id = tenant_id self.task_id = task_id if task_id else get_uuid() + self._thread_pool = ThreadPoolExecutor(max_workers=5) self.load() def load(self): self.components = self.dsl["components"] cpn_nms = set([]) - for k, cpn in self.components.items(): - cpn_nms.add(cpn["obj"]["component_name"]) - for k, cpn in self.components.items(): cpn_nms.add(cpn["obj"]["component_name"]) param = component_class(cpn["obj"]["component_name"] + "Param")() @@ -157,7 +160,7 @@ def get_tenant_id(self): return self._tenant_id def get_value_with_variable(self,value: str) -> Any: - pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*") + pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*") out_parts = [] last = 0 @@ -207,17 +210,60 @@ def get_variable_param_value(self, obj: Any, path: str) -> Any: for key in path.split('.'): if cur is None: return None + if isinstance(cur, str): try: cur = json.loads(cur) except Exception: return None + if isinstance(cur, dict): cur = cur.get(key) - else: - cur = getattr(cur, key, None) + continue + + if isinstance(cur, (list, tuple)): + try: + idx = int(key) + cur = cur[idx] + except Exception: + return None + continue + + cur = getattr(cur, key, None) return cur + def set_variable_value(self, exp: str,value): + exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}") + if exp.find("@") < 0: + self.globals[exp] = value + return + cpn_id, var_nm = exp.split("@") + cpn = self.get_component(cpn_id) + if not cpn: + raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'") + parts = var_nm.split(".", 1) + root_key = parts[0] + rest = parts[1] if len(parts) > 1 else "" + if not rest: + cpn["obj"].set_output(root_key, value) + return + root_val = cpn["obj"].output(root_key) + if not root_val: + root_val = {} + cpn["obj"].set_output(root_key, self.set_variable_param_value(root_val,rest,value)) + + def set_variable_param_value(self, obj: Any, path: str, value) -> Any: + cur = obj + keys = path.split('.') + if not path: + return value + for key in keys: + if key not in cur or not isinstance(cur[key], dict): + cur[key] = {} + cur = cur[key] + cur[keys[-1]] = value + return obj + def is_canceled(self) -> bool: return has_canceled(self.task_id) @@ -232,14 +278,16 @@ def cancel_task(self) -> bool: class Canvas(Graph): - def __init__(self, dsl: str, tenant_id=None, task_id=None): + def __init__(self, dsl: str, tenant_id=None, task_id=None, canvas_id=None): self.globals = { "sys.query": "", "sys.user_id": tenant_id, "sys.conversation_turns": 0, "sys.files": [] } + self.variables = {} super().__init__(dsl, tenant_id, task_id) + self._id = canvas_id def load(self): super().load() @@ -253,6 +301,10 @@ def load(self): "sys.conversation_turns": 0, "sys.files": [] } + if "variables" in self.dsl: + self.variables = self.dsl["variables"] + else: + self.variables = {} self.retrieval = self.dsl["retrieval"] self.memory = self.dsl.get("memory", []) @@ -269,6 +321,7 @@ def reset(self, mem=False): self.history = [] self.retrieval = [] self.memory = [] + print(self.variables) for k in self.globals.keys(): if k.startswith("sys."): if isinstance(self.globals[k], str): @@ -283,9 +336,31 @@ def reset(self, mem=False): self.globals[k] = {} else: self.globals[k] = None + if k.startswith("env."): + key = k[4:] + if key in self.variables: + variable = self.variables[key] + if variable["value"]: + self.globals[k] = variable["value"] + else: + if variable["type"] == "string": + self.globals[k] = "" + elif variable["type"] == "number": + self.globals[k] = 0 + elif variable["type"] == "boolean": + self.globals[k] = False + elif variable["type"] == "object": + self.globals[k] = {} + elif variable["type"].startswith("array"): + self.globals[k] = [] + else: + self.globals[k] = "" + else: + self.globals[k] = "" - def run(self, **kwargs): + async def run(self, **kwargs): st = time.perf_counter() + self._loop = asyncio.get_running_loop() self.message_id = get_uuid() created_at = int(time.time()) self.add_user_input(kwargs.get("query")) @@ -294,16 +369,19 @@ def run(self, **kwargs): if kwargs.get("webhook_payload"): for k, cpn in self.components.items(): - if self.components[k]["obj"].component_name.lower() == "webhook": - for kk, vv in kwargs["webhook_payload"].items(): + if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook": + payload = kwargs.get("webhook_payload", {}) + if "input" in payload: + self.components[k]["obj"].set_input_value("request", payload["input"]) + for kk, vv in payload.items(): + if kk == "input": + continue self.components[k]["obj"].set_output(kk, vv) - self.components[k]["obj"].reset(True) - for k in kwargs.keys(): if k in ["query", "user_id", "files"] and kwargs[k]: if k == "files": - self.globals[f"sys.{k}"] = self.get_files(kwargs[k]) + self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k]) else: self.globals[f"sys.{k}"] = kwargs[k] if not self.globals["sys.conversation_turns"] : @@ -333,31 +411,50 @@ def decorate(event, dt): yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) - def _run_batch(f, t): + async def _run_batch(f, t): if self.is_canceled(): msg = f"Task {self.task_id} has been canceled during batch execution." logging.info(msg) raise TaskCanceledException(msg) - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] - i = f - while i < t: - cpn = self.get_component_obj(self.path[i]) - if cpn.component_name.lower() in ["begin", "userfillup"]: - thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {}))) - i += 1 + loop = asyncio.get_running_loop() + tasks = [] + + def _run_async_in_thread(coro_func, **call_kwargs): + return asyncio.run(coro_func(**call_kwargs)) + + i = f + while i < t: + cpn = self.get_component_obj(self.path[i]) + task_fn = None + call_kwargs = None + + if cpn.component_name.lower() in ["begin", "userfillup"]: + call_kwargs = {"inputs": kwargs.get("inputs", {})} + task_fn = cpn.invoke + i += 1 + else: + for _, ele in cpn.get_input_elements().items(): + if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: + self.path.pop(i) + t -= 1 + break else: - for _, ele in cpn.get_input_elements().items(): - if isinstance(ele, dict) and ele.get("_cpn_id") and ele.get("_cpn_id") not in self.path[:i] and self.path[0].lower().find("userfillup") < 0: - self.path.pop(i) - t -= 1 - break - else: - thr.append(executor.submit(cpn.invoke, **cpn.get_input())) - i += 1 - for t in thr: - t.result() + call_kwargs = cpn.get_input() + task_fn = cpn.invoke + i += 1 + + if task_fn is None: + continue + + invoke_async = getattr(cpn, "invoke_async", None) + if invoke_async and asyncio.iscoroutinefunction(invoke_async): + tasks.append(loop.run_in_executor(self._thread_pool, partial(_run_async_in_thread, invoke_async, **(call_kwargs or {})))) + else: + tasks.append(loop.run_in_executor(self._thread_pool, partial(task_fn, **(call_kwargs or {})))) + + if tasks: + await asyncio.gather(*tasks) def _node_finished(cpn_obj): return decorate("node_finished",{ @@ -374,6 +471,7 @@ def _node_finished(cpn_obj): self.error = "" idx = len(self.path) - 1 partials = [] + tts_mdl = None while idx < len(self.path): to = len(self.path) for i in range(idx, to): @@ -384,31 +482,72 @@ def _node_finished(cpn_obj): "component_type": self.get_component_type(self.path[i]), "thoughts": self.get_component_thoughts(self.path[i]) }) - _run_batch(idx, to) + await _run_batch(idx, to) to = len(self.path) - # post processing of components invocation + # post-processing of components invocation for i in range(idx, to): cpn = self.get_component(self.path[i]) cpn_obj = self.get_component_obj(self.path[i]) if cpn_obj.component_name.lower() == "message": + if cpn_obj.get_param("auto_play"): + tts_mdl = LLMBundle(self._tenant_id, LLMType.TTS) if isinstance(cpn_obj.output("content"), partial): _m = "" - for m in cpn_obj.output("content")(): + buff_m = "" + stream = cpn_obj.output("content")() + async def _process_stream(m): + nonlocal buff_m, _m, tts_mdl if not m: - continue + return if m == "": - yield decorate("message", {"content": "", "start_to_think": True}) + return decorate("message", {"content": "", "start_to_think": True}) + elif m == "": - yield decorate("message", {"content": "", "end_to_think": True}) - else: - yield decorate("message", {"content": m}) - _m += m + return decorate("message", {"content": "", "end_to_think": True}) + + buff_m += m + _m += m + + if len(buff_m) > 16: + ev = decorate( + "message", + { + "content": m, + "audio_binary": self.tts(tts_mdl, buff_m) + } + ) + buff_m = "" + return ev + + return decorate("message", {"content": m}) + + if inspect.isasyncgen(stream): + async for m in stream: + ev= await _process_stream(m) + if ev: + yield ev + else: + for m in stream: + ev= await _process_stream(m) + if ev: + yield ev + if buff_m: + yield decorate("message", {"content": "", "audio_binary": self.tts(tts_mdl, buff_m)}) + buff_m = "" cpn_obj.set_output("content", _m) cite = re.search(r"\[ID:[ 0-9]+\]", _m) else: yield decorate("message", {"content": cpn_obj.output("content")}) cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content")) - yield decorate("message_end", {"reference": self.get_reference() if cite else None}) + + message_end = {} + if cpn_obj.get_param("status"): + message_end["status"] = cpn_obj.get_param("status") + if isinstance(cpn_obj.output("attachment"), dict): + message_end["attachment"] = cpn_obj.output("attachment") + if cite: + message_end["reference"] = self.get_reference() + yield decorate("message_end", message_end) while partials: _cpn_obj = self.get_component_obj(partials[0]) @@ -429,7 +568,7 @@ def _node_finished(cpn_obj): else: self.error = cpn_obj.error() - if cpn_obj.component_name.lower() != "iteration": + if cpn_obj.component_name.lower() not in ("iteration","loop"): if isinstance(cpn_obj.output("content"), partial): if self.error: cpn_obj.set_output("content", None) @@ -454,14 +593,16 @@ def _extend_path(cpn_ids): for cpn_id in cpn_ids: _append_path(cpn_id) - if cpn_obj.component_name.lower() == "iterationitem" and cpn_obj.end(): + if cpn_obj.component_name.lower() in ("iterationitem","loopitem") and cpn_obj.end(): iter = cpn_obj.get_parent() yield _node_finished(iter) _extend_path(self.get_component(cpn["parent_id"])["downstream"]) elif cpn_obj.component_name.lower() in ["categorize", "switch"]: _extend_path(cpn_obj.output("_next")) - elif cpn_obj.component_name.lower() == "iteration": + elif cpn_obj.component_name.lower() in ("iteration", "loop"): _append_path(cpn_obj.get_start()) + elif cpn_obj.component_name.lower() == "exitloop" and cpn_obj.get_parent().component_name.lower() == "loop": + _extend_path(self.get_component(cpn["parent_id"])["downstream"]) elif not cpn["downstream"] and cpn_obj.get_parent(): _append_path(cpn_obj.get_parent().get_start()) else: @@ -517,6 +658,50 @@ def is_reff(self, exp: str) -> bool: return False return True + + def tts(self,tts_mdl, text): + def clean_tts_text(text: str) -> str: + if not text: + return "" + + text = text.encode("utf-8", "ignore").decode("utf-8", "ignore") + + text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) + + emoji_pattern = re.compile( + "[\U0001F600-\U0001F64F" + "\U0001F300-\U0001F5FF" + "\U0001F680-\U0001F6FF" + "\U0001F1E0-\U0001F1FF" + "\U00002700-\U000027BF" + "\U0001F900-\U0001F9FF" + "\U0001FA70-\U0001FAFF" + "\U0001FAD0-\U0001FAFF]+", + flags=re.UNICODE + ) + text = emoji_pattern.sub("", text) + + text = re.sub(r"\s+", " ", text).strip() + + MAX_LEN = 500 + if len(text) > MAX_LEN: + text = text[:MAX_LEN] + + return text + if not tts_mdl or not text: + return None + text = clean_tts_text(text) + if not text: + return None + bin = b"" + try: + for chunk in tts_mdl.tts(text): + bin += chunk + except Exception as e: + logging.error(f"TTS failed: {e}, text={text!r}") + return None + return binascii.hexlify(bin).decode("utf-8") + def get_history(self, window_size): convs = [] if window_size <= 0: @@ -537,6 +722,9 @@ def get_prologue(self): def get_mode(self): return self.components["begin"]["obj"]._param.mode + def get_sys_query(self): + return self.globals.get("sys.query", "") + def set_global_param(self, **kwargs): self.globals.update(kwargs) @@ -546,20 +734,30 @@ def get_preset_param(self): def get_component_input_elements(self, cpnnm): return self.components[cpnnm]["obj"].get_input_elements() - def get_files(self, files: Union[None, list[dict]]) -> list[str]: + async def get_files_async(self, files: Union[None, list[dict]]) -> list[str]: if not files: return [] def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) - threads = [] + loop = asyncio.get_running_loop() + tasks = [] for file in files: if file["mime_type"].find("image") >=0: - threads.append(exe.submit(image_to_base64, file)) + tasks.append(loop.run_in_executor(self._thread_pool, image_to_base64, file)) continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) - return [th.result() for th in threads] + tasks.append(loop.run_in_executor(self._thread_pool, FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + return await asyncio.gather(*tasks) + + def get_files(self, files: Union[None, list[dict]]) -> list[str]: + """ + Synchronous wrapper for get_files_async, used by sync component invoke paths. + """ + loop = getattr(self, "_loop", None) + if loop and loop.is_running(): + return asyncio.run_coroutine_threadsafe(self.get_files_async(files), loop).result() + + return asyncio.run(self.get_files_async(files)) def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any, elapsed_time=None): agent_ids = agent_id.split("-->") @@ -613,4 +811,3 @@ def get_memory(self) -> list[Tuple]: def get_component_thoughts(self, cpn_id) -> str: return self.components.get(cpn_id)["obj"].thoughts() - diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 98dfbc92fe8..5ff55adf93e 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import json import logging import os import re -from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from functools import partial from typing import Any @@ -28,9 +29,9 @@ from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.mcp_server_service import MCPServerService from common.connection_utils import timeout -from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \ - citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool +from rag.prompts.generator import next_step_async, COMPLETE_TASK, \ + citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt +from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool from agent.component.llm import LLMParam, LLM @@ -83,9 +84,11 @@ class Agent(LLM, ToolBase): def __init__(self, canvas, id, param: LLMParam): LLM.__init__(self, canvas, id, param) self.tools = {} - for cpn in self._param.tools: + for idx, cpn in enumerate(self._param.tools): cpn = self._load_tool_obj(cpn) - self.tools[cpn.get_meta()["function"]["name"]] = cpn + original_name = cpn.get_meta()["function"]["name"] + indexed_name = f"{original_name}_{idx}" + self.tools[indexed_name] = cpn self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id, max_retries=self._param.max_retries, @@ -93,7 +96,12 @@ def __init__(self, canvas, id, param: LLMParam): max_rounds=self._param.max_rounds, verbose_tool_use=True ) - self.tool_meta = [v.get_meta() for _,v in self.tools.items()] + self.tool_meta = [] + for indexed_name, tool_obj in self.tools.items(): + original_meta = tool_obj.get_meta() + indexed_meta = deepcopy(original_meta) + indexed_meta["function"]["name"] = indexed_name + self.tool_meta.append(indexed_meta) for mcp in self._param.mcp: _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"]) @@ -107,7 +115,8 @@ def __init__(self, canvas, id, param: LLMParam): def _load_tool_obj(self, cpn: dict) -> object: from agent.component import component_class - param = component_class(cpn["component_name"] + "Param")() + tool_name = cpn["component_name"] + param = component_class(tool_name + "Param")() param.update(cpn["params"]) try: param.check() @@ -137,8 +146,34 @@ def get_input_form(self) -> dict[str, dict]: res.update(cpn.get_input_form()) return res - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))) + def _get_output_schema(self): + try: + cand = self._param.outputs.get("structured") + except Exception: + return None + + if isinstance(cand, dict): + if isinstance(cand.get("properties"), dict) and len(cand["properties"]) > 0: + return cand + for k in ("schema", "structured"): + if isinstance(cand.get(k), dict) and isinstance(cand[k].get("properties"), dict) and len(cand[k]["properties"]) > 0: + return cand[k] + + return None + + async def _force_format_to_schema_async(self, text: str, schema_prompt: str) -> str: + fmt_msgs = [ + {"role": "system", "content": schema_prompt + "\nIMPORTANT: Output ONLY valid JSON. No markdown, no extra text."}, + {"role": "user", "content": text}, + ] + _, fmt_msgs = message_fit_in(fmt_msgs, int(self.chat_mdl.max_length * 0.97)) + return await self._generate_async(fmt_msgs) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))) + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Agent processing"): return @@ -157,25 +192,25 @@ def _invoke(self, **kwargs): if not self.tools: if self.check_if_canceled("Agent processing"): return - return LLM._invoke(self, **kwargs) + return await LLM._invoke_async(self, **kwargs) prompt, msg, user_defined_prompt = self._prepare_prompt_variables() + output_schema = self._get_output_schema() + schema_prompt = "" + if output_schema: + schema = json.dumps(output_schema, ensure_ascii=False, indent=2) + schema_prompt = structured_output_prompt(schema) downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] ex = self.exception_handler() - output_structure=None - try: - output_structure=self._param.outputs['structured'] - except Exception: - pass - if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]): - self.set_output("content", partial(self.stream_output_with_tools, prompt, msg, user_defined_prompt)) + if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not (ex and ex["goto"]) and not output_schema: + self.set_output("content", partial(self.stream_output_with_tools_async, prompt, deepcopy(msg), user_defined_prompt)) return _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) use_tools = [] ans = "" - for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt): + async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt): if self.check_if_canceled("Agent processing"): return ans += delta_ans @@ -188,16 +223,38 @@ def _invoke(self, **kwargs): self.set_output("_ERROR", ans) return + if output_schema: + error = "" + for _ in range(self._param.max_retries + 1): + try: + def clean_formated_answer(ans: str) -> str: + ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) + ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) + return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) + obj = json_repair.loads(clean_formated_answer(ans)) + self.set_output("structured", obj) + if use_tools: + self.set_output("use_tools", use_tools) + return obj + except Exception: + error = "The answer cannot be parsed as JSON" + ans = await self._force_format_to_schema_async(ans, schema_prompt) + if ans.find("**ERROR**") >= 0: + continue + + self.set_output("_ERROR", error) + return + self.set_output("content", ans) if use_tools: self.set_output("use_tools", use_tools) return ans - def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}): + async def stream_output_with_tools_async(self, prompt, msg, user_defined_prompt={}): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) answer_without_toolcall = "" use_tools = [] - for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt): + async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt): if self.check_if_canceled("Agent streaming"): return @@ -215,55 +272,58 @@ def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}): if use_tools: self.set_output("use_tools", use_tools) - def _gen_citations(self, text): - retrievals = self._canvas.get_reference() - retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} - formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True) - for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, - {"role": "user", "content": text} - ]): - yield delta_ans - - def _react_with_tools_streamly(self, prompt, history: list[dict], use_tools, user_defined_prompt={}): + async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): token_count = 0 tool_metas = self.tool_meta hist = deepcopy(history) last_calling = "" if len(hist) > 3: st = timer() - user_request = full_question(messages=history, chat_mdl=self.chat_mdl) + user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) else: user_request = history[-1]["content"] - def use_tool(name, args): - nonlocal hist, use_tools, token_count,last_calling,user_request + def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str: + """Build a minimal task_desc by concatenating prompt, query, and tool schemas.""" + user_defined_prompt = user_defined_prompt or {} + + task_desc = ( + "### Agent Prompt\n" + f"{prompt}\n\n" + "### User Request\n" + f"{user_request}\n\n" + ) + + if user_defined_prompt: + udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2) + task_desc += "\n### User Defined Prompts\n" + udp_json + "\n" + + return task_desc + + + async def use_tool_async(name, args): + nonlocal hist, use_tools, last_calling logging.info(f"{last_calling=} == {name=}") - # Summarize of function calling - #if all([ - # isinstance(self.toolcall_session.get_tool_obj(name), Agent), - # last_calling, - # last_calling != name - #]): - # self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"]),user_defined_prompt)) last_calling = name - tool_response = self.toolcall_session.tool_call(name, args) + tool_response = await self.toolcall_session.tool_call_async(name, args) use_tools.append({ "name": name, "arguments": args, "results": tool_response }) - # self.callback("add_memory", {}, "...") - #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt) - return name, tool_response - def complete(): + async def complete(): nonlocal hist need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 + if schema_prompt: + need2cite = False cited = False - if hist[0]["role"] == "system" and need2cite: - if len(hist) < 7: + if hist and hist[0]["role"] == "system": + if schema_prompt: + hist[0]["content"] += "\n" + schema_prompt + if need2cite and len(hist) < 7: hist[0]["content"] += citation_prompt() cited = True yield "", token_count @@ -272,7 +332,7 @@ def complete(): if len(hist) > 12: _hist = [hist[0], hist[1], *hist[-10:]] entire_txt = "" - for delta_ans in self._generate_streamly(_hist): + async for delta_ans in self._generate_streamly(_hist): if not need2cite or cited: yield delta_ans, 0 entire_txt += delta_ans @@ -281,7 +341,7 @@ def complete(): st = timer() txt = "" - for delta_ans in self._gen_citations(entire_txt): + async for delta_ans in self._gen_citations_async(entire_txt): if self.check_if_canceled("Agent streaming"): return yield delta_ans, 0 @@ -289,6 +349,21 @@ def complete(): self.callback("gen_citations", {}, txt, elapsed_time=timer()-st) + def build_observation(tool_call_res: list[tuple]) -> str: + """ + Build a Observation from tool call results. + No LLM involved. + """ + if not tool_call_res: + return "" + + lines = ["Observation:"] + for name, result in tool_call_res: + lines.append(f"[{name} result]") + lines.append(str(result)) + + return "\n".join(lines) + def append_user_content(hist, content): if hist[-1]["role"] == "user": hist[-1]["content"] += content @@ -296,14 +371,14 @@ def append_user_content(hist, content): hist.append({"role": "user", "content": content}) st = timer() - task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt) + task_desc = build_task_desc(prompt, user_request, user_defined_prompt) self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) for _ in range(self._param.max_rounds + 1): if self.check_if_canceled("Agent streaming"): return - response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) + response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) # self.callback("next_step", {}, str(response)[:256]+"...") - token_count += tk + token_count += tk or 0 hist.append({"role": "assistant", "content": response}) try: functions = json_repair.loads(re.sub(r"```.*", "", response)) @@ -312,23 +387,24 @@ def append_user_content(hist, content): for f in functions: if not isinstance(f, dict): raise TypeError(f"An object type should be returned, but `{f}`") - with ThreadPoolExecutor(max_workers=5) as executor: - thr = [] - for func in functions: - name = func["name"] - args = func["arguments"] - if name == COMPLETE_TASK: - append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n") - for txt, tkcnt in complete(): - yield txt, tkcnt - return - - thr.append(executor.submit(use_tool, name, args)) - - st = timer() - reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr], user_defined_prompt) - append_user_content(hist, reflection) - self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) + + tool_tasks = [] + for func in functions: + name = func["name"] + args = func["arguments"] + if name == COMPLETE_TASK: + append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n") + async for txt, tkcnt in complete(): + yield txt, tkcnt + return + + tool_tasks.append(asyncio.create_task(use_tool_async(name, args))) + + results = await asyncio.gather(*tool_tasks) if tool_tasks else [] + st = timer() + reflection = build_observation(results) + append_user_content(hist, reflection) + self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) except Exception as e: logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}") @@ -352,27 +428,159 @@ def append_user_content(hist, content): return append_user_content(hist, final_instruction) - for txt, tkcnt in complete(): + async for txt, tkcnt in complete(): yield txt, tkcnt - def get_useful_memory(self, goal: str, sub_goal:str, topn=3, user_defined_prompt:dict={}) -> str: - # self.callback("get_useful_memory", {"topn": 3}, "...") - mems = self._canvas.get_memory() - rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems], user_defined_prompt) - try: - rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn] - mems = [mems[r] for r in rank] - return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems]) - except Exception as e: - logging.exception(e) - - return "Error occurred." +# async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""): +# token_count = 0 +# tool_metas = self.tool_meta +# hist = deepcopy(history) +# last_calling = "" +# if len(hist) > 3: +# st = timer() +# user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) +# self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) +# else: +# user_request = history[-1]["content"] + +# async def use_tool_async(name, args): +# nonlocal hist, use_tools, last_calling +# logging.info(f"{last_calling=} == {name=}") +# last_calling = name +# tool_response = await self.toolcall_session.tool_call_async(name, args) +# use_tools.append({ +# "name": name, +# "arguments": args, +# "results": tool_response +# }) +# # self.callback("add_memory", {}, "...") +# #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt) + +# return name, tool_response + +# async def complete(): +# nonlocal hist +# need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0 +# if schema_prompt: +# need2cite = False +# cited = False +# if hist and hist[0]["role"] == "system": +# if schema_prompt: +# hist[0]["content"] += "\n" + schema_prompt +# if need2cite and len(hist) < 7: +# hist[0]["content"] += citation_prompt() +# cited = True +# yield "", token_count + +# _hist = hist +# if len(hist) > 12: +# _hist = [hist[0], hist[1], *hist[-10:]] +# entire_txt = "" +# async for delta_ans in self._generate_streamly(_hist): +# if not need2cite or cited: +# yield delta_ans, 0 +# entire_txt += delta_ans +# if not need2cite or cited: +# return + +# st = timer() +# txt = "" +# async for delta_ans in self._gen_citations_async(entire_txt): +# if self.check_if_canceled("Agent streaming"): +# return +# yield delta_ans, 0 +# txt += delta_ans + +# self.callback("gen_citations", {}, txt, elapsed_time=timer()-st) + +# def append_user_content(hist, content): +# if hist[-1]["role"] == "user": +# hist[-1]["content"] += content +# else: +# hist.append({"role": "user", "content": content}) + +# st = timer() +# task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt) +# self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) +# for _ in range(self._param.max_rounds + 1): +# if self.check_if_canceled("Agent streaming"): +# return +# response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) +# # self.callback("next_step", {}, str(response)[:256]+"...") +# token_count += tk or 0 +# hist.append({"role": "assistant", "content": response}) +# try: +# functions = json_repair.loads(re.sub(r"```.*", "", response)) +# if not isinstance(functions, list): +# raise TypeError(f"List should be returned, but `{functions}`") +# for f in functions: +# if not isinstance(f, dict): +# raise TypeError(f"An object type should be returned, but `{f}`") + +# tool_tasks = [] +# for func in functions: +# name = func["name"] +# args = func["arguments"] +# if name == COMPLETE_TASK: +# append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n") +# async for txt, tkcnt in complete(): +# yield txt, tkcnt +# return + +# tool_tasks.append(asyncio.create_task(use_tool_async(name, args))) + +# results = await asyncio.gather(*tool_tasks) if tool_tasks else [] +# st = timer() +# reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt) +# append_user_content(hist, reflection) +# self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st) + +# except Exception as e: +# logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}") +# e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}" +# append_user_content(hist, str(e)) + +# logging.warning( f"Exceed max rounds: {self._param.max_rounds}") +# final_instruction = f""" +# {user_request} +# IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request. +# Instructions: +# 1. SYNTHESIZE all information collected during this conversation +# 2. Provide a COMPLETE response using existing data - do not suggest additional research +# 3. Structure your response as a FINAL DELIVERABLE, not a plan +# 4. If information is incomplete, state what you found and provide the best analysis possible with available data +# 5. DO NOT mention conversation limits or suggest further steps +# 6. Focus on delivering VALUE with the information already gathered +# Respond immediately with your final comprehensive answer. +# """ +# if self.check_if_canceled("Agent final instruction"): +# return +# append_user_content(hist, final_instruction) + +# async for txt, tkcnt in complete(): +# yield txt, tkcnt + + async def _gen_citations_async(self, text): + retrievals = self._canvas.get_reference() + retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} + formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True) + async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, + {"role": "user", "content": text} + ]): + yield delta_ans - def reset(self, temp=False): + def reset(self, only_output=False): """ Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession. """ + for k in self._param.outputs.keys(): + self._param.outputs[k]["value"] = None + for k, cpn in self.tools.items(): if hasattr(cpn, "reset") and callable(cpn.reset): cpn.reset() - + if only_output: + return + for k in self._param.inputs.keys(): + self._param.inputs[k]["value"] = None + self._param.debug_inputs = {} diff --git a/agent/component/base.py b/agent/component/base.py index 31ad46820b7..264f3972a34 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -14,6 +14,7 @@ # limitations under the License. # +import asyncio import re import time from abc import ABC @@ -23,11 +24,9 @@ import logging from typing import Any, List, Union import pandas as pd -import trio from agent import settings from common.connection_utils import timeout - _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params" _USER_FEEDED_PARAMS = "_user_feeded_params" @@ -97,7 +96,7 @@ def as_dict(self): def _recursive_convert_obj_to_dict(obj): ret_dict = {} if isinstance(obj, dict): - for k,v in obj.items(): + for k, v in obj.items(): if isinstance(v, dict) or (v and type(v).__name__ not in dir(builtins)): ret_dict[k] = _recursive_convert_obj_to_dict(v) else: @@ -253,96 +252,65 @@ def _validate_param(self, param_obj, validation_json): self._validate_param(attr, validation_json) @staticmethod - def check_string(param, descr): + def check_string(param, description): if type(param).__name__ not in ["str"]: - raise ValueError( - descr + " {} not supported, should be string type".format(param) - ) + raise ValueError(description + " {} not supported, should be string type".format(param)) @staticmethod - def check_empty(param, descr): + def check_empty(param, description): if not param: - raise ValueError( - descr + " does not support empty value." - ) + raise ValueError(description + " does not support empty value.") @staticmethod - def check_positive_integer(param, descr): + def check_positive_integer(param, description): if type(param).__name__ not in ["int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive integer".format(param) - ) + raise ValueError(description + " {} not supported, should be positive integer".format(param)) @staticmethod - def check_positive_number(param, descr): + def check_positive_number(param, description): if type(param).__name__ not in ["float", "int", "long"] or param <= 0: - raise ValueError( - descr + " {} not supported, should be positive numeric".format(param) - ) + raise ValueError(description + " {} not supported, should be positive numeric".format(param)) @staticmethod - def check_nonnegative_number(param, descr): + def check_nonnegative_number(param, description): if type(param).__name__ not in ["float", "int", "long"] or param < 0: - raise ValueError( - descr - + " {} not supported, should be non-negative numeric".format(param) - ) + raise ValueError(description + " {} not supported, should be non-negative numeric".format(param)) @staticmethod - def check_decimal_float(param, descr): + def check_decimal_float(param, description): if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1: - raise ValueError( - descr - + " {} not supported, should be a float number in range [0, 1]".format( - param - ) - ) + raise ValueError(description + " {} not supported, should be a float number in range [0, 1]".format(param)) @staticmethod - def check_boolean(param, descr): + def check_boolean(param, description): if type(param).__name__ != "bool": - raise ValueError( - descr + " {} not supported, should be bool type".format(param) - ) + raise ValueError(description + " {} not supported, should be bool type".format(param)) @staticmethod - def check_open_unit_interval(param, descr): + def check_open_unit_interval(param, description): if type(param).__name__ not in ["float"] or param <= 0 or param >= 1: - raise ValueError( - descr + " should be a numeric number between 0 and 1 exclusively" - ) + raise ValueError(description + " should be a numeric number between 0 and 1 exclusively") @staticmethod - def check_valid_value(param, descr, valid_values): + def check_valid_value(param, description, valid_values): if param not in valid_values: - raise ValueError( - descr - + " {} is not supported, it should be in {}".format(param, valid_values) - ) + raise ValueError(description + " {} is not supported, it should be in {}".format(param, valid_values)) @staticmethod - def check_defined_type(param, descr, types): + def check_defined_type(param, description, types): if type(param).__name__ not in types: - raise ValueError( - descr + " {} not supported, should be one of {}".format(param, types) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, types)) @staticmethod - def check_and_change_lower(param, valid_list, descr=""): + def check_and_change_lower(param, valid_list, description=""): if type(param).__name__ != "str": - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list)) lower_param = param.lower() if lower_param in valid_list: return lower_param else: - raise ValueError( - descr - + " {} not supported, should be one of {}".format(param, valid_list) - ) + raise ValueError(description + " {} not supported, should be one of {}".format(param, valid_list)) @staticmethod def _greater_equal_than(value, limit): @@ -374,16 +342,16 @@ def _in(value, right_value_list): def _not_in(value, wrong_value_list): return value not in wrong_value_list - def _warn_deprecated_param(self, param_name, descr): + def _warn_deprecated_param(self, param_name, description): if self._deprecated_params_set.get(param_name): logging.warning( - f"{descr} {param_name} is deprecated and ignored in this version." + f"{description} {param_name} is deprecated and ignored in this version." ) - def _warn_to_deprecate_param(self, param_name, descr, new_param): + def _warn_to_deprecate_param(self, param_name, description, new_param): if self._deprecated_params_set.get(param_name): logging.warning( - f"{descr} {param_name} will be deprecated in future release; " + f"{description} {param_name} will be deprecated in future release; " f"please use {new_param} instead." ) return True @@ -392,8 +360,8 @@ def _warn_to_deprecate_param(self, param_name, descr, new_param): class ComponentBase(ABC): component_name: str - thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10))) - variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" + thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) + variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" def __str__(self): """ @@ -407,7 +375,7 @@ def __str__(self): "params": {} }}""".format(self.component_name, self._param - ) + ) def __init__(self, canvas, id, param: ComponentParamBase): from agent.canvas import Graph # Local import to avoid cyclic dependency @@ -445,14 +413,42 @@ def invoke(self, **kwargs) -> dict[str, Any]: self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return self.output() - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + async def invoke_async(self, **kwargs) -> dict[str, Any]: + """ + Async wrapper for component invocation. + Prefers coroutine `_invoke_async` if present; otherwise falls back to `_invoke`. + Handles timing and error recording consistently with `invoke`. + """ + self.set_output("_created_time", time.perf_counter()) + try: + if self.check_if_canceled("Component processing"): + return + + fn_async = getattr(self, "_invoke_async", None) + if fn_async and asyncio.iscoroutinefunction(fn_async): + await fn_async(**kwargs) + elif asyncio.iscoroutinefunction(self._invoke): + await self._invoke(**kwargs) + else: + await asyncio.to_thread(self._invoke, **kwargs) + except Exception as e: + if self.get_exception_default_value(): + self.set_exception_default_value() + else: + self.set_output("_ERROR", str(e)) + logging.exception(e) + self._param.debug_inputs = {} + self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) + return self.output() + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): raise NotImplementedError() - def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]: + def output(self, var_nm: str = None) -> Union[dict[str, Any], Any]: if var_nm: return self._param.outputs.get(var_nm, {}).get("value", "") - return {k: o.get("value") for k,o in self._param.outputs.items()} + return {k: o.get("value") for k, o in self._param.outputs.items()} def set_output(self, key: str, value: Any): if key not in self._param.outputs: @@ -463,15 +459,18 @@ def error(self): return self._param.outputs.get("_ERROR", {}).get("value") def reset(self, only_output=False): - for k in self._param.outputs.keys(): - self._param.outputs[k]["value"] = None + outputs: dict = self._param.outputs # for better performance + for k in outputs.keys(): + outputs[k]["value"] = None if only_output: return - for k in self._param.inputs.keys(): - self._param.inputs[k]["value"] = None + + inputs: dict = self._param.inputs # for better performance + for k in inputs.keys(): + inputs[k]["value"] = None self._param.debug_inputs = {} - def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]: + def get_input(self, key: str = None) -> Union[Any, dict[str, Any]]: if key: return self._param.inputs.get(key, {}).get("value") @@ -495,13 +494,13 @@ def get_input_values(self) -> Union[Any, dict[str, Any]]: def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: res = {} - for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE|re.DOTALL): + for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE | re.DOTALL): exp = r.group(1) - cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp) + cpn_id, var_nm = exp.split("@") if exp.find("@") > 0 else ("", exp) res[exp] = { - "name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp, + "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}") if cpn_id else exp, "value": self._canvas.get_variable_value(exp), - "_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, + "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, "_cpn_id": cpn_id } return res @@ -552,6 +551,7 @@ def string_format(content: str, kv: dict[str, str]) -> str: for n, v in kv.items(): def repl(_match, val=v): return str(val) if val is not None else "" + content = re.sub( r"\{%s\}" % re.escape(n), repl, diff --git a/agent/component/begin.py b/agent/component/begin.py index b5985bb7a90..bcbfdbf24b7 100644 --- a/agent/component/begin.py +++ b/agent/component/begin.py @@ -14,6 +14,7 @@ # limitations under the License. # from agent.component.fillup import UserFillUpParam, UserFillUp +from api.db.services.file_service import FileService class BeginParam(UserFillUpParam): @@ -27,7 +28,7 @@ def __init__(self): self.prologue = "Hi! I'm your smart assistant. What can I do for you?" def check(self): - self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"]) + self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"]) def get_input_form(self) -> dict[str, dict]: return getattr(self, "inputs") @@ -48,7 +49,7 @@ def _invoke(self, **kwargs): if v.get("optional") and v.get("value", None) is None: v = None else: - v = self._canvas.get_files([v["value"]]) + v = FileService.get_files([v["value"]]) else: v = v.get("value") self.set_output(k, v) diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 1333889bbdb..27cffb91c88 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import re @@ -97,7 +98,7 @@ class Categorize(LLM, ABC): component_name = "Categorize" @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Categorize processing"): return @@ -121,7 +122,7 @@ def _invoke(self, **kwargs): if self.check_if_canceled("Categorize processing"): return - ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) + ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) logging.info(f"input: {user_prompt}, answer: {str(ans)}") if ERROR_PREFIX in ans: raise Exception(ans) @@ -144,5 +145,9 @@ def _invoke(self, **kwargs): self.set_output("category_name", max_category) self.set_output("_next", cpn_ids) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()])) diff --git a/agent/component/data_operations.py b/agent/component/data_operations.py index fab7d8c0fa7..cddd20996cd 100644 --- a/agent/component/data_operations.py +++ b/agent/component/data_operations.py @@ -1,3 +1,18 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from abc import ABC import ast import os diff --git a/agent/component/docs_generator.py b/agent/component/docs_generator.py new file mode 100644 index 00000000000..9c244295843 --- /dev/null +++ b/agent/component/docs_generator.py @@ -0,0 +1,1570 @@ +import json +import os +import re +import base64 +from datetime import datetime +from abc import ABC +from io import BytesIO +from typing import Optional +from functools import partial +from reportlab.lib.pagesizes import A4 +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.lib.units import inch +from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY +from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image, TableStyle, LongTable +from reportlab.lib import colors +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont +from reportlab.pdfbase.cidfonts import UnicodeCIDFont + +from agent.component.base import ComponentParamBase +from api.utils.api_utils import timeout +from .message import Message + + +class PDFGeneratorParam(ComponentParamBase): + """ + Define the PDF Generator component parameters. + """ + + def __init__(self): + super().__init__() + # Output format + self.output_format = "pdf" # pdf, docx, txt + + # Content inputs + self.content = "" + self.title = "" + self.subtitle = "" + self.header_text = "" + self.footer_text = "" + + # Images + self.logo_image = "" # base64 or file path + self.logo_position = "left" # left, center, right + self.logo_width = 2.0 # inches + self.logo_height = 1.0 # inches + + # Styling + self.font_family = "Helvetica" # Helvetica, Times-Roman, Courier + self.font_size = 12 + self.title_font_size = 24 + self.heading1_font_size = 18 + self.heading2_font_size = 16 + self.heading3_font_size = 14 + self.text_color = "#000000" + self.title_color = "#000000" + + # Page settings + self.page_size = "A4" + self.orientation = "portrait" # portrait, landscape + self.margin_top = 1.0 # inches + self.margin_bottom = 1.0 + self.margin_left = 1.0 + self.margin_right = 1.0 + self.line_spacing = 1.2 + + # Output settings + self.filename = "" + self.output_directory = "/tmp/pdf_outputs" + self.add_page_numbers = True + self.add_timestamp = True + + # Advanced features + self.watermark_text = "" + self.enable_toc = False + + self.outputs = { + "file_path": {"value": "", "type": "string"}, + "pdf_base64": {"value": "", "type": "string"}, + "download": {"value": "", "type": "string"}, + "success": {"value": False, "type": "boolean"} + } + + def check(self): + self.check_empty(self.content, "[PDFGenerator] Content") + self.check_valid_value(self.output_format, "[PDFGenerator] Output format", ["pdf", "docx", "txt"]) + self.check_valid_value(self.logo_position, "[PDFGenerator] Logo position", ["left", "center", "right"]) + self.check_valid_value(self.font_family, "[PDFGenerator] Font family", + ["Helvetica", "Times-Roman", "Courier", "Helvetica-Bold", "Times-Bold"]) + self.check_valid_value(self.page_size, "[PDFGenerator] Page size", ["A4", "Letter"]) + self.check_valid_value(self.orientation, "[PDFGenerator] Orientation", ["portrait", "landscape"]) + self.check_positive_number(self.font_size, "[PDFGenerator] Font size") + self.check_positive_number(self.margin_top, "[PDFGenerator] Margin top") + + +class PDFGenerator(Message, ABC): + component_name = "PDFGenerator" + + # Track if Unicode fonts have been registered + _unicode_fonts_registered = False + _unicode_font_name = None + _unicode_font_bold_name = None + + @classmethod + def _reset_font_cache(cls): + """Reset font registration cache - useful for testing""" + cls._unicode_fonts_registered = False + cls._unicode_font_name = None + cls._unicode_font_bold_name = None + + @classmethod + def _register_unicode_fonts(cls): + """Register Unicode-compatible fonts for multi-language support. + + Uses CID fonts (STSong-Light) for reliable CJK rendering as TTF fonts + have issues with glyph mapping in some ReportLab versions. + """ + # If already registered successfully, return True + if cls._unicode_fonts_registered and cls._unicode_font_name is not None: + return True + + # Reset and try again if previous registration failed + cls._unicode_fonts_registered = True + cls._unicode_font_name = None + cls._unicode_font_bold_name = None + + # Use CID fonts for reliable CJK support + # These are built into ReportLab and work reliably across all platforms + cid_fonts = [ + 'STSong-Light', # Simplified Chinese + 'HeiseiMin-W3', # Japanese + 'HYSMyeongJo-Medium', # Korean + ] + + for cid_font in cid_fonts: + try: + pdfmetrics.registerFont(UnicodeCIDFont(cid_font)) + cls._unicode_font_name = cid_font + cls._unicode_font_bold_name = cid_font # CID fonts don't have bold variants + print(f"Registered CID font: {cid_font}") + break + except Exception as e: + print(f"Failed to register CID font {cid_font}: {e}") + continue + + # If CID fonts fail, try TTF fonts as fallback + if not cls._unicode_font_name: + font_paths = [ + '/usr/share/fonts/truetype/freefont/FreeSans.ttf', + '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', + ] + + for font_path in font_paths: + if os.path.exists(font_path): + try: + pdfmetrics.registerFont(TTFont('UnicodeFont', font_path)) + cls._unicode_font_name = 'UnicodeFont' + cls._unicode_font_bold_name = 'UnicodeFont' + print(f"Registered TTF font from: {font_path}") + + # Register font family + from reportlab.pdfbase.pdfmetrics import registerFontFamily + registerFontFamily('UnicodeFont', normal='UnicodeFont', bold='UnicodeFont') + break + except Exception as e: + print(f"Failed to register TTF font {font_path}: {e}") + continue + + return cls._unicode_font_name is not None + + @staticmethod + def _needs_unicode_font(text: str) -> bool: + """Check if text contains CJK or other complex scripts that need special fonts. + + Standard PDF fonts (Helvetica, Times, Courier) support: + - Basic Latin, Extended Latin, Cyrillic, Greek + + CID fonts are needed for: + - CJK (Chinese, Japanese, Korean) + - Arabic, Hebrew (RTL scripts) + - Thai, Hindi, and other Indic scripts + """ + if not text: + return False + + for char in text: + code = ord(char) + + # CJK Unified Ideographs and related ranges + if 0x4E00 <= code <= 0x9FFF: # CJK Unified Ideographs + return True + if 0x3400 <= code <= 0x4DBF: # CJK Extension A + return True + if 0x3000 <= code <= 0x303F: # CJK Symbols and Punctuation + return True + if 0x3040 <= code <= 0x309F: # Hiragana + return True + if 0x30A0 <= code <= 0x30FF: # Katakana + return True + if 0xAC00 <= code <= 0xD7AF: # Hangul Syllables + return True + if 0x1100 <= code <= 0x11FF: # Hangul Jamo + return True + + # Arabic and Hebrew (RTL scripts) + if 0x0600 <= code <= 0x06FF: # Arabic + return True + if 0x0590 <= code <= 0x05FF: # Hebrew + return True + + # Indic scripts + if 0x0900 <= code <= 0x097F: # Devanagari (Hindi) + return True + if 0x0E00 <= code <= 0x0E7F: # Thai + return True + + return False + + def _get_font_for_content(self, content: str) -> tuple: + """Get appropriate font based on content, returns (regular_font, bold_font)""" + if self._needs_unicode_font(content): + if self._register_unicode_fonts() and self._unicode_font_name: + return (self._unicode_font_name, self._unicode_font_bold_name or self._unicode_font_name) + else: + print("Warning: Content contains non-Latin characters but no Unicode font available") + + # Fall back to configured font + return (self._param.font_family, self._get_bold_font_name()) + + def _get_active_font(self) -> str: + """Get the currently active font (Unicode or configured)""" + return getattr(self, '_active_font', self._param.font_family) + + def _get_active_bold_font(self) -> str: + """Get the currently active bold font (Unicode or configured)""" + return getattr(self, '_active_bold_font', self._get_bold_font_name()) + + def _get_bold_font_name(self) -> str: + """Get the correct bold variant of the current font family""" + font_map = { + 'Helvetica': 'Helvetica-Bold', + 'Times-Roman': 'Times-Bold', + 'Courier': 'Courier-Bold', + } + font_family = getattr(self._param, 'font_family', 'Helvetica') + if 'Bold' in font_family: + return font_family + return font_map.get(font_family, 'Helvetica-Bold') + + def get_input_form(self) -> dict[str, dict]: + return { + "content": { + "name": "Content", + "type": "text" + }, + "title": { + "name": "Title", + "type": "line" + }, + "subtitle": { + "name": "Subtitle", + "type": "line" + } + } + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + import traceback + + try: + # Get content from parameters (which may contain variable references) + content = self._param.content or "" + title = self._param.title or "" + subtitle = self._param.subtitle or "" + + # Log PDF generation start + print(f"Starting PDF generation for title: {title}, content length: {len(content)} chars") + + # Resolve variable references in content using canvas + if content and self._canvas.is_reff(content): + # Extract the variable reference and get its value + import re + matches = re.findall(self.variable_ref_patt, content, flags=re.DOTALL) + for match in matches: + try: + var_value = self._canvas.get_variable_value(match) + if var_value: + # Handle partial (streaming) content + if isinstance(var_value, partial): + resolved_content = "" + for chunk in var_value(): + resolved_content += chunk + content = content.replace("{" + match + "}", resolved_content) + else: + content = content.replace("{" + match + "}", str(var_value)) + except Exception as e: + print(f"Error resolving variable {match}: {str(e)}") + content = content.replace("{" + match + "}", f"[ERROR: {str(e)}]") + + # Also process with get_kwargs for any remaining variables + if content: + try: + content, _ = self.get_kwargs(content, kwargs) + except Exception as e: + print(f"Error processing content with get_kwargs: {str(e)}") + + # Process template variables in title + if title and self._canvas.is_reff(title): + try: + matches = re.findall(self.variable_ref_patt, title, flags=re.DOTALL) + for match in matches: + var_value = self._canvas.get_variable_value(match) + if var_value: + title = title.replace("{" + match + "}", str(var_value)) + except Exception as e: + print(f"Error processing title variables: {str(e)}") + + if title: + try: + title, _ = self.get_kwargs(title, kwargs) + except Exception: + pass + + # Process template variables in subtitle + if subtitle and self._canvas.is_reff(subtitle): + try: + matches = re.findall(self.variable_ref_patt, subtitle, flags=re.DOTALL) + for match in matches: + var_value = self._canvas.get_variable_value(match) + if var_value: + subtitle = subtitle.replace("{" + match + "}", str(var_value)) + except Exception as e: + print(f"Error processing subtitle variables: {str(e)}") + + if subtitle: + try: + subtitle, _ = self.get_kwargs(subtitle, kwargs) + except Exception: + pass + + # If content is still empty, check if it was passed directly + if not content: + content = kwargs.get("content", "") + + # Generate document based on format + try: + output_format = self._param.output_format or "pdf" + + if output_format == "pdf": + file_path, doc_base64 = self._generate_pdf(content, title, subtitle) + mime_type = "application/pdf" + elif output_format == "docx": + file_path, doc_base64 = self._generate_docx(content, title, subtitle) + mime_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" + elif output_format == "txt": + file_path, doc_base64 = self._generate_txt(content, title, subtitle) + mime_type = "text/plain" + else: + raise Exception(f"Unsupported output format: {output_format}") + + filename = os.path.basename(file_path) + + # Verify the file was created and has content + if not os.path.exists(file_path): + raise Exception(f"Document file was not created: {file_path}") + + file_size = os.path.getsize(file_path) + if file_size == 0: + raise Exception(f"Document file is empty: {file_path}") + + print(f"Successfully generated {output_format.upper()}: {file_path} (Size: {file_size} bytes)") + + # Set outputs + self.set_output("file_path", file_path) + self.set_output("pdf_base64", doc_base64) # Keep same output name for compatibility + self.set_output("success", True) + + # Create download info object + download_info = { + "filename": filename, + "path": file_path, + "base64": doc_base64, + "mime_type": mime_type, + "size": file_size + } + # Output download info as JSON string so it can be used in Message block + download_json = json.dumps(download_info) + self.set_output("download", download_json) + + return download_info + + except Exception as e: + error_msg = f"Error in _generate_pdf: {str(e)}\n{traceback.format_exc()}" + print(error_msg) + self.set_output("success", False) + self.set_output("_ERROR", f"PDF generation failed: {str(e)}") + raise + + except Exception as e: + error_msg = f"Error in PDFGenerator._invoke: {str(e)}\n{traceback.format_exc()}" + print(error_msg) + self.set_output("success", False) + self.set_output("_ERROR", f"PDF generation failed: {str(e)}") + raise + + def _generate_pdf(self, content: str, title: str = "", subtitle: str = "") -> tuple[str, str]: + """Generate PDF from markdown-style content with improved error handling and concurrency support""" + import uuid + import traceback + + # Create output directory if it doesn't exist + os.makedirs(self._param.output_directory, exist_ok=True) + + # Initialize variables that need cleanup + buffer = None + temp_file_path = None + file_path = None + + try: + # Generate a unique filename to prevent conflicts + if self._param.filename: + base_name = os.path.splitext(self._param.filename)[0] + filename = f"{base_name}_{uuid.uuid4().hex[:8]}.pdf" + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"document_{timestamp}_{uuid.uuid4().hex[:8]}.pdf" + + file_path = os.path.join(self._param.output_directory, filename) + temp_file_path = f"{file_path}.tmp" + + # Setup page size + page_size = A4 + if self._param.orientation == "landscape": + page_size = (A4[1], A4[0]) + + # Create PDF buffer and document + buffer = BytesIO() + doc = SimpleDocTemplate( + buffer, + pagesize=page_size, + topMargin=self._param.margin_top * inch, + bottomMargin=self._param.margin_bottom * inch, + leftMargin=self._param.margin_left * inch, + rightMargin=self._param.margin_right * inch + ) + + # Build story (content elements) + story = [] + # Combine all text content for Unicode font detection + all_text = f"{title} {subtitle} {content}" + + # IMPORTANT: Register Unicode fonts BEFORE creating any styles or Paragraphs + # This ensures the font family is available for ReportLab's HTML parser + if self._needs_unicode_font(all_text): + self._register_unicode_fonts() + + styles = self._create_styles(all_text) + + # Add logo if provided + if self._param.logo_image: + logo = self._add_logo() + if logo: + story.append(logo) + story.append(Spacer(1, 0.3 * inch)) + + # Add title + if title: + title_para = Paragraph(self._escape_html(title), styles['PDFTitle']) + story.append(title_para) + story.append(Spacer(1, 0.2 * inch)) + + # Add subtitle + if subtitle: + subtitle_para = Paragraph(self._escape_html(subtitle), styles['PDFSubtitle']) + story.append(subtitle_para) + story.append(Spacer(1, 0.3 * inch)) + + # Add timestamp if enabled + if self._param.add_timestamp: + timestamp_text = f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + timestamp_para = Paragraph(timestamp_text, styles['Italic']) + story.append(timestamp_para) + story.append(Spacer(1, 0.2 * inch)) + + # Parse and add content + content_elements = self._parse_markdown_content(content, styles) + story.extend(content_elements) + + # Build PDF + doc.build(story, onFirstPage=self._add_page_decorations, onLaterPages=self._add_page_decorations) + + # Get PDF bytes + pdf_bytes = buffer.getvalue() + + # Write to temporary file first + with open(temp_file_path, 'wb') as f: + f.write(pdf_bytes) + + # Atomic rename to final filename (works across different filesystems) + if os.path.exists(file_path): + os.remove(file_path) + os.rename(temp_file_path, file_path) + + # Verify the file was created and has content + if not os.path.exists(file_path): + raise Exception(f"Failed to create output file: {file_path}") + + file_size = os.path.getsize(file_path) + if file_size == 0: + raise Exception(f"Generated PDF is empty: {file_path}") + + # Convert to base64 + pdf_base64 = base64.b64encode(pdf_bytes).decode('utf-8') + + return file_path, pdf_base64 + + except Exception as e: + # Clean up any temporary files on error + if temp_file_path and os.path.exists(temp_file_path): + try: + os.remove(temp_file_path) + except Exception as cleanup_error: + print(f"Error cleaning up temporary file: {cleanup_error}") + + error_msg = f"Error generating PDF: {str(e)}\n{traceback.format_exc()}" + print(error_msg) + raise Exception(f"PDF generation failed: {str(e)}") + + finally: + # Ensure buffer is always closed + if buffer is not None: + try: + buffer.close() + except Exception as close_error: + print(f"Error closing buffer: {close_error}") + + def _create_styles(self, content: str = ""): + """Create custom paragraph styles with Unicode font support if needed""" + # Check if content contains CJK characters that need special fonts + needs_cjk = self._needs_unicode_font(content) + + if needs_cjk: + # Use CID fonts for CJK content + if self._register_unicode_fonts() and self._unicode_font_name: + regular_font = self._unicode_font_name + bold_font = self._unicode_font_bold_name or self._unicode_font_name + print(f"Using CID font for CJK content: {regular_font}") + else: + # Fall back to configured font if CID fonts unavailable + regular_font = self._param.font_family + bold_font = self._get_bold_font_name() + print(f"Warning: CJK content detected but no CID font available, using {regular_font}") + else: + # Use user-selected font for Latin-only content + regular_font = self._param.font_family + bold_font = self._get_bold_font_name() + print(f"Using configured font: {regular_font}") + + # Store active fonts as instance variables for use in other methods + self._active_font = regular_font + self._active_bold_font = bold_font + + # Get fresh style sheet + styles = getSampleStyleSheet() + + # Helper function to get the correct bold font name + def get_bold_font(font_family): + """Get the correct bold variant of a font family""" + # If using Unicode font, return the Unicode bold + if font_family in ('UnicodeFont', self._unicode_font_name): + return bold_font + font_map = { + 'Helvetica': 'Helvetica-Bold', + 'Times-Roman': 'Times-Bold', + 'Courier': 'Courier-Bold', + } + if 'Bold' in font_family: + return font_family + return font_map.get(font_family, 'Helvetica-Bold') + + # Use detected font instead of configured font for non-Latin content + active_font = regular_font + active_bold_font = bold_font + + # Helper function to add or update style + def add_or_update_style(name, **kwargs): + if name in styles: + # Update existing style + style = styles[name] + for key, value in kwargs.items(): + setattr(style, key, value) + else: + # Add new style + styles.add(ParagraphStyle(name=name, **kwargs)) + + # IMPORTANT: Update base styles to use Unicode font for non-Latin content + # This ensures ALL text uses the correct font, not just our custom styles + add_or_update_style('Normal', fontName=active_font) + add_or_update_style('BodyText', fontName=active_font) + add_or_update_style('Bullet', fontName=active_font) + add_or_update_style('Heading1', fontName=active_bold_font) + add_or_update_style('Heading2', fontName=active_bold_font) + add_or_update_style('Heading3', fontName=active_bold_font) + add_or_update_style('Title', fontName=active_bold_font) + + # Title style + add_or_update_style( + 'PDFTitle', + parent=styles['Heading1'], + fontSize=self._param.title_font_size, + textColor=colors.HexColor(self._param.title_color), + fontName=active_bold_font, + alignment=TA_CENTER, + spaceAfter=12 + ) + + # Subtitle style + add_or_update_style( + 'PDFSubtitle', + parent=styles['Heading2'], + fontSize=self._param.heading2_font_size, + textColor=colors.HexColor(self._param.text_color), + fontName=active_font, + alignment=TA_CENTER, + spaceAfter=12 + ) + + # Custom heading styles + add_or_update_style( + 'CustomHeading1', + parent=styles['Heading1'], + fontSize=self._param.heading1_font_size, + fontName=active_bold_font, + textColor=colors.HexColor(self._param.text_color), + spaceAfter=12, + spaceBefore=12 + ) + + add_or_update_style( + 'CustomHeading2', + parent=styles['Heading2'], + fontSize=self._param.heading2_font_size, + fontName=active_bold_font, + textColor=colors.HexColor(self._param.text_color), + spaceAfter=10, + spaceBefore=10 + ) + + add_or_update_style( + 'CustomHeading3', + parent=styles['Heading3'], + fontSize=self._param.heading3_font_size, + fontName=active_bold_font, + textColor=colors.HexColor(self._param.text_color), + spaceAfter=8, + spaceBefore=8 + ) + + # Body text style + add_or_update_style( + 'CustomBody', + parent=styles['BodyText'], + fontSize=self._param.font_size, + fontName=active_font, + textColor=colors.HexColor(self._param.text_color), + leading=self._param.font_size * self._param.line_spacing, + alignment=TA_JUSTIFY + ) + + # Bullet style + add_or_update_style( + 'CustomBullet', + parent=styles['BodyText'], + fontSize=self._param.font_size, + fontName=active_font, + textColor=colors.HexColor(self._param.text_color), + leftIndent=20, + bulletIndent=10 + ) + + # Code style (keep Courier for code blocks) + add_or_update_style( + 'PDFCode', + parent=styles.get('Code', styles['Normal']), + fontSize=self._param.font_size - 1, + fontName='Courier', + textColor=colors.HexColor('#333333'), + backColor=colors.HexColor('#f5f5f5'), + leftIndent=20, + rightIndent=20 + ) + + # Italic style + add_or_update_style( + 'Italic', + parent=styles['Normal'], + fontSize=self._param.font_size, + fontName=active_font, + textColor=colors.HexColor(self._param.text_color) + ) + + return styles + + def _parse_markdown_content(self, content: str, styles): + """Parse markdown-style content and convert to PDF elements""" + elements = [] + lines = content.split('\n') + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # Skip empty lines + if not line: + elements.append(Spacer(1, 0.1 * inch)) + i += 1 + continue + + # Horizontal rule + if line == '---' or line == '___': + elements.append(Spacer(1, 0.1 * inch)) + elements.append(self._create_horizontal_line()) + elements.append(Spacer(1, 0.1 * inch)) + i += 1 + continue + + # Heading 1 + if line.startswith('# ') and not line.startswith('## '): + text = line[2:].strip() + elements.append(Paragraph(self._format_inline(text), styles['CustomHeading1'])) + i += 1 + continue + + # Heading 2 + if line.startswith('## ') and not line.startswith('### '): + text = line[3:].strip() + elements.append(Paragraph(self._format_inline(text), styles['CustomHeading2'])) + i += 1 + continue + + # Heading 3 + if line.startswith('### '): + text = line[4:].strip() + elements.append(Paragraph(self._format_inline(text), styles['CustomHeading3'])) + i += 1 + continue + + # Bullet list + if line.startswith('- ') or line.startswith('* '): + bullet_items = [] + while i < len(lines) and (lines[i].strip().startswith('- ') or lines[i].strip().startswith('* ')): + item_text = lines[i].strip()[2:].strip() + formatted = self._format_inline(item_text) + bullet_items.append(f"• {formatted}") + i += 1 + for item in bullet_items: + elements.append(Paragraph(item, styles['CustomBullet'])) + continue + + # Numbered list + if re.match(r'^\d+\.\s', line): + numbered_items = [] + counter = 1 + while i < len(lines) and re.match(r'^\d+\.\s', lines[i].strip()): + item_text = re.sub(r'^\d+\.\s', '', lines[i].strip()) + numbered_items.append(f"{counter}. {self._format_inline(item_text)}") + counter += 1 + i += 1 + for item in numbered_items: + elements.append(Paragraph(item, styles['CustomBullet'])) + continue + + # Table detection (markdown table must start with |) + if line.startswith('|') and '|' in line: + table_lines = [] + # Collect all consecutive lines that look like table rows + while i < len(lines) and lines[i].strip() and '|' in lines[i]: + table_lines.append(lines[i].strip()) + i += 1 + + # Only process if we have at least 2 lines (header + separator or header + data) + if len(table_lines) >= 2: + table_elements = self._create_table(table_lines) + if table_elements: + # _create_table now returns a list of elements + elements.extend(table_elements) + elements.append(Spacer(1, 0.2 * inch)) + continue + else: + # Not a valid table, treat as regular text + i -= len(table_lines) # Reset position + + # Code block + if line.startswith('```'): + code_lines = [] + i += 1 + while i < len(lines) and not lines[i].strip().startswith('```'): + code_lines.append(lines[i]) + i += 1 + if i < len(lines): + i += 1 + code_text = '\n'.join(code_lines) + elements.append(Paragraph(self._escape_html(code_text), styles['PDFCode'])) + elements.append(Spacer(1, 0.1 * inch)) + continue + + # Regular paragraph + paragraph_lines = [line] + i += 1 + while i < len(lines) and lines[i].strip() and not self._is_special_line(lines[i]): + paragraph_lines.append(lines[i].strip()) + i += 1 + + paragraph_text = ' '.join(paragraph_lines) + formatted_text = self._format_inline(paragraph_text) + elements.append(Paragraph(formatted_text, styles['CustomBody'])) + elements.append(Spacer(1, 0.1 * inch)) + + return elements + + def _is_special_line(self, line: str) -> bool: + """Check if line is a special markdown element""" + line = line.strip() + return (line.startswith('#') or + line.startswith('- ') or + line.startswith('* ') or + re.match(r'^\d+\.\s', line) or + line in ['---', '___'] or + line.startswith('```') or + '|' in line) + + def _format_inline(self, text: str) -> str: + """Format inline markdown (bold, italic, code)""" + # First, escape the existing HTML to not conflict with our tags. + text = self._escape_html(text) + + # IMPORTANT: Process inline code FIRST to protect underscores inside code blocks + # Use a placeholder to protect code blocks from italic/bold processing + code_blocks = [] + def save_code(match): + code_blocks.append(match.group(1)) + return f"__CODE_BLOCK_{len(code_blocks)-1}__" + + text = re.sub(r'`(.+?)`', save_code, text) + + # Then, apply markdown formatting. + # The order is important: from most specific to least specific. + + # Bold and italic combined: ***text*** or ___text___ + text = re.sub(r'\*\*\*(.+?)\*\*\*', r'\1', text) + text = re.sub(r'___(.+?)___', r'\1', text) + + # Bold: **text** or __text__ + text = re.sub(r'\*\*(.+?)\*\*', r'\1', text) + text = re.sub(r'__([^_]+?)__', r'\1', text) # More restrictive to avoid matching placeholders + + # Italic: *text* or _text_ (but not underscores in words like variable_name) + text = re.sub(r'\*([^*]+?)\*', r'\1', text) + # Only match _text_ when surrounded by spaces or at start/end, not mid-word underscores + text = re.sub(r'(?\1', text) + + # Restore code blocks with proper formatting + for i, code in enumerate(code_blocks): + text = text.replace(f"__CODE_BLOCK_{i}__", f'{code}') + + return text + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters and clean up markdown. + + Args: + text: Input text that may contain HTML or markdown + + Returns: + str: Cleaned and escaped text + """ + if not text: + return "" + + # Ensure we're working with a string + text = str(text) + + # Remove HTML form elements and tags + text = re.sub(r']*>', '', text, flags=re.IGNORECASE) # Remove input tags + text = re.sub(r']*>.*?', '', text, flags=re.IGNORECASE | re.DOTALL) # Remove textarea + text = re.sub(r']*>.*?', '', text, flags=re.IGNORECASE | re.DOTALL) # Remove select + text = re.sub(r']*>.*?', '', text, flags=re.IGNORECASE | re.DOTALL) # Remove buttons + text = re.sub(r']*>.*?', '', text, flags=re.IGNORECASE | re.DOTALL) # Remove forms + + # Remove other common HTML tags (but preserve content) + text = re.sub(r']*>', '', text, flags=re.IGNORECASE) + text = re.sub(r'', '', text, flags=re.IGNORECASE) + text = re.sub(r']*>', '', text, flags=re.IGNORECASE) + text = re.sub(r'', '', text, flags=re.IGNORECASE) + text = re.sub(r']*>', '', text, flags=re.IGNORECASE) + text = re.sub(r'

', '\n', text, flags=re.IGNORECASE) + + # First, handle common markdown table artifacts + text = re.sub(r'^[|\-\s:]+$', '', text, flags=re.MULTILINE) # Remove separator lines + text = re.sub(r'^\s*\|\s*|\s*\|\s*$', '', text) # Remove leading/trailing pipes + text = re.sub(r'\s*\|\s*', ' | ', text) # Normalize pipes + + # Remove markdown links, but keep other formatting characters for _format_inline + text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text) # Remove markdown links + + # Escape HTML special characters + text = text.replace('&', '&') + text = text.replace('<', '<') + text = text.replace('>', '>') + + # Clean up excessive whitespace + text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text) # Multiple blank lines to double + text = re.sub(r' +', ' ', text) # Multiple spaces to single + + return text.strip() + + def _get_cell_style(self, row_idx: int, is_header: bool = False, font_size: int = None) -> 'ParagraphStyle': + """Get the appropriate style for a table cell.""" + styles = getSampleStyleSheet() + + # Helper function to get the correct bold font name + def get_bold_font(font_family): + font_map = { + 'Helvetica': 'Helvetica-Bold', + 'Times-Roman': 'Times-Bold', + 'Courier': 'Courier-Bold', + } + if 'Bold' in font_family: + return font_family + return font_map.get(font_family, 'Helvetica-Bold') + + if is_header: + return ParagraphStyle( + 'TableHeader', + parent=styles['Normal'], + fontSize=self._param.font_size, + fontName=self._get_active_bold_font(), + textColor=colors.whitesmoke, + alignment=TA_CENTER, + leading=self._param.font_size * 1.2, + wordWrap='CJK' + ) + else: + font_size = font_size or (self._param.font_size - 1) + return ParagraphStyle( + 'TableCell', + parent=styles['Normal'], + fontSize=font_size, + fontName=self._get_active_font(), + textColor=colors.black, + alignment=TA_LEFT, + leading=font_size * 1.15, + wordWrap='CJK' + ) + + def _convert_table_to_definition_list(self, data: list[list[str]]) -> list: + """Convert a table to a definition list format for better handling of large content. + + This method handles both simple and complex tables, including those with nested content. + It ensures that large cell content is properly wrapped and paginated. + """ + elements = [] + styles = getSampleStyleSheet() + + # Base styles + base_font_size = getattr(self._param, 'font_size', 10) + + # Body style + body_style = ParagraphStyle( + 'TableBody', + parent=styles['Normal'], + fontSize=base_font_size, + fontName=self._get_active_font(), + textColor=colors.HexColor(getattr(self._param, 'text_color', '#000000')), + spaceAfter=6, + leading=base_font_size * 1.2 + ) + + # Label style (for field names) + label_style = ParagraphStyle( + 'LabelStyle', + parent=body_style, + fontName=self._get_active_bold_font(), + textColor=colors.HexColor('#2c3e50'), + fontSize=base_font_size, + spaceAfter=4, + leftIndent=0, + leading=base_font_size * 1.3 + ) + + # Value style (for cell content) - clean, no borders + value_style = ParagraphStyle( + 'ValueStyle', + parent=body_style, + leftIndent=15, + rightIndent=0, + spaceAfter=8, + spaceBefore=2, + fontSize=base_font_size, + textColor=colors.HexColor('#333333'), + alignment=TA_JUSTIFY, + leading=base_font_size * 1.4, + # No borders or background - clean text only + ) + + try: + # If we have no data, return empty list + if not data or not any(data): + return elements + + # Get column headers or generate them + headers = [] + if data and len(data) > 0: + headers = [str(h).strip() for h in data[0]] + + # If no headers or empty headers, generate them + if not any(headers): + headers = [f"Column {i+1}" for i in range(len(data[0]) if data and len(data) > 0 else 0)] + + # Process each data row (skip header if it exists) + start_row = 1 if len(data) > 1 and any(data[0]) else 0 + + for row_idx in range(start_row, len(data)): + row = data[row_idx] if row_idx < len(data) else [] + if not row: + continue + + # Create a container for the row + row_elements = [] + + # Process each cell in the row + for col_idx in range(len(headers)): + if col_idx >= len(headers): + continue + + # Get cell content + cell_text = str(row[col_idx]).strip() if col_idx < len(row) and row[col_idx] is not None else "" + + # Skip empty cells + if not cell_text or cell_text.isspace(): + continue + + # Clean up markdown artifacts for regular text content + cell_text = str(cell_text) # Ensure it's a string + + # Remove markdown table formatting + cell_text = re.sub(r'^[|\-\s:]+$', '', cell_text, flags=re.MULTILINE) # Remove separator lines + cell_text = re.sub(r'^\s*\|\s*|\s*\|\s*$', '', cell_text) # Remove leading/trailing pipes + cell_text = re.sub(r'\s*\|\s*', ' | ', cell_text) # Normalize pipes + cell_text = re.sub(r'\s+', ' ', cell_text).strip() # Normalize whitespace + + # Remove any remaining markdown formatting + cell_text = re.sub(r'`(.*?)`', r'\1', cell_text) # Remove code ticks + cell_text = re.sub(r'\*\*(.*?)\*\*', r'\1', cell_text) # Remove bold + cell_text = re.sub(r'\*(.*?)\*', r'\1', cell_text) # Remove italic + + # Clean up any HTML entities or special characters + cell_text = self._escape_html(cell_text) + + # If content still looks like a table, convert it to plain text + if '|' in cell_text and ('--' in cell_text or any(cell_text.count('|') > 2 for line in cell_text.split('\n') if line.strip())): + # Convert to a simple text format + lines = [line.strip() for line in cell_text.split('\n') if line.strip()] + cell_text = ' | '.join(lines[:5]) # Join first 5 lines with pipe + if len(lines) > 5: + cell_text += '...' + + # Process long content with better wrapping + max_chars_per_line = 100 # Reduced for better readability + max_paragraphs = 3 # Maximum number of paragraphs to show initially + + # Split into paragraphs + paragraphs = [p for p in cell_text.split('\n\n') if p.strip()] + + # If content is too long, truncate with "show more" indicator + if len(paragraphs) > max_paragraphs or any(len(p) > max_chars_per_line * 3 for p in paragraphs): + wrapped_paragraphs = [] + + for i, para in enumerate(paragraphs[:max_paragraphs]): + if len(para) > max_chars_per_line * 3: + # Split long paragraphs + words = para.split() + current_line = [] + current_length = 0 + + for word in words: + if current_line and current_length + len(word) + 1 > max_chars_per_line: + wrapped_paragraphs.append(' '.join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += len(word) + (1 if current_line else 0) + + if current_line: + wrapped_paragraphs.append(' '.join(current_line)) + else: + wrapped_paragraphs.append(para) + + # Add "show more" indicator if there are more paragraphs + if len(paragraphs) > max_paragraphs: + wrapped_paragraphs.append(f"... and {len(paragraphs) - max_paragraphs} more paragraphs") + + cell_text = '\n\n'.join(wrapped_paragraphs) + + # Add label and content with clean formatting (no borders) + label_para = Paragraph(f"{self._escape_html(headers[col_idx])}:", label_style) + value_para = Paragraph(self._escape_html(cell_text), value_style) + + # Add elements with proper spacing + row_elements.append(label_para) + row_elements.append(Spacer(1, 0.03 * 72)) # Tiny space between label and value + row_elements.append(value_para) + + # Add spacing between rows + if row_elements and row_idx < len(data) - 1: + # Add a subtle horizontal line as separator + row_elements.append(Spacer(1, 0.1 * 72)) + row_elements.append(self._create_horizontal_line(width=0.5, color='#e0e0e0')) + row_elements.append(Spacer(1, 0.15 * 72)) + + elements.extend(row_elements) + + # Add some space after the table + if elements: + elements.append(Spacer(1, 0.3 * 72)) # 0.3 inches in points + + except Exception as e: + # Fallback to simple text representation if something goes wrong + error_style = ParagraphStyle( + 'ErrorStyle', + parent=styles['Normal'], + fontSize=base_font_size - 1, + textColor=colors.red, + backColor=colors.HexColor('#fff0f0'), + borderWidth=1, + borderColor=colors.red, + borderPadding=5 + ) + + error_msg = [ + Paragraph("Error processing table:", error_style), + Paragraph(str(e), error_style), + Spacer(1, 0.2 * 72) + ] + + # Add a simplified version of the table + try: + for row in data[:10]: # Limit to first 10 rows to avoid huge error output + error_msg.append(Paragraph(" | ".join(str(cell) for cell in row), body_style)) + if len(data) > 10: + error_msg.append(Paragraph(f"... and {len(data) - 10} more rows", body_style)) + except Exception: + pass + + elements.extend(error_msg) + + return elements + + def _create_table(self, table_lines: list[str]) -> Optional[list]: + """Create a table from markdown table syntax with robust error handling. + + This method handles simple tables and falls back to a list format for complex cases. + + Returns: + A list of flowables (could be a table or alternative representation) + Returns None if the table cannot be created. + """ + if not table_lines or len(table_lines) < 2: + return None + + try: + # Parse table data + data = [] + max_columns = 0 + + for line in table_lines: + # Skip separator lines (e.g., |---|---|) + if re.match(r'^\|[\s\-:]+\|$', line): + continue + + # Handle empty lines within tables + if not line.strip(): + continue + + # Split by | and clean up cells + cells = [] + in_quotes = False + current_cell = "" + + # Custom split to handle escaped pipes and quoted content + for char in line[1:]: # Skip initial | + if char == '|' and not in_quotes: + cells.append(current_cell.strip()) + current_cell = "" + elif char == '"': + in_quotes = not in_quotes + current_cell += char + elif char == '\\' and not in_quotes: + # Handle escaped characters + pass + else: + current_cell += char + + # Add the last cell + if current_cell.strip() or len(cells) > 0: + cells.append(current_cell.strip()) + + # Remove empty first/last elements if they're empty (from leading/trailing |) + if cells and not cells[0]: + cells = cells[1:] + if cells and not cells[-1]: + cells = cells[:-1] + + if cells: + data.append(cells) + max_columns = max(max_columns, len(cells)) + + if not data or max_columns == 0: + return None + + # Ensure all rows have the same number of columns + for row in data: + while len(row) < max_columns: + row.append('') + + # Calculate available width for table + from reportlab.lib.pagesizes import A4 + page_width = A4[0] if self._param.orientation == 'portrait' else A4[1] + available_width = page_width - (self._param.margin_left + self._param.margin_right) * inch + + # Check if we should use definition list format + max_cell_length = max((len(str(cell)) for row in data for cell in row), default=0) + total_rows = len(data) + + # Use definition list format if: + # - Any cell is too large (> 300 chars), OR + # - More than 6 columns, OR + # - More than 20 rows, OR + # - Contains nested tables or complex structures + has_nested_tables = any('|' in cell and '---' in cell for row in data for cell in row) + has_complex_cells = any(len(str(cell)) > 150 for row in data for cell in row) + + should_use_list_format = ( + max_cell_length > 300 or + max_columns > 6 or + total_rows > 20 or + has_nested_tables or + has_complex_cells + ) + + if should_use_list_format: + return self._convert_table_to_definition_list(data) + + # Process cells for normal table + processed_data = [] + for row_idx, row in enumerate(data): + processed_row = [] + for cell_idx, cell in enumerate(row): + cell_text = str(cell).strip() if cell is not None else "" + + # Handle empty cells + if not cell_text: + processed_row.append("") + continue + + # Clean up markdown table artifacts + cell_text = re.sub(r'\\\|', '|', cell_text) # Unescape pipes + cell_text = re.sub(r'\\n', '\n', cell_text) # Handle explicit newlines + + # Check for nested tables + if '|' in cell_text and '---' in cell_text: + # This cell contains a nested table + nested_lines = [line.strip() for line in cell_text.split('\n') if line.strip()] + nested_table = self._create_table(nested_lines) + if nested_table: + processed_row.append(nested_table[0]) # Add the nested table + continue + + # Process as regular text + font_size = self._param.font_size - 1 if row_idx > 0 else self._param.font_size + try: + style = self._get_cell_style(row_idx, is_header=(row_idx == 0), font_size=font_size) + escaped_text = self._escape_html(cell_text) + processed_row.append(Paragraph(escaped_text, style)) + except Exception: + processed_row.append(self._escape_html(cell_text)) + + processed_data.append(processed_row) + + # Calculate column widths + min_col_width = 0.5 * inch + max_cols = int(available_width / min_col_width) + + if max_columns > max_cols: + return self._convert_table_to_definition_list(data) + + col_width = max(min_col_width, available_width / max_columns) + col_widths = [col_width] * max_columns + + # Create the table + try: + table = LongTable(processed_data, colWidths=col_widths, repeatRows=1) + + # Define table style + table_style = [ + ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c3e50')), # Darker header + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, 0), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), self._get_active_bold_font()), + ('FONTSIZE', (0, 0), (-1, -1), self._param.font_size - 1), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#f8f9fa')), # Lighter background + ('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#dee2e6')), # Lighter grid + ('VALIGN', (0, 0), (-1, -1), 'TOP'), + ('TOPPADDING', (0, 0), (-1, -1), 8), + ('BOTTOMPADDING', (0, 0), (-1, -1), 8), + ('LEFTPADDING', (0, 0), (-1, -1), 8), + ('RIGHTPADDING', (0, 0), (-1, -1), 8), + ] + + # Add zebra striping for better readability + for i in range(1, len(processed_data)): + if i % 2 == 0: + table_style.append(('BACKGROUND', (0, i), (-1, i), colors.HexColor('#f1f3f5'))) + + table.setStyle(TableStyle(table_style)) + + # Add a small spacer after the table + return [table, Spacer(1, 0.2 * inch)] + + except Exception as table_error: + print(f"Error creating table: {table_error}") + return self._convert_table_to_definition_list(data) + + except Exception as e: + print(f"Error processing table: {e}") + # Return a simple text representation of the table + try: + text_content = [] + for row in data: + text_content.append(" | ".join(str(cell) for cell in row)) + return [Paragraph("
".join(text_content), self._get_cell_style(0))] + except Exception: + return None + + def _create_horizontal_line(self, width: float = 1, color: str = None): + """Create a horizontal line with customizable width and color + + Args: + width: Line thickness in points (default: 1) + color: Hex color string (default: grey) + + Returns: + HRFlowable: Horizontal line element + """ + from reportlab.platypus import HRFlowable + line_color = colors.HexColor(color) if color else colors.grey + return HRFlowable(width="100%", thickness=width, color=line_color, spaceBefore=0, spaceAfter=0) + + def _add_logo(self) -> Optional[Image]: + """Add logo image to PDF""" + try: + # Check if it's base64 or file path + if self._param.logo_image.startswith('data:image'): + # Extract base64 data + base64_data = self._param.logo_image.split(',')[1] + image_data = base64.b64decode(base64_data) + img = Image(BytesIO(image_data)) + elif os.path.exists(self._param.logo_image): + img = Image(self._param.logo_image) + else: + return None + + # Set size + img.drawWidth = self._param.logo_width * inch + img.drawHeight = self._param.logo_height * inch + + # Set alignment + if self._param.logo_position == 'center': + img.hAlign = 'CENTER' + elif self._param.logo_position == 'right': + img.hAlign = 'RIGHT' + else: + img.hAlign = 'LEFT' + + return img + except Exception as e: + print(f"Error adding logo: {e}") + return None + + def _add_page_decorations(self, canvas, doc): + """Add header, footer, page numbers, watermark""" + canvas.saveState() + + # Get active font for decorations + active_font = self._get_active_font() + + # Add watermark + if self._param.watermark_text: + canvas.setFont(active_font, 60) + canvas.setFillColorRGB(0.9, 0.9, 0.9, alpha=0.3) + canvas.saveState() + canvas.translate(doc.pagesize[0] / 2, doc.pagesize[1] / 2) + canvas.rotate(45) + canvas.drawCentredString(0, 0, self._param.watermark_text) + canvas.restoreState() + + # Add header + if self._param.header_text: + canvas.setFont(active_font, 9) + canvas.setFillColorRGB(0.5, 0.5, 0.5) + canvas.drawString(doc.leftMargin, doc.pagesize[1] - 0.5 * inch, self._param.header_text) + + # Add footer + if self._param.footer_text: + canvas.setFont(active_font, 9) + canvas.setFillColorRGB(0.5, 0.5, 0.5) + canvas.drawString(doc.leftMargin, 0.5 * inch, self._param.footer_text) + + # Add page numbers + if self._param.add_page_numbers: + page_num = canvas.getPageNumber() + text = f"Page {page_num}" + canvas.setFont(active_font, 9) + canvas.setFillColorRGB(0.5, 0.5, 0.5) + canvas.drawRightString(doc.pagesize[0] - doc.rightMargin, 0.5 * inch, text) + + canvas.restoreState() + + def thoughts(self) -> str: + return "Generating PDF document with formatted content..." + + def _generate_docx(self, content: str, title: str = "", subtitle: str = "") -> tuple[str, str]: + """Generate DOCX from markdown-style content""" + import uuid + from docx import Document + from docx.shared import Pt + from docx.enum.text import WD_ALIGN_PARAGRAPH + + # Create output directory if it doesn't exist + os.makedirs(self._param.output_directory, exist_ok=True) + + try: + # Generate filename + if self._param.filename: + base_name = os.path.splitext(self._param.filename)[0] + filename = f"{base_name}_{uuid.uuid4().hex[:8]}.docx" + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"document_{timestamp}_{uuid.uuid4().hex[:8]}.docx" + + file_path = os.path.join(self._param.output_directory, filename) + + # Create document + doc = Document() + + # Add title + if title: + title_para = doc.add_heading(title, level=0) + title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER + + # Add subtitle + if subtitle: + subtitle_para = doc.add_heading(subtitle, level=1) + subtitle_para.alignment = WD_ALIGN_PARAGRAPH.CENTER + + # Add timestamp if enabled + if self._param.add_timestamp: + timestamp_text = f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + ts_para = doc.add_paragraph(timestamp_text) + ts_para.runs[0].italic = True + ts_para.runs[0].font.size = Pt(9) + + # Parse and add content + lines = content.split('\n') + i = 0 + while i < len(lines): + line = lines[i].strip() + + if not line: + i += 1 + continue + + # Headings + if line.startswith('# ') and not line.startswith('## '): + doc.add_heading(line[2:].strip(), level=1) + elif line.startswith('## ') and not line.startswith('### '): + doc.add_heading(line[3:].strip(), level=2) + elif line.startswith('### '): + doc.add_heading(line[4:].strip(), level=3) + # Bullet list + elif line.startswith('- ') or line.startswith('* '): + doc.add_paragraph(line[2:].strip(), style='List Bullet') + # Numbered list + elif re.match(r'^\d+\.\s', line): + text = re.sub(r'^\d+\.\s', '', line) + doc.add_paragraph(text, style='List Number') + # Regular paragraph + else: + para = doc.add_paragraph(line) + para.runs[0].font.size = Pt(self._param.font_size) + + i += 1 + + # Save document + doc.save(file_path) + + # Read and encode to base64 + with open(file_path, 'rb') as f: + doc_bytes = f.read() + doc_base64 = base64.b64encode(doc_bytes).decode('utf-8') + + return file_path, doc_base64 + + except Exception as e: + raise Exception(f"DOCX generation failed: {str(e)}") + + def _generate_txt(self, content: str, title: str = "", subtitle: str = "") -> tuple[str, str]: + """Generate TXT from markdown-style content""" + import uuid + + # Create output directory if it doesn't exist + os.makedirs(self._param.output_directory, exist_ok=True) + + try: + # Generate filename + if self._param.filename: + base_name = os.path.splitext(self._param.filename)[0] + filename = f"{base_name}_{uuid.uuid4().hex[:8]}.txt" + else: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"document_{timestamp}_{uuid.uuid4().hex[:8]}.txt" + + file_path = os.path.join(self._param.output_directory, filename) + + # Build text content + text_content = [] + + if title: + text_content.append(title.upper()) + text_content.append("=" * len(title)) + text_content.append("") + + if subtitle: + text_content.append(subtitle) + text_content.append("-" * len(subtitle)) + text_content.append("") + + if self._param.add_timestamp: + timestamp_text = f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + text_content.append(timestamp_text) + text_content.append("") + + # Add content (keep markdown formatting for readability) + text_content.append(content) + + # Join and save + final_text = '\n'.join(text_content) + + with open(file_path, 'w', encoding='utf-8') as f: + f.write(final_text) + + # Encode to base64 + txt_base64 = base64.b64encode(final_text.encode('utf-8')).decode('utf-8') + + return file_path, txt_base64 + + except Exception as e: + raise Exception(f"TXT generation failed: {str(e)}") diff --git a/agent/component/excel_processor.py b/agent/component/excel_processor.py new file mode 100644 index 00000000000..65b3a9bd202 --- /dev/null +++ b/agent/component/excel_processor.py @@ -0,0 +1,401 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +ExcelProcessor Component + +A component for reading, processing, and generating Excel files in RAGFlow agents. +Supports multiple Excel file inputs, data transformation, and Excel output generation. +""" + +import logging +import os +from abc import ABC +from io import BytesIO + +import pandas as pd + +from agent.component.base import ComponentBase, ComponentParamBase +from api.db.services.file_service import FileService +from api.utils.api_utils import timeout +from common import settings +from common.misc_utils import get_uuid + + +class ExcelProcessorParam(ComponentParamBase): + """ + Define the ExcelProcessor component parameters. + """ + def __init__(self): + super().__init__() + # Input configuration + self.input_files = [] # Variable references to uploaded files + self.operation = "read" # read, merge, transform, output + + # Processing options + self.sheet_selection = "all" # all, first, or comma-separated sheet names + self.merge_strategy = "concat" # concat, join + self.join_on = "" # Column name for join operations + + # Transform options (for LLM-guided transformations) + self.transform_instructions = "" + self.transform_data = "" # Variable reference to transformation data + + # Output options + self.output_format = "xlsx" # xlsx, csv + self.output_filename = "output" + + # Component outputs + self.outputs = { + "data": { + "type": "object", + "value": {} + }, + "summary": { + "type": "str", + "value": "" + }, + "markdown": { + "type": "str", + "value": "" + } + } + + def check(self): + self.check_valid_value( + self.operation, + "[ExcelProcessor] Operation", + ["read", "merge", "transform", "output"] + ) + self.check_valid_value( + self.output_format, + "[ExcelProcessor] Output format", + ["xlsx", "csv"] + ) + return True + + +class ExcelProcessor(ComponentBase, ABC): + """ + Excel processing component for RAGFlow agents. + + Operations: + - read: Parse Excel files into structured data + - merge: Combine multiple Excel files + - transform: Apply data transformations based on instructions + - output: Generate Excel file output + """ + component_name = "ExcelProcessor" + + def get_input_form(self) -> dict[str, dict]: + """Define input form for the component.""" + res = {} + for ref in (self._param.input_files or []): + for k, o in self.get_input_elements_from_text(ref).items(): + res[k] = {"name": o.get("name", ""), "type": "file"} + if self._param.transform_data: + for k, o in self.get_input_elements_from_text(self._param.transform_data).items(): + res[k] = {"name": o.get("name", ""), "type": "object"} + return res + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + if self.check_if_canceled("ExcelProcessor processing"): + return + + operation = self._param.operation.lower() + + if operation == "read": + self._read_excels() + elif operation == "merge": + self._merge_excels() + elif operation == "transform": + self._transform_data() + elif operation == "output": + self._output_excel() + else: + self.set_output("summary", f"Unknown operation: {operation}") + + def _get_file_content(self, file_ref: str) -> tuple[bytes, str]: + """ + Get file content from a variable reference. + Returns (content_bytes, filename). + """ + value = self._canvas.get_variable_value(file_ref) + if value is None: + return None, None + + # Handle different value formats + if isinstance(value, dict): + # File reference from Begin/UserFillUp component + file_id = value.get("id") or value.get("file_id") + created_by = value.get("created_by") or self._canvas.get_tenant_id() + filename = value.get("name") or value.get("filename", "unknown.xlsx") + if file_id: + content = FileService.get_blob(created_by, file_id) + return content, filename + elif isinstance(value, list) and len(value) > 0: + # List of file references - return first + return self._get_file_content_from_list(value[0]) + elif isinstance(value, str): + # Could be base64 encoded or a path + if value.startswith("data:"): + import base64 + # Extract base64 content + _, encoded = value.split(",", 1) + return base64.b64decode(encoded), "uploaded.xlsx" + + return None, None + + def _get_file_content_from_list(self, item) -> tuple[bytes, str]: + """Extract file content from a list item.""" + if isinstance(item, dict): + return self._get_file_content(item) + return None, None + + def _parse_excel_to_dataframes(self, content: bytes, filename: str) -> dict[str, pd.DataFrame]: + """Parse Excel content into a dictionary of DataFrames (one per sheet).""" + try: + excel_file = BytesIO(content) + + if filename.lower().endswith(".csv"): + df = pd.read_csv(excel_file) + return {"Sheet1": df} + else: + # Read all sheets + xlsx = pd.ExcelFile(excel_file, engine='openpyxl') + sheet_selection = self._param.sheet_selection + + if sheet_selection == "all": + sheets_to_read = xlsx.sheet_names + elif sheet_selection == "first": + sheets_to_read = [xlsx.sheet_names[0]] if xlsx.sheet_names else [] + else: + # Comma-separated sheet names + requested = [s.strip() for s in sheet_selection.split(",")] + sheets_to_read = [s for s in requested if s in xlsx.sheet_names] + + dfs = {} + for sheet in sheets_to_read: + dfs[sheet] = pd.read_excel(xlsx, sheet_name=sheet) + return dfs + + except Exception as e: + logging.error(f"Error parsing Excel file {filename}: {e}") + return {} + + def _read_excels(self): + """Read and parse Excel files into structured data.""" + all_data = {} + summaries = [] + markdown_parts = [] + + for file_ref in (self._param.input_files or []): + if self.check_if_canceled("ExcelProcessor reading"): + return + + # Get variable value + value = self._canvas.get_variable_value(file_ref) + self.set_input_value(file_ref, str(value)[:200] if value else "") + + if value is None: + continue + + # Handle file content + content, filename = self._get_file_content(file_ref) + if content is None: + continue + + # Parse Excel + dfs = self._parse_excel_to_dataframes(content, filename) + + for sheet_name, df in dfs.items(): + key = f"{filename}_{sheet_name}" if len(dfs) > 1 else filename + all_data[key] = df.to_dict(orient="records") + + # Build summary + summaries.append(f"**{key}**: {len(df)} rows, {len(df.columns)} columns ({', '.join(df.columns.tolist()[:5])}{'...' if len(df.columns) > 5 else ''})") + + # Build markdown table + markdown_parts.append(f"### {key}\n\n{df.head(10).to_markdown(index=False)}\n") + + # Set outputs + self.set_output("data", all_data) + self.set_output("summary", "\n".join(summaries) if summaries else "No Excel files found") + self.set_output("markdown", "\n\n".join(markdown_parts) if markdown_parts else "No data") + + def _merge_excels(self): + """Merge multiple Excel files/sheets into one.""" + all_dfs = [] + + for file_ref in (self._param.input_files or []): + if self.check_if_canceled("ExcelProcessor merging"): + return + + value = self._canvas.get_variable_value(file_ref) + self.set_input_value(file_ref, str(value)[:200] if value else "") + + if value is None: + continue + + content, filename = self._get_file_content(file_ref) + if content is None: + continue + + dfs = self._parse_excel_to_dataframes(content, filename) + all_dfs.extend(dfs.values()) + + if not all_dfs: + self.set_output("data", {}) + self.set_output("summary", "No data to merge") + return + + # Merge strategy + if self._param.merge_strategy == "concat": + merged_df = pd.concat(all_dfs, ignore_index=True) + elif self._param.merge_strategy == "join" and self._param.join_on: + # Join on specified column + merged_df = all_dfs[0] + for df in all_dfs[1:]: + merged_df = merged_df.merge(df, on=self._param.join_on, how="outer") + else: + merged_df = pd.concat(all_dfs, ignore_index=True) + + self.set_output("data", {"merged": merged_df.to_dict(orient="records")}) + self.set_output("summary", f"Merged {len(all_dfs)} sources into {len(merged_df)} rows, {len(merged_df.columns)} columns") + self.set_output("markdown", merged_df.head(20).to_markdown(index=False)) + + def _transform_data(self): + """Apply transformations to data based on instructions or input data.""" + # Get the data to transform + transform_ref = self._param.transform_data + if not transform_ref: + self.set_output("summary", "No transform data reference provided") + return + + data = self._canvas.get_variable_value(transform_ref) + self.set_input_value(transform_ref, str(data)[:300] if data else "") + + if data is None: + self.set_output("summary", "Transform data is empty") + return + + # Convert to DataFrame + if isinstance(data, dict): + # Could be {"sheet": [rows]} format + if all(isinstance(v, list) for v in data.values()): + # Multiple sheets + all_markdown = [] + for sheet_name, rows in data.items(): + df = pd.DataFrame(rows) + all_markdown.append(f"### {sheet_name}\n\n{df.to_markdown(index=False)}") + self.set_output("data", data) + self.set_output("markdown", "\n\n".join(all_markdown)) + else: + df = pd.DataFrame([data]) + self.set_output("data", df.to_dict(orient="records")) + self.set_output("markdown", df.to_markdown(index=False)) + elif isinstance(data, list): + df = pd.DataFrame(data) + self.set_output("data", df.to_dict(orient="records")) + self.set_output("markdown", df.to_markdown(index=False)) + else: + self.set_output("data", {"raw": str(data)}) + self.set_output("markdown", str(data)) + + self.set_output("summary", "Transformed data ready for processing") + + def _output_excel(self): + """Generate Excel file output from data.""" + # Get data from transform_data reference + transform_ref = self._param.transform_data + if not transform_ref: + self.set_output("summary", "No data reference for output") + return + + data = self._canvas.get_variable_value(transform_ref) + self.set_input_value(transform_ref, str(data)[:300] if data else "") + + if data is None: + self.set_output("summary", "No data to output") + return + + try: + # Prepare DataFrames + if isinstance(data, dict): + if all(isinstance(v, list) for v in data.values()): + # Multi-sheet format + dfs = {k: pd.DataFrame(v) for k, v in data.items()} + else: + dfs = {"Sheet1": pd.DataFrame([data])} + elif isinstance(data, list): + dfs = {"Sheet1": pd.DataFrame(data)} + else: + self.set_output("summary", "Invalid data format for Excel output") + return + + # Generate output + doc_id = get_uuid() + + if self._param.output_format == "csv": + # For CSV, only output first sheet + first_df = list(dfs.values())[0] + binary_content = first_df.to_csv(index=False).encode("utf-8") + filename = f"{self._param.output_filename}.csv" + else: + # Excel output + excel_io = BytesIO() + with pd.ExcelWriter(excel_io, engine='openpyxl') as writer: + for sheet_name, df in dfs.items(): + # Sanitize sheet name (max 31 chars, no special chars) + safe_name = sheet_name[:31].replace("/", "_").replace("\\", "_") + df.to_excel(writer, sheet_name=safe_name, index=False) + excel_io.seek(0) + binary_content = excel_io.read() + filename = f"{self._param.output_filename}.xlsx" + + # Store file + settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) + + # Set attachment output + self.set_output("attachment", { + "doc_id": doc_id, + "format": self._param.output_format, + "file_name": filename + }) + + total_rows = sum(len(df) for df in dfs.values()) + self.set_output("summary", f"Generated {filename} with {len(dfs)} sheet(s), {total_rows} total rows") + self.set_output("data", {k: v.to_dict(orient="records") for k, v in dfs.items()}) + + logging.info(f"ExcelProcessor: Generated {filename} as {doc_id}") + + except Exception as e: + logging.error(f"ExcelProcessor output error: {e}") + self.set_output("summary", f"Error generating output: {str(e)}") + + def thoughts(self) -> str: + """Return component thoughts for UI display.""" + op = self._param.operation + if op == "read": + return "Reading Excel files..." + elif op == "merge": + return "Merging Excel data..." + elif op == "transform": + return "Transforming data..." + elif op == "output": + return "Generating Excel output..." + return "Processing Excel..." diff --git a/agent/component/webhook.py b/agent/component/exit_loop.py similarity index 64% rename from agent/component/webhook.py rename to agent/component/exit_loop.py index c707d455626..9dc04491293 100644 --- a/agent/component/webhook.py +++ b/agent/component/exit_loop.py @@ -13,26 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from agent.component.base import ComponentParamBase, ComponentBase +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase -class WebhookParam(ComponentParamBase): +class ExitLoopParam(ComponentParamBase, ABC): + def check(self): + return True - """ - Define the Begin component parameters. - """ - def __init__(self): - super().__init__() - def get_input_form(self) -> dict[str, dict]: - return getattr(self, "inputs") - - -class Webhook(ComponentBase): - component_name = "Webhook" +class ExitLoop(ComponentBase, ABC): + component_name = "ExitLoop" def _invoke(self, **kwargs): pass def thoughts(self) -> str: - return "" + return "" \ No newline at end of file diff --git a/agent/component/fillup.py b/agent/component/fillup.py index 7428912d490..10163d10c0b 100644 --- a/agent/component/fillup.py +++ b/agent/component/fillup.py @@ -18,6 +18,7 @@ from functools import partial from agent.component.base import ComponentParamBase, ComponentBase +from api.db.services.file_service import FileService class UserFillUpParam(ComponentParamBase): @@ -63,6 +64,13 @@ def _invoke(self, **kwargs): for k, v in kwargs.get("inputs", {}).items(): if self.check_if_canceled("UserFillUp processing"): return + if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0: + if v.get("optional") and v.get("value", None) is None: + v = None + else: + v = FileService.get_files([v["value"]]) + else: + v = v.get("value") self.set_output(k, v) def thoughts(self) -> str: diff --git a/agent/component/iteration.py b/agent/component/iteration.py index a39147d8f81..ae5c0b6772d 100644 --- a/agent/component/iteration.py +++ b/agent/component/iteration.py @@ -32,6 +32,7 @@ class IterationParam(ComponentParamBase): def __init__(self): super().__init__() self.items_ref = "" + self.variable={} def get_input_form(self) -> dict[str, dict]: return { diff --git a/agent/component/list_operations.py b/agent/component/list_operations.py new file mode 100644 index 00000000000..6016f758507 --- /dev/null +++ b/agent/component/list_operations.py @@ -0,0 +1,168 @@ +from abc import ABC +import os +from agent.component.base import ComponentBase, ComponentParamBase +from api.utils.api_utils import timeout + +class ListOperationsParam(ComponentParamBase): + """ + Define the List Operations component parameters. + """ + def __init__(self): + super().__init__() + self.query = "" + self.operations = "topN" + self.n=0 + self.sort_method = "asc" + self.filter = { + "operator": "=", + "value": "" + } + self.outputs = { + "result": { + "value": [], + "type": "Array of ?" + }, + "first": { + "value": "", + "type": "?" + }, + "last": { + "value": "", + "type": "?" + } + } + + def check(self): + self.check_empty(self.query, "query") + self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"]) + + def get_input_form(self) -> dict[str, dict]: + return {} + + +class ListOperations(ComponentBase,ABC): + component_name = "ListOperations" + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + self.input_objects=[] + inputs = getattr(self._param, "query", None) + self.inputs = self._canvas.get_variable_value(inputs) + if not isinstance(self.inputs, list): + raise TypeError("The input of List Operations should be an array.") + self.set_input_value(inputs, self.inputs) + if self._param.operations == "topN": + self._topN() + elif self._param.operations == "head": + self._head() + elif self._param.operations == "tail": + self._tail() + elif self._param.operations == "filter": + self._filter() + elif self._param.operations == "sort": + self._sort() + elif self._param.operations == "drop_duplicates": + self._drop_duplicates() + + + def _coerce_n(self): + try: + return int(getattr(self._param, "n", 0)) + except Exception: + return 0 + + def _set_outputs(self, outputs): + self._param.outputs["result"]["value"] = outputs + self._param.outputs["first"]["value"] = outputs[0] if outputs else None + self._param.outputs["last"]["value"] = outputs[-1] if outputs else None + + def _topN(self): + n = self._coerce_n() + if n < 1: + outputs = [] + else: + n = min(n, len(self.inputs)) + outputs = self.inputs[:n] + self._set_outputs(outputs) + + def _head(self): + n = self._coerce_n() + if 1 <= n <= len(self.inputs): + outputs = [self.inputs[n - 1]] + else: + outputs = [] + self._set_outputs(outputs) + + def _tail(self): + n = self._coerce_n() + if 1 <= n <= len(self.inputs): + outputs = [self.inputs[-n]] + else: + outputs = [] + self._set_outputs(outputs) + + def _filter(self): + self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])]) + + def _norm(self,v): + s = "" if v is None else str(v) + return s + + def _eval(self, v, operator, value): + if operator == "=": + return v == value + elif operator == "≠": + return v != value + elif operator == "contains": + return value in v + elif operator == "start with": + return v.startswith(value) + elif operator == "end with": + return v.endswith(value) + else: + return False + + def _sort(self): + items = self.inputs or [] + method = getattr(self._param, "sort_method", "asc") or "asc" + reverse = method == "desc" + + if not items: + self._set_outputs([]) + return + + first = items[0] + + if isinstance(first, dict): + outputs = sorted( + items, + key=lambda x: self._hashable(x), + reverse=reverse, + ) + else: + outputs = sorted(items, reverse=reverse) + + self._set_outputs(outputs) + + def _drop_duplicates(self): + seen = set() + outs = [] + for item in self.inputs: + k = self._hashable(item) + if k in seen: + continue + seen.add(k) + outs.append(item) + self._set_outputs(outs) + + def _hashable(self,x): + if isinstance(x, dict): + return tuple(sorted((k, self._hashable(v)) for k, v in x.items())) + if isinstance(x, (list, tuple)): + return tuple(self._hashable(v) for v in x) + if isinstance(x, set): + return tuple(sorted(self._hashable(v) for v in x)) + return x + + def thoughts(self) -> str: + return "ListOperation in progress" diff --git a/agent/component/llm.py b/agent/component/llm.py index 6ce0f65a551..e9d8770684c 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import os import re from copy import deepcopy -from typing import Any, Generator +from typing import Any, AsyncGenerator import json_repair from functools import partial from common.constants import LLMType @@ -55,7 +56,6 @@ def check(self): self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens") self.check_decimal_float(float(self.top_p), "[Agent] Top P") self.check_empty(self.llm_id, "[Agent] LLM") - self.check_empty(self.sys_prompt, "[Agent] System prompt") self.check_empty(self.prompts, "[Agent] User prompt") def gen_conf(self): @@ -166,25 +166,67 @@ def _extract_prompts(self, sys_prompt): sys_prompt = re.sub(rf"<{tag}>(.*?)", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE) return pts, sys_prompt - def _generate(self, msg:list[dict], **kwargs) -> str: + async def _generate_async(self, msg: list[dict], **kwargs) -> str: if not self.imgs: - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) + + async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]: + async def delta_wrapper(txt_iter): + ans = "" + last_idx = 0 + endswith_think = False + + def delta(txt): + nonlocal ans, last_idx, endswith_think + delta_ans = txt[last_idx:] + ans = txt + + if delta_ans.find("") == 0: + last_idx += len("") + return "" + elif delta_ans.find("") > 0: + delta_ans = txt[last_idx:last_idx + delta_ans.find("")] + last_idx += delta_ans.find("") + return delta_ans + elif delta_ans.endswith(""): + endswith_think = True + elif endswith_think: + endswith_think = False + return "" + + last_idx = len(ans) + if ans.endswith(""): + last_idx -= len("") + return re.sub(r"(|)", "", delta_ans) + + async for t in txt_iter: + yield delta(t) - def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]: - ans = "" + if not self.imgs: + async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)): + yield t + return + + async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)): + yield t + + async def _stream_output_async(self, prompt, msg): + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + answer = "" last_idx = 0 endswith_think = False + def delta(txt): - nonlocal ans, last_idx, endswith_think + nonlocal answer, last_idx, endswith_think delta_ans = txt[last_idx:] - ans = txt + answer = txt if delta_ans.find("") == 0: last_idx += len("") return "" elif delta_ans.find("") > 0: - delta_ans = txt[last_idx:last_idx+delta_ans.find("")] + delta_ans = txt[last_idx:last_idx + delta_ans.find("")] last_idx += delta_ans.find("") return delta_ans elif delta_ans.endswith(""): @@ -193,20 +235,33 @@ def delta(txt): endswith_think = False return "" - last_idx = len(ans) - if ans.endswith(""): + last_idx = len(answer) + if answer.endswith(""): last_idx -= len("") return re.sub(r"(|)", "", delta_ans) - if not self.imgs: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs): - yield delta(txt) - else: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): - yield delta(txt) + stream_kwargs = {"images": self.imgs} if self.imgs else {} + async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs): + if self.check_if_canceled("LLM streaming"): + return + + if isinstance(ans, int): + continue + + if ans.find("**ERROR**") >= 0: + if self.get_exception_default_value(): + self.set_output("content", self.get_exception_default_value()) + yield self.get_exception_default_value() + else: + self.set_output("_ERROR", ans) + return + + yield delta(ans) + + self.set_output("content", answer) @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("LLM processing"): return @@ -217,22 +272,25 @@ def clean_formated_answer(ans: str) -> str: prompt, msg, _ = self._prepare_prompt_variables() error: str = "" - output_structure=None + output_structure = None try: - output_structure = self._param.outputs['structured'] + output_structure = self._param.outputs["structured"] except Exception: pass - if output_structure: - schema=json.dumps(output_structure, ensure_ascii=False, indent=2) - prompt += structured_output_prompt(schema) - for _ in range(self._param.max_retries+1): + if output_structure and isinstance(output_structure, dict) and output_structure.get("properties") and len(output_structure["properties"]) > 0: + schema = json.dumps(output_structure, ensure_ascii=False, indent=2) + prompt_with_schema = prompt + structured_output_prompt(schema) + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("LLM processing"): return - _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + _, msg_fit = message_fit_in( + [{"role": "system", "content": prompt_with_schema}, *deepcopy(msg)], + int(self.chat_mdl.max_length * 0.97), + ) error = "" - ans = self._generate(msg) - msg.pop(0) + ans = await self._generate_async(msg_fit) + msg_fit.pop(0) if ans.find("**ERROR**") >= 0: logging.error(f"LLM response error: {ans}") error = ans @@ -241,7 +299,7 @@ def clean_formated_answer(ans: str) -> str: self.set_output("structured", json_repair.loads(clean_formated_answer(ans))) return except Exception: - msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"}) + msg_fit.append({"role": "user", "content": "The answer can't not be parsed as JSON"}) error = "The answer can't not be parsed as JSON" if error: self.set_output("_ERROR", error) @@ -249,18 +307,23 @@ def clean_formated_answer(ans: str) -> str: downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else [] ex = self.exception_handler() - if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not output_structure and not (ex and ex["goto"]): - self.set_output("content", partial(self._stream_output, prompt, msg)) + if any([self._canvas.get_component_obj(cid).component_name.lower() == "message" for cid in downstreams]) and not ( + ex and ex["goto"] + ): + self.set_output("content", partial(self._stream_output_async, prompt, deepcopy(msg))) return - for _ in range(self._param.max_retries+1): + error = "" + for _ in range(self._param.max_retries + 1): if self.check_if_canceled("LLM processing"): return - _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) + _, msg_fit = message_fit_in( + [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97) + ) error = "" - ans = self._generate(msg) - msg.pop(0) + ans = await self._generate_async(msg_fit) + msg_fit.pop(0) if ans.find("**ERROR**") >= 0: logging.error(f"LLM response error: {ans}") error = ans @@ -274,26 +337,12 @@ def clean_formated_answer(ans: str) -> str: else: self.set_output("_ERROR", error) - def _stream_output(self, prompt, msg): - _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) - answer = "" - for ans in self._generate_streamly(msg): - if self.check_if_canceled("LLM streaming"): - return - - if ans.find("**ERROR**") >= 0: - if self.get_exception_default_value(): - self.set_output("content", self.get_exception_default_value()) - yield self.get_exception_default_value() - else: - self.set_output("_ERROR", ans) - return - yield ans - answer += ans - self.set_output("content", answer) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) - def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): - summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) + async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): + summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) logging.info(f"[MEMORY]: {summ}") self._canvas.add_memory(user, assist, summ) diff --git a/agent/component/loop.py b/agent/component/loop.py new file mode 100644 index 00000000000..484dfae8256 --- /dev/null +++ b/agent/component/loop.py @@ -0,0 +1,80 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase + + +class LoopParam(ComponentParamBase): + """ + Define the Loop component parameters. + """ + + def __init__(self): + super().__init__() + self.loop_variables = [] + self.loop_termination_condition=[] + self.maximum_loop_count = 0 + + def get_input_form(self) -> dict[str, dict]: + return { + "items": { + "type": "json", + "name": "Items" + } + } + + def check(self): + return True + + +class Loop(ComponentBase, ABC): + component_name = "Loop" + + def get_start(self): + for cid in self._canvas.components.keys(): + if self._canvas.get_component(cid)["obj"].component_name.lower() != "loopitem": + continue + if self._canvas.get_component(cid)["parent_id"] == self._id: + return cid + + def _invoke(self, **kwargs): + if self.check_if_canceled("Loop processing"): + return + + for item in self._param.loop_variables: + if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]): + assert "Loop Variable is not complete." + if item["input_mode"]=="variable": + self.set_output(item["variable"],self._canvas.get_variable_value(item["value"])) + elif item["input_mode"]=="constant": + self.set_output(item["variable"],item["value"]) + else: + if item["type"] == "number": + self.set_output(item["variable"], 0) + elif item["type"] == "string": + self.set_output(item["variable"], "") + elif item["type"] == "boolean": + self.set_output(item["variable"], False) + elif item["type"].startswith("object"): + self.set_output(item["variable"], {}) + elif item["type"].startswith("array"): + self.set_output(item["variable"], []) + else: + self.set_output(item["variable"], "") + + + def thoughts(self) -> str: + return "Loop from canvas." \ No newline at end of file diff --git a/agent/component/loopitem.py b/agent/component/loopitem.py new file mode 100644 index 00000000000..b656ea78948 --- /dev/null +++ b/agent/component/loopitem.py @@ -0,0 +1,167 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from agent.component.base import ComponentBase, ComponentParamBase + + +class LoopItemParam(ComponentParamBase): + """ + Define the LoopItem component parameters. + """ + def check(self): + return True + +class LoopItem(ComponentBase, ABC): + component_name = "LoopItem" + + def __init__(self, canvas, id, param: ComponentParamBase): + super().__init__(canvas, id, param) + self._idx = 0 + + + def _invoke(self, **kwargs): + if self.check_if_canceled("LoopItem processing"): + return + parent = self.get_parent() + maximum_loop_count = parent._param.maximum_loop_count + if self._idx >= maximum_loop_count: + self._idx = -1 + return + if self._idx > 0: + if self.check_if_canceled("LoopItem processing"): + return + self._idx += 1 + + def evaluate_condition(self,var, operator, value): + if isinstance(var, str): + if operator == "contains": + return value in var + elif operator == "not contains": + return value not in var + elif operator == "start with": + return var.startswith(value) + elif operator == "end with": + return var.endswith(value) + elif operator == "is": + return var == value + elif operator == "is not": + return var != value + elif operator == "empty": + return var == "" + elif operator == "not empty": + return var != "" + + elif isinstance(var, (int, float)): + if operator == "=": + return var == value + elif operator == "≠": + return var != value + elif operator == ">": + return var > value + elif operator == "<": + return var < value + elif operator == "≥": + return var >= value + elif operator == "≤": + return var <= value + elif operator == "empty": + return var is None + elif operator == "not empty": + return var is not None + + elif isinstance(var, bool): + if operator == "is": + return var is value + elif operator == "is not": + return var is not value + elif operator == "empty": + return var is None + elif operator == "not empty": + return var is not None + + elif isinstance(var, dict): + if operator == "empty": + return len(var) == 0 + elif operator == "not empty": + return len(var) > 0 + + elif isinstance(var, list): + if operator == "contains": + return value in var + elif operator == "not contains": + return value not in var + + elif operator == "is": + return var == value + elif operator == "is not": + return var != value + + elif operator == "empty": + return len(var) == 0 + elif operator == "not empty": + return len(var) > 0 + elif var is None: + if operator == "empty": + return True + return False + + raise Exception(f"Invalid operator: {operator}") + + def end(self): + if self._idx == -1: + return True + parent = self.get_parent() + logical_operator = parent._param.logical_operator if hasattr(parent._param, "logical_operator") else "and" + conditions = [] + for item in parent._param.loop_termination_condition: + if not item.get("variable") or not item.get("operator"): + raise ValueError("Loop condition is incomplete.") + var = self._canvas.get_variable_value(item["variable"]) + operator = item["operator"] + input_mode = item.get("input_mode", "constant") + + if input_mode == "variable": + value = self._canvas.get_variable_value(item.get("value", "")) + elif input_mode == "constant": + value = item.get("value", "") + else: + raise ValueError("Invalid input mode.") + conditions.append(self.evaluate_condition(var, operator, value)) + should_end = ( + all(conditions) if logical_operator == "and" + else any(conditions) if logical_operator == "or" + else None + ) + if should_end is None: + raise ValueError("Invalid logical operator,should be 'and' or 'or'.") + + if should_end: + self._idx = -1 + return True + + return False + + def next(self): + if self._idx == -1: + self._idx = 0 + else: + self._idx += 1 + if self._idx >= len(self._items): + self._idx = -1 + return False + + def thoughts(self) -> str: + return "Next turn..." \ No newline at end of file diff --git a/agent/component/message.py b/agent/component/message.py index 641198083e9..bf393f541d6 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -13,10 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import nest_asyncio +nest_asyncio.apply() +import inspect import json import os import random import re +import logging +import tempfile from functools import partial from typing import Any @@ -24,6 +30,10 @@ from jinja2 import Template as Jinja2Template from common.connection_utils import timeout +from common.misc_utils import get_uuid +from common import settings + +from api.db.joint_services.memory_message_service import queue_save_to_memory_task class MessageParam(ComponentParamBase): @@ -34,6 +44,8 @@ def __init__(self): super().__init__() self.content = [] self.stream = True + self.output_format = None # default output format + self.auto_play = False self.outputs = { "content": { "type": "str" @@ -61,8 +73,12 @@ def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[ v = "" ans = "" if isinstance(v, partial): - for t in v(): - ans += t + iter_obj = v() + if inspect.isasyncgen(iter_obj): + ans = asyncio.run(self._consume_async_gen(iter_obj)) + else: + for t in iter_obj: + ans += t elif isinstance(v, list) and delimiter: ans = delimiter.join([str(vv) for vv in v]) elif not isinstance(v, str): @@ -84,7 +100,13 @@ def get_kwargs(self, script:str, kwargs:dict = {}, delimiter:str=None) -> tuple[ _kwargs[_n] = v return script, _kwargs - def _stream(self, rand_cnt:str): + async def _consume_async_gen(self, agen): + buf = "" + async for t in agen: + buf += t + return buf + + async def _stream(self, rand_cnt:str): s = 0 all_content = "" cache = {} @@ -106,15 +128,27 @@ def _stream(self, rand_cnt:str): v = "" if isinstance(v, partial): cnt = "" - for t in v(): - if self.check_if_canceled("Message streaming"): - return + iter_obj = v() + if inspect.isasyncgen(iter_obj): + async for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return - all_content += t - cnt += t - yield t + all_content += t + cnt += t + yield t + else: + for t in iter_obj: + if self.check_if_canceled("Message streaming"): + return + + all_content += t + cnt += t + yield t self.set_input_value(exp, cnt) continue + elif inspect.isawaitable(v): + v = await v elif not isinstance(v, str): try: v = json.dumps(v, ensure_ascii=False) @@ -133,6 +167,8 @@ def _stream(self, rand_cnt:str): yield rand_cnt[s: ] self.set_output("content", all_content) + self._convert_content(all_content) + await self._save_to_memory(all_content) def _is_jinjia2(self, content:str) -> bool: patt = [ @@ -164,6 +200,241 @@ def _invoke(self, **kwargs): content = re.sub(n, v, content) self.set_output("content", content) + self._convert_content(content) + self._save_to_memory(content) def thoughts(self) -> str: return "" + + def _parse_markdown_table_lines(self, table_lines: list): + """ + Parse a list of Markdown table lines into a pandas DataFrame. + + Args: + table_lines: List of strings, each representing a row in the Markdown table + (excluding separator lines like |---|---|) + + Returns: + pandas DataFrame with the table data, or None if parsing fails + """ + import pandas as pd + + if not table_lines: + return None + + rows = [] + headers = None + + for line in table_lines: + # Split by | and clean up + cells = [cell.strip() for cell in line.split('|')] + # Remove empty first and last elements from split (caused by leading/trailing |) + cells = [c for c in cells if c] + + if headers is None: + headers = cells + else: + rows.append(cells) + + if headers and rows: + # Ensure all rows have same number of columns as headers + normalized_rows = [] + for row in rows: + while len(row) < len(headers): + row.append('') + normalized_rows.append(row[:len(headers)]) + + return pd.DataFrame(normalized_rows, columns=headers) + + return None + + def _convert_content(self, content): + if not self._param.output_format: + return + + import pypandoc + doc_id = get_uuid() + + if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx", "xlsx"}: + self._param.output_format = "markdown" + + try: + if self._param.output_format in {"markdown", "html"}: + if isinstance(content, str): + converted = pypandoc.convert_text( + content, + to=self._param.output_format, + format="markdown", + ) + else: + converted = pypandoc.convert_file( + content, + to=self._param.output_format, + format="markdown", + ) + + binary_content = converted.encode("utf-8") + + elif self._param.output_format == "xlsx": + import pandas as pd + from io import BytesIO + + # Debug: log the content being parsed + logging.info(f"XLSX Parser: Content length={len(content) if content else 0}, first 500 chars: {content[:500] if content else 'None'}") + + # Try to parse ALL Markdown tables from the content + # Each table will be written to a separate sheet + tables = [] # List of (sheet_name, dataframe) + + if isinstance(content, str): + lines = content.strip().split('\n') + logging.info(f"XLSX Parser: Total lines={len(lines)}, lines starting with '|': {sum(1 for line in lines if line.strip().startswith('|'))}") + current_table_lines = [] + current_table_title = None + pending_title = None + in_table = False + table_count = 0 + + for i, line in enumerate(lines): + stripped = line.strip() + + # Check for potential table title (lines before a table) + # Look for patterns like "Table 1:", "## Table", or markdown headers + if not in_table and stripped and not stripped.startswith('|'): + # Check if this could be a table title + lower_stripped = stripped.lower() + if (lower_stripped.startswith('table') or + stripped.startswith('#') or + ':' in stripped): + pending_title = stripped.lstrip('#').strip() + + if stripped.startswith('|') and '|' in stripped[1:]: + # Check if this is a separator line (|---|---|) + cleaned = stripped.replace(' ', '').replace('|', '').replace('-', '').replace(':', '') + if cleaned == '': + continue # Skip separator line + + if not in_table: + # Starting a new table + in_table = True + current_table_lines = [] + current_table_title = pending_title + pending_title = None + + current_table_lines.append(stripped) + + elif in_table and not stripped.startswith('|'): + # End of current table - save it + if current_table_lines: + df = self._parse_markdown_table_lines(current_table_lines) + if df is not None and not df.empty: + table_count += 1 + # Generate sheet name + if current_table_title: + # Clean and truncate title for sheet name + sheet_name = current_table_title[:31] + sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '') + else: + sheet_name = f"Table_{table_count}" + tables.append((sheet_name, df)) + + # Reset for next table + in_table = False + current_table_lines = [] + current_table_title = None + + # Check if this line could be a title for the next table + if stripped: + lower_stripped = stripped.lower() + if (lower_stripped.startswith('table') or + stripped.startswith('#') or + ':' in stripped): + pending_title = stripped.lstrip('#').strip() + + # Don't forget the last table if content ends with a table + if in_table and current_table_lines: + df = self._parse_markdown_table_lines(current_table_lines) + if df is not None and not df.empty: + table_count += 1 + if current_table_title: + sheet_name = current_table_title[:31] + sheet_name = sheet_name.replace('/', '_').replace('\\', '_').replace('*', '').replace('?', '').replace('[', '').replace(']', '').replace(':', '') + else: + sheet_name = f"Table_{table_count}" + tables.append((sheet_name, df)) + + # Fallback: if no tables found, create single sheet with content + if not tables: + df = pd.DataFrame({"Content": [content if content else ""]}) + tables = [("Data", df)] + + # Write all tables to Excel, each in a separate sheet + excel_io = BytesIO() + with pd.ExcelWriter(excel_io, engine='openpyxl') as writer: + used_names = set() + for sheet_name, df in tables: + # Ensure unique sheet names + original_name = sheet_name + counter = 1 + while sheet_name in used_names: + suffix = f"_{counter}" + sheet_name = original_name[:31-len(suffix)] + suffix + counter += 1 + used_names.add(sheet_name) + df.to_excel(writer, sheet_name=sheet_name, index=False) + + excel_io.seek(0) + binary_content = excel_io.read() + + logging.info(f"Generated Excel with {len(tables)} sheet(s): {[t[0] for t in tables]}") + + else: # pdf, docx + with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp: + tmp_name = tmp.name + + try: + if isinstance(content, str): + pypandoc.convert_text( + content, + to=self._param.output_format, + format="markdown", + outputfile=tmp_name, + ) + else: + pypandoc.convert_file( + content, + to=self._param.output_format, + format="markdown", + outputfile=tmp_name, + ) + + with open(tmp_name, "rb") as f: + binary_content = f.read() + + finally: + if os.path.exists(tmp_name): + os.remove(tmp_name) + + settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content) + self.set_output("attachment", { + "doc_id":doc_id, + "format":self._param.output_format, + "file_name":f"{doc_id[:8]}.{self._param.output_format}"}) + + logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})") + + except Exception as e: + logging.error(f"Error converting content to {self._param.output_format}: {e}") + + async def _save_to_memory(self, content): + if not hasattr(self._param, "memory_ids") or not self._param.memory_ids: + return True, "No memory selected." + + message_dict = { + "user_id": self._canvas._tenant_id, + "agent_id": self._canvas._id, + "session_id": self._canvas.task_id, + "user_input": self._canvas.get_sys_query(), + "agent_response": content + } + return await queue_save_to_memory_task(self._param.memory_ids, message_dict) diff --git a/agent/component/variable_assigner.py b/agent/component/variable_assigner.py new file mode 100644 index 00000000000..08b28334312 --- /dev/null +++ b/agent/component/variable_assigner.py @@ -0,0 +1,192 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +import os +import numbers +from agent.component.base import ComponentBase, ComponentParamBase +from api.utils.api_utils import timeout + +class VariableAssignerParam(ComponentParamBase): + """ + Define the Variable Assigner component parameters. + """ + def __init__(self): + super().__init__() + self.variables=[] + + def check(self): + return True + + def get_input_form(self) -> dict[str, dict]: + return { + "items": { + "type": "json", + "name": "Items" + } + } + +class VariableAssigner(ComponentBase,ABC): + component_name = "VariableAssigner" + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + if not isinstance(self._param.variables,list): + return + else: + for item in self._param.variables: + if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]): + assert "Variable is not complete." + variable=item["variable"] + operator=item["operator"] + parameter=item["parameter"] + variable_value=self._canvas.get_variable_value(variable) + new_variable=self._operate(variable_value,operator,parameter) + self._canvas.set_variable_value(variable, new_variable) + + def _operate(self,variable,operator,parameter): + if operator == "overwrite": + return self._overwrite(parameter) + elif operator == "clear": + return self._clear(variable) + elif operator == "set": + return self._set(variable,parameter) + elif operator == "append": + return self._append(variable,parameter) + elif operator == "extend": + return self._extend(variable,parameter) + elif operator == "remove_first": + return self._remove_first(variable) + elif operator == "remove_last": + return self._remove_last(variable) + elif operator == "+=": + return self._add(variable,parameter) + elif operator == "-=": + return self._subtract(variable,parameter) + elif operator == "*=": + return self._multiply(variable,parameter) + elif operator == "/=": + return self._divide(variable,parameter) + else: + return + + def _overwrite(self,parameter): + return self._canvas.get_variable_value(parameter) + + def _clear(self,variable): + if isinstance(variable,list): + return [] + elif isinstance(variable,str): + return "" + elif isinstance(variable,dict): + return {} + elif isinstance(variable,int): + return 0 + elif isinstance(variable,float): + return 0.0 + elif isinstance(variable,bool): + return False + else: + return None + + def _set(self,variable,parameter): + if variable is None: + return self._canvas.get_value_with_variable(parameter) + elif isinstance(variable,str): + return self._canvas.get_value_with_variable(parameter) + elif isinstance(variable,bool): + return parameter + elif isinstance(variable,int): + return parameter + elif isinstance(variable,float): + return parameter + else: + return parameter + + def _append(self,variable,parameter): + parameter=self._canvas.get_variable_value(parameter) + if variable is None: + variable=[] + if not isinstance(variable,list): + return "ERROR:VARIABLE_NOT_LIST" + elif len(variable)!=0 and not isinstance(parameter,type(variable[0])): + return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE" + else: + variable.append(parameter) + return variable + + def _extend(self,variable,parameter): + parameter=self._canvas.get_variable_value(parameter) + if variable is None: + variable=[] + if not isinstance(variable,list): + return "ERROR:VARIABLE_NOT_LIST" + elif not isinstance(parameter,list): + return "ERROR:PARAMETER_NOT_LIST" + elif len(variable)!=0 and len(parameter)!=0 and not isinstance(parameter[0],type(variable[0])): + return "ERROR:PARAMETER_NOT_LIST_ELEMENT_TYPE" + else: + return variable + parameter + + def _remove_first(self,variable): + if len(variable)==0: + return variable + if not isinstance(variable,list): + return "ERROR:VARIABLE_NOT_LIST" + else: + return variable[1:] + + def _remove_last(self,variable): + if len(variable)==0: + return variable + if not isinstance(variable,list): + return "ERROR:VARIABLE_NOT_LIST" + else: + return variable[:-1] + + def is_number(self, value): + if isinstance(value, bool): + return False + return isinstance(value, numbers.Number) + + def _add(self,variable,parameter): + if self.is_number(variable) and self.is_number(parameter): + return variable + parameter + else: + return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" + + def _subtract(self,variable,parameter): + if self.is_number(variable) and self.is_number(parameter): + return variable - parameter + else: + return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" + + def _multiply(self,variable,parameter): + if self.is_number(variable) and self.is_number(parameter): + return variable * parameter + else: + return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" + + def _divide(self,variable,parameter): + if self.is_number(variable) and self.is_number(parameter): + if parameter==0: + return "ERROR:DIVIDE_BY_ZERO" + else: + return variable/parameter + else: + return "ERROR:VARIABLE_NOT_NUMBER or PARAMETER_NOT_NUMBER" + + def thoughts(self) -> str: + return "Assign variables from canvas." \ No newline at end of file diff --git a/agent/templates/advanced_ingestion_pipeline.json b/agent/templates/advanced_ingestion_pipeline.json index cfd211f4688..2e996e248be 100644 --- a/agent/templates/advanced_ingestion_pipeline.json +++ b/agent/templates/advanced_ingestion_pipeline.json @@ -193,7 +193,7 @@ "presence_penalty": 0.4, "prompts": [ { - "content": "Text Content:\n{Splitter:KindDingosJam@chunks}\n", + "content": "Text Content:\n{Splitter:NineTiesSin@chunks}\n", "role": "user" } ], @@ -226,7 +226,7 @@ "presence_penalty": 0.4, "prompts": [ { - "content": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n", + "content": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n", "role": "user" } ], @@ -259,7 +259,7 @@ "presence_penalty": 0.4, "prompts": [ { - "content": "Content: \n\n{Splitter:KindDingosJam@chunks}", + "content": "Content: \n\n{Splitter:CuteBusesBet@chunks}", "role": "user" } ], @@ -485,7 +485,7 @@ "outputs": {}, "presencePenaltyEnabled": false, "presence_penalty": 0.4, - "prompts": "Text Content:\n{Splitter:KindDingosJam@chunks}\n", + "prompts": "Text Content:\n{Splitter:NineTiesSin@chunks}\n", "sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nExtract the most important keywords/phrases of a given piece of text content.\n\nRequirements\n- Summarize the text content, and give the top 5 important keywords/phrases.\n- The keywords MUST be in the same language as the given piece of text content.\n- The keywords are delimited by ENGLISH COMMA.\n- Output keywords ONLY.", "temperature": 0.1, "temperatureEnabled": false, @@ -522,7 +522,7 @@ "outputs": {}, "presencePenaltyEnabled": false, "presence_penalty": 0.4, - "prompts": "Text Content:\n\n{Splitter:KindDingosJam@chunks}\n", + "prompts": "Text Content:\n\n{Splitter:TastyPointsLay@chunks}\n", "sys_prompt": "Role\nYou are a text analyzer.\n\nTask\nPropose 3 questions about a given piece of text content.\n\nRequirements\n- Understand and summarize the text content, and propose the top 3 important questions.\n- The questions SHOULD NOT have overlapping meanings.\n- The questions SHOULD cover the main content of the text as much as possible.\n- The questions MUST be in the same language as the given piece of text content.\n- One question per line.\n- Output questions ONLY.", "temperature": 0.1, "temperatureEnabled": false, @@ -559,7 +559,7 @@ "outputs": {}, "presencePenaltyEnabled": false, "presence_penalty": 0.4, - "prompts": "Content: \n\n{Splitter:KindDingosJam@chunks}", + "prompts": "Content: \n\n{Splitter:BlueResultsWink@chunks}", "sys_prompt": "Extract important structured information from the given content. Output ONLY a valid JSON string with no additional text. If no important structured information is found, output an empty JSON object: {}.\n\nImportant structured information may include: names, dates, locations, events, key facts, numerical data, or other extractable entities.", "temperature": 0.1, "temperatureEnabled": false, diff --git a/agent/templates/customer_service.json b/agent/templates/customer_service.json index 24e022feda6..fc3704e5353 100644 --- a/agent/templates/customer_service.json +++ b/agent/templates/customer_service.json @@ -11,707 +11,952 @@ "zh": "多智能体系统,用于智能客服场景。基于用户意图分类,使用主智能体识别用户需求类型,并将任务分配给子智能体进行处理。"}, "canvas_type": "Agent", "dsl": { - "components": { - "Agent:RottenRiversDo": { - "downstream": [ - "Message:PurpleCitiesSee" - ], - "obj": { - "component_name": "Agent", - "params": { - "delay_after_error": 1, - "description": "", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.7, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 3, - "max_rounds": 2, - "max_tokens": 256, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "presencePenaltyEnabled": false, - "presence_penalty": 0.4, - "prompts": [ - { - "content": "The user query is {sys.query}", - "role": "user" - } - ], - "sys_prompt": "# Role \n\nYou are **Customer Server Agent**. Classify every user message; handle **contact** yourself. This is a multi-agent system.\n\n## Categories \n\n1. **contact** \u2013 user gives phone, e\u2011mail, WeChat, Line, Discord, etc. \n\n2. **casual** \u2013 small talk, not about the product. \n\n3. **complain** \u2013 complaints or profanity about the product/service. \n\n4. **product** \u2013 questions on product use, appearance, function, or errors.\n\n## If contact \n\nReply with one random item below\u2014do not change wording or call sub\u2011agents: \n\n1. Okay, I've already written this down. What else can I do for you? \n\n2. Got it. What else can I do for you? \n\n3. Thanks for your trust! Our expert will contact you ASAP. Anything else I can help with? \n\n4. Thanks! Anything else I can do for you?\n\n\n---\n\n\n## Otherwise (casual\u202f/\u202fcomplain\u202f/\u202fproduct) \n\nLet Sub\u2011Agent returns its answer\n\n## Sub\u2011Agent \n\n- casual \u2192 **Casual Agent** \nThis is an agent for handles casual conversationk.\n\n- complain \u2192 **Soothe Agent** \nThis is an agent for handles complaints or emotional input.\n\n- product \u2192 **Product Agent** \nThis is an agent for handles product-related queries and can use the `Retrieval` tool.\n\n## Importance\n\n- When the Sub\u2011Agent returns its answer, forward that answer to the user verbatim \u2014 do not add, edit, or reason further.\n ", - "temperature": 0.1, - "temperatureEnabled": true, - "tools": [ - { - "component_name": "Agent", - "id": "Agent:SlowKiwisBehave", - "name": "Casual Agent", - "params": { - "delay_after_error": 1, - "description": "This is an agent for handles casual conversationk.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.3, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 1, - "max_rounds": 1, - "max_tokens": 4096, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "parameter": "Balance", - "presencePenaltyEnabled": false, - "presence_penalty": 0.2, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "You are a friendly and casual conversational assistant. \n\nYour primary goal is to engage users in light and enjoyable daily conversation. \n\n- Keep a natural, relaxed, and positive tone. \n\n- Avoid sensitive, controversial, or negative topics. \n\n- You may gently guide the conversation by introducing related casual topics if the user shows interest. \n\n", - "temperature": 0.5, - "temperatureEnabled": true, - "tools": [], - "topPEnabled": false, - "top_p": 0.85, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - } - }, - { - "component_name": "Agent", - "id": "Agent:PoorTaxesRescue", - "name": "Soothe Agent", - "params": { - "delay_after_error": 1, - "description": "This is an agent for handles complaints or emotional input.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.3, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 1, - "max_rounds": 1, - "max_tokens": 4096, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "parameter": "Balance", - "presencePenaltyEnabled": false, - "presence_penalty": 0.2, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "You are an empathetic mood-soothing assistant. \n\nYour role is to comfort and encourage users when they feel upset or frustrated. \n\n- Use a warm, kind, and understanding tone. \n\n- Focus on showing empathy and emotional support rather than solving the problem directly. \n\n- Always encourage users with positive and reassuring statements. ", - "temperature": 0.5, - "temperatureEnabled": true, - "tools": [], - "topPEnabled": false, - "top_p": 0.85, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - } - }, - { - "component_name": "Agent", - "id": "Agent:SillyTurkeysRest", - "name": "Product Agent", - "params": { - "delay_after_error": 1, - "description": "This is an agent for handles product-related queries and can use the `Retrieval` tool.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.7, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 3, - "max_rounds": 2, - "max_tokens": 256, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "presencePenaltyEnabled": false, - "presence_penalty": 0.4, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "# Role \n\nYou are a Product Information Advisor with access to the **Retrieval** tool.\n\n# Workflow \n\n1. Run **Retrieval** with a focused query from the user\u2019s question. \n\n2. Draft the reply **strictly** from the returned passages. \n\n3. If nothing relevant is retrieved, reply: \n\n \u201cI cannot find relevant documents in the knowledge base.\u201d\n\n# Rules \n\n- No assumptions, guesses, or extra\u2011KB knowledge. \n\n- Factual, concise. Use bullets / numbers when helpful. \n\n", - "temperature": 0.1, - "temperatureEnabled": true, - "tools": [ - { - "component_name": "Retrieval", - "name": "Retrieval", - "params": { - "cross_languages": [], - "description": "This is a product knowledge base", - "empty_response": "", - "kb_ids": [], - "keywords_similarity_weight": 0.7, - "outputs": { - "formalized_content": { - "type": "string", - "value": "" - } - }, - "rerank_id": "", - "similarity_threshold": 0.2, - "top_k": 1024, - "top_n": 8, - "use_kg": false - } - } - ], - "topPEnabled": false, - "top_p": 0.3, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - } - } - ], - "topPEnabled": false, - "top_p": 0.3, - "user_prompt": "", - "visual_files_var": "" - } - }, - "upstream": [ - "begin" - ] - }, - "Message:PurpleCitiesSee": { - "downstream": [], - "obj": { - "component_name": "Message", - "params": { - "content": [ - "{Agent:RottenRiversDo@content}" - ] - } - }, - "upstream": [ - "Agent:RottenRiversDo" - ] + "components": { + "Agent:DullTownsHope": { + "downstream": [ + "VariableAggregator:FuzzyBerriesFlow" + ], + "obj": { + "component_name": "Agent", + "params": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.3, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 4096, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } }, - "begin": { - "downstream": [ - "Agent:RottenRiversDo" - ], - "obj": { - "component_name": "Begin", - "params": { - "enablePrologue": true, - "inputs": {}, - "mode": "conversational", - "prologue": "Hi! I'm an official AI customer service representative. How can I help you?" - } - }, - "upstream": [] + "parameter": "Balance", + "presencePenaltyEnabled": false, + "presence_penalty": 0.2, + "prompts": [ + { + "content": "The user query is {sys.query}", + "role": "user" + } + ], + "sys_prompt": "You are an empathetic mood-soothing assistant. \n\nYour role is to comfort and encourage users when they feel upset or frustrated. \n\n- Use a warm, kind, and understanding tone. \n\n- Focus on showing empathy and emotional support rather than solving the problem directly. \n\n- Always encourage users with positive and reassuring statements. ", + "temperature": 0.5, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.85, + "user_prompt": "", + "visual_files_var": "" } }, - "globals": { - "sys.conversation_turns": 0, - "sys.files": [], - "sys.query": "", - "sys.user_id": "" + "upstream": [ + "Categorize:DullFriendsThank" + ] }, - "graph": { - "edges": [ + "Agent:KhakiSunsJudge": { + "downstream": [ + "VariableAggregator:FuzzyBerriesFlow" + ], + "obj": { + "component_name": "Agent", + "params": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.7, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 256, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } + }, + "presencePenaltyEnabled": false, + "presence_penalty": 0.4, + "prompts": [ { - "data": { - "isHovered": false - }, - "id": "xy-edge__beginstart-Agent:RottenRiversDoend", - "source": "begin", - "sourceHandle": "start", - "target": "Agent:RottenRiversDo", - "targetHandle": "end" - }, + "content": "The user query is {sys.query}\n\nThe relevant document are {Retrieval:ShyPumasJoke@formalized_content}", + "role": "user" + } + ], + "sys_prompt": "You are a highly professional product information advisor. \n\nYour only mission is to provide accurate, factual, and structured answers to all product-related queries.\n\nAbsolutely no assumptions, guesses, or fabricated content are allowed. \n\n**Key Principles:**\n\n1. **Strict Database Reliance:** \n\n - Every answer must be based solely on the verified product information stored in the relevant documen.\n\n - You are NOT allowed to invent, speculate, or infer details beyond what is retrieved. \n\n - If you cannot find relevant data, respond with: *\"I cannot find this information in our official product database. Please check back later or provide more details for further search.\"*\n\n2. **Information Accuracy and Structure:** \n\n - Provide information in a clear, concise, and professional way. \n\n - Use bullet points or numbered lists if there are multiple key points (e.g., features, price, warranty, technical specifications). \n\n - Always specify the version or model number when applicable to avoid confusion.\n\n3. **Tone and Style:** \n\n - Maintain a polite, professional, and helpful tone at all times. \n\n - Avoid marketing exaggeration or promotional language; stay strictly factual. \n\n - Do not express personal opinions; only cite official product data.\n\n4. **User Guidance:** \n\n - If the user’s query is unclear or too broad, politely request clarification or guide them to provide more specific product details (e.g., product name, model, version). \n\n - Example: *\"Could you please specify the product model or category so I can retrieve the most relevant information for you?\"*\n\n5. **Response Length and Formatting:** \n\n - Keep each answer within 100–150 words for general queries. \n\n - For complex or multi-step explanations, you may extend to 200–250 words, but always remain clear and well-structured.\n\n6. **Critical Reminder:** \n\nYour authority and reliability depend entirely on the relevant document responses. Any fabricated, speculative, or unverified content will be considered a critical failure of your role.\n\n\n", + "temperature": 0.1, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.3, + "user_prompt": "", + "visual_files_var": "" + } + }, + "upstream": [ + "Retrieval:ShyPumasJoke" + ] + }, + "Agent:TwelveOwlsWatch": { + "downstream": [ + "VariableAggregator:FuzzyBerriesFlow" + ], + "obj": { + "component_name": "Agent", + "params": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.3, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 4096, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } + }, + "parameter": "Balance", + "presencePenaltyEnabled": false, + "presence_penalty": 0.2, + "prompts": [ { - "data": { - "isHovered": false - }, - "id": "xy-edge__Agent:RottenRiversDoagentBottom-Agent:SlowKiwisBehaveagentTop", - "source": "Agent:RottenRiversDo", - "sourceHandle": "agentBottom", - "target": "Agent:SlowKiwisBehave", - "targetHandle": "agentTop" + "content": "The user query is {sys.query}", + "role": "user" + } + ], + "sys_prompt": "You are a friendly and casual conversational assistant. \n\nYour primary goal is to engage users in light and enjoyable daily conversation. \n\n- Keep a natural, relaxed, and positive tone. \n\n- Avoid sensitive, controversial, or negative topics. \n\n- You may gently guide the conversation by introducing related casual topics if the user shows interest. \n\n", + "temperature": 0.5, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.85, + "user_prompt": "", + "visual_files_var": "" + } + }, + "upstream": [ + "Categorize:DullFriendsThank" + ] + }, + "Categorize:DullFriendsThank": { + "downstream": [ + "Message:BreezyDonutsHeal", + "Agent:TwelveOwlsWatch", + "Agent:DullTownsHope", + "Retrieval:ShyPumasJoke" + ], + "obj": { + "component_name": "Categorize", + "params": { + "category_description": { + "1. contact": { + "description": "This answer provide a specific contact information, like e-mail, phone number, wechat number, line number, twitter, discord, etc,.", + "examples": [ + "My phone number is 203921\nkevinhu.hk@gmail.com\nThis is my discord number: johndowson_29384\n13212123432\n8379829" + ], + "to": [ + "Message:BreezyDonutsHeal" + ] }, - { - "data": { - "isHovered": false - }, - "id": "xy-edge__Agent:RottenRiversDoagentBottom-Agent:PoorTaxesRescueagentTop", - "source": "Agent:RottenRiversDo", - "sourceHandle": "agentBottom", - "target": "Agent:PoorTaxesRescue", - "targetHandle": "agentTop" + "2. casual": { + "description": "The question is not about the product usage, appearance and how it works. Just casual chat.", + "examples": [ + "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?" + ], + "to": [ + "Agent:TwelveOwlsWatch" + ] }, - { - "data": { - "isHovered": false - }, - "id": "xy-edge__Agent:RottenRiversDoagentBottom-Agent:SillyTurkeysRestagentTop", - "source": "Agent:RottenRiversDo", - "sourceHandle": "agentBottom", - "target": "Agent:SillyTurkeysRest", - "targetHandle": "agentTop" + "3. complain": { + "description": "Complain even curse about the product or service you provide. But the comment is not specific enough.", + "examples": [ + "How bad is it.\nIt's really sucks.\nDamn, for God's sake, can it be more steady?\nShit, I just can't use this shit.\nI can't stand it anymore." + ], + "to": [ + "Agent:DullTownsHope" + ] }, + "4. product related": { + "description": "The question is about the product usage, appearance and how it works.", + "examples": [ + "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?\nException: Can't connect to ES cluster\nHow to build the RAGFlow image from scratch" + ], + "to": [ + "Retrieval:ShyPumasJoke" + ] + } + }, + "llm_id": "deepseek-chat@DeepSeek", + "message_history_window_size": 1, + "outputs": { + "category_name": { + "type": "string" + } + }, + "query": "sys.query", + "temperature": "0.1" + } + }, + "upstream": [ + "begin" + ] + }, + "Message:BreezyDonutsHeal": { + "downstream": [], + "obj": { + "component_name": "Message", + "params": { + "content": [ + "Okay, I've already write this down. What else I can do for you?", + "Get it. What else I can do for you?", + "Thanks for your trust! Our expert will contact ASAP. So, anything else I can do for you?", + "Thanks! So, anything else I can do for you?" + ] + } + }, + "upstream": [ + "Categorize:DullFriendsThank" + ] + }, + "Message:DryBusesCarry": { + "downstream": [], + "obj": { + "component_name": "Message", + "params": { + "content": [ + "{VariableAggregator:FuzzyBerriesFlow@LLM_Response}" + ] + } + }, + "upstream": [ + "VariableAggregator:FuzzyBerriesFlow" + ] + }, + "Retrieval:ShyPumasJoke": { + "downstream": [ + "Agent:KhakiSunsJudge" + ], + "obj": { + "component_name": "Retrieval", + "params": { + "cross_languages": [], + "empty_response": "", + "kb_ids": [], + "keywords_similarity_weight": 0.7, + "outputs": { + "formalized_content": { + "type": "string", + "value": "" + } + }, + "query": "sys.query", + "rerank_id": "", + "similarity_threshold": 0.2, + "top_k": 1024, + "top_n": 8, + "use_kg": false + } + }, + "upstream": [ + "Categorize:DullFriendsThank" + ] + }, + "VariableAggregator:FuzzyBerriesFlow": { + "downstream": [ + "Message:DryBusesCarry" + ], + "obj": { + "component_name": "VariableAggregator", + "params": { + "groups": [ { - "data": { - "isHovered": false + "group_name": "LLM_Response", + "type": "string", + "variables": [ + { + "value": "Agent:TwelveOwlsWatch@content" }, - "id": "xy-edge__Agent:SillyTurkeysResttool-Tool:CrazyShirtsKissend", - "source": "Agent:SillyTurkeysRest", - "sourceHandle": "tool", - "target": "Tool:CrazyShirtsKiss", - "targetHandle": "end" - }, - { - "data": { - "isHovered": false + { + "value": "Agent:DullTownsHope@content" }, - "id": "xy-edge__Agent:RottenRiversDostart-Message:PurpleCitiesSeeend", - "source": "Agent:RottenRiversDo", - "sourceHandle": "start", - "target": "Message:PurpleCitiesSee", - "targetHandle": "end" + { + "value": "Agent:KhakiSunsJudge@content" + } + ] } ], - "nodes": [ + "outputs": { + "LLM_Response": { + "type": "string" + } + } + } + }, + "upstream": [ + "Agent:DullTownsHope", + "Agent:TwelveOwlsWatch", + "Agent:KhakiSunsJudge" + ] + }, + "begin": { + "downstream": [ + "Categorize:DullFriendsThank" + ], + "obj": { + "component_name": "Begin", + "params": { + "enablePrologue": true, + "inputs": {}, + "mode": "conversational", + "prologue": "Hi! I'm an official AI customer service representative. How can I help you?" + } + }, + "upstream": [] + } + }, + "globals": { + "sys.conversation_turns": 0, + "sys.files": [], + "sys.query": "", + "sys.user_id": "" + }, + "graph": { + "edges": [ + { + "data": { + "isHovered": false + }, + "id": "xy-edge__beginstart-Categorize:DullFriendsThankend", + "source": "begin", + "sourceHandle": "start", + "target": "Categorize:DullFriendsThank", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Categorize:DullFriendsThanke4d754a5-a33e-4096-8648-8688e5474a15-Message:BreezyDonutsHealend", + "source": "Categorize:DullFriendsThank", + "sourceHandle": "e4d754a5-a33e-4096-8648-8688e5474a15", + "target": "Message:BreezyDonutsHeal", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Categorize:DullFriendsThank8cbf6ea3-a176-490d-9f8c-86373c932583-Agent:TwelveOwlsWatchend", + "source": "Categorize:DullFriendsThank", + "sourceHandle": "8cbf6ea3-a176-490d-9f8c-86373c932583", + "target": "Agent:TwelveOwlsWatch", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Categorize:DullFriendsThankacc40a78-1b9e-4d2f-b5d6-64e01ab69269-Agent:DullTownsHopeend", + "source": "Categorize:DullFriendsThank", + "sourceHandle": "acc40a78-1b9e-4d2f-b5d6-64e01ab69269", + "target": "Agent:DullTownsHope", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Categorize:DullFriendsThankdfa5eead-9341-4f22-9236-068dbfb745e8-Retrieval:ShyPumasJokeend", + "source": "Categorize:DullFriendsThank", + "sourceHandle": "dfa5eead-9341-4f22-9236-068dbfb745e8", + "target": "Retrieval:ShyPumasJoke", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Retrieval:ShyPumasJokestart-Agent:KhakiSunsJudgeend", + "source": "Retrieval:ShyPumasJoke", + "sourceHandle": "start", + "target": "Agent:KhakiSunsJudge", + "targetHandle": "end" + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Agent:DullTownsHopestart-VariableAggregator:FuzzyBerriesFlowend", + "source": "Agent:DullTownsHope", + "sourceHandle": "start", + "target": "VariableAggregator:FuzzyBerriesFlow", + "targetHandle": "end" + }, + { + "id": "xy-edge__Agent:TwelveOwlsWatchstart-VariableAggregator:FuzzyBerriesFlowend", + "markerEnd": "logo", + "source": "Agent:TwelveOwlsWatch", + "sourceHandle": "start", + "target": "VariableAggregator:FuzzyBerriesFlow", + "targetHandle": "end", + "type": "buttonEdge", + "zIndex": 1001 + }, + { + "data": { + "isHovered": false + }, + "id": "xy-edge__Agent:KhakiSunsJudgestart-VariableAggregator:FuzzyBerriesFlowend", + "markerEnd": "logo", + "source": "Agent:KhakiSunsJudge", + "sourceHandle": "start", + "target": "VariableAggregator:FuzzyBerriesFlow", + "targetHandle": "end", + "type": "buttonEdge", + "zIndex": 1001 + }, + { + "id": "xy-edge__VariableAggregator:FuzzyBerriesFlowstart-Message:DryBusesCarryend", + "source": "VariableAggregator:FuzzyBerriesFlow", + "sourceHandle": "start", + "target": "Message:DryBusesCarry", + "targetHandle": "end" + } + ], + "nodes": [ + { + "data": { + "form": { + "enablePrologue": true, + "inputs": {}, + "mode": "conversational", + "prologue": "Hi! I'm an official AI customer service representative. How can I help you?" + }, + "label": "Begin", + "name": "begin" + }, + "id": "begin", + "measured": { + "height": 48, + "width": 200 + }, + "position": { + "x": 50, + "y": 200 + }, + "selected": false, + "sourcePosition": "left", + "targetPosition": "right", + "type": "beginNode" + }, + { + "data": { + "form": { + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.5, + "items": [ { - "data": { - "form": { - "enablePrologue": true, - "inputs": {}, - "mode": "conversational", - "prologue": "Hi! I'm an official AI customer service representative. How can I help you?" - }, - "label": "Begin", - "name": "begin" - }, - "id": "begin", - "measured": { - "height": 48, - "width": 200 - }, - "position": { - "x": 50, - "y": 200 - }, - "selected": false, - "sourcePosition": "left", - "targetPosition": "right", - "type": "beginNode" + "description": "This answer provide a specific contact information, like e-mail, phone number, wechat number, line number, twitter, discord, etc,.", + "examples": [ + { + "value": "My phone number is 203921\nkevinhu.hk@gmail.com\nThis is my discord number: johndowson_29384\n13212123432\n8379829" + } + ], + "name": "1. contact", + "uuid": "e4d754a5-a33e-4096-8648-8688e5474a15" }, { - "data": { - "form": { - "delay_after_error": 1, - "description": "", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.7, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 3, - "max_rounds": 2, - "max_tokens": 256, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "presencePenaltyEnabled": false, - "presence_penalty": 0.4, - "prompts": [ - { - "content": "The user query is {sys.query}", - "role": "user" - } - ], - "sys_prompt": "# Role \n\nYou are **Customer Server Agent**. Classify every user message; handle **contact** yourself. This is a multi-agent system.\n\n## Categories \n\n1. **contact** \u2013 user gives phone, e\u2011mail, WeChat, Line, Discord, etc. \n\n2. **casual** \u2013 small talk, not about the product. \n\n3. **complain** \u2013 complaints or profanity about the product/service. \n\n4. **product** \u2013 questions on product use, appearance, function, or errors.\n\n## If contact \n\nReply with one random item below\u2014do not change wording or call sub\u2011agents: \n\n1. Okay, I've already written this down. What else can I do for you? \n\n2. Got it. What else can I do for you? \n\n3. Thanks for your trust! Our expert will contact you ASAP. Anything else I can help with? \n\n4. Thanks! Anything else I can do for you?\n\n\n---\n\n\n## Otherwise (casual\u202f/\u202fcomplain\u202f/\u202fproduct) \n\nLet Sub\u2011Agent returns its answer\n\n## Sub\u2011Agent \n\n- casual \u2192 **Casual Agent** \nThis is an agent for handles casual conversationk.\n\n- complain \u2192 **Soothe Agent** \nThis is an agent for handles complaints or emotional input.\n\n- product \u2192 **Product Agent** \nThis is an agent for handles product-related queries and can use the `Retrieval` tool.\n\n## Importance\n\n- When the Sub\u2011Agent returns its answer, forward that answer to the user verbatim \u2014 do not add, edit, or reason further.\n ", - "temperature": 0.1, - "temperatureEnabled": true, - "tools": [], - "topPEnabled": false, - "top_p": 0.3, - "user_prompt": "", - "visual_files_var": "" - }, - "label": "Agent", - "name": "Customer Server Agent" - }, - "dragging": false, - "id": "Agent:RottenRiversDo", - "measured": { - "height": 84, - "width": 200 - }, - "position": { - "x": 350, - "y": 198.88981333505626 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "agentNode" + "description": "The question is not about the product usage, appearance and how it works. Just casual chat.", + "examples": [ + { + "value": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?" + } + ], + "name": "2. casual", + "uuid": "8cbf6ea3-a176-490d-9f8c-86373c932583" }, { - "data": { - "form": { - "delay_after_error": 1, - "description": "This is an agent for handles casual conversationk.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.3, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 1, - "max_rounds": 1, - "max_tokens": 4096, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "parameter": "Balance", - "presencePenaltyEnabled": false, - "presence_penalty": 0.2, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "You are a friendly and casual conversational assistant. \n\nYour primary goal is to engage users in light and enjoyable daily conversation. \n\n- Keep a natural, relaxed, and positive tone. \n\n- Avoid sensitive, controversial, or negative topics. \n\n- You may gently guide the conversation by introducing related casual topics if the user shows interest. \n\n", - "temperature": 0.5, - "temperatureEnabled": true, - "tools": [], - "topPEnabled": false, - "top_p": 0.85, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - }, - "label": "Agent", - "name": "Casual Agent" - }, - "dragging": false, - "id": "Agent:SlowKiwisBehave", - "measured": { - "height": 84, - "width": 200 - }, - "position": { - "x": 124.4782938105834, - "y": 402.1704532368496 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "agentNode" + "description": "Complain even curse about the product or service you provide. But the comment is not specific enough.", + "examples": [ + { + "value": "How bad is it.\nIt's really sucks.\nDamn, for God's sake, can it be more steady?\nShit, I just can't use this shit.\nI can't stand it anymore." + } + ], + "name": "3. complain", + "uuid": "acc40a78-1b9e-4d2f-b5d6-64e01ab69269" }, { - "data": { - "form": { - "delay_after_error": 1, - "description": "This is an agent for handles complaints or emotional input.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.3, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 1, - "max_rounds": 1, - "max_tokens": 4096, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "parameter": "Balance", - "presencePenaltyEnabled": false, - "presence_penalty": 0.2, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "You are an empathetic mood-soothing assistant. \n\nYour role is to comfort and encourage users when they feel upset or frustrated. \n\n- Use a warm, kind, and understanding tone. \n\n- Focus on showing empathy and emotional support rather than solving the problem directly. \n\n- Always encourage users with positive and reassuring statements. ", - "temperature": 0.5, - "temperatureEnabled": true, - "tools": [], - "topPEnabled": false, - "top_p": 0.85, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - }, - "label": "Agent", - "name": "Soothe Agent" - }, - "dragging": false, - "id": "Agent:PoorTaxesRescue", - "measured": { - "height": 84, - "width": 200 - }, - "position": { - "x": 402.02090711979577, - "y": 363.3139199638186 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "agentNode" + "description": "The question is about the product usage, appearance and how it works.", + "examples": [ + { + "value": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?\nException: Can't connect to ES cluster\nHow to build the RAGFlow image from scratch" + } + ], + "name": "4. product related", + "uuid": "dfa5eead-9341-4f22-9236-068dbfb745e8" + } + ], + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_tokens": 4096, + "message_history_window_size": 1, + "outputs": { + "category_name": { + "type": "string" + } }, - { - "data": { - "form": { - "delay_after_error": 1, - "description": "This is an agent for handles product-related queries and can use the `Retrieval` tool.", - "exception_comment": "", - "exception_default_value": "", - "exception_goto": [], - "exception_method": null, - "frequencyPenaltyEnabled": false, - "frequency_penalty": 0.7, - "llm_id": "deepseek-chat@DeepSeek", - "maxTokensEnabled": false, - "max_retries": 3, - "max_rounds": 2, - "max_tokens": 256, - "mcp": [], - "message_history_window_size": 12, - "outputs": { - "content": { - "type": "string", - "value": "" - } - }, - "presencePenaltyEnabled": false, - "presence_penalty": 0.4, - "prompts": [ - { - "content": "{sys.query}", - "role": "user" - } - ], - "sys_prompt": "# Role \n\nYou are a Product Information Advisor with access to the **Retrieval** tool.\n\n# Workflow \n\n1. Run **Retrieval** with a focused query from the user\u2019s question. \n\n2. Draft the reply **strictly** from the returned passages. \n\n3. If nothing relevant is retrieved, reply: \n\n \u201cI cannot find relevant documents in the knowledge base.\u201d\n\n# Rules \n\n- No assumptions, guesses, or extra\u2011KB knowledge. \n\n- Factual, concise. Use bullets / numbers when helpful. \n\n", - "temperature": 0.1, - "temperatureEnabled": true, - "tools": [ - { - "component_name": "Retrieval", - "name": "Retrieval", - "params": { - "cross_languages": [], - "description": "This is a product knowledge base", - "empty_response": "", - "kb_ids": [], - "keywords_similarity_weight": 0.7, - "outputs": { - "formalized_content": { - "type": "string", - "value": "" - } - }, - "rerank_id": "", - "similarity_threshold": 0.2, - "top_k": 1024, - "top_n": 8, - "use_kg": false - } - } - ], - "topPEnabled": false, - "top_p": 0.3, - "user_prompt": "This is the order you need to send to the agent.", - "visual_files_var": "" - }, - "label": "Agent", - "name": "Product Agent" - }, - "dragging": false, - "id": "Agent:SillyTurkeysRest", - "measured": { - "height": 84, - "width": 200 - }, - "position": { - "x": 684.0042670887832, - "y": 317.79626670112515 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "agentNode" + "parameter": "Precise", + "presencePenaltyEnabled": false, + "presence_penalty": 0.5, + "query": "sys.query", + "temperature": "0.1", + "temperatureEnabled": true, + "topPEnabled": false, + "top_p": 0.75 + }, + "label": "Categorize", + "name": "Categorize" + }, + "dragging": false, + "id": "Categorize:DullFriendsThank", + "measured": { + "height": 218, + "width": 200 + }, + "position": { + "x": 377.1140727959881, + "y": 138.1799140251472 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "categorizeNode" + }, + { + "data": { + "form": { + "content": [ + "Okay, I've already write this down. What else I can do for you?", + "Get it. What else I can do for you?", + "Thanks for your trust! Our expert will contact ASAP. So, anything else I can do for you?", + "Thanks! So, anything else I can do for you?" + ] + }, + "label": "Message", + "name": "What else?" + }, + "dragging": false, + "id": "Message:BreezyDonutsHeal", + "measured": { + "height": 56, + "width": 200 + }, + "position": { + "x": 724.8348409169271, + "y": 60.09138437270154 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "messageNode" + }, + { + "data": { + "form": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.3, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 4096, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } }, + "parameter": "Balance", + "presencePenaltyEnabled": false, + "presence_penalty": 0.2, + "prompts": [ { - "data": { - "form": { - "description": "This is an agent for a specific task.", - "user_prompt": "This is the order you need to send to the agent." - }, - "label": "Tool", - "name": "flow.tool_0" - }, - "dragging": false, - "id": "Tool:CrazyShirtsKiss", - "measured": { - "height": 48, - "width": 200 - }, - "position": { - "x": 659.7339736658578, - "y": 443.3638400568565 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "toolNode" + "content": "The user query is {sys.query}", + "role": "user" + } + ], + "sys_prompt": "You are a friendly and casual conversational assistant. \n\nYour primary goal is to engage users in light and enjoyable daily conversation. \n\n- Keep a natural, relaxed, and positive tone. \n\n- Avoid sensitive, controversial, or negative topics. \n\n- You may gently guide the conversation by introducing related casual topics if the user shows interest. \n\n", + "temperature": 0.5, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.85, + "user_prompt": "", + "visual_files_var": "" + }, + "label": "Agent", + "name": "Causal chat" + }, + "dragging": false, + "id": "Agent:TwelveOwlsWatch", + "measured": { + "height": 84, + "width": 200 + }, + "position": { + "x": 720.4965892695689, + "y": 167.46311264481432 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "agentNode" + }, + { + "data": { + "form": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.3, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 4096, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } }, + "parameter": "Balance", + "presencePenaltyEnabled": false, + "presence_penalty": 0.2, + "prompts": [ { - "data": { - "form": { - "content": [ - "{Agent:RottenRiversDo@content}" - ] - }, - "label": "Message", - "name": "Response" - }, - "dragging": false, - "id": "Message:PurpleCitiesSee", - "measured": { - "height": 56, - "width": 200 - }, - "position": { - "x": 675.534293293706, - "y": 158.92309339708154 - }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "messageNode" + "content": "The user query is {sys.query}", + "role": "user" + } + ], + "sys_prompt": "You are an empathetic mood-soothing assistant. \n\nYour role is to comfort and encourage users when they feel upset or frustrated. \n\n- Use a warm, kind, and understanding tone. \n\n- Focus on showing empathy and emotional support rather than solving the problem directly. \n\n- Always encourage users with positive and reassuring statements. ", + "temperature": 0.5, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.85, + "user_prompt": "", + "visual_files_var": "" + }, + "label": "Agent", + "name": "Soothe mood" + }, + "dragging": false, + "id": "Agent:DullTownsHope", + "measured": { + "height": 84, + "width": 200 + }, + "position": { + "x": 722.665715093248, + "y": 281.3422183879642 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "agentNode" + }, + { + "data": { + "form": { + "cross_languages": [], + "empty_response": "", + "kb_ids": [], + "keywords_similarity_weight": 0.7, + "outputs": { + "formalized_content": { + "type": "string", + "value": "" + } }, - { - "data": { - "form": { - "text": "This is a multi-agent system for intelligent customer service processing based on user intent classification. It uses the lead-agent to identify the type of user needs, assign tasks to sub-agents for processing, and finally the lead agent outputs the results." - }, - "label": "Note", - "name": "Workflow Overall Description" - }, - "dragHandle": ".note-drag-handle", - "dragging": false, - "height": 140, - "id": "Note:MoodyTurtlesCount", - "measured": { - "height": 140, - "width": 385 - }, - "position": { - "x": -59.311679338397, - "y": -2.2203733298874866 - }, - "resizing": false, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "noteNode", - "width": 385 + "query": "sys.query", + "rerank_id": "", + "similarity_threshold": 0.2, + "top_k": 1024, + "top_n": 8, + "use_kg": false + }, + "label": "Retrieval", + "name": "Search product info" + }, + "dragging": false, + "id": "Retrieval:ShyPumasJoke", + "measured": { + "height": 50, + "width": 200 + }, + "position": { + "x": 645.6873721057459, + "y": 516.6923702571407 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "retrievalNode" + }, + { + "data": { + "form": { + "delay_after_error": 1, + "description": "", + "exception_comment": "", + "exception_default_value": "", + "exception_goto": [], + "exception_method": null, + "frequencyPenaltyEnabled": false, + "frequency_penalty": 0.7, + "llm_id": "deepseek-chat@DeepSeek", + "maxTokensEnabled": false, + "max_retries": 3, + "max_rounds": 5, + "max_tokens": 256, + "mcp": [], + "message_history_window_size": 12, + "outputs": { + "content": { + "type": "string", + "value": "" + } }, + "presencePenaltyEnabled": false, + "presence_penalty": 0.4, + "prompts": [ { - "data": { - "form": { - "text": "Answers will be given strictly according to the content retrieved from the knowledge base." - }, - "label": "Note", - "name": "Product Agent " - }, - "dragHandle": ".note-drag-handle", - "dragging": false, - "id": "Note:ColdCoinsBathe", - "measured": { - "height": 136, - "width": 249 + "content": "The user query is {sys.query}\n\nThe relevant document are {Retrieval:ShyPumasJoke@formalized_content}", + "role": "user" + } + ], + "sys_prompt": "You are a highly professional product information advisor. \n\nYour only mission is to provide accurate, factual, and structured answers to all product-related queries.\n\nAbsolutely no assumptions, guesses, or fabricated content are allowed. \n\n**Key Principles:**\n\n1. **Strict Database Reliance:** \n\n - Every answer must be based solely on the verified product information stored in the relevant documen.\n\n - You are NOT allowed to invent, speculate, or infer details beyond what is retrieved. \n\n - If you cannot find relevant data, respond with: *\"I cannot find this information in our official product database. Please check back later or provide more details for further search.\"*\n\n2. **Information Accuracy and Structure:** \n\n - Provide information in a clear, concise, and professional way. \n\n - Use bullet points or numbered lists if there are multiple key points (e.g., features, price, warranty, technical specifications). \n\n - Always specify the version or model number when applicable to avoid confusion.\n\n3. **Tone and Style:** \n\n - Maintain a polite, professional, and helpful tone at all times. \n\n - Avoid marketing exaggeration or promotional language; stay strictly factual. \n\n - Do not express personal opinions; only cite official product data.\n\n4. **User Guidance:** \n\n - If the user’s query is unclear or too broad, politely request clarification or guide them to provide more specific product details (e.g., product name, model, version). \n\n - Example: *\"Could you please specify the product model or category so I can retrieve the most relevant information for you?\"*\n\n5. **Response Length and Formatting:** \n\n - Keep each answer within 100–150 words for general queries. \n\n - For complex or multi-step explanations, you may extend to 200–250 words, but always remain clear and well-structured.\n\n6. **Critical Reminder:** \n\nYour authority and reliability depend entirely on the relevant document responses. Any fabricated, speculative, or unverified content will be considered a critical failure of your role.\n\n\n", + "temperature": 0.1, + "temperatureEnabled": true, + "tools": [], + "topPEnabled": false, + "top_p": 0.3, + "user_prompt": "", + "visual_files_var": "" + }, + "label": "Agent", + "name": "Product info" + }, + "dragging": false, + "id": "Agent:KhakiSunsJudge", + "measured": { + "height": 84, + "width": 200 + }, + "position": { + "x": 726.580040161058, + "y": 386.5448208363979 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "agentNode" + }, + { + "data": { + "form": { + "text": "This is an intelligent customer service processing system workflow based on user intent classification. It uses LLM to identify user demand types and transfers them to the corresponding professional agent for processing." + }, + "label": "Note", + "name": "Workflow Overall Description" + }, + "dragHandle": ".note-drag-handle", + "dragging": false, + "height": 171, + "id": "Note:AllGuestsShow", + "measured": { + "height": 171, + "width": 380 + }, + "position": { + "x": -283.6407251474677, + "y": 157.2943019466498 + }, + "resizing": false, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "noteNode", + "width": 380 + }, + { + "data": { + "form": { + "text": "Here, product document snippets related to the user's question will be retrieved from the knowledge base first, and the relevant document snippets will be passed to the LLM together with the user's question." + }, + "label": "Note", + "name": "Product info Agent" + }, + "dragHandle": ".note-drag-handle", + "dragging": false, + "height": 154, + "id": "Note:IcyBooksCough", + "measured": { + "height": 154, + "width": 370 + }, + "position": { + "x": 1014.0959071234828, + "y": 492.830874176321 + }, + "resizing": false, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "noteNode", + "width": 370 + }, + { + "data": { + "form": { + "text": "Here, a text will be randomly selected for answering" + }, + "label": "Note", + "name": "What else?" + }, + "dragHandle": ".note-drag-handle", + "dragging": false, + "id": "Note:AllThingsHide", + "measured": { + "height": 136, + "width": 249 + }, + "position": { + "x": 770.7060131788647, + "y": -123.23496705283817 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "noteNode" + }, + { + "data": { + "form": { + "groups": [ + { + "group_name": "LLM_Response", + "type": "string", + "variables": [ + { + "value": "Agent:TwelveOwlsWatch@content" }, - "position": { - "x": 994.4238924667025, - "y": 329.08949370720796 + { + "value": "Agent:DullTownsHope@content" }, - "selected": false, - "sourcePosition": "right", - "targetPosition": "left", - "type": "noteNode" + { + "value": "Agent:KhakiSunsJudge@content" + } + ] } - ] + ], + "outputs": { + "LLM_Response": { + "type": "string" + } + } + }, + "label": "VariableAggregator", + "name": "Variable aggregator" + }, + "dragging": false, + "id": "VariableAggregator:FuzzyBerriesFlow", + "measured": { + "height": 150, + "width": 200 + }, + "position": { + "x": 1061.596672609154, + "y": 247.90496561846572 + }, + "selected": false, + "sourcePosition": "right", + "targetPosition": "left", + "type": "variableAggregatorNode" }, - "history": [], - "messages": [], - "path": [], - "retrieval": [] + { + "data": { + "form": { + "content": [ + "{VariableAggregator:FuzzyBerriesFlow@LLM_Response}" + ] + }, + "label": "Message", + "name": "Response" + }, + "dragging": false, + "id": "Message:DryBusesCarry", + "measured": { + "height": 50, + "width": 200 + }, + "position": { + "x": 1364.5500382017049, + "y": 296.59667260915404 + }, + "selected": true, + "sourcePosition": "right", + "targetPosition": "left", + "type": "messageNode" + } + ] }, + "history": [], + "messages": [], + "path": [], + "retrieval": [], + "variables": {} + }, "avatar": "" } \ No newline at end of file diff --git a/agent/templates/sql_assistant.json b/agent/templates/sql_assistant.json index 92804abc6ee..6c6030f67d7 100644 --- a/agent/templates/sql_assistant.json +++ b/agent/templates/sql_assistant.json @@ -83,10 +83,10 @@ "value": [] } }, - "password": "20010812Yy!", + "password": "", "port": 3306, "sql": "{Agent:WickedGoatsDivide@content}", - "username": "13637682833@163.com" + "username": "" } }, "upstream": [ @@ -527,10 +527,10 @@ "value": [] } }, - "password": "20010812Yy!", + "password": "", "port": 3306, "sql": "{Agent:WickedGoatsDivide@content}", - "username": "13637682833@163.com" + "username": "" }, "label": "ExeSQL", "name": "ExeSQL" @@ -578,7 +578,7 @@ { "data": { "form": { - "text": "Searches for relevant database creation statements.\n\nIt should label with a knowledgebase to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter." + "text": "Searches for relevant database creation statements.\n\nIt should label with a dataset to which the schema is dumped in. You could use \" General \" as parsing method, \" 2 \" as chunk size and \" ; \" as delimiter." }, "label": "Note", "name": "Note Schema" diff --git a/agent/test/dsl_examples/categorize_and_agent_with_tavily.json b/agent/test/dsl_examples/categorize_and_agent_with_tavily.json index 7d956744664..49738f14d93 100644 --- a/agent/test/dsl_examples/categorize_and_agent_with_tavily.json +++ b/agent/test/dsl_examples/categorize_and_agent_with_tavily.json @@ -75,7 +75,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/iteration.json b/agent/test/dsl_examples/iteration.json index dd44484239a..dc976aa8b1f 100644 --- a/agent/test/dsl_examples/iteration.json +++ b/agent/test/dsl_examples/iteration.json @@ -82,7 +82,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/retrieval_and_generate.json b/agent/test/dsl_examples/retrieval_and_generate.json index 9f9f9bac4f4..3897e877fde 100644 --- a/agent/test/dsl_examples/retrieval_and_generate.json +++ b/agent/test/dsl_examples/retrieval_and_generate.json @@ -31,7 +31,7 @@ "component_name": "LLM", "params": { "llm_id": "deepseek-chat", - "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.", + "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n Above is the knowledge base.", "temperature": 0.2 } }, @@ -51,7 +51,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/retrieval_categorize_and_generate.json b/agent/test/dsl_examples/retrieval_categorize_and_generate.json index c506b9a6bfc..2b8dfb779e4 100644 --- a/agent/test/dsl_examples/retrieval_categorize_and_generate.json +++ b/agent/test/dsl_examples/retrieval_categorize_and_generate.json @@ -65,7 +65,7 @@ "component_name": "Agent", "params": { "llm_id": "deepseek-chat", - "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.", + "sys_prompt": "You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.", "temperature": 0.2 } }, @@ -85,7 +85,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/test/dsl_examples/tavily_and_generate.json b/agent/test/dsl_examples/tavily_and_generate.json index f2f79b4b73a..95739224a08 100644 --- a/agent/test/dsl_examples/tavily_and_generate.json +++ b/agent/test/dsl_examples/tavily_and_generate.json @@ -25,7 +25,7 @@ "component_name": "LLM", "params": { "llm_id": "deepseek-chat", - "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n The above is the knowledge base.", + "sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n Above is the knowledge base.", "temperature": 0.2 } }, @@ -45,7 +45,7 @@ }, "history": [], "path": [], - "retrival": {"chunks": [], "doc_aggs": []}, + "retrieval": {"chunks": [], "doc_aggs": []}, "globals": { "sys.query": "", "sys.user_id": "", diff --git a/agent/tools/base.py b/agent/tools/base.py index a3d569694a4..ac8336f5d32 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -17,13 +17,13 @@ import re import time from copy import deepcopy +import asyncio from functools import partial from typing import TypedDict, List, Any from agent.component.base import ComponentParamBase, ComponentBase from common.misc_utils import hash_str2int -from rag.llm.chat_model import ToolCallSession from rag.prompts.generator import kb_prompt -from rag.utils.mcp_tool_call_conn import MCPToolCallSession +from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession from timeit import default_timer as timer @@ -49,12 +49,19 @@ def __init__(self, tools_map: dict[str, object], callback: partial): self.callback = callback def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: + return asyncio.run(self.tool_call_async(name, arguments)) + + async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" st = timer() - if isinstance(self.tools_map[name], MCPToolCallSession): - resp = self.tools_map[name].tool_call(name, arguments, 60) + tool_obj = self.tools_map[name] + if isinstance(tool_obj, MCPToolCallSession): + resp = await asyncio.to_thread(tool_obj.tool_call, name, arguments, 60) else: - resp = self.tools_map[name].invoke(**arguments) + if hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): + resp = await tool_obj.invoke_async(**arguments) + else: + resp = await asyncio.to_thread(tool_obj.invoke, **arguments) self.callback(name, arguments, resp, elapsed_time=timer()-st) return resp @@ -140,6 +147,33 @@ def invoke(self, **kwargs): self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) return res + async def invoke_async(self, **kwargs): + """ + Async wrapper for tool invocation. + If `_invoke` is a coroutine, await it directly; otherwise run in a thread to avoid blocking. + Mirrors the exception handling of `invoke`. + """ + if self.check_if_canceled("Tool processing"): + return + + self.set_output("_created_time", time.perf_counter()) + try: + fn_async = getattr(self, "_invoke_async", None) + if fn_async and asyncio.iscoroutinefunction(fn_async): + res = await fn_async(**kwargs) + elif asyncio.iscoroutinefunction(self._invoke): + res = await self._invoke(**kwargs) + else: + res = await asyncio.to_thread(self._invoke, **kwargs) + except Exception as e: + self._param.outputs["_ERROR"] = {"value": str(e)} + logging.exception(e) + res = str(e) + self._param.debug_inputs = [] + + self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time")) + return res + def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None): chunks = [] aggs = [] diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index adba4168e28..678d56f020a 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import ast import base64 +import json import logging import os from abc import ABC -from strenum import StrEnum from typing import Optional + from pydantic import BaseModel, Field, field_validator -from agent.tools.base import ToolParamBase, ToolBase, ToolMeta -from common.connection_utils import timeout +from strenum import StrEnum + +from agent.tools.base import ToolBase, ToolMeta, ToolParamBase from common import settings +from common.connection_utils import timeout class Language(StrEnum): @@ -62,10 +66,10 @@ class CodeExecParam(ToolParamBase): """ def __init__(self): - self.meta:ToolMeta = { + self.meta: ToolMeta = { "name": "execute_code", "description": """ -This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string. +This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It receives a piece of code and return a Json string. Here's a code example for Python(`main` function MUST be included): def main() -> dict: \"\"\" @@ -99,16 +103,12 @@ def fibonacci_recursive(n): "enum": ["python", "javascript"], "required": True, }, - "script": { - "type": "string", - "description": "A piece of code in right format. There MUST be main function.", - "required": True - } - } + "script": {"type": "string", "description": "A piece of code in right format. There MUST be main function.", "required": True}, + }, } super().__init__() self.lang = Language.PYTHON.value - self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}" + self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}' self.arguments = {} self.outputs = {"result": {"value": "", "type": "string"}} @@ -119,17 +119,14 @@ def check(self): def get_input_form(self) -> dict[str, dict]: res = {} for k, v in self.arguments.items(): - res[k] = { - "type": "line", - "name": k - } + res[k] = {"type": "line", "name": k} return res class CodeExec(ToolBase, ABC): component_name = "CodeExec" - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) def _invoke(self, **kwargs): if self.check_if_canceled("CodeExec processing"): return @@ -138,17 +135,12 @@ def _invoke(self, **kwargs): script = kwargs.get("script", self._param.script) arguments = {} for k, v in self._param.arguments.items(): - if kwargs.get(k): arguments[k] = kwargs[k] continue arguments[k] = self._canvas.get_variable_value(v) if v else None - self._execute_code( - language=lang, - code=script, - arguments=arguments - ) + self._execute_code(language=lang, code=script, arguments=arguments) def _execute_code(self, language: str, code: str, arguments: dict): import requests @@ -169,7 +161,7 @@ def _execute_code(self, language: str, code: str, arguments: dict): if self.check_if_canceled("CodeExec execution"): return "Task has been canceled" - resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:") if self.check_if_canceled("CodeExec execution"): @@ -183,35 +175,10 @@ def _execute_code(self, language: str, code: str, arguments: dict): if stderr: self.set_output("_ERROR", stderr) return - try: - rt = eval(body.get("stdout", "")) - except Exception: - rt = body.get("stdout", "") - logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}") - if isinstance(rt, tuple): - for i, (k, o) in enumerate(self._param.outputs.items()): - if self.check_if_canceled("CodeExec execution"): - return - - if k.find("_") == 0: - continue - o["value"] = rt[i] - elif isinstance(rt, dict): - for i, (k, o) in enumerate(self._param.outputs.items()): - if self.check_if_canceled("CodeExec execution"): - return - - if k not in rt or k.find("_") == 0: - continue - o["value"] = rt[k] - else: - for i, (k, o) in enumerate(self._param.outputs.items()): - if self.check_if_canceled("CodeExec execution"): - return - - if k.find("_") == 0: - continue - o["value"] = rt + raw_stdout = body.get("stdout", "") + parsed_stdout = self._deserialize_stdout(raw_stdout) + logging.info(f"[CodeExec]: http://{settings.SANDBOX_HOST}:9385/run -> {parsed_stdout}") + self._populate_outputs(parsed_stdout, raw_stdout) else: self.set_output("_ERROR", "There is no response from sandbox") @@ -228,3 +195,149 @@ def _encode_code(self, code: str) -> str: def thoughts(self) -> str: return "Running a short script to process data." + + def _deserialize_stdout(self, stdout: str): + text = str(stdout).strip() + if not text: + return "" + for loader in (json.loads, ast.literal_eval): + try: + return loader(text) + except Exception: + continue + return text + + def _coerce_output_value(self, value, expected_type: Optional[str]): + if expected_type is None: + return value + + etype = expected_type.strip().lower() + inner_type = None + if etype.startswith("array<") and etype.endswith(">"): + inner_type = etype[6:-1].strip() + etype = "array" + + try: + if etype == "string": + return "" if value is None else str(value) + + if etype == "number": + if value is None or value == "": + return None + if isinstance(value, (int, float)): + return value + if isinstance(value, str): + try: + return float(value) + except Exception: + return value + return float(value) + + if etype == "boolean": + if isinstance(value, bool): + return value + if isinstance(value, str): + lv = value.lower() + if lv in ("true", "1", "yes", "y", "on"): + return True + if lv in ("false", "0", "no", "n", "off"): + return False + return bool(value) + + if etype == "array": + candidate = value + if isinstance(candidate, str): + parsed = self._deserialize_stdout(candidate) + candidate = parsed + if isinstance(candidate, tuple): + candidate = list(candidate) + if not isinstance(candidate, list): + candidate = [] if candidate is None else [candidate] + + if inner_type == "string": + return ["" if v is None else str(v) for v in candidate] + if inner_type == "number": + coerced = [] + for v in candidate: + try: + if v is None or v == "": + coerced.append(None) + elif isinstance(v, (int, float)): + coerced.append(v) + else: + coerced.append(float(v)) + except Exception: + coerced.append(v) + return coerced + return candidate + + if etype == "object": + if isinstance(value, dict): + return value + if isinstance(value, str): + parsed = self._deserialize_stdout(value) + if isinstance(parsed, dict): + return parsed + return value + except Exception: + return value + + return value + + def _populate_outputs(self, parsed_stdout, raw_stdout: str): + outputs_items = list(self._param.outputs.items()) + logging.info(f"[CodeExec]: outputs schema keys: {[k for k, _ in outputs_items]}") + if not outputs_items: + return + + if isinstance(parsed_stdout, dict): + for key, meta in outputs_items: + if key.startswith("_"): + continue + val = self._get_by_path(parsed_stdout, key) + coerced = self._coerce_output_value(val, meta.get("type")) + logging.info(f"[CodeExec]: populate dict key='{key}' raw='{val}' coerced='{coerced}'") + self.set_output(key, coerced) + return + + if isinstance(parsed_stdout, (list, tuple)): + for idx, (key, meta) in enumerate(outputs_items): + if key.startswith("_"): + continue + val = parsed_stdout[idx] if idx < len(parsed_stdout) else None + coerced = self._coerce_output_value(val, meta.get("type")) + logging.info(f"[CodeExec]: populate list key='{key}' raw='{val}' coerced='{coerced}'") + self.set_output(key, coerced) + return + + default_val = parsed_stdout if parsed_stdout is not None else raw_stdout + for idx, (key, meta) in enumerate(outputs_items): + if key.startswith("_"): + continue + val = default_val if idx == 0 else None + coerced = self._coerce_output_value(val, meta.get("type")) + logging.info(f"[CodeExec]: populate scalar key='{key}' raw='{val}' coerced='{coerced}'") + self.set_output(key, coerced) + + def _get_by_path(self, data, path: str): + if not path: + return None + cur = data + for part in path.split("."): + part = part.strip() + if not part: + return None + if isinstance(cur, dict): + cur = cur.get(part) + elif isinstance(cur, list): + try: + idx = int(part) + cur = cur[idx] + except Exception: + return None + else: + return None + if cur is None: + return None + logging.info(f"[CodeExec]: resolve path '{path}' -> {cur}") + return cur diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index ab388a08ee3..21df960befb 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio from functools import partial import json import os @@ -21,13 +22,15 @@ from agent.tools.base import ToolParamBase, ToolBase, ToolMeta from common.constants import LLMType from api.db.services.document_service import DocumentService -from api.db.services.dialog_service import meta_filter +from common.metadata_utils import apply_meta_data_filter from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from api.db.services.memory_service import MemoryService +from api.db.joint_services import memory_message_service from common import settings from common.connection_utils import timeout from rag.app.tag import label_question -from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter +from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt class RetrievalParam(ToolParamBase): @@ -56,6 +59,7 @@ def __init__(self): self.top_n = 8 self.top_k = 1024 self.kb_ids = [] + self.memory_ids = [] self.kb_vars = [] self.rerank_id = "" self.empty_response = "" @@ -80,15 +84,7 @@ def get_input_form(self) -> dict[str, dict]: class Retrieval(ToolBase, ABC): component_name = "Retrieval" - @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) - def _invoke(self, **kwargs): - if self.check_if_canceled("Retrieval processing"): - return - - if not kwargs.get("query"): - self.set_output("formalized_content", self._param.empty_response) - return - + async def _retrieve_kb(self, query_text: str): kb_ids: list[str] = [] for id in self._param.kb_ids: if id.find("@") < 0: @@ -123,54 +119,58 @@ def _invoke(self, **kwargs): if self._param.rerank_id: rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) - vars = self.get_input_elements_from_text(kwargs["query"]) - vars = {k:o["value"] for k,o in vars.items()} - query = self.string_format(kwargs["query"], vars) + vars = self.get_input_elements_from_text(query_text) + vars = {k: o["value"] for k, o in vars.items()} + query = self.string_format(query_text, vars) - doc_ids=[] - if self._param.meta_data_filter!={}: + doc_ids = [] + if self._param.meta_data_filter != {}: metas = DocumentService.get_meta_by_kbs(kb_ids) - if self._param.meta_data_filter.get("method") == "auto": + + def _resolve_manual_filter(flt: dict) -> dict: + pat = re.compile(self.variable_ref_patt) + s = flt.get("value", "") + out_parts = [] + last = 0 + + for m in pat.finditer(s): + out_parts.append(s[last:m.start()]) + key = m.group(1) + v = self._canvas.get_variable_value(key) + if v is None: + rep = "" + elif isinstance(v, partial): + buf = [] + for chunk in v(): + buf.append(chunk) + rep = "".join(buf) + elif isinstance(v, str): + rep = v + else: + rep = json.dumps(v, ensure_ascii=False) + + out_parts.append(rep) + last = m.end() + + out_parts.append(s[last:]) + flt["value"] = "".join(out_parts) + return flt + + chat_mdl = None + if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) - filters = gen_meta_filter(chat_mdl, metas, query) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None - elif self._param.meta_data_filter.get("method") == "manual": - filters=self._param.meta_data_filter["manual"] - for flt in filters: - pat = re.compile(self.variable_ref_patt) - s = flt["value"] - out_parts = [] - last = 0 - - for m in pat.finditer(s): - out_parts.append(s[last:m.start()]) - key = m.group(1) - v = self._canvas.get_variable_value(key) - if v is None: - rep = "" - elif isinstance(v, partial): - buf = [] - for chunk in v(): - buf.append(chunk) - rep = "".join(buf) - elif isinstance(v, str): - rep = v - else: - rep = json.dumps(v, ensure_ascii=False) - - out_parts.append(rep) - last = m.end() - - out_parts.append(s[last:]) - flt["value"] = "".join(out_parts) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None + + doc_ids = await apply_meta_data_filter( + self._param.meta_data_filter, + metas, + query, + chat_mdl, + doc_ids, + _resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None, + ) if self._param.cross_languages: - query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) + query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) @@ -193,17 +193,20 @@ def _invoke(self, **kwargs): if self._param.toc_enhance: chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) - cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) + cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], + chat_mdl, self._param.top_n) if self.check_if_canceled("Retrieval processing"): return if cks: kbinfos["chunks"] = cks + kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], + [kb.tenant_id for kb in kbs]) if self._param.use_kg: ck = settings.kg_retriever.retrieval(query, - [kb.tenant_id for kb in kbs], - kb_ids, - embd_mdl, - LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)) + [kb.tenant_id for kb in kbs], + kb_ids, + embd_mdl, + LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -212,7 +215,8 @@ def _invoke(self, **kwargs): kbinfos = {"chunks": [], "doc_aggs": []} if self._param.use_kg and kbs: - ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) + ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, + LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) if self.check_if_canceled("Retrieval processing"): return if ck["content_with_weight"]: @@ -242,6 +246,58 @@ def _invoke(self, **kwargs): return form_cnt + async def _retrieve_memory(self, query_text: str): + memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids] + memory_list = MemoryService.get_by_ids(memory_ids) + if not memory_list: + raise Exception("No memory is selected.") + + embd_names = list({memory.embd_id for memory in memory_list}) + assert len(embd_names) == 1, "Memory use different embedding models." + + vars = self.get_input_elements_from_text(query_text) + vars = {k: o["value"] for k, o in vars.items()} + query = self.string_format(query_text, vars) + # query message + message_list = memory_message_service.query_message({"memory_id": memory_ids}, { + "query": query, + "similarity_threshold": self._param.similarity_threshold, + "keywords_similarity_weight": self._param.keywords_similarity_weight, + "top_n": self._param.top_n + }) + if not message_list: + self.set_output("formalized_content", self._param.empty_response) + return "" + formated_content = "\n".join(memory_prompt(message_list, 200000)) + # set formalized_content output + self.set_output("formalized_content", formated_content) + + return formated_content + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) + async def _invoke_async(self, **kwargs): + if self.check_if_canceled("Retrieval processing"): + return + if not kwargs.get("query"): + self.set_output("formalized_content", self._param.empty_response) + return + + if hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "dataset": + return await self._retrieve_kb(kwargs["query"]) + elif hasattr(self._param, "retrieval_from") and self._param.retrieval_from == "memory": + return await self._retrieve_memory(kwargs["query"]) + elif self._param.kb_ids: + return await self._retrieve_kb(kwargs["query"]) + elif hasattr(self._param, "memory_ids") and self._param.memory_ids: + return await self._retrieve_memory(kwargs["query"]) + else: + self.set_output("formalized_content", self._param.empty_response) + return + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return """ Keywords: {} diff --git a/agent/tools/yahoofinance.py b/agent/tools/yahoofinance.py index 324dfb64308..06a4a9dad45 100644 --- a/agent/tools/yahoofinance.py +++ b/agent/tools/yahoofinance.py @@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))) def _invoke(self, **kwargs): if self.check_if_canceled("YahooFinance processing"): - return + return None if not kwargs.get("stock_code"): self.set_output("report", "") @@ -84,33 +84,33 @@ def _invoke(self, **kwargs): last_e = "" for _ in range(self._param.max_retries+1): if self.check_if_canceled("YahooFinance processing"): - return + return None - yohoo_res = [] + yahoo_res = [] try: msft = yf.Ticker(kwargs["stock_code"]) if self.check_if_canceled("YahooFinance processing"): - return + return None if self._param.info: - yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n") + yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n") if self._param.history: - yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n") + yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n") if self._param.financials: - yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n") + yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n") if self._param.balance_sheet: - yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n") - yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n") + yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n") + yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n") if self._param.cash_flow_statement: - yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n") - yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n") + yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n") + yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n") if self._param.news: - yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n") - self.set_output("report", "\n\n".join(yohoo_res)) + yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n") + self.set_output("report", "\n\n".join(yahoo_res)) return self.output("report") except Exception as e: if self.check_if_canceled("YahooFinance processing"): - return + return None last_e = e logging.exception(f"YahooFinance error: {e}") diff --git a/agentic_reasoning/deep_research.py b/agentic_reasoning/deep_research.py index d7121245f0f..20f7017f474 100644 --- a/agentic_reasoning/deep_research.py +++ b/agentic_reasoning/deep_research.py @@ -51,7 +51,7 @@ def _remove_result_tags(text: str) -> str: """Remove Result Tags""" return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT) - def _generate_reasoning(self, msg_history): + async def _generate_reasoning(self, msg_history): """Generate reasoning steps""" query_think = "" if msg_history[-1]["role"] != "user": @@ -59,13 +59,14 @@ def _generate_reasoning(self, msg_history): else: msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n" - for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): + async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) if not ans: continue query_think = ans yield query_think - return query_think + query_think = "" + yield query_think def _extract_search_queries(self, query_think, question, step_index): """Extract search queries from thinking""" @@ -143,10 +144,10 @@ def _update_chunk_info(self, chunk_info, kbinfos): if d["doc_id"] not in dids: chunk_info["doc_aggs"].append(d) - def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): + async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): """Extract and summarize relevant information""" summary_think = "" - for ans in self.chat_mdl.chat_streamly( + async for ans in self.chat_mdl.async_chat_streamly( RELEVANT_EXTRACTION_PROMPT.format( prev_reasoning=truncated_prev_reasoning, search_query=search_query, @@ -160,10 +161,11 @@ def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos continue summary_think = ans yield summary_think + summary_think = "" - return summary_think + yield summary_think - def thinking(self, chunk_info: dict, question: str): + async def thinking(self, chunk_info: dict, question: str): executed_search_queries = [] msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}] all_reasoning_steps = [] @@ -180,7 +182,7 @@ def thinking(self, chunk_info: dict, question: str): # Step 1: Generate reasoning query_think = "" - for ans in self._generate_reasoning(msg_history): + async for ans in self._generate_reasoning(msg_history): query_think = ans yield {"answer": think + self._remove_query_tags(query_think) + "", "reference": {}, "audio_binary": None} @@ -223,7 +225,7 @@ def thinking(self, chunk_info: dict, question: str): # Step 6: Extract relevant information think += "\n\n" summary_think = "" - for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): + async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): summary_think = ans yield {"answer": think + self._remove_result_tags(summary_think) + "", "reference": {}, "audio_binary": None} diff --git a/api/__init__.py b/api/__init__.py index 643f79713c8..a42cd9a6dd7 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -14,5 +14,5 @@ # limitations under the License. # -from beartype.claw import beartype_this_package -beartype_this_package() +# from beartype.claw import beartype_this_package +# beartype_this_package() diff --git a/api/apps/__init__.py b/api/apps/__init__.py index f2009db2c16..c329679f8fb 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -13,36 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging import os import sys -import logging from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from flask import Blueprint, Flask -from werkzeug.wrappers.request import Request -from flask_cors import CORS +from quart import Blueprint, Quart, request, g, current_app, session from flasgger import Swagger from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer - +from quart_cors import cors from common.constants import StatusEnum -from api.db.db_models import close_connection +from api.db.db_models import close_connection, APIToken from api.db.services import UserService from api.utils.json_encode import CustomJSONEncoder from api.utils import commands -from flask_mail import Mail -from flask_session import Session -from flask_login import LoginManager +from quart_auth import Unauthorized from common import settings from api.utils.api_utils import server_error_response from api.constants import API_VERSION +from common.misc_utils import get_uuid -__all__ = ["app"] +settings.init_settings() -Request.json = property(lambda self: self.get_json(force=True, silent=True)) +__all__ = ["app"] -app = Flask(__name__) -smtp_mail_server = Mail() +app = Quart(__name__) +app = cors(app, allow_origin="*") # Add this at the beginning of your file to configure Swagger UI swagger_config = { @@ -76,32 +73,168 @@ }, ) -CORS(app, supports_credentials=True, max_age=2592000) app.url_map.strict_slashes = False app.json_encoder = CustomJSONEncoder app.errorhandler(Exception)(server_error_response) +# Configure Quart timeouts for slow LLM responses (e.g., local Ollama on CPU) +# Default Quart timeouts are 60 seconds which is too short for many LLM backends +app.config["RESPONSE_TIMEOUT"] = int(os.environ.get("QUART_RESPONSE_TIMEOUT", 600)) +app.config["BODY_TIMEOUT"] = int(os.environ.get("QUART_BODY_TIMEOUT", 600)) + ## convince for dev and debug # app.config["LOGIN_DISABLED"] = True app.config["SESSION_PERMANENT"] = False -app.config["SESSION_TYPE"] = "filesystem" +app.config["SESSION_TYPE"] = "redis" +app.config["SESSION_REDIS"] = settings.decrypt_database_config(name="redis") app.config["MAX_CONTENT_LENGTH"] = int( os.environ.get("MAX_CONTENT_LENGTH", 1024 * 1024 * 1024) ) +app.config['SECRET_KEY'] = settings.SECRET_KEY +app.secret_key = settings.SECRET_KEY +commands.register_commands(app) -Session(app) -login_manager = LoginManager() -login_manager.init_app(app) +from functools import wraps +from typing import ParamSpec, TypeVar +from collections.abc import Awaitable, Callable +from werkzeug.local import LocalProxy -commands.register_commands(app) +T = TypeVar("T") +P = ParamSpec("P") + + +def _load_user(): + jwt = Serializer(secret_key=settings.SECRET_KEY) + authorization = request.headers.get("Authorization") + g.user = None + if not authorization: + return None + + try: + access_token = str(jwt.loads(authorization)) + + if not access_token or not access_token.strip(): + logging.warning("Authentication attempt with empty access token") + return None + + # Access tokens should be UUIDs (32 hex characters) + if len(access_token.strip()) < 32: + logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") + return None + + user = UserService.query( + access_token=access_token, status=StatusEnum.VALID.value + ) + if not user and len(authorization.split()) == 2: + objs = APIToken.query(token=authorization.split()[1]) + if objs: + user = UserService.query(id=objs[0].tenant_id, status=StatusEnum.VALID.value) + if user: + if not user[0].access_token or not user[0].access_token.strip(): + logging.warning(f"User {user[0].email} has empty access_token in database") + return None + g.user = user[0] + return user[0] + except Exception as e: + logging.warning(f"load_user got exception {e}") + + +current_user = LocalProxy(_load_user) + + +def login_required(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + """A decorator to restrict route access to authenticated users. + + This should be used to wrap a route handler (or view function) to + enforce that only authenticated requests can access it. Note that + it is important that this decorator be wrapped by the route + decorator and not vice, versa, as below. + + .. code-block:: python + @app.route('/') + @login_required + async def index(): + ... -def search_pages_path(pages_dir): + If the request is not authenticated a + `quart.exceptions.Unauthorized` exception will be raised. + + """ + + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + if not current_user: # or not session.get("_user_id"): + raise Unauthorized() + else: + return await current_app.ensure_async(func)(*args, **kwargs) + + return wrapper + + +def login_user(user, remember=False, duration=None, force=False, fresh=True): + """ + Logs a user in. You should pass the actual user object to this. If the + user's `is_active` property is ``False``, they will not be logged in + unless `force` is ``True``. + + This will return ``True`` if the login attempt succeeds, and ``False`` if + it fails (i.e. because the user is inactive). + + :param user: The user object to log in. + :type user: object + :param remember: Whether to remember the user after their session expires. + Defaults to ``False``. + :type remember: bool + :param duration: The amount of time before the remember cookie expires. If + ``None`` the value set in the settings is used. Defaults to ``None``. + :type duration: :class:`datetime.timedelta` + :param force: If the user is inactive, setting this to ``True`` will log + them in regardless. Defaults to ``False``. + :type force: bool + :param fresh: setting this to ``False`` will log in the user with a session + marked as not "fresh". Defaults to ``True``. + :type fresh: bool + """ + if not force and not user.is_active: + return False + + session["_user_id"] = user.id + session["_fresh"] = fresh + session["_id"] = get_uuid() + return True + + +def logout_user(): + """ + Logs a user out. (You do not need to pass the actual user.) This will + also clean up the remember me cookie if it exists. + """ + if "_user_id" in session: + session.pop("_user_id") + + if "_fresh" in session: + session.pop("_fresh") + + if "_id" in session: + session.pop("_id") + + COOKIE_NAME = "remember_token" + cookie_name = current_app.config.get("REMEMBER_COOKIE_NAME", COOKIE_NAME) + if cookie_name in request.cookies: + session["_remember"] = "clear" + if "_remember_seconds" in session: + session.pop("_remember_seconds") + + return True + + +def search_pages_path(page_path): app_path_list = [ - path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".") + path for path in page_path.glob("*_app.py") if not path.name.startswith(".") ] api_path_list = [ - path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".") + path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".") ] app_path_list.extend(api_path_list) return app_path_list @@ -138,44 +271,22 @@ def register_page(page_path): ] client_urls_prefix = [ - register_page(path) for dir in pages_dir for path in search_pages_path(dir) + register_page(path) for directory in pages_dir for path in search_pages_path(directory) ] -@login_manager.request_loader -def load_user(web_request): - jwt = Serializer(secret_key=settings.SECRET_KEY) - authorization = web_request.headers.get("Authorization") - if authorization: - try: - access_token = str(jwt.loads(authorization)) - - if not access_token or not access_token.strip(): - logging.warning("Authentication attempt with empty access token") - return None - - # Access tokens should be UUIDs (32 hex characters) - if len(access_token.strip()) < 32: - logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") - return None - - user = UserService.query( - access_token=access_token, status=StatusEnum.VALID.value - ) - if user: - if not user[0].access_token or not user[0].access_token.strip(): - logging.warning(f"User {user[0].email} has empty access_token in database") - return None - return user[0] - else: - return None - except Exception as e: - logging.warning(f"load_user got exception {e}") - return None - else: - return None +@app.errorhandler(404) +async def not_found(error): + error_msg: str = f"The requested URL {request.path} was not found" + logging.error(error_msg) + return { + "error": "Not Found", + "message": error_msg, + }, 404 @app.teardown_request -def _db_close(exc): +def _db_close(exception): + if exception: + logging.exception(f"Request failed: {exception}") close_connection() diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 1ab1c462ac8..97d7dc94302 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -13,46 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import json -import os -import re from datetime import datetime, timedelta -from flask import request, Response -from api.db.services.llm_service import LLMBundle -from flask_login import login_required, current_user - -from api.db import VALID_FILE_TYPES, FileType -from api.db.db_models import APIToken, Task, File -from api.db.services import duplicate_name +from quart import request +from api.db.db_models import APIToken from api.db.services.api_service import APITokenService, API4ConversationService -from api.db.services.dialog_service import DialogService, chat -from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.db.services.file2document_service import File2DocumentService -from api.db.services.file_service import FileService -from api.db.services.knowledgebase_service import KnowledgebaseService -from api.db.services.task_service import queue_tasks, TaskService from api.db.services.user_service import UserTenantService -from common.misc_utils import get_uuid -from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource -from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \ - generate_confirmation_token - -from api.utils.file_utils import filename_type, thumbnail -from rag.app.tag import label_question -from rag.prompts.generator import keyword_extraction +from api.utils.api_utils import generate_confirmation_token, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.time_utils import current_timestamp, datetime_format - -from api.db.services.canvas_service import UserCanvasService -from agent.canvas import Canvas -from functools import partial -from pathlib import Path -from common import settings +from api.apps import login_required, current_user @manager.route('/new_token', methods=['POST']) # noqa: F821 @login_required -def new_token(): - req = request.json +async def new_token(): + req = await get_request_json() try: tenants = UserTenantService.query(user_id=current_user.id) if not tenants: @@ -97,8 +71,8 @@ def token_list(): @manager.route('/rm', methods=['POST']) # noqa: F821 @validate_request("tokens", "tenant_id") @login_required -def rm(): - req = request.json +async def rm(): + req = await get_request_json() try: for token in req["tokens"]: APITokenService.filter_delete( @@ -126,770 +100,18 @@ def stats(): "to_date", datetime.now().strftime("%Y-%m-%d %H:%M:%S")), "agent" if "canvas_id" in request.args else None) - res = { - "pv": [(o["dt"], o["pv"]) for o in objs], - "uv": [(o["dt"], o["uv"]) for o in objs], - "speed": [(o["dt"], float(o["tokens"]) / (float(o["duration"] + 0.1))) for o in objs], - "tokens": [(o["dt"], float(o["tokens"]) / 1000.) for o in objs], - "round": [(o["dt"], o["round"]) for o in objs], - "thumb_up": [(o["dt"], o["thumb_up"]) for o in objs] - } - return get_json_result(data=res) - except Exception as e: - return server_error_response(e) - - -@manager.route('/new_conversation', methods=['GET']) # noqa: F821 -def set_conversation(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - try: - if objs[0].source == "agent": - e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id) - if not e: - return server_error_response("canvas not found.") - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - conv = { - "id": get_uuid(), - "dialog_id": cvs.id, - "user_id": request.args.get("user_id", ""), - "message": [{"role": "assistant", "content": canvas.get_prologue()}], - "source": "agent" - } - API4ConversationService.save(**conv) - return get_json_result(data=conv) - else: - e, dia = DialogService.get_by_id(objs[0].dialog_id) - if not e: - return get_data_error_result(message="Dialog not found") - conv = { - "id": get_uuid(), - "dialog_id": dia.id, - "user_id": request.args.get("user_id", ""), - "message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}] - } - API4ConversationService.save(**conv) - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route('/completion', methods=['POST']) # noqa: F821 -@validate_request("conversation_id", "messages") -def completion(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - req = request.json - e, conv = API4ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - if "quote" not in req: - req["quote"] = False - - msg = [] - for m in req["messages"]: - if m["role"] == "system": - continue - if m["role"] == "assistant" and not msg: - continue - msg.append(m) - if not msg[-1].get("id"): - msg[-1]["id"] = get_uuid() - message_id = msg[-1]["id"] - - def fillin_conv(ans): - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(ans["reference"]) - else: - conv.reference[-1] = ans["reference"] - conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} - ans["id"] = message_id - - def rename_field(ans): - reference = ans['reference'] - if not isinstance(reference, dict): - return - for chunk_i in reference.get('chunks', []): - if 'docnm_kwd' in chunk_i: - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - - try: - if conv.source == "agent": - stream = req.get("stream", True) - conv.message.append(msg[-1]) - e, cvs = UserCanvasService.get_by_id(conv.dialog_id) - if not e: - return server_error_response("canvas not found.") - del req["conversation_id"] - del req["messages"] - - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - final_ans = {"reference": [], "content": ""} - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - - canvas.messages.append(msg[-1]) - canvas.add_user_input(msg[-1]["content"]) - answer = canvas.run(stream=stream) - - assert answer is not None, "Nothing. Is it over?" - - if stream: - assert isinstance(answer, partial), "Nothing. Is it over?" - - def sse(): - nonlocal answer, cvs, conv - try: - for ans in answer(): - for k in ans.keys(): - final_ans[k] = ans[k] - ans = {"answer": ans["content"], "reference": ans.get("reference", [])} - fillin_conv(ans) - rename_field(ans) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" - - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - canvas.history.append(("assistant", final_ans["content"])) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - API4ConversationService.append_message(conv.id, conv.to_dict()) - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - resp = Response(sse(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - - result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} - fillin_conv(result) - API4ConversationService.append_message(conv.id, conv.to_dict()) - rename_field(result) - return get_json_result(data=result) - - # ******************For dialog****************** - conv.message.append(msg[-1]) - e, dia = DialogService.get_by_id(conv.dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - del req["conversation_id"] - del req["messages"] - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - def stream(): - nonlocal dia, msg, req, conv - try: - for ans in chat(dia, msg, True, **req): - fillin_conv(ans) - rename_field(ans) - yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, - ensure_ascii=False) + "\n\n" - API4ConversationService.append_message(conv.id, conv.to_dict()) - except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" - yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" - - if req.get("stream", True): - resp = Response(stream(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp - - answer = None - for ans in chat(dia, msg, **req): - answer = ans - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - break - rename_field(answer) - return get_json_result(data=answer) - - except Exception as e: - return server_error_response(e) - - -@manager.route('/conversation/', methods=['GET']) # noqa: F821 -# @login_required -def get_conversation(conversation_id): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - try: - e, conv = API4ConversationService.get_by_id(conversation_id) - if not e: - return get_data_error_result(message="Conversation not found!") - - conv = conv.to_dict() - if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token: - return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"', - code=RetCode.AUTHENTICATION_ERROR) - - for referenct_i in conv['reference']: - if referenct_i is None or len(referenct_i) == 0: - continue - for chunk_i in referenct_i['chunks']: - if 'docnm_kwd' in chunk_i.keys(): - chunk_i['doc_name'] = chunk_i['docnm_kwd'] - chunk_i.pop('docnm_kwd') - return get_json_result(data=conv) - except Exception as e: - return server_error_response(e) - - -@manager.route('/document/upload', methods=['POST']) # noqa: F821 -@validate_request("kb_name") -def upload(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - kb_name = request.form.get("kb_name").strip() - tenant_id = objs[0].tenant_id - - try: - e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) - if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") - kb_id = kb.id - except Exception as e: - return server_error_response(e) - - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) - - file = request.files['file'] - if file.filename == '': - return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) - - root_folder = FileService.get_root_folder(tenant_id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, tenant_id) - kb_root_folder = FileService.get_kb_folder(tenant_id) - kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) - - try: - if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)): - return get_data_error_result( - message="Exceed the maximum file number of a free user!") - - filename = duplicate_name( - DocumentService.query, - name=file.filename, - kb_id=kb_id) - filetype = filename_type(filename) - if not filetype: - return get_data_error_result( - message="This type of file has not been supported yet!") - - location = filename - while settings.STORAGE_IMPL.obj_exist(kb_id, location): - location += "_" - blob = request.files['file'].read() - settings.STORAGE_IMPL.put(kb_id, location, blob) - doc = { - "id": get_uuid(), - "kb_id": kb.id, - "parser_id": kb.parser_id, - "parser_config": kb.parser_config, - "created_by": kb.tenant_id, - "type": filetype, - "name": filename, - "location": location, - "size": len(blob), - "thumbnail": thumbnail(filename, blob), - "suffix": Path(filename).suffix.lstrip("."), - } - - form_data = request.form - if "parser_id" in form_data.keys(): - if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]: - doc["parser_id"] = request.form.get("parser_id").strip() - if doc["type"] == FileType.VISUAL: - doc["parser_id"] = ParserType.PICTURE.value - if doc["type"] == FileType.AURAL: - doc["parser_id"] = ParserType.AUDIO.value - if re.search(r"\.(ppt|pptx|pages)$", filename): - doc["parser_id"] = ParserType.PRESENTATION.value - if re.search(r"\.(eml)$", filename): - doc["parser_id"] = ParserType.EMAIL.value - - doc_result = DocumentService.insert(doc) - FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id) - except Exception as e: - return server_error_response(e) - - if "run" in form_data.keys(): - if request.form.get("run").strip() == "1": - try: - info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0} - DocumentService.update_by_id(doc["id"], info) - # if str(req["run"]) == TaskStatus.CANCEL.value: - tenant_id = DocumentService.get_tenant_id(doc["id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - # e, doc = DocumentService.get_by_id(doc["id"]) - TaskService.filter_delete([Task.doc_id == doc["id"]]) - e, doc = DocumentService.get_by_id(doc["id"]) - doc = doc.to_dict() - doc["tenant_id"] = tenant_id - bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"]) - queue_tasks(doc, bucket, name, 0) - except Exception as e: - return server_error_response(e) - - return get_json_result(data=doc_result.to_json()) - - -@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821 -@validate_request("conversation_id") -def upload_parse(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - if 'file' not in request.files: - return get_json_result( - data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) - - file_objs = request.files.getlist('file') - for file_obj in file_objs: - if file_obj.filename == '': - return get_json_result( - data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR) - - doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) - return get_json_result(data=doc_ids) - - -@manager.route('/list_chunks', methods=['POST']) # noqa: F821 -# @login_required -def list_chunks(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - try: - if "doc_name" in req.keys(): - tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name']) - doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name']) - - elif "doc_id" in req.keys(): - tenant_id = DocumentService.get_tenant_id(req['doc_id']) - doc_id = req['doc_id'] - else: - return get_json_result( - data=False, message="Can't find doc_name or doc_id" - ) - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - - res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids) - res = [ - { - "content": res_item["content_with_weight"], - "doc_name": res_item["docnm_kwd"], - "image_id": res_item["img_id"] - } for res_item in res - ] - - except Exception as e: - return server_error_response(e) - - return get_json_result(data=res) - -@manager.route('/get_chunk/', methods=['GET']) # noqa: F821 -# @login_required -def get_chunk(chunk_id): - from rag.nlp import search - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - try: - tenant_id = objs[0].tenant_id - kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids) - if chunk is None: - return server_error_response(Exception("Chunk not found")) - k = [] - for n in chunk.keys(): - if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): - k.append(n) - for n in k: - del chunk[n] - - return get_json_result(data=chunk) - except Exception as e: - return server_error_response(e) - -@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821 -# @login_required -def list_kb_docs(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - tenant_id = objs[0].tenant_id - kb_name = req.get("kb_name", "").strip() - - try: - e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id) - if not e: - return get_data_error_result( - message="Can't find this knowledgebase!") - kb_id = kb.id - - except Exception as e: - return server_error_response(e) - - page_number = int(req.get("page", 1)) - items_per_page = int(req.get("page_size", 15)) - orderby = req.get("orderby", "create_time") - desc = req.get("desc", True) - keywords = req.get("keywords", "") - status = req.get("status", []) - if status: - invalid_status = {s for s in status if s not in VALID_TASK_STATUS} - if invalid_status: - return get_data_error_result( - message=f"Invalid filter status conditions: {', '.join(invalid_status)}" - ) - types = req.get("types", []) - if types: - invalid_types = {t for t in types if t not in VALID_FILE_TYPES} - if invalid_types: - return get_data_error_result( - message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}" - ) - try: - docs, tol = DocumentService.get_by_kb_id( - kb_id, page_number, items_per_page, orderby, desc, keywords, status, types) - docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs] - - return get_json_result(data={"total": tol, "docs": docs}) - - except Exception as e: - return server_error_response(e) - - -@manager.route('/document/infos', methods=['POST']) # noqa: F821 -@validate_request("doc_ids") -def docinfos(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - req = request.json - doc_ids = req["doc_ids"] - docs = DocumentService.get_by_ids(doc_ids) - return get_json_result(data=list(docs.dicts())) - - -@manager.route('/document', methods=['DELETE']) # noqa: F821 -# @login_required -def document_rm(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - tenant_id = objs[0].tenant_id - req = request.json - try: - doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", [])) - for doc_id in req.get("doc_ids", []): - if doc_id not in doc_ids: - doc_ids.append(doc_id) - - if not doc_ids: - return get_json_result( - data=False, message="Can't find doc_names or doc_ids" - ) - - except Exception as e: - return server_error_response(e) - - root_folder = FileService.get_root_folder(tenant_id) - pf_id = root_folder["id"] - FileService.init_knowledgebase_docs(pf_id, tenant_id) - - errors = "" - docs = DocumentService.get_by_ids(doc_ids) - doc_dic = {} - for doc in docs: - doc_dic[doc.id] = doc + res = {"pv": [], "uv": [], "speed": [], "tokens": [], "round": [], "thumb_up": []} - for doc_id in doc_ids: - try: - if doc_id not in doc_dic: - return get_data_error_result(message="Document not found!") - doc = doc_dic[doc_id] - tenant_id = DocumentService.get_tenant_id(doc_id) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") + for obj in objs: + dt = obj["dt"] + res["pv"].append((dt, obj["pv"])) + res["uv"].append((dt, obj["uv"])) + res["speed"].append((dt, float(obj["tokens"]) / (float(obj["duration"]) + 0.1))) # +0.1 to avoid division by zero + res["tokens"].append((dt, float(obj["tokens"]) / 1000.0)) # convert to thousands + res["round"].append((dt, obj["round"])) + res["thumb_up"].append((dt, obj["thumb_up"])) - b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - - if not DocumentService.remove_document(doc, tenant_id): - return get_data_error_result( - message="Database error (Document removal)!") - - f2d = File2DocumentService.get_by_document_id(doc_id) - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) - File2DocumentService.delete_by_document_id(doc_id) - - settings.STORAGE_IMPL.rm(b, n) - except Exception as e: - errors += str(e) - - if errors: - return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) - - return get_json_result(data=True) - - -@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821 -@validate_request("Authorization", "conversation_id", "word") -def completion_faq(): - import base64 - req = request.json - - token = req["Authorization"] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - e, conv = API4ConversationService.get_by_id(req["conversation_id"]) - if not e: - return get_data_error_result(message="Conversation not found!") - if "quote" not in req: - req["quote"] = True - - msg = [{"role": "user", "content": req["word"]}] - if not msg[-1].get("id"): - msg[-1]["id"] = get_uuid() - message_id = msg[-1]["id"] - - def fillin_conv(ans): - nonlocal conv, message_id - if not conv.reference: - conv.reference.append(ans["reference"]) - else: - conv.reference[-1] = ans["reference"] - conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id} - ans["id"] = message_id - - try: - if conv.source == "agent": - conv.message.append(msg[-1]) - e, cvs = UserCanvasService.get_by_id(conv.dialog_id) - if not e: - return server_error_response("canvas not found.") - - if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - final_ans = {"reference": [], "doc_aggs": []} - canvas = Canvas(cvs.dsl, objs[0].tenant_id) - - canvas.messages.append(msg[-1]) - canvas.add_user_input(msg[-1]["content"]) - answer = canvas.run(stream=False) - - assert answer is not None, "Nothing. Is it over?" - - data_type_picture = { - "type": 3, - "url": "base64 content" - } - data = [ - { - "type": 1, - "content": "" - } - ] - final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else "" - canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id}) - if final_ans.get("reference"): - canvas.reference.append(final_ans["reference"]) - cvs.dsl = json.loads(str(canvas)) - - ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])} - data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - - chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] - for chunk_idx in chunk_idxs[:1]: - if ans["reference"]["chunks"][chunk_idx]["img_id"]: - try: - bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = settings.STORAGE_IMPL.get(bkt, nm) - data_type_picture["url"] = base64.b64encode(response).decode('utf-8') - data.append(data_type_picture) - break - except Exception as e: - return server_error_response(e) - - response = {"code": 200, "msg": "success", "data": data} - return response - - # ******************For dialog****************** - conv.message.append(msg[-1]) - e, dia = DialogService.get_by_id(conv.dialog_id) - if not e: - return get_data_error_result(message="Dialog not found!") - del req["conversation_id"] - - if not conv.reference: - conv.reference = [] - conv.message.append({"role": "assistant", "content": "", "id": message_id}) - conv.reference.append({"chunks": [], "doc_aggs": []}) - - data_type_picture = { - "type": 3, - "url": "base64 content" - } - data = [ - { - "type": 1, - "content": "" - } - ] - ans = "" - for a in chat(dia, msg, stream=False, **req): - ans = a - break - data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"]) - fillin_conv(ans) - API4ConversationService.append_message(conv.id, conv.to_dict()) - - chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])] - for chunk_idx in chunk_idxs[:1]: - if ans["reference"]["chunks"][chunk_idx]["img_id"]: - try: - bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-") - response = settings.STORAGE_IMPL.get(bkt, nm) - data_type_picture["url"] = base64.b64encode(response).decode('utf-8') - data.append(data_type_picture) - break - except Exception as e: - return server_error_response(e) - - response = {"code": 200, "msg": "success", "data": data} - return response - - except Exception as e: - return server_error_response(e) - - -@manager.route('/retrieval', methods=['POST']) # noqa: F821 -@validate_request("kb_id", "question") -def retrieval(): - token = request.headers.get('Authorization').split()[1] - objs = APIToken.query(token=token) - if not objs: - return get_json_result( - data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR) - - req = request.json - kb_ids = req.get("kb_id", []) - doc_ids = req.get("doc_ids", []) - question = req.get("question") - page = int(req.get("page", 1)) - size = int(req.get("page_size", 30)) - similarity_threshold = float(req.get("similarity_threshold", 0.2)) - vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) - top = int(req.get("top_k", 1024)) - highlight = bool(req.get("highlight", False)) - - try: - kbs = KnowledgebaseService.get_by_ids(kb_ids) - embd_nms = list(set([kb.embd_id for kb in kbs])) - if len(embd_nms) != 1: - return get_json_result( - data=False, message='Knowledge bases use different embedding models or does not exist."', - code=RetCode.AUTHENTICATION_ERROR) - - embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id) - rerank_mdl = None - if req.get("rerank_id"): - rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) - if req.get("keyword", False): - chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) - ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight= highlight, - rank_feature=label_question(question, kbs)) - for c in ranks["chunks"]: - c.pop("vector", None) - return get_json_result(data=ranks) + return get_json_result(data=res) except Exception as e: - if str(e).find("not_found") > 0: - return get_json_result(data=False, message='No chunk found! Check the chunk status please!', - code=RetCode.DATA_ERROR) return server_error_response(e) diff --git a/api/apps/auth/github.py b/api/apps/auth/github.py index f48d4a5fc27..918ff60db8c 100644 --- a/api/apps/auth/github.py +++ b/api/apps/auth/github.py @@ -14,7 +14,7 @@ # limitations under the License. # -import requests +from common.http_client import async_request, sync_request from .oauth import OAuthClient, UserInfo @@ -34,24 +34,49 @@ def __init__(self, config): def fetch_user_info(self, access_token, **kwargs): """ - Fetch GitHub user info. + Fetch GitHub user info (synchronous). """ user_info = {} try: headers = {"Authorization": f"Bearer {access_token}"} - # user info - response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) + response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout) response.raise_for_status() user_info.update(response.json()) - # email info - response = requests.get(self.userinfo_url+"/emails", headers=headers, timeout=self.http_request_timeout) + email_response = sync_request( + "GET", self.userinfo_url + "/emails", headers=headers, timeout=self.http_request_timeout + ) + email_response.raise_for_status() + email_info = email_response.json() + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] + return self.normalize_user_info(user_info) + except Exception as e: + raise ValueError(f"Failed to fetch github user info: {e}") + + async def async_fetch_user_info(self, access_token, **kwargs): + """Async variant of fetch_user_info using httpx.""" + user_info = {} + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = await async_request( + "GET", + self.userinfo_url, + headers=headers, + timeout=self.http_request_timeout, + ) response.raise_for_status() - email_info = response.json() - user_info["email"] = next( - (email for email in email_info if email["primary"]), None - )["email"] + user_info.update(response.json()) + + email_response = await async_request( + "GET", + self.userinfo_url + "/emails", + headers=headers, + timeout=self.http_request_timeout, + ) + email_response.raise_for_status() + email_info = email_response.json() + user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] return self.normalize_user_info(user_info) - except requests.exceptions.RequestException as e: + except Exception as e: raise ValueError(f"Failed to fetch github user info: {e}") diff --git a/api/apps/auth/oauth.py b/api/apps/auth/oauth.py index 6f7e0e5b54a..5b2afcea1d0 100644 --- a/api/apps/auth/oauth.py +++ b/api/apps/auth/oauth.py @@ -14,8 +14,8 @@ # limitations under the License. # -import requests import urllib.parse +from common.http_client import async_request, sync_request class UserInfo: @@ -74,15 +74,40 @@ def exchange_code_for_token(self, code): "redirect_uri": self.redirect_uri, "grant_type": "authorization_code" } - response = requests.post( + response = sync_request( + "POST", self.token_url, data=payload, headers={"Accept": "application/json"}, - timeout=self.http_request_timeout + timeout=self.http_request_timeout, ) response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except Exception as e: + raise ValueError(f"Failed to exchange authorization code for token: {e}") + + async def async_exchange_code_for_token(self, code): + """ + Async variant of exchange_code_for_token using httpx. + """ + payload = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, + "grant_type": "authorization_code", + } + try: + response = await async_request( + "POST", + self.token_url, + data=payload, + headers={"Accept": "application/json"}, + timeout=self.http_request_timeout, + ) + response.raise_for_status() + return response.json() + except Exception as e: raise ValueError(f"Failed to exchange authorization code for token: {e}") @@ -92,11 +117,27 @@ def fetch_user_info(self, access_token, **kwargs): """ try: headers = {"Authorization": f"Bearer {access_token}"} - response = requests.get(self.userinfo_url, headers=headers, timeout=self.http_request_timeout) + response = sync_request("GET", self.userinfo_url, headers=headers, timeout=self.http_request_timeout) + response.raise_for_status() + user_info = response.json() + return self.normalize_user_info(user_info) + except Exception as e: + raise ValueError(f"Failed to fetch user info: {e}") + + async def async_fetch_user_info(self, access_token, **kwargs): + """Async variant of fetch_user_info using httpx.""" + headers = {"Authorization": f"Bearer {access_token}"} + try: + response = await async_request( + "GET", + self.userinfo_url, + headers=headers, + timeout=self.http_request_timeout, + ) response.raise_for_status() user_info = response.json() return self.normalize_user_info(user_info) - except requests.exceptions.RequestException as e: + except Exception as e: raise ValueError(f"Failed to fetch user info: {e}") diff --git a/api/apps/auth/oidc.py b/api/apps/auth/oidc.py index cafcaadfdfd..80ac79399f2 100644 --- a/api/apps/auth/oidc.py +++ b/api/apps/auth/oidc.py @@ -15,7 +15,7 @@ # import jwt -import requests +from common.http_client import sync_request from .oauth import OAuthClient @@ -50,10 +50,10 @@ def _load_oidc_metadata(issuer): """ try: metadata_url = f"{issuer}/.well-known/openid-configuration" - response = requests.get(metadata_url, timeout=7) + response = sync_request("GET", metadata_url, timeout=7) response.raise_for_status() return response.json() - except requests.exceptions.RequestException as e: + except Exception as e: raise ValueError(f"Failed to fetch OIDC metadata: {e}") @@ -95,6 +95,13 @@ def fetch_user_info(self, access_token, id_token=None, **kwargs): user_info.update(super().fetch_user_info(access_token).to_dict()) return self.normalize_user_info(user_info) + async def async_fetch_user_info(self, access_token, id_token=None, **kwargs): + user_info = {} + if id_token: + user_info = self.parse_id_token(id_token) + user_info.update((await super().async_fetch_user_info(access_token)).to_dict()) + return self.normalize_user_info(user_info) + def normalize_user_info(self, user_info): return super().normalize_user_info(user_info) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 0ac2951ae5f..21bd237894f 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -13,19 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import inspect import json import logging -import re -import sys from functools import partial - -import flask -import trio -from flask import request, Response -from flask_login import login_required, current_user - +from quart import request, Response, make_response from agent.component import LLM -from api.db import CanvasCategory, FileType +from api.db import CanvasCategory from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService from api.db.services.document_service import DocumentService from api.db.services.file_service import FileService @@ -35,17 +30,18 @@ from api.db.services.user_canvas_version import UserCanvasVersionService from common.constants import RetCode from common.misc_utils import get_uuid -from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result +from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \ + get_request_json from agent.canvas import Canvas from peewee import MySQLDatabase, PostgresqlDatabase from api.db.db_models import APIToken, Task import time -from api.utils.file_utils import filename_type, read_potential_broken_pdf from rag.flow.pipeline import Pipeline from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN from common import settings +from api.apps import login_required, current_user @manager.route('/templates', methods=['GET']) # noqa: F821 @@ -57,8 +53,9 @@ def templates(): @manager.route('/rm', methods=['POST']) # noqa: F821 @validate_request("canvas_ids") @login_required -def rm(): - for i in request.json["canvas_ids"]: +async def rm(): + req = await get_request_json() + for i in req["canvas_ids"]: if not UserCanvasService.accessible(i, current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -70,8 +67,8 @@ def rm(): @manager.route('/set', methods=['POST']) # noqa: F821 @validate_request("dsl", "title") @login_required -def save(): - req = request.json +async def save(): + req = await get_request_json() if not isinstance(req["dsl"], str): req["dsl"] = json.dumps(req["dsl"], ensure_ascii=False) req["dsl"] = json.loads(req["dsl"]) @@ -129,18 +126,18 @@ def getsse(canvas_id): @manager.route('/completion', methods=['POST']) # noqa: F821 @validate_request("id") @login_required -def run(): - req = request.json +async def run(): + req = await get_request_json() query = req.get("query", "") files = req.get("files", []) inputs = req.get("inputs", {}) user_id = req.get("user_id", current_user.id) - if not UserCanvasService.accessible(req["id"], current_user.id): + if not await asyncio.to_thread(UserCanvasService.accessible, req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', code=RetCode.OPERATING_ERROR) - e, cvs = UserCanvasService.get_by_id(req["id"]) + e, cvs = await asyncio.to_thread(UserCanvasService.get_by_id, req["id"]) if not e: return get_data_error_result(message="canvas not found.") @@ -150,20 +147,20 @@ def run(): if cvs.canvas_category == CanvasCategory.DataFlow: task_id = get_uuid() Pipeline(cvs.dsl, tenant_id=current_user.id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"]) - ok, error_message = queue_dataflow(tenant_id=user_id, flow_id=req["id"], task_id=task_id, file=files[0], priority=0) + ok, error_message = await asyncio.to_thread(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0) if not ok: return get_data_error_result(message=error_message) return get_json_result(data={"message_id": task_id}) try: - canvas = Canvas(cvs.dsl, current_user.id) + canvas = Canvas(cvs.dsl, current_user.id, canvas_id=cvs.id) except Exception as e: return server_error_response(e) - def sse(): + async def sse(): nonlocal canvas, user_id try: - for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" cvs.dsl = json.loads(str(canvas)) @@ -179,15 +176,15 @@ def sse(): resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - resp.call_on_close(lambda: canvas.cancel_task()) + #resp.call_on_close(lambda: canvas.cancel_task()) return resp @manager.route('/rerun', methods=['POST']) # noqa: F821 @validate_request("id", "dsl", "component_id") @login_required -def rerun(): - req = request.json +async def rerun(): + req = await get_request_json() doc = PipelineOperationLogService.get_documents_info(req["id"]) if not doc: return get_data_error_result(message="Document not found.") @@ -195,7 +192,7 @@ def rerun(): if 0 < doc["progress"] < 1: return get_data_error_result(message=f"`{doc['name']}` is processing...") - if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]): + if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]): settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"]) doc["progress_msg"] = "" doc["chunk_num"] = 0 @@ -224,8 +221,8 @@ def cancel(task_id): @manager.route('/reset', methods=['POST']) # noqa: F821 @validate_request("id") @login_required -def reset(): - req = request.json +async def reset(): + req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', @@ -235,7 +232,7 @@ def reset(): if not e: return get_data_error_result(message="canvas not found.") - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) + canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) canvas.reset() req["dsl"] = json.loads(str(canvas)) UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]}) @@ -245,76 +242,16 @@ def reset(): @manager.route("/upload/", methods=["POST"]) # noqa: F821 -def upload(canvas_id): +async def upload(canvas_id): e, cvs = UserCanvasService.get_by_canvas_id(canvas_id) if not e: return get_data_error_result(message="canvas not found.") user_id = cvs["user_id"] - def structured(filename, filetype, blob, content_type): - nonlocal user_id - if filetype == FileType.PDF.value: - blob = read_potential_broken_pdf(blob) - - location = get_uuid() - FileService.put_blob(user_id, location, blob) - - return { - "id": location, - "name": filename, - "size": sys.getsizeof(blob), - "extension": filename.split(".")[-1].lower(), - "mime_type": content_type, - "created_by": user_id, - "created_at": time.time(), - "preview_url": None - } - - if request.args.get("url"): - from crawl4ai import ( - AsyncWebCrawler, - BrowserConfig, - CrawlerRunConfig, - DefaultMarkdownGenerator, - PruningContentFilter, - CrawlResult - ) - try: - url = request.args.get("url") - filename = re.sub(r"\?.*", "", url.split("/")[-1]) - async def adownload(): - browser_config = BrowserConfig( - headless=True, - verbose=False, - ) - async with AsyncWebCrawler(config=browser_config) as crawler: - crawler_config = CrawlerRunConfig( - markdown_generator=DefaultMarkdownGenerator( - content_filter=PruningContentFilter() - ), - pdf=True, - screenshot=False - ) - result: CrawlResult = await crawler.arun( - url=url, - config=crawler_config - ) - return result - page = trio.run(adownload()) - if page.pdf: - if filename.split(".")[-1].lower() != "pdf": - filename += ".pdf" - return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"])) - - return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)) - - except Exception as e: - return server_error_response(e) - - file = request.files['file'] + files = await request.files + file = files['file'] if files and files.get("file") else None try: - DocumentService.check_doc_health(user_id, file.filename) - return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type)) + return get_json_result(data=FileService.upload_info(user_id, file, request.args.get("url"))) except Exception as e: return server_error_response(e) @@ -333,7 +270,7 @@ def input_form(): data=False, message='Only owner of canvas authorized for this operation.', code=RetCode.OPERATING_ERROR) - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) + canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) return get_json_result(data=canvas.get_component_input_form(cpn_id)) except Exception as e: return server_error_response(e) @@ -342,15 +279,15 @@ def input_form(): @manager.route('/debug', methods=['POST']) # noqa: F821 @validate_request("id", "component_id", "params") @login_required -def debug(): - req = request.json +async def debug(): + req = await get_request_json() if not UserCanvasService.accessible(req["id"], current_user.id): return get_json_result( data=False, message='Only owner of canvas authorized for this operation.', code=RetCode.OPERATING_ERROR) try: e, user_canvas = UserCanvasService.get_by_id(req["id"]) - canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id) + canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id) canvas.reset() canvas.message_id = get_uuid() component = canvas.get_component(req["component_id"])["obj"] @@ -363,8 +300,13 @@ def debug(): for k in outputs.keys(): if isinstance(outputs[k], partial): txt = "" - for c in outputs[k](): - txt += c + iter_obj = outputs[k]() + if inspect.isasyncgen(iter_obj): + async for c in iter_obj: + txt += c + else: + for c in iter_obj: + txt += c outputs[k] = txt return get_json_result(data=outputs) except Exception as e: @@ -374,8 +316,8 @@ def debug(): @manager.route('/test_db_connect', methods=['POST']) # noqa: F821 @validate_request("db_type", "database", "username", "host", "port", "password") @login_required -def test_db_connect(): - req = request.json +async def test_db_connect(): + req = await get_request_json() try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"], @@ -406,7 +348,15 @@ def test_db_connect(): f"UID={req['username']};" f"PWD={req['password']};" ) - logging.info(conn_str) + redacted_conn_str = ( + f"DATABASE={req['database']};" + f"HOSTNAME={req['host']};" + f"PORT={req['port']};" + f"PROTOCOL=TCPIP;" + f"UID={req['username']};" + f"PWD=****;" + ) + logging.info(redacted_conn_str) conn = ibm_db.connect(conn_str, "", "") stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") ibm_db.fetch_assoc(stmt) @@ -426,7 +376,6 @@ def _parse_catalog_schema(db_name: str): try: import trino import os - from trino.auth import BasicAuthentication except Exception as e: return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}") @@ -438,7 +387,7 @@ def _parse_catalog_schema(db_name: str): auth = None if http_scheme == "https" and req.get("password"): - auth = BasicAuthentication(req.get("username") or "ragflow", req["password"]) + auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) conn = trino.dbapi.connect( host=req["host"], @@ -471,8 +420,8 @@ def _parse_catalog_schema(db_name: str): @login_required def getlistversion(canvas_id): try: - list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) - return get_json_result(data=list) + versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1) + return get_json_result(data=versions) except Exception as e: return get_data_error_result(message=f"Error getting history files: {e}") @@ -520,8 +469,8 @@ def list_canvas(): @manager.route('/setting', methods=['POST']) # noqa: F821 @validate_request("id", "title", "permission") @login_required -def setting(): - req = request.json +async def setting(): + req = await get_request_json() req["user_id"] = current_user.id if not UserCanvasService.accessible(req["id"], current_user.id): @@ -602,8 +551,8 @@ def prompts(): @manager.route('/download', methods=['GET']) # noqa: F821 -def download(): +async def download(): id = request.args.get("id") created_by = request.args.get("created_by") blob = FileService.get_blob(created_by, id) - return flask.make_response(blob) + return await make_response(blob) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 78a614ddf7c..f5b248fd5ef 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -13,35 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import datetime import json import re - +import base64 import xxhash -from flask import request -from flask_login import current_user, login_required +from quart import request -from api.db.services.dialog_service import meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ + get_request_json from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search -from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction +from rag.prompts.generator import cross_languages, keyword_extraction from common.string_utils import remove_redundant_spaces from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD from common import settings +from api.apps import login_required, current_user @manager.route('/list', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id") -def list_chunk(): - req = request.json +async def list_chunk(): + req = await get_request_json() doc_id = req["doc_id"] page = int(req.get("page", 1)) size = int(req.get("size", 30)) @@ -74,6 +76,7 @@ def list_chunk(): "image_id": sres.field[id].get("img_id", ""), "available_int": int(sres.field[id].get("available_int", 1)), "positions": sres.field[id].get("position_int", []), + "doc_type_kwd": sres.field[id].get("doc_type_kwd") } assert isinstance(d["positions"], list) assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5) @@ -121,8 +124,8 @@ def get(): @manager.route('/set', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id", "chunk_id", "content_with_weight") -def set(): - req = request.json +async def set(): + req = await get_request_json() d = { "id": req["chunk_id"], "content_with_weight": req["content_with_weight"]} @@ -146,31 +149,42 @@ def set(): d["available_int"] = req["available_int"] try: - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - embd_id = DocumentService.get_embd_id(req["doc_id"]) - embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) - - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - - if doc.parser_id == ParserType.QA: - arr = [ - t for t in re.split( - r"[\n\t]", - req["content_with_weight"]) if len(t) > 1] - q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:])) - d = beAdoc(d, q, a, not any( - [rag_tokenizer.is_chinese(t) for t in q + a])) - - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] - d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id) - return get_json_result(data=True) + def _set_sync(): + tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + if not tenant_id: + return get_data_error_result(message="Tenant not found!") + + embd_id = DocumentService.get_embd_id(req["doc_id"]) + embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) + + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + + _d = d + if doc.parser_id == ParserType.QA: + arr = [ + t for t in re.split( + r"[\n\t]", + req["content_with_weight"]) if len(t) > 1] + q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:])) + _d = beAdoc(d, q, a, not any( + [rag_tokenizer.is_chinese(t) for t in q + a])) + + v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not _d.get("question_kwd") else "\n".join(_d["question_kwd"])]) + v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] + _d["q_%d_vec" % len(v)] = v.tolist() + settings.docStoreConn.update({"id": req["chunk_id"]}, _d, search.index_name(tenant_id), doc.kb_id) + + # update image + image_base64 = req.get("image_base64", None) + if image_base64: + bkt, name = req.get("img_id", "-").split("-") + image_binary = base64.b64decode(image_base64) + settings.STORAGE_IMPL.put(bkt, name, image_binary) + return get_json_result(data=True) + + return await asyncio.to_thread(_set_sync) except Exception as e: return server_error_response(e) @@ -178,19 +192,22 @@ def set(): @manager.route('/switch', methods=['POST']) # noqa: F821 @login_required @validate_request("chunk_ids", "available_int", "doc_id") -def switch(): - req = request.json +async def switch(): + req = await get_request_json() try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - for cid in req["chunk_ids"]: - if not settings.docStoreConn.update({"id": cid}, - {"available_int": int(req["available_int"])}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id): - return get_data_error_result(message="Index updating failure") - return get_json_result(data=True) + def _switch_sync(): + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + for cid in req["chunk_ids"]: + if not settings.docStoreConn.update({"id": cid}, + {"available_int": int(req["available_int"])}, + search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + doc.kb_id): + return get_data_error_result(message="Index updating failure") + return get_json_result(data=True) + + return await asyncio.to_thread(_switch_sync) except Exception as e: return server_error_response(e) @@ -198,23 +215,26 @@ def switch(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("chunk_ids", "doc_id") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, - search.index_name(DocumentService.get_tenant_id(req["doc_id"])), - doc.kb_id): - return get_data_error_result(message="Chunk deleting failure") - deleted_chunk_ids = req["chunk_ids"] - chunk_number = len(deleted_chunk_ids) - DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) - for cid in deleted_chunk_ids: - if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): - settings.STORAGE_IMPL.rm(doc.kb_id, cid) - return get_json_result(data=True) + def _rm_sync(): + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, + search.index_name(DocumentService.get_tenant_id(req["doc_id"])), + doc.kb_id): + return get_data_error_result(message="Chunk deleting failure") + deleted_chunk_ids = req["chunk_ids"] + chunk_number = len(deleted_chunk_ids) + DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) + for cid in deleted_chunk_ids: + if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): + settings.STORAGE_IMPL.rm(doc.kb_id, cid) + return get_json_result(data=True) + + return await asyncio.to_thread(_rm_sync) except Exception as e: return server_error_response(e) @@ -222,8 +242,8 @@ def rm(): @manager.route('/create', methods=['POST']) # noqa: F821 @login_required @validate_request("doc_id", "content_with_weight") -def create(): - req = request.json +async def create(): + req = await get_request_json() chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]} @@ -244,35 +264,38 @@ def create(): d["tag_feas"] = req["tag_feas"] try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - d["kb_id"] = [doc.kb_id] - d["docnm_kwd"] = doc.name - d["title_tks"] = rag_tokenizer.tokenize(doc.name) - d["doc_id"] = doc.id - - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - - e, kb = KnowledgebaseService.get_by_id(doc.kb_id) - if not e: - return get_data_error_result(message="Knowledgebase not found!") - if kb.pagerank: - d[PAGERANK_FLD] = kb.pagerank - - embd_id = DocumentService.get_embd_id(req["doc_id"]) - embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) - - v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) - v = 0.1 * v[0] + 0.9 * v[1] - d["q_%d_vec" % len(v)] = v.tolist() - settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) - - DocumentService.increment_chunk_num( - doc.id, doc.kb_id, c, 1, 0) - return get_json_result(data={"chunk_id": chunck_id}) + def _create_sync(): + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + d["kb_id"] = [doc.kb_id] + d["docnm_kwd"] = doc.name + d["title_tks"] = rag_tokenizer.tokenize(doc.name) + d["doc_id"] = doc.id + + tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + if not tenant_id: + return get_data_error_result(message="Tenant not found!") + + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not e: + return get_data_error_result(message="Knowledgebase not found!") + if kb.pagerank: + d[PAGERANK_FLD] = kb.pagerank + + embd_id = DocumentService.get_embd_id(req["doc_id"]) + embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) + + v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) + v = 0.1 * v[0] + 0.9 * v[1] + d["q_%d_vec" % len(v)] = v.tolist() + settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) + + DocumentService.increment_chunk_num( + doc.id, doc.kb_id, c, 1, 0) + return get_json_result(data={"chunk_id": chunck_id}) + + return await asyncio.to_thread(_create_sync) except Exception as e: return server_error_response(e) @@ -280,8 +303,8 @@ def create(): @manager.route('/retrieval_test', methods=['POST']) # noqa: F821 @login_required @validate_request("kb_id", "question") -def retrieval_test(): - req = request.json +async def retrieval_test(): + req = await get_request_json() page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] @@ -296,25 +319,29 @@ def retrieval_test(): use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) - tenant_ids = [] - - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) - if not doc_ids: - doc_ids = None - - try: - tenants = UserTenantService.query(user_id=current_user.id) + user_id = current_user.id + + async def _retrieval(): + local_doc_ids = list(doc_ids) if doc_ids else [] + tenant_ids = [] + + meta_data_filter = {} + chat_mdl = None + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + else: + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(user_id, LLMType.CHAT) + + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids) + + tenants = UserTenantService.query(user_id=user_id) for kb_id in kb_ids: for tenant in tenants: if KnowledgebaseService.query( @@ -323,15 +350,16 @@ def retrieval_test(): break else: return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', + data=False, message='Only owner of dataset authorized for this operation.', code=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: return get_data_error_result(message="Knowledgebase not found!") + _question = question if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -341,31 +369,35 @@ def retrieval_test(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + _question += await keyword_extraction(chat_mdl, _question) - labels = label_question(question, [kb]) - ranks = settings.retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, + labels = label_question(_question, [kb]) + ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size, float(req.get("similarity_threshold", 0.0)), float(req.get("vector_similarity_weight", 0.3)), top, - doc_ids, rerank_mdl=rerank_mdl, + local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight", False), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(question, + ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) + ranks["chunks"] = settings.retriever.retrieval_by_children(ranks["chunks"], tenant_ids) for c in ranks["chunks"]: c.pop("vector", None) ranks["labels"] = labels return get_json_result(data=ranks) + + try: + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! Check the chunk status please!', diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index 23965e617a2..fb074419bb5 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import time @@ -20,24 +21,25 @@ from html import escape from typing import Any -from flask import make_response, request -from flask_login import current_user, login_required +from quart import request, make_response from google_auth_oauthlib.flow import Flow from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService -from api.utils.api_utils import get_data_error_result, get_json_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, validate_request from common.constants import RetCode, TaskStatus -from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource -from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES +from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, GMAIL_WEB_OAUTH_REDIRECT_URI, BOX_WEB_OAUTH_REDIRECT_URI, DocumentSource +from common.data_source.google_util.constant import WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES from common.misc_utils import get_uuid from rag.utils.redis_conn import REDIS_CONN +from api.apps import login_required, current_user +from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required -def set_connector(): - req = request.json +async def set_connector(): + req = await get_request_json() if req.get("id"): conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} ConnectorService.update_by_id(req["id"], conn) @@ -55,10 +57,9 @@ def set_connector(): "timeout_secs": int(req.get("timeout_secs", 60 * 29)), "status": TaskStatus.SCHEDULE, } - conn["status"] = TaskStatus.SCHEDULE ConnectorService.save(**conn) - time.sleep(1) + await asyncio.sleep(1) e, conn = ConnectorService.get_by_id(req["id"]) return get_json_result(data=conn.to_dict()) @@ -89,8 +90,8 @@ def list_logs(connector_id): @manager.route("//resume", methods=["PUT"]) # noqa: F821 @login_required -def resume(connector_id): - req = request.json +async def resume(connector_id): + req = await get_request_json() if req.get("resume"): ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) else: @@ -101,8 +102,8 @@ def resume(connector_id): @manager.route("//rebuild", methods=["PUT"]) # noqa: F821 @login_required @validate_request("kb_id") -def rebuild(connector_id): - req = request.json +async def rebuild(connector_id): + req = await get_request_json() err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) if err: return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) @@ -117,17 +118,27 @@ def rm_connector(connector_id): return get_json_result(data=True) -GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state" -GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result" WEB_FLOW_TTL_SECS = 15 * 60 -def _web_state_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}" +def _web_state_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth state. + The default prefix keeps backward compatibility for Google Drive. + When source_type == "gmail", a different prefix is used so that + Drive/Gmail flows don't clash in Redis. + """ + prefix = f"{source_type}_web_flow_state" + return f"{prefix}:{flow_id}" -def _web_result_cache_key(flow_id: str) -> str: - return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}" + +def _web_result_cache_key(flow_id: str, source_type: str | None = None) -> str: + """Return Redis key for web OAuth result. + + Mirrors _web_state_cache_key logic for result storage. + """ + prefix = f"{source_type}_web_flow_result" + return f"{prefix}:{flow_id}" def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]: @@ -146,43 +157,61 @@ def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: return {"web": web_section} -def _render_web_oauth_popup(flow_id: str, success: bool, message: str): +async def _render_web_oauth_popup(flow_id: str, success: bool, message: str, source="drive"): status = "success" if success else "error" auto_close = "window.close();" if success else "" escaped_message = escape(message) + # Drive: ragflow-google-drive-oauth + # Gmail: ragflow-gmail-oauth + payload_type = f"ragflow-{source}-oauth" payload_json = json.dumps( { - "type": "ragflow-google-drive-oauth", + "type": payload_type, "status": status, "flowId": flow_id or "", "message": message, } ) - html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format( + # TODO(google-oauth): title/heading/message may need to reflect drive/gmail based on cached type + html = WEB_OAUTH_POPUP_TEMPLATE.format( + title=f"Google {source.capitalize()} Authorization", heading="Authorization complete" if success else "Authorization failed", message=escaped_message, payload_json=payload_json, auto_close=auto_close, ) - response = make_response(html, 200) + response = await make_response(html, 200) response.headers["Content-Type"] = "text/html; charset=utf-8" return response -@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 +@manager.route("/google/oauth/web/start", methods=["POST"]) # noqa: F821 @login_required @validate_request("credentials") -def start_google_drive_web_oauth(): - if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: +async def start_google_web_oauth(): + source = request.args.get("type", "google-drive") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") + + if source == "gmail": + redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GMAIL] + else: + redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + scopes = GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE] + + if not redirect_uri: return get_json_result( code=RetCode.SERVER_ERROR, - message="Google Drive OAuth redirect URI is not configured on the server.", + message="Google OAuth redirect URI is not configured on the server.", ) - req = request.json or {} + req = await get_request_json() raw_credentials = req.get("credentials", "") + try: credentials = _load_credentials(raw_credentials) + print(credentials) except ValueError as exc: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc)) @@ -199,8 +228,8 @@ def start_google_drive_web_oauth(): flow_id = str(uuid.uuid4()) try: - flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) - flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI + flow = Flow.from_client_config(client_config, scopes=scopes) + flow.redirect_uri = redirect_uri authorization_url, _ = flow.authorization_url( access_type="offline", include_granted_scopes="true", @@ -219,7 +248,7 @@ def start_google_drive_web_oauth(): "client_config": client_config, "created_at": int(time.time()), } - REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.set_obj(_web_state_cache_key(flow_id, source), cache_payload, WEB_FLOW_TTL_SECS) return get_json_result( data={ @@ -230,60 +259,115 @@ def start_google_drive_web_oauth(): ) +@manager.route("/gmail/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def google_gmail_web_oauth_callback(): + state_id = request.args.get("state") + error = request.args.get("error") + source = "gmail" + + error_description = request.args.get("error_description") or error + + if not state_id: + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) + + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) + if not state_cache: + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) + + state_obj = json.loads(state_cache) + client_config = state_obj.get("client_config") + if not client_config: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) + + if error: + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) + + code = request.args.get("code") + if not code: + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) + + try: + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) + flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GMAIL]) + flow.redirect_uri = GMAIL_WEB_OAUTH_REDIRECT_URI + flow.fetch_token(code=code) + except Exception as exc: # pragma: no cover - defensive + logging.exception("Failed to exchange Google OAuth code: %s", exc) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source) + + creds_json = flow.credentials.to_json() + result_payload = { + "user_id": state_obj.get("user_id"), + "credentials": creds_json, + } + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) + + @manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821 -def google_drive_web_oauth_callback(): +async def google_drive_web_oauth_callback(): state_id = request.args.get("state") error = request.args.get("error") + source = "google-drive" + error_description = request.args.get("error_description") or error if not state_id: - return _render_web_oauth_popup("", False, "Missing OAuth state parameter.") + return await _render_web_oauth_popup("", False, "Missing OAuth state parameter.", source) - state_cache = REDIS_CONN.get(_web_state_cache_key(state_id)) + state_cache = REDIS_CONN.get(_web_state_cache_key(state_id, source)) if not state_cache: - return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.") + return await _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.", source) state_obj = json.loads(state_cache) client_config = state_obj.get("client_config") if not client_config: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.", source) if error: - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.", source) code = request.args.get("code") if not code: - return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.") + return await _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.", source) try: + # TODO(google-oauth): branch scopes/redirect_uri based on source_type (drive vs gmail) flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE]) flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI flow.fetch_token(code=code) except Exception as exc: # pragma: no cover - defensive logging.exception("Failed to exchange Google OAuth code: %s", exc) - REDIS_CONN.delete(_web_state_cache_key(state_id)) - return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.") + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.", source) creds_json = flow.credentials.to_json() result_payload = { "user_id": state_obj.get("user_id"), "credentials": creds_json, } - REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS) - REDIS_CONN.delete(_web_state_cache_key(state_id)) - - return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.") + REDIS_CONN.set_obj(_web_result_cache_key(state_id, source), result_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.delete(_web_state_cache_key(state_id, source)) + return await _render_web_oauth_popup(state_id, True, "Authorization completed successfully.", source) -@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821 +@manager.route("/google/oauth/web/result", methods=["POST"]) # noqa: F821 @login_required @validate_request("flow_id") -def poll_google_drive_web_result(): - req = request.json or {} +async def poll_google_web_result(): + req = await request.json or {} + source = request.args.get("type") + if source not in ("google-drive", "gmail"): + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google OAuth type.") flow_id = req.get("flow_id") - cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id)) + cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id, source)) if not cache_raw: return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.") @@ -291,5 +375,109 @@ def poll_google_drive_web_result(): if result.get("user_id") != current_user.id: return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") - REDIS_CONN.delete(_web_result_cache_key(flow_id)) + REDIS_CONN.delete(_web_result_cache_key(flow_id, source)) return get_json_result(data={"credentials": result.get("credentials")}) + +@manager.route("/box/oauth/web/start", methods=["POST"]) # noqa: F821 +@login_required +async def start_box_web_oauth(): + req = await get_request_json() + + client_id = req.get("client_id") + client_secret = req.get("client_secret") + redirect_uri = req.get("redirect_uri", BOX_WEB_OAUTH_REDIRECT_URI) + + if not client_id or not client_secret: + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box client_id and client_secret are required.") + + flow_id = str(uuid.uuid4()) + + box_auth = BoxOAuth( + OAuthConfig( + client_id=client_id, + client_secret=client_secret, + ) + ) + + auth_url = box_auth.get_authorize_url( + options=GetAuthorizeUrlOptions( + redirect_uri=redirect_uri, + state=flow_id, + ) + ) + + cache_payload = { + "user_id": current_user.id, + "auth_url": auth_url, + "client_id": client_id, + "client_secret": client_secret, + "created_at": int(time.time()), + } + REDIS_CONN.set_obj(_web_state_cache_key(flow_id, "box"), cache_payload, WEB_FLOW_TTL_SECS) + return get_json_result( + data = { + "flow_id": flow_id, + "authorization_url": auth_url, + "expires_in": WEB_FLOW_TTL_SECS,} + ) + +@manager.route("/box/oauth/web/callback", methods=["GET"]) # noqa: F821 +async def box_web_oauth_callback(): + flow_id = request.args.get("state") + if not flow_id: + return await _render_web_oauth_popup("", False, "Missing OAuth parameters.", "box") + + code = request.args.get("code") + if not code: + return await _render_web_oauth_popup(flow_id, False, "Missing authorization code from Box.", "box") + + cache_payload = json.loads(REDIS_CONN.get(_web_state_cache_key(flow_id, "box"))) + if not cache_payload: + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Box OAuth session expired or invalid.") + + error = request.args.get("error") + error_description = request.args.get("error_description") or error + if error: + REDIS_CONN.delete(_web_state_cache_key(flow_id, "box")) + return await _render_web_oauth_popup(flow_id, False, error_description or "Authorization failed.", "box") + + auth = BoxOAuth( + OAuthConfig( + client_id=cache_payload.get("client_id"), + client_secret=cache_payload.get("client_secret"), + ) + ) + + auth.get_tokens_authorization_code_grant(code) + token = auth.retrieve_token() + result_payload = { + "user_id": cache_payload.get("user_id"), + "client_id": cache_payload.get("client_id"), + "client_secret": cache_payload.get("client_secret"), + "access_token": token.access_token, + "refresh_token": token.refresh_token, + } + + REDIS_CONN.set_obj(_web_result_cache_key(flow_id, "box"), result_payload, WEB_FLOW_TTL_SECS) + REDIS_CONN.delete(_web_state_cache_key(flow_id, "box")) + + return await _render_web_oauth_popup(flow_id, True, "Authorization completed successfully.", "box") + +@manager.route("/box/oauth/web/result", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("flow_id") +async def poll_box_web_result(): + req = await get_request_json() + flow_id = req.get("flow_id") + + cache_blob = REDIS_CONN.get(_web_result_cache_key(flow_id, "box")) + if not cache_blob: + return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.") + + cache_raw = json.loads(cache_blob) + if cache_raw.get("user_id") != current_user.id: + return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.") + + REDIS_CONN.delete(_web_result_cache_key(flow_id, "box")) + + return get_json_result(data={"credentials": cache_raw}) \ No newline at end of file diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 984e57caccd..b85921115c2 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -14,19 +14,21 @@ # limitations under the License. # import json +import os import re import logging from copy import deepcopy -from flask import Response, request -from flask_login import current_user, login_required +import tempfile +from quart import Response, request +from api.apps import current_user, login_required from api.db.db_models import APIToken from api.db.services.conversation_service import ConversationService, structure_answer -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from rag.prompts.template import load_prompt from rag.prompts.generator import chunks_format from common.constants import RetCode, LLMType @@ -34,8 +36,8 @@ @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required -def set_conversation(): - req = request.json +async def set_conversation(): + req = await get_request_json() conv_id = req.get("conversation_id") is_new = req.get("is_new") name = req.get("name", "New conversation") @@ -78,14 +80,13 @@ def set_conversation(): @manager.route("/get", methods=["GET"]) # noqa: F821 @login_required -def get(): +async def get(): conv_id = request.args["conversation_id"] try: e, conv = ConversationService.get_by_id(conv_id) if not e: return get_data_error_result(message="Conversation not found!") tenants = UserTenantService.query(user_id=current_user.id) - avatar = None for tenant in tenants: dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id) if dialog and len(dialog) > 0: @@ -129,8 +130,9 @@ def getsse(dialog_id): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required -def rm(): - conv_ids = request.json["conversation_ids"] +async def rm(): + req = await get_request_json() + conv_ids = req["conversation_ids"] try: for cid in conv_ids: exist, conv = ConversationService.get_by_id(cid) @@ -150,7 +152,7 @@ def rm(): @manager.route("/list", methods=["GET"]) # noqa: F821 @login_required -def list_conversation(): +async def list_conversation(): dialog_id = request.args["dialog_id"] try: if not DialogService.query(tenant_id=current_user.id, id=dialog_id): @@ -166,8 +168,8 @@ def list_conversation(): @manager.route("/completion", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "messages") -def completion(): - req = request.json +async def completion(): + req = await get_request_json() msg = [] for m in req["messages"]: if m["role"] == "system": @@ -216,10 +218,10 @@ def completion(): dia.llm_setting = chat_model_config is_embedded = bool(chat_model_id) - def stream(): + async def stream(): nonlocal dia, msg, req, conv try: - for ans in chat(dia, msg, True, **req): + async for ans in async_chat(dia, msg, True, **req): ans = structure_answer(conv, ans, message_id, conv.id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" if not is_embedded: @@ -239,7 +241,7 @@ def stream(): else: answer = None - for ans in chat(dia, msg, **req): + async for ans in async_chat(dia, msg, **req): answer = structure_answer(conv, ans, message_id, conv.id) if not is_embedded: ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -248,11 +250,69 @@ def stream(): except Exception as e: return server_error_response(e) +@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821 +@login_required +async def sequence2txt(): + req = await request.form + stream_mode = req.get("stream", "false").lower() == "true" + files = await request.files + if "file" not in files: + return get_data_error_result(message="Missing 'file' in multipart form-data") + + uploaded = files["file"] + + ALLOWED_EXTS = { + ".wav", ".mp3", ".m4a", ".aac", + ".flac", ".ogg", ".webm", + ".opus", ".wma" + } + + filename = uploaded.filename or "" + suffix = os.path.splitext(filename)[-1].lower() + if suffix not in ALLOWED_EXTS: + return get_data_error_result(message= + f"Unsupported audio format: {suffix}. " + f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}" + ) + fd, temp_audio_path = tempfile.mkstemp(suffix=suffix) + os.close(fd) + await uploaded.save(temp_audio_path) + + tenants = TenantService.get_info_by(current_user.id) + if not tenants: + return get_data_error_result(message="Tenant not found!") + + asr_id = tenants[0]["asr_id"] + if not asr_id: + return get_data_error_result(message="No default ASR model is set") + + asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id) + if not stream_mode: + text = asr_mdl.transcription(temp_audio_path) + try: + os.remove(temp_audio_path) + except Exception as e: + logging.error(f"Failed to remove temp audio file: {str(e)}") + return get_json_result(data={"text": text}) + async def event_stream(): + try: + for evt in asr_mdl.stream_transcription(temp_audio_path): + yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n" + except Exception as e: + err = {"event": "error", "text": str(e)} + yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n" + finally: + try: + os.remove(temp_audio_path) + except Exception as e: + logging.error(f"Failed to remove temp audio file: {str(e)}") + + return Response(event_stream(), content_type="text/event-stream") @manager.route("/tts", methods=["POST"]) # noqa: F821 @login_required -def tts(): - req = request.json +async def tts(): + req = await get_request_json() text = req["text"] tenants = TenantService.get_info_by(current_user.id) @@ -284,8 +344,8 @@ def stream_audio(): @manager.route("/delete_msg", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "message_id") -def delete_msg(): - req = request.json +async def delete_msg(): + req = await get_request_json() e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -307,8 +367,8 @@ def delete_msg(): @manager.route("/thumbup", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id", "message_id") -def thumbup(): - req = request.json +async def thumbup(): + req = await get_request_json() e, conv = ConversationService.get_by_id(req["conversation_id"]) if not e: return get_data_error_result(message="Conversation not found!") @@ -334,8 +394,8 @@ def thumbup(): @manager.route("/ask", methods=["POST"]) # noqa: F821 @login_required @validate_request("question", "kb_ids") -def ask_about(): - req = request.json +async def ask_about(): + req = await get_request_json() uid = current_user.id search_id = req.get("search_id", "") @@ -346,10 +406,10 @@ def ask_about(): if search_app: search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" @@ -366,8 +426,8 @@ def stream(): @manager.route("/mindmap", methods=["POST"]) # noqa: F821 @login_required @validate_request("question", "kb_ids") -def mindmap(): - req = request.json +async def mindmap(): + req = await get_request_json() search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} search_config = search_app.get("search_config", {}) if search_app else {} @@ -375,7 +435,7 @@ def mindmap(): kb_ids.extend(req["kb_ids"]) kb_ids = list(set(kb_ids)) - mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) + mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) @@ -384,8 +444,8 @@ def mindmap(): @manager.route("/related_questions", methods=["POST"]) # noqa: F821 @login_required @validate_request("question") -def related_questions(): - req = request.json +async def related_questions(): + req = await get_request_json() search_id = req.get("search_id", "") search_config = {} @@ -402,7 +462,7 @@ def related_questions(): if "parameter" in gen_conf: del gen_conf["parameter"] prompt = load_prompt("related_question") - ans = chat_mdl.chat( + ans = await chat_mdl.async_chat( prompt, [ { diff --git a/api/apps/dialog_app.py b/api/apps/dialog_app.py index 99f70056891..d2aad88ee1a 100644 --- a/api/apps/dialog_app.py +++ b/api/apps/dialog_app.py @@ -14,25 +14,24 @@ # limitations under the License. # -from flask import request -from flask_login import login_required, current_user +from quart import request from api.db.services import duplicate_name from api.db.services.dialog_service import DialogService from common.constants import StatusEnum from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.misc_utils import get_uuid from common.constants import RetCode -from api.utils.api_utils import get_json_result +from api.apps import login_required, current_user @manager.route('/set', methods=['POST']) # noqa: F821 @validate_request("prompt_config") @login_required -def set_dialog(): - req = request.json +async def set_dialog(): + req = await get_request_json() dialog_id = req.get("dialog_id", "") is_create = not dialog_id name = req.get("name", "New Dialog") @@ -66,7 +65,7 @@ def set_dialog(): if not is_create: if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']: - return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base / Tavily used here.") + return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") for p in prompt_config["parameters"]: if p["optional"]: @@ -154,33 +153,34 @@ def get_kb_names(kb_ids): @login_required def list_dialogs(): try: - diags = DialogService.query( + conversations = DialogService.query( tenant_id=current_user.id, status=StatusEnum.VALID.value, reverse=True, order_by=DialogService.model.create_time) - diags = [d.to_dict() for d in diags] - for d in diags: - d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) - return get_json_result(data=diags) + conversations = [d.to_dict() for d in conversations] + for conversation in conversations: + conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"]) + return get_json_result(data=conversations) except Exception as e: return server_error_response(e) @manager.route('/next', methods=['POST']) # noqa: F821 @login_required -def list_dialogs_next(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - parser_id = request.args.get("parser_id") - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": +async def list_dialogs_next(): + args = request.args + keywords = args.get("keywords", "") + page_number = int(args.get("page", 0)) + items_per_page = int(args.get("page_size", 0)) + parser_id = args.get("parser_id") + orderby = args.get("orderby", "create_time") + if args.get("desc", "true").lower() == "false": desc = False else: desc = True - req = request.get_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -207,8 +207,8 @@ def list_dialogs_next(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("dialog_ids") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() dialog_list=[] tenants = UserTenantService.query(user_id=current_user.id) try: diff --git a/api/apps/document_app.py b/api/apps/document_app.py index c2e37598e92..4fcc07e65c8 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -13,22 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License # +import asyncio import json import os.path import pathlib import re from pathlib import Path - -import flask -from flask import request -from flask_login import current_user, login_required - +from quart import request, make_response +from api.apps import current_user, login_required from api.common.check_team_permission import check_kb_team_permission from api.constants import FILE_NAME_LEN_LIMIT, IMG_BASE64_PREFIX from api.db import VALID_FILE_TYPES, FileType from api.db.db_models import Task from api.db.services import duplicate_name from api.db.services.document_service import DocumentService, doc_upload_and_parse +from common.metadata_utils import meta_filter, convert_conditions from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService @@ -39,7 +38,7 @@ get_data_error_result, get_json_result, server_error_response, - validate_request, + validate_request, get_request_json, ) from api.utils.file_utils import filename_type, thumbnail from common.file_utils import get_project_base_directory @@ -53,14 +52,16 @@ @manager.route("/upload", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id") -def upload(): - kb_id = request.form.get("kb_id") +async def upload(): + form = await request.form + kb_id = form.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - if "file" not in request.files: + files = await request.files + if "file" not in files: return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist("file") + file_objs = files.getlist("file") for file_obj in file_objs: if file_obj.filename == "": return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR) @@ -69,11 +70,11 @@ def upload(): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - raise LookupError("Can't find this knowledgebase!") + raise LookupError("Can't find this dataset!") if not check_kb_team_permission(kb, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - err, files = FileService.upload_document(kb, file_objs, current_user.id) + err, files = await asyncio.to_thread(FileService.upload_document, kb, file_objs, current_user.id) if err: return get_json_result(data=files, message="\n".join(err), code=RetCode.SERVER_ERROR) @@ -87,17 +88,18 @@ def upload(): @manager.route("/web_crawl", methods=["POST"]) # noqa: F821 @login_required @validate_request("kb_id", "name", "url") -def web_crawl(): - kb_id = request.form.get("kb_id") +async def web_crawl(): + form = await request.form + kb_id = form.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - name = request.form.get("name") - url = request.form.get("url") + name = form.get("name") + url = form.get("url") if not is_valid_url(url): return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - raise LookupError("Can't find this knowledgebase!") + raise LookupError("Can't find this dataset!") if check_kb_team_permission(kb, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -152,8 +154,8 @@ def web_crawl(): @manager.route("/create", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "kb_id") -def create(): - req = request.json +async def create(): + req = await get_request_json() kb_id = req["kb_id"] if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -167,10 +169,10 @@ def create(): try: e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - return get_data_error_result(message="Can't find this knowledgebase!") + return get_data_error_result(message="Can't find this dataset!") if DocumentService.query(name=req["name"], kb_id=kb_id): - return get_data_error_result(message="Duplicated document name in the same knowledgebase.") + return get_data_error_result(message="Duplicated document name in the same dataset.") kb_root_folder = FileService.get_kb_folder(kb.tenant_id) if not kb_root_folder: @@ -208,7 +210,7 @@ def create(): @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_docs(): +async def list_docs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -217,7 +219,7 @@ def list_docs(): if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): break else: - return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) @@ -230,7 +232,11 @@ def list_docs(): create_time_from = int(request.args.get("create_time_from", 0)) create_time_to = int(request.args.get("create_time_to", 0)) - req = request.get_json() + req = await get_request_json() + + return_empty_metadata = req.get("return_empty_metadata", False) + if isinstance(return_empty_metadata, str): + return_empty_metadata = return_empty_metadata.lower() == "true" run_status = req.get("run_status", []) if run_status: @@ -245,9 +251,74 @@ def list_docs(): return get_data_error_result(message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}") suffix = req.get("suffix", []) + metadata_condition = req.get("metadata_condition", {}) or {} + metadata = req.get("metadata", {}) or {} + if isinstance(metadata, dict) and metadata.get("empty_metadata"): + return_empty_metadata = True + metadata = {k: v for k, v in metadata.items() if k != "empty_metadata"} + if return_empty_metadata: + metadata_condition = {} + metadata = {} + else: + if metadata_condition and not isinstance(metadata_condition, dict): + return get_data_error_result(message="metadata_condition must be an object.") + if metadata and not isinstance(metadata, dict): + return get_data_error_result(message="metadata must be an object.") + + doc_ids_filter = None + metas = None + if metadata_condition or metadata: + metas = DocumentService.get_flatted_meta_by_kbs([kb_id]) + + if metadata_condition: + doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + if metadata_condition.get("conditions") and not doc_ids_filter: + return get_json_result(data={"total": 0, "docs": []}) + + if metadata: + metadata_doc_ids = None + for key, values in metadata.items(): + if not values: + continue + if not isinstance(values, list): + values = [values] + values = [str(v) for v in values if v is not None and str(v).strip()] + if not values: + continue + key_doc_ids = set() + for value in values: + key_doc_ids.update(metas.get(key, {}).get(value, [])) + if metadata_doc_ids is None: + metadata_doc_ids = key_doc_ids + else: + metadata_doc_ids &= key_doc_ids + if not metadata_doc_ids: + return get_json_result(data={"total": 0, "docs": []}) + if metadata_doc_ids is not None: + if doc_ids_filter is None: + doc_ids_filter = metadata_doc_ids + else: + doc_ids_filter &= metadata_doc_ids + if not doc_ids_filter: + return get_json_result(data={"total": 0, "docs": []}) + + if doc_ids_filter is not None: + doc_ids_filter = list(doc_ids_filter) try: - docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix) + docs, tol = DocumentService.get_by_kb_id( + kb_id, + page_number, + items_per_page, + orderby, + desc, + keywords, + run_status, + types, + suffix, + doc_ids_filter, + return_empty_metadata=return_empty_metadata, + ) if create_time_from or create_time_to: filtered_docs = [] @@ -270,8 +341,8 @@ def list_docs(): @manager.route("/filter", methods=["POST"]) # noqa: F821 @login_required -def get_filter(): - req = request.get_json() +async def get_filter(): + req = await get_request_json() kb_id = req.get("kb_id") if not kb_id: @@ -281,7 +352,7 @@ def get_filter(): if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): break else: - return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) keywords = req.get("keywords", "") @@ -308,8 +379,8 @@ def get_filter(): @manager.route("/infos", methods=["POST"]) # noqa: F821 @login_required -def docinfos(): - req = request.json +async def doc_infos(): + req = await get_request_json() doc_ids = req["doc_ids"] for doc_id in doc_ids: if not DocumentService.accessible(doc_id, current_user.id): @@ -318,6 +389,107 @@ def docinfos(): return get_json_result(data=list(docs.dicts())) +@manager.route("/metadata/summary", methods=["POST"]) # noqa: F821 +@login_required +async def metadata_summary(): + req = await get_request_json() + kb_id = req.get("kb_id") + if not kb_id: + return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): + break + else: + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) + + try: + summary = DocumentService.get_metadata_summary(kb_id) + return get_json_result(data={"summary": summary}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/metadata/update", methods=["POST"]) # noqa: F821 +@login_required +async def metadata_update(): + req = await get_request_json() + kb_id = req.get("kb_id") + if not kb_id: + return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) + + tenants = UserTenantService.query(user_id=current_user.id) + for tenant in tenants: + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): + break + else: + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) + + selector = req.get("selector", {}) or {} + updates = req.get("updates", []) or [] + deletes = req.get("deletes", []) or [] + + if not isinstance(selector, dict): + return get_json_result(data=False, message="selector must be an object.", code=RetCode.ARGUMENT_ERROR) + if not isinstance(updates, list) or not isinstance(deletes, list): + return get_json_result(data=False, message="updates and deletes must be lists.", code=RetCode.ARGUMENT_ERROR) + + metadata_condition = selector.get("metadata_condition", {}) or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_json_result(data=False, message="metadata_condition must be an object.", code=RetCode.ARGUMENT_ERROR) + + document_ids = selector.get("document_ids", []) or [] + if document_ids and not isinstance(document_ids, list): + return get_json_result(data=False, message="document_ids must be a list.", code=RetCode.ARGUMENT_ERROR) + + for upd in updates: + if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: + return get_json_result(data=False, message="Each update requires key and value.", code=RetCode.ARGUMENT_ERROR) + for d in deletes: + if not isinstance(d, dict) or not d.get("key"): + return get_json_result(data=False, message="Each delete requires key.", code=RetCode.ARGUMENT_ERROR) + + kb_doc_ids = KnowledgebaseService.list_documents_by_ids([kb_id]) + target_doc_ids = set(kb_doc_ids) + if document_ids: + invalid_ids = set(document_ids) - set(kb_doc_ids) + if invalid_ids: + return get_json_result(data=False, message=f"These documents do not belong to dataset {kb_id}: {', '.join(invalid_ids)}", code=RetCode.ARGUMENT_ERROR) + target_doc_ids = set(document_ids) + + if metadata_condition: + metas = DocumentService.get_flatted_meta_by_kbs([kb_id]) + filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + target_doc_ids = target_doc_ids & filtered_ids + if metadata_condition.get("conditions") and not target_doc_ids: + return get_json_result(data={"updated": 0, "matched_docs": 0}) + + target_doc_ids = list(target_doc_ids) + updated = DocumentService.batch_update_metadata(kb_id, target_doc_ids, updates, deletes) + return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + + +@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("doc_id", "metadata") +async def update_metadata_setting(): + req = await get_request_json() + if not DocumentService.accessible(req["doc_id"], current_user.id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + + DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]}) + e, doc = DocumentService.get_by_id(doc.id) + if not e: + return get_data_error_result(message="Document not found!") + + return get_json_result(data=doc.to_dict()) + + @manager.route("/thumbnails", methods=["GET"]) # noqa: F821 # @login_required def thumbnails(): @@ -340,8 +512,8 @@ def thumbnails(): @manager.route("/change_status", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_ids", "status") -def change_status(): - req = request.get_json() +async def change_status(): + req = await get_request_json() doc_ids = req.get("doc_ids", []) status = str(req.get("status", "")) @@ -361,7 +533,7 @@ def change_status(): continue e, kb = KnowledgebaseService.get_by_id(doc.kb_id) if not e: - result[doc_id] = {"error": "Can't find this knowledgebase!"} + result[doc_id] = {"error": "Can't find this dataset!"} continue if not DocumentService.update_by_id(doc_id, {"status": str(status)}): result[doc_id] = {"error": "Database error (Document update)!"} @@ -380,8 +552,8 @@ def change_status(): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() doc_ids = req["doc_id"] if isinstance(doc_ids, str): doc_ids = [doc_ids] @@ -390,7 +562,7 @@ def rm(): if not DocumentService.accessible4deletion(doc_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - errors = FileService.delete_docs(doc_ids, current_user.id) + errors = await asyncio.to_thread(FileService.delete_docs, doc_ids, current_user.id) if errors: return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR) @@ -401,46 +573,57 @@ def rm(): @manager.route("/run", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_ids", "run") -def run(): - req = request.json - for doc_id in req["doc_ids"]: - if not DocumentService.accessible(doc_id, current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) +async def run(): + req = await get_request_json() try: - kb_table_num_map = {} - for id in req["doc_ids"]: - info = {"run": str(req["run"]), "progress": 0} - if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False): - info["progress_msg"] = "" - info["chunk_num"] = 0 - info["token_num"] = 0 - - tenant_id = DocumentService.get_tenant_id(id) - if not tenant_id: - return get_data_error_result(message="Tenant not found!") - e, doc = DocumentService.get_by_id(id) - if not e: - return get_data_error_result(message="Document not found!") - - if str(req["run"]) == TaskStatus.CANCEL.value: - if str(doc.run) == TaskStatus.RUNNING.value: - cancel_all_task_of(id) - else: - return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status") - if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]): - DocumentService.clear_chunk_num_when_rerun(doc.id) + def _run_sync(): + for doc_id in req["doc_ids"]: + if not DocumentService.accessible(doc_id, current_user.id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + kb_table_num_map = {} + for id in req["doc_ids"]: + info = {"run": str(req["run"]), "progress": 0} + if str(req["run"]) == TaskStatus.RUNNING.value and req.get("delete", False): + info["progress_msg"] = "" + info["chunk_num"] = 0 + info["token_num"] = 0 + + tenant_id = DocumentService.get_tenant_id(id) + if not tenant_id: + return get_data_error_result(message="Tenant not found!") + e, doc = DocumentService.get_by_id(id) + if not e: + return get_data_error_result(message="Document not found!") + + if str(req["run"]) == TaskStatus.CANCEL.value: + if str(doc.run) == TaskStatus.RUNNING.value: + cancel_all_task_of(id) + else: + return get_data_error_result(message="Cannot cancel a task that is not in RUNNING status") + if all([("delete" not in req or req["delete"]), str(req["run"]) == TaskStatus.RUNNING.value, str(doc.run) == TaskStatus.DONE.value]): + DocumentService.clear_chunk_num_when_rerun(doc.id) + + DocumentService.update_by_id(id, info) + if req.get("delete", False): + TaskService.filter_delete([Task.doc_id == id]) + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) + + if str(req["run"]) == TaskStatus.RUNNING.value: + if req.get("apply_kb"): + e, kb = KnowledgebaseService.get_by_id(doc.kb_id) + if not e: + raise LookupError("Can't find this dataset!") + doc.parser_config["enable_metadata"] = kb.parser_config.get("enable_metadata", False) + doc.parser_config["metadata"] = kb.parser_config.get("metadata", {}) + DocumentService.update_parser_config(doc.id, doc.parser_config) + doc_dict = doc.to_dict() + DocumentService.run(tenant_id, doc_dict, kb_table_num_map) - DocumentService.update_by_id(id, info) - if req.get("delete", False): - TaskService.filter_delete([Task.doc_id == id]) - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id) - - if str(req["run"]) == TaskStatus.RUNNING.value: - doc = doc.to_dict() - DocumentService.run(tenant_id, doc, kb_table_num_map) + return get_json_result(data=True) - return get_json_result(data=True) + return await asyncio.to_thread(_run_sync) except Exception as e: return server_error_response(e) @@ -448,66 +631,72 @@ def run(): @manager.route("/rename", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "name") -def rename(): - req = request.json - if not DocumentService.accessible(req["doc_id"], current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) +async def rename(): + req = await get_request_json() try: - e, doc = DocumentService.get_by_id(req["doc_id"]) - if not e: - return get_data_error_result(message="Document not found!") - if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: - return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR) - if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: - return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR) + def _rename_sync(): + if not DocumentService.accessible(req["doc_id"], current_user.id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + e, doc = DocumentService.get_by_id(req["doc_id"]) + if not e: + return get_data_error_result(message="Document not found!") + if pathlib.Path(req["name"].lower()).suffix != pathlib.Path(doc.name.lower()).suffix: + return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.ARGUMENT_ERROR) + if len(req["name"].encode("utf-8")) > FILE_NAME_LEN_LIMIT: + return get_json_result(data=False, message=f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.", code=RetCode.ARGUMENT_ERROR) - for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): - if d.name == req["name"]: - return get_data_error_result(message="Duplicated document name in the same knowledgebase.") + for d in DocumentService.query(name=req["name"], kb_id=doc.kb_id): + if d.name == req["name"]: + return get_data_error_result(message="Duplicated document name in the same dataset.") - if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}): - return get_data_error_result(message="Database error (Document rename)!") + if not DocumentService.update_by_id(req["doc_id"], {"name": req["name"]}): + return get_data_error_result(message="Database error (Document rename)!") - informs = File2DocumentService.get_by_document_id(req["doc_id"]) - if informs: - e, file = FileService.get_by_id(informs[0].file_id) - FileService.update_by_id(file.id, {"name": req["name"]}) + informs = File2DocumentService.get_by_document_id(req["doc_id"]) + if informs: + e, file = FileService.get_by_id(informs[0].file_id) + FileService.update_by_id(file.id, {"name": req["name"]}) - tenant_id = DocumentService.get_tenant_id(req["doc_id"]) - title_tks = rag_tokenizer.tokenize(req["name"]) - es_body = { - "docnm_kwd": req["name"], - "title_tks": title_tks, - "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks), - } - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): - settings.docStoreConn.update( - {"doc_id": req["doc_id"]}, - es_body, - search.index_name(tenant_id), - doc.kb_id, - ) + tenant_id = DocumentService.get_tenant_id(req["doc_id"]) + title_tks = rag_tokenizer.tokenize(req["name"]) + es_body = { + "docnm_kwd": req["name"], + "title_tks": title_tks, + "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks), + } + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): + settings.docStoreConn.update( + {"doc_id": req["doc_id"]}, + es_body, + search.index_name(tenant_id), + doc.kb_id, + ) + return get_json_result(data=True) + + return await asyncio.to_thread(_rename_sync) - return get_json_result(data=True) except Exception as e: return server_error_response(e) @manager.route("/get/", methods=["GET"]) # noqa: F821 # @login_required -def get(doc_id): +async def get(doc_id): try: e, doc = DocumentService.get_by_id(doc_id) if not e: return get_data_error_result(message="Document not found!") b, n = File2DocumentService.get_storage_address(doc_id=doc_id) - response = flask.make_response(settings.STORAGE_IMPL.get(b, n)) + data = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n) + response = await make_response(data) ext = re.search(r"\.([^.]+)$", doc.name.lower()) ext = ext.group(1) if ext else None if ext: if doc.type == FileType.VISUAL.value: + content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}") else: content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}") @@ -517,12 +706,27 @@ def get(doc_id): return server_error_response(e) +@manager.route("/download/", methods=["GET"]) # noqa: F821 +@login_required +async def download_attachment(attachment_id): + try: + ext = request.args.get("ext", "markdown") + data = await asyncio.to_thread(settings.STORAGE_IMPL.get, current_user.id, attachment_id) + response = await make_response(data) + response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}")) + + return response + + except Exception as e: + return server_error_response(e) + + @manager.route("/change_parser", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id") -def change_parser(): +async def change_parser(): - req = request.json + req = await get_request_json() if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) @@ -542,8 +746,10 @@ def reset_doc(): tenant_id = DocumentService.get_tenant_id(req["doc_id"]) if not tenant_id: return get_data_error_result(message="Tenant not found!") - if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id): + DocumentService.delete_chunk_images(doc, tenant_id) + if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id): settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) + return None try: if "pipeline_id" in req and req["pipeline_id"] != "": @@ -572,13 +778,14 @@ def reset_doc(): @manager.route("/image/", methods=["GET"]) # noqa: F821 # @login_required -def get_image(image_id): +async def get_image(image_id): try: arr = image_id.split("-") if len(arr) != 2: return get_data_error_result(message="Image not found.") bkt, nm = image_id.split("-") - response = flask.make_response(settings.STORAGE_IMPL.get(bkt, nm)) + data = await asyncio.to_thread(settings.STORAGE_IMPL.get, bkt, nm) + response = await make_response(data) response.headers.set("Content-Type", "image/JPEG") return response except Exception as e: @@ -588,24 +795,26 @@ def get_image(image_id): @manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821 @login_required @validate_request("conversation_id") -def upload_and_parse(): - if "file" not in request.files: +async def upload_and_parse(): + files = await request.files + if "file" not in files: return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist("file") + file_objs = files.getlist("file") for file_obj in file_objs: if file_obj.filename == "": return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR) - doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) - + form = await request.form + doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id) return get_json_result(data=doc_ids) @manager.route("/parse", methods=["POST"]) # noqa: F821 @login_required -def parse(): - url = request.json.get("url") if request.json else "" +async def parse(): + req = await get_request_json() + url = req.get("url", "") if url: if not is_valid_url(url): return get_json_result(data=False, message="The URL format is invalid", code=RetCode.ARGUMENT_ERROR) @@ -646,10 +855,11 @@ def read(self): txt = FileService.parse_docs([f], current_user.id) return get_json_result(data=txt) - if "file" not in request.files: + files = await request.files + if "file" not in files: return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist("file") + file_objs = files.getlist("file") txt = FileService.parse_docs(file_objs, current_user.id) return get_json_result(data=txt) @@ -658,8 +868,8 @@ def read(self): @manager.route("/set_meta", methods=["POST"]) # noqa: F821 @login_required @validate_request("doc_id", "meta") -def set_meta(): - req = request.json +async def set_meta(): + req = await get_request_json() if not DocumentService.accessible(req["doc_id"], current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: @@ -667,7 +877,10 @@ def set_meta(): if not isinstance(meta, dict): return get_json_result(data=False, message="Only dictionary type supported.", code=RetCode.ARGUMENT_ERROR) for k, v in meta.items(): - if not isinstance(v, str) and not isinstance(v, int) and not isinstance(v, float): + if isinstance(v, list): + if not all(isinstance(i, (str, int, float)) for i in v): + return get_json_result(data=False, message=f"The type is not supported in list: {v}", code=RetCode.ARGUMENT_ERROR) + elif not isinstance(v, (str, int, float)): return get_json_result(data=False, message=f"The type is not supported: {v}", code=RetCode.ARGUMENT_ERROR) except Exception as e: return get_json_result(data=False, message=f"Json syntax error: {e}", code=RetCode.ARGUMENT_ERROR) @@ -685,3 +898,13 @@ def set_meta(): return get_json_result(data=True) except Exception as e: return server_error_response(e) + + +@manager.route("/upload_info", methods=["POST"]) # noqa: F821 +async def upload_info(): + files = await request.files + file = files['file'] if files and files.get("file") else None + try: + return get_json_result(data=FileService.upload_info(current_user.id, file, request.args.get("url"))) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/evaluation_app.py b/api/apps/evaluation_app.py new file mode 100644 index 00000000000..b33db26da17 --- /dev/null +++ b/api/apps/evaluation_app.py @@ -0,0 +1,479 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +RAG Evaluation API Endpoints + +Provides REST API for RAG evaluation functionality including: +- Dataset management +- Test case management +- Evaluation execution +- Results retrieval +- Configuration recommendations +""" + +from quart import request +from api.apps import login_required, current_user +from api.db.services.evaluation_service import EvaluationService +from api.utils.api_utils import ( + get_data_error_result, + get_json_result, + get_request_json, + server_error_response, + validate_request +) +from common.constants import RetCode + + +# ==================== Dataset Management ==================== + +@manager.route('/dataset/create', methods=['POST']) # noqa: F821 +@login_required +@validate_request("name", "kb_ids") +async def create_dataset(): + """ + Create a new evaluation dataset. + + Request body: + { + "name": "Dataset name", + "description": "Optional description", + "kb_ids": ["kb_id1", "kb_id2"] + } + """ + try: + req = await get_request_json() + name = req.get("name", "").strip() + description = req.get("description", "") + kb_ids = req.get("kb_ids", []) + + if not name: + return get_data_error_result(message="Dataset name cannot be empty") + + if not kb_ids or not isinstance(kb_ids, list): + return get_data_error_result(message="kb_ids must be a non-empty list") + + success, result = EvaluationService.create_dataset( + name=name, + description=description, + kb_ids=kb_ids, + tenant_id=current_user.id, + user_id=current_user.id + ) + + if not success: + return get_data_error_result(message=result) + + return get_json_result(data={"dataset_id": result}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset/list', methods=['GET']) # noqa: F821 +@login_required +async def list_datasets(): + """ + List evaluation datasets for current tenant. + + Query params: + - page: Page number (default: 1) + - page_size: Items per page (default: 20) + """ + try: + page = int(request.args.get("page", 1)) + page_size = int(request.args.get("page_size", 20)) + + result = EvaluationService.list_datasets( + tenant_id=current_user.id, + user_id=current_user.id, + page=page, + page_size=page_size + ) + + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset/', methods=['GET']) # noqa: F821 +@login_required +async def get_dataset(dataset_id): + """Get dataset details by ID""" + try: + dataset = EvaluationService.get_dataset(dataset_id) + if not dataset: + return get_data_error_result( + message="Dataset not found", + code=RetCode.DATA_ERROR + ) + + return get_json_result(data=dataset) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset/', methods=['PUT']) # noqa: F821 +@login_required +async def update_dataset(dataset_id): + """ + Update dataset. + + Request body: + { + "name": "New name", + "description": "New description", + "kb_ids": ["kb_id1", "kb_id2"] + } + """ + try: + req = await get_request_json() + + # Remove fields that shouldn't be updated + req.pop("id", None) + req.pop("tenant_id", None) + req.pop("created_by", None) + req.pop("create_time", None) + + success = EvaluationService.update_dataset(dataset_id, **req) + + if not success: + return get_data_error_result(message="Failed to update dataset") + + return get_json_result(data={"dataset_id": dataset_id}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset/', methods=['DELETE']) # noqa: F821 +@login_required +async def delete_dataset(dataset_id): + """Delete dataset (soft delete)""" + try: + success = EvaluationService.delete_dataset(dataset_id) + + if not success: + return get_data_error_result(message="Failed to delete dataset") + + return get_json_result(data={"dataset_id": dataset_id}) + except Exception as e: + return server_error_response(e) + + +# ==================== Test Case Management ==================== + +@manager.route('/dataset//case/add', methods=['POST']) # noqa: F821 +@login_required +@validate_request("question") +async def add_test_case(dataset_id): + """ + Add a test case to a dataset. + + Request body: + { + "question": "Test question", + "reference_answer": "Optional ground truth answer", + "relevant_doc_ids": ["doc_id1", "doc_id2"], + "relevant_chunk_ids": ["chunk_id1", "chunk_id2"], + "metadata": {"key": "value"} + } + """ + try: + req = await get_request_json() + question = req.get("question", "").strip() + + if not question: + return get_data_error_result(message="Question cannot be empty") + + success, result = EvaluationService.add_test_case( + dataset_id=dataset_id, + question=question, + reference_answer=req.get("reference_answer"), + relevant_doc_ids=req.get("relevant_doc_ids"), + relevant_chunk_ids=req.get("relevant_chunk_ids"), + metadata=req.get("metadata") + ) + + if not success: + return get_data_error_result(message=result) + + return get_json_result(data={"case_id": result}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset//case/import', methods=['POST']) # noqa: F821 +@login_required +@validate_request("cases") +async def import_test_cases(dataset_id): + """ + Bulk import test cases. + + Request body: + { + "cases": [ + { + "question": "Question 1", + "reference_answer": "Answer 1", + ... + }, + { + "question": "Question 2", + ... + } + ] + } + """ + try: + req = await get_request_json() + cases = req.get("cases", []) + + if not cases or not isinstance(cases, list): + return get_data_error_result(message="cases must be a non-empty list") + + success_count, failure_count = EvaluationService.import_test_cases( + dataset_id=dataset_id, + cases=cases + ) + + return get_json_result(data={ + "success_count": success_count, + "failure_count": failure_count, + "total": len(cases) + }) + except Exception as e: + return server_error_response(e) + + +@manager.route('/dataset//cases', methods=['GET']) # noqa: F821 +@login_required +async def get_test_cases(dataset_id): + """Get all test cases for a dataset""" + try: + cases = EvaluationService.get_test_cases(dataset_id) + return get_json_result(data={"cases": cases, "total": len(cases)}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/case/', methods=['DELETE']) # noqa: F821 +@login_required +async def delete_test_case(case_id): + """Delete a test case""" + try: + success = EvaluationService.delete_test_case(case_id) + + if not success: + return get_data_error_result(message="Failed to delete test case") + + return get_json_result(data={"case_id": case_id}) + except Exception as e: + return server_error_response(e) + + +# ==================== Evaluation Execution ==================== + +@manager.route('/run/start', methods=['POST']) # noqa: F821 +@login_required +@validate_request("dataset_id", "dialog_id") +async def start_evaluation(): + """ + Start an evaluation run. + + Request body: + { + "dataset_id": "dataset_id", + "dialog_id": "dialog_id", + "name": "Optional run name" + } + """ + try: + req = await get_request_json() + dataset_id = req.get("dataset_id") + dialog_id = req.get("dialog_id") + name = req.get("name") + + success, result = EvaluationService.start_evaluation( + dataset_id=dataset_id, + dialog_id=dialog_id, + user_id=current_user.id, + name=name + ) + + if not success: + return get_data_error_result(message=result) + + return get_json_result(data={"run_id": result}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run/', methods=['GET']) # noqa: F821 +@login_required +async def get_evaluation_run(run_id): + """Get evaluation run details""" + try: + result = EvaluationService.get_run_results(run_id) + + if not result: + return get_data_error_result( + message="Evaluation run not found", + code=RetCode.DATA_ERROR + ) + + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run//results', methods=['GET']) # noqa: F821 +@login_required +async def get_run_results(run_id): + """Get detailed results for an evaluation run""" + try: + result = EvaluationService.get_run_results(run_id) + + if not result: + return get_data_error_result( + message="Evaluation run not found", + code=RetCode.DATA_ERROR + ) + + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run/list', methods=['GET']) # noqa: F821 +@login_required +async def list_evaluation_runs(): + """ + List evaluation runs. + + Query params: + - dataset_id: Filter by dataset (optional) + - dialog_id: Filter by dialog (optional) + - page: Page number (default: 1) + - page_size: Items per page (default: 20) + """ + try: + # TODO: Implement list_runs in EvaluationService + return get_json_result(data={"runs": [], "total": 0}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run/', methods=['DELETE']) # noqa: F821 +@login_required +async def delete_evaluation_run(run_id): + """Delete an evaluation run""" + try: + # TODO: Implement delete_run in EvaluationService + return get_json_result(data={"run_id": run_id}) + except Exception as e: + return server_error_response(e) + + +# ==================== Analysis & Recommendations ==================== + +@manager.route('/run//recommendations', methods=['GET']) # noqa: F821 +@login_required +async def get_recommendations(run_id): + """Get configuration recommendations based on evaluation results""" + try: + recommendations = EvaluationService.get_recommendations(run_id) + return get_json_result(data={"recommendations": recommendations}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/compare', methods=['POST']) # noqa: F821 +@login_required +@validate_request("run_ids") +async def compare_runs(): + """ + Compare multiple evaluation runs. + + Request body: + { + "run_ids": ["run_id1", "run_id2", "run_id3"] + } + """ + try: + req = await get_request_json() + run_ids = req.get("run_ids", []) + + if not run_ids or not isinstance(run_ids, list) or len(run_ids) < 2: + return get_data_error_result( + message="run_ids must be a list with at least 2 run IDs" + ) + + # TODO: Implement compare_runs in EvaluationService + return get_json_result(data={"comparison": {}}) + except Exception as e: + return server_error_response(e) + + +@manager.route('/run//export', methods=['GET']) # noqa: F821 +@login_required +async def export_results(run_id): + """Export evaluation results as JSON/CSV""" + try: + # format_type = request.args.get("format", "json") # TODO: Use for CSV export + + result = EvaluationService.get_run_results(run_id) + + if not result: + return get_data_error_result( + message="Evaluation run not found", + code=RetCode.DATA_ERROR + ) + + # TODO: Implement CSV export + return get_json_result(data=result) + except Exception as e: + return server_error_response(e) + + +# ==================== Real-time Evaluation ==================== + +@manager.route('/evaluate_single', methods=['POST']) # noqa: F821 +@login_required +@validate_request("question", "dialog_id") +async def evaluate_single(): + """ + Evaluate a single question-answer pair in real-time. + + Request body: + { + "question": "Test question", + "dialog_id": "dialog_id", + "reference_answer": "Optional ground truth", + "relevant_chunk_ids": ["chunk_id1", "chunk_id2"] + } + """ + try: + # req = await get_request_json() # TODO: Use for single evaluation implementation + + # TODO: Implement single evaluation + # This would execute the RAG pipeline and return metrics immediately + + return get_json_result(data={ + "answer": "", + "metrics": {}, + "retrieved_chunks": [] + }) + except Exception as e: + return server_error_response(e) diff --git a/api/apps/file2document_app.py b/api/apps/file2document_app.py index ca1e6b096d5..f410e8a1767 100644 --- a/api/apps/file2document_app.py +++ b/api/apps/file2document_app.py @@ -19,22 +19,20 @@ from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService -from flask import request -from flask_login import login_required, current_user +from api.apps import login_required, current_user from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.misc_utils import get_uuid from common.constants import RetCode from api.db import FileType from api.db.services.document_service import DocumentService -from api.utils.api_utils import get_json_result @manager.route('/convert', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids", "kb_ids") -def convert(): - req = request.json +async def convert(): + req = await get_request_json() kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] @@ -70,7 +68,7 @@ def convert(): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return get_data_error_result( - message="Can't find this knowledgebase!") + message="Can't find this dataset!") e, file = FileService.get_by_id(id) if not e: return get_data_error_result( @@ -79,7 +77,8 @@ def convert(): doc = DocumentService.insert({ "id": get_uuid(), "kb_id": kb.id, - "parser_id": FileService.get_parser(file.type, file.name, kb.parser_id), + "parser_id": kb.parser_id, + "pipeline_id": kb.pipeline_id, "parser_config": kb.parser_config, "created_by": current_user.id, "type": file.type, @@ -103,8 +102,8 @@ def convert(): @manager.route('/rm', methods=['POST']) # noqa: F821 @login_required @validate_request("file_ids") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() file_ids = req["file_ids"] if not file_ids: return get_json_result( diff --git a/api/apps/file_app.py b/api/apps/file_app.py index 279e32525bb..1ce5d4caed9 100644 --- a/api/apps/file_app.py +++ b/api/apps/file_app.py @@ -14,13 +14,12 @@ # limitations under the License # import logging +import asyncio import os import pathlib import re - -import flask -from flask import request -from flask_login import login_required, current_user +from quart import request, make_response +from api.apps import login_required, current_user from api.common.check_team_permission import check_file_team_permission from api.db.services.document_service import DocumentService @@ -31,7 +30,7 @@ from api.db import FileType from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.utils.api_utils import get_json_result +from api.utils.api_utils import get_json_result, get_request_json from api.utils.file_utils import filename_type from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings @@ -40,17 +39,19 @@ @manager.route('/upload', methods=['POST']) # noqa: F821 @login_required # @validate_request("parent_id") -def upload(): - pf_id = request.form.get("parent_id") +async def upload(): + form = await request.form + pf_id = form.get("parent_id") if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] - if 'file' not in request.files: + files = await request.files + if 'file' not in files: return get_json_result( data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist('file') + file_objs = files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': @@ -61,9 +62,10 @@ def upload(): e, pf_folder = FileService.get_by_id(pf_id) if not e: return get_data_error_result( message="Can't find this folder!") - for file_obj in file_objs: + + async def _handle_single_file(file_obj): MAX_FILE_NUM_PER_USER: int = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0)) - if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(current_user.id): + if 0 < MAX_FILE_NUM_PER_USER <= await asyncio.to_thread(DocumentService.get_doc_count, current_user.id): return get_data_error_result( message="Exceed the maximum file number of a free user!") # split file name path @@ -75,35 +77,36 @@ def upload(): file_len = len(file_obj_names) # get folder - file_id_list = FileService.get_id_list_by_id(pf_id, file_obj_names, 1, [pf_id]) + file_id_list = await asyncio.to_thread(FileService.get_id_list_by_id, pf_id, file_obj_names, 1, [pf_id]) len_id_list = len(file_id_list) # create folder if file_len != len_id_list: - e, file = FileService.get_by_id(file_id_list[len_id_list - 1]) + e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 1]) if not e: return get_data_error_result(message="Folder not found!") - last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, + last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 1], file_obj_names, len_id_list) else: - e, file = FileService.get_by_id(file_id_list[len_id_list - 2]) + e, file = await asyncio.to_thread(FileService.get_by_id, file_id_list[len_id_list - 2]) if not e: return get_data_error_result(message="Folder not found!") - last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, + last_folder = await asyncio.to_thread(FileService.create_folder, file, file_id_list[len_id_list - 2], file_obj_names, len_id_list) # file type filetype = filename_type(file_obj_names[file_len - 1]) location = file_obj_names[file_len - 1] - while settings.STORAGE_IMPL.obj_exist(last_folder.id, location): + while await asyncio.to_thread(settings.STORAGE_IMPL.obj_exist, last_folder.id, location): location += "_" - blob = file_obj.read() - filename = duplicate_name( + blob = await asyncio.to_thread(file_obj.read) + filename = await asyncio.to_thread( + duplicate_name, FileService.query, name=file_obj_names[file_len - 1], parent_id=last_folder.id) - settings.STORAGE_IMPL.put(last_folder.id, location, blob) - file = { + await asyncio.to_thread(settings.STORAGE_IMPL.put, last_folder.id, location, blob) + file_data = { "id": get_uuid(), "parent_id": last_folder.id, "tenant_id": current_user.id, @@ -113,8 +116,13 @@ def upload(): "location": location, "size": len(blob), } - file = FileService.insert(file) - file_res.append(file.to_json()) + inserted = await asyncio.to_thread(FileService.insert, file_data) + return inserted.to_json() + + for file_obj in file_objs: + res = await _handle_single_file(file_obj) + file_res.append(res) + return get_json_result(data=file_res) except Exception as e: return server_error_response(e) @@ -123,10 +131,10 @@ def upload(): @manager.route('/create', methods=['POST']) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.json - pf_id = request.json.get("parent_id") - input_file_type = request.json.get("type") +async def create(): + req = await get_request_json() + pf_id = req.get("parent_id") + input_file_type = req.get("type") if not pf_id: root_folder = FileService.get_root_folder(current_user.id) pf_id = root_folder["id"] @@ -238,59 +246,62 @@ def get_all_parent_folders(): @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("file_ids") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() file_ids = req["file_ids"] - def _delete_single_file(file): - try: - if file.location: - settings.STORAGE_IMPL.rm(file.parent_id, file.location) - except Exception: - logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}") - - informs = File2DocumentService.get_by_file_id(file.id) - for inform in informs: - doc_id = inform.document_id - e, doc = DocumentService.get_by_id(doc_id) - if e and doc: - tenant_id = DocumentService.get_tenant_id(doc_id) - if tenant_id: - DocumentService.remove_document(doc, tenant_id) - File2DocumentService.delete_by_file_id(file.id) - - FileService.delete(file) - - def _delete_folder_recursive(folder, tenant_id): - sub_files = FileService.list_all_files_by_parent_id(folder.id) - for sub_file in sub_files: - if sub_file.type == FileType.FOLDER.value: - _delete_folder_recursive(sub_file, tenant_id) - else: - _delete_single_file(sub_file) + try: + def _delete_single_file(file): + try: + if file.location: + settings.STORAGE_IMPL.rm(file.parent_id, file.location) + except Exception as e: + logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}") + + informs = File2DocumentService.get_by_file_id(file.id) + for inform in informs: + doc_id = inform.document_id + e, doc = DocumentService.get_by_id(doc_id) + if e and doc: + tenant_id = DocumentService.get_tenant_id(doc_id) + if tenant_id: + DocumentService.remove_document(doc, tenant_id) + File2DocumentService.delete_by_file_id(file.id) + + FileService.delete(file) + + def _delete_folder_recursive(folder, tenant_id): + sub_files = FileService.list_all_files_by_parent_id(folder.id) + for sub_file in sub_files: + if sub_file.type == FileType.FOLDER.value: + _delete_folder_recursive(sub_file, tenant_id) + else: + _delete_single_file(sub_file) - FileService.delete(folder) + FileService.delete(folder) - try: - for file_id in file_ids: - e, file = FileService.get_by_id(file_id) - if not e or not file: - return get_data_error_result(message="File or Folder not found!") - if not file.tenant_id: - return get_data_error_result(message="Tenant not found!") - if not check_file_team_permission(file, current_user.id): - return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + def _rm_sync(): + for file_id in file_ids: + e, file = FileService.get_by_id(file_id) + if not e or not file: + return get_data_error_result(message="File or Folder not found!") + if not file.tenant_id: + return get_data_error_result(message="Tenant not found!") + if not check_file_team_permission(file, current_user.id): + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) - if file.source_type == FileSource.KNOWLEDGEBASE: - continue + if file.source_type == FileSource.KNOWLEDGEBASE: + continue - if file.type == FileType.FOLDER.value: - _delete_folder_recursive(file, current_user.id) - continue + if file.type == FileType.FOLDER.value: + _delete_folder_recursive(file, current_user.id) + continue - _delete_single_file(file) + _delete_single_file(file) - return get_json_result(data=True) + return get_json_result(data=True) + + return await asyncio.to_thread(_rm_sync) except Exception as e: return server_error_response(e) @@ -299,8 +310,8 @@ def _delete_folder_recursive(folder, tenant_id): @manager.route('/rename', methods=['POST']) # noqa: F821 @login_required @validate_request("file_id", "name") -def rename(): - req = request.json +async def rename(): + req = await get_request_json() try: e, file = FileService.get_by_id(req["file_id"]) if not e: @@ -338,7 +349,7 @@ def rename(): @manager.route('/get/', methods=['GET']) # noqa: F821 @login_required -def get(file_id): +async def get(file_id): try: e, file = FileService.get_by_id(file_id) if not e: @@ -346,12 +357,12 @@ def get(file_id): if not check_file_team_permission(file, current_user.id): return get_json_result(data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) - blob = settings.STORAGE_IMPL.get(file.parent_id, file.location) + blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, file.parent_id, file.location) if not blob: b, n = File2DocumentService.get_storage_address(file_id=file_id) - blob = settings.STORAGE_IMPL.get(b, n) + blob = await asyncio.to_thread(settings.STORAGE_IMPL.get, b, n) - response = flask.make_response(blob) + response = await make_response(blob) ext = re.search(r"\.([^.]+)$", file.name.lower()) ext = ext.group(1) if ext else None if ext: @@ -368,8 +379,8 @@ def get(file_id): @manager.route("/mv", methods=["POST"]) # noqa: F821 @login_required @validate_request("src_file_ids", "dest_file_id") -def move(): - req = request.json +async def move(): + req = await get_request_json() try: file_ids = req["src_file_ids"] dest_parent_id = req["dest_file_id"] @@ -444,10 +455,12 @@ def _move_entry_recursive(source_file_entry, dest_folder): }, ) - for file in files: - _move_entry_recursive(file, dest_folder) + def _move_sync(): + for file in files: + _move_entry_recursive(file, dest_folder) + return get_json_result(data=True) - return get_json_result(data=True) + return await asyncio.to_thread(_move_sync) except Exception as e: return server_error_response(e) diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py index 7094c28d705..fff982563f9 100644 --- a/api/apps/kb_app.py +++ b/api/apps/kb_app.py @@ -16,12 +16,12 @@ import json import logging import random +import re +import asyncio -from flask import request -from flask_login import login_required, current_user +from quart import request import numpy as np - from api.db.services.connector_service import Connector2KbService from api.db.services.llm_service import LLMBundle from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks @@ -30,7 +30,8 @@ from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from api.db.services.task_service import TaskService, GRAPH_RAPTOR_FAKE_DOC_ID from api.db.services.user_service import TenantService, UserTenantService -from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters +from api.utils.api_utils import get_error_data_result, server_error_response, get_data_error_result, validate_request, not_allowed_parameters, \ + get_request_json from api.db import VALID_FILE_TYPES from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.db_models import File @@ -38,26 +39,31 @@ from rag.nlp import search from api.constants import DATASET_NAME_LIMIT from rag.utils.redis_conn import REDIS_CONN -from rag.utils.doc_store_conn import OrderByExpr from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD from common import settings +from common.doc_store.doc_store_base import OrderByExpr +from api.apps import login_required, current_user + @manager.route('/create', methods=['post']) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.json - req = KnowledgebaseService.create_with_name( +async def create(): + req = await get_request_json() + e, res = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = current_user.id, parser_id = req.pop("parser_id", None), **req ) + if not e: + return res + try: - if not KnowledgebaseService.save(**req): + if not KnowledgebaseService.save(**res): return get_data_error_result() - return get_json_result(data={"kb_id":req["id"]}) + return get_json_result(data={"kb_id":res["id"]}) except Exception as e: return server_error_response(e) @@ -66,8 +72,8 @@ def create(): @login_required @validate_request("kb_id", "name", "description", "parser_id") @not_allowed_parameters("id", "tenant_id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -def update(): - req = request.json +async def update(): + req = await get_request_json() if not isinstance(req["name"], str): return get_data_error_result(message="Dataset name must be string.") if req["name"].strip() == "": @@ -87,19 +93,32 @@ def update(): if not KnowledgebaseService.query( created_by=current_user.id, id=req["kb_id"]): return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', + data=False, message='Only owner of dataset authorized for this operation.', code=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) + + # Rename folder in FileService + if e and req["name"].lower() != kb.name.lower(): + FileService.filter_update( + [ + File.tenant_id == kb.tenant_id, + File.source_type == FileSource.KNOWLEDGEBASE, + File.type == "folder", + File.name == kb.name, + ], + {"name": req["name"]}, + ) + if not e: return get_data_error_result( - message="Can't find this knowledgebase!") + message="Can't find this dataset!") if req["name"].lower() != kb.name.lower() \ and len( KnowledgebaseService.query(name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value)) >= 1: return get_data_error_result( - message="Duplicated knowledgebase name.") + message="Duplicated dataset name.") del req["kb_id"] connectors = [] @@ -111,12 +130,22 @@ def update(): if kb.pagerank != req.get("pagerank", 0): if req.get("pagerank", 0) > 0: - settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]}, - search.index_name(kb.tenant_id), kb.id) + await asyncio.to_thread( + settings.docStoreConn.update, + {"kb_id": kb.id}, + {PAGERANK_FLD: req["pagerank"]}, + search.index_name(kb.tenant_id), + kb.id, + ) else: # Elasticsearch requires PAGERANK_FLD be non-zero! - settings.docStoreConn.update({"exists": PAGERANK_FLD}, {"remove": PAGERANK_FLD}, - search.index_name(kb.tenant_id), kb.id) + await asyncio.to_thread( + settings.docStoreConn.update, + {"exists": PAGERANK_FLD}, + {"remove": PAGERANK_FLD}, + search.index_name(kb.tenant_id), + kb.id, + ) e, kb = KnowledgebaseService.get_by_id(kb.id) if not e: @@ -134,6 +163,21 @@ def update(): return server_error_response(e) +@manager.route('/update_metadata_setting', methods=['post']) # noqa: F821 +@login_required +@validate_request("kb_id", "metadata") +async def update_metadata_setting(): + req = await get_request_json() + e, kb = KnowledgebaseService.get_by_id(req["kb_id"]) + if not e: + return get_data_error_result( + message="Database error (Knowledgebase rename)!") + kb = kb.to_dict() + kb["parser_config"]["metadata"] = req["metadata"] + KnowledgebaseService.update_by_id(kb["id"], kb) + return get_json_result(data=kb) + + @manager.route('/detail', methods=['GET']) # noqa: F821 @login_required def detail(): @@ -146,12 +190,12 @@ def detail(): break else: return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', + data=False, message='Only owner of dataset authorized for this operation.', code=RetCode.OPERATING_ERROR) kb = KnowledgebaseService.get_detail(kb_id) if not kb: return get_data_error_result( - message="Can't find this knowledgebase!") + message="Can't find this dataset!") kb["size"] = DocumentService.get_total_size_by_kb_id(kb_id=kb["id"],keywords="", run_status=[], types=[]) kb["connectors"] = Connector2KbService.list_connectors(kb_id) @@ -165,18 +209,19 @@ def detail(): @manager.route('/list', methods=['POST']) # noqa: F821 @login_required -def list_kbs(): - keywords = request.args.get("keywords", "") - page_number = int(request.args.get("page", 0)) - items_per_page = int(request.args.get("page_size", 0)) - parser_id = request.args.get("parser_id") - orderby = request.args.get("orderby", "create_time") - if request.args.get("desc", "true").lower() == "false": +async def list_kbs(): + args = request.args + keywords = args.get("keywords", "") + page_number = int(args.get("page", 0)) + items_per_page = int(args.get("page_size", 0)) + parser_id = args.get("parser_id") + orderby = args.get("orderby", "create_time") + if args.get("desc", "true").lower() == "false": desc = False else: desc = True - req = request.get_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -198,11 +243,12 @@ def list_kbs(): except Exception as e: return server_error_response(e) + @manager.route('/rm', methods=['post']) # noqa: F821 @login_required @validate_request("kb_id") -def rm(): - req = request.json +async def rm(): + req = await get_request_json() if not KnowledgebaseService.accessible4deletion(req["kb_id"], current_user.id): return get_json_result( data=False, @@ -214,28 +260,37 @@ def rm(): created_by=current_user.id, id=req["kb_id"]) if not kbs: return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', + data=False, message='Only owner of dataset authorized for this operation.', code=RetCode.OPERATING_ERROR) - for doc in DocumentService.query(kb_id=req["kb_id"]): - if not DocumentService.remove_document(doc, kbs[0].tenant_id): + def _rm_sync(): + for doc in DocumentService.query(kb_id=req["kb_id"]): + if not DocumentService.remove_document(doc, kbs[0].tenant_id): + return get_data_error_result( + message="Database error (Document removal)!") + f2d = File2DocumentService.get_by_document_id(doc.id) + if f2d: + FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) + File2DocumentService.delete_by_document_id(doc.id) + FileService.filter_delete( + [ + File.tenant_id == kbs[0].tenant_id, + File.source_type == FileSource.KNOWLEDGEBASE, + File.type == "folder", + File.name == kbs[0].name, + ] + ) + if not KnowledgebaseService.delete_by_id(req["kb_id"]): return get_data_error_result( - message="Database error (Document removal)!") - f2d = File2DocumentService.get_by_document_id(doc.id) - if f2d: - FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id]) - File2DocumentService.delete_by_document_id(doc.id) - FileService.filter_delete( - [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name]) - if not KnowledgebaseService.delete_by_id(req["kb_id"]): - return get_data_error_result( - message="Database error (Knowledgebase removal)!") - for kb in kbs: - settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) - settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id) - if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): - settings.STORAGE_IMPL.remove_bucket(kb.id) - return get_json_result(data=True) + message="Database error (Knowledgebase removal)!") + for kb in kbs: + settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id) + settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id) + if hasattr(settings.STORAGE_IMPL, 'remove_bucket'): + settings.STORAGE_IMPL.remove_bucket(kb.id) + return get_json_result(data=True) + + return await asyncio.to_thread(_rm_sync) except Exception as e: return server_error_response(e) @@ -278,8 +333,8 @@ def list_tags_from_kbs(): @manager.route('//rm_tags', methods=['POST']) # noqa: F821 @login_required -def rm_tags(kb_id): - req = request.json +async def rm_tags(kb_id): + req = await get_request_json() if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -298,8 +353,8 @@ def rm_tags(kb_id): @manager.route('//rename_tag', methods=['POST']) # noqa: F821 @login_required -def rename_tags(kb_id): - req = request.json +async def rename_tags(kb_id): + req = await get_request_json() if not KnowledgebaseService.accessible(kb_id, current_user.id): return get_json_result( data=False, @@ -331,7 +386,7 @@ def knowledge_graph(kb_id): } obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id): + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id): return get_json_result(data=obj) sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id]) if not len(sres.ids): @@ -402,7 +457,7 @@ def get_basic_info(): @manager.route("/list_pipeline_logs", methods=["POST"]) # noqa: F821 @login_required -def list_pipeline_logs(): +async def list_pipeline_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -421,7 +476,7 @@ def list_pipeline_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = request.get_json() + req = await get_request_json() operation_status = req.get("operation_status", []) if operation_status: @@ -446,7 +501,7 @@ def list_pipeline_logs(): @manager.route("/list_pipeline_dataset_logs", methods=["POST"]) # noqa: F821 @login_required -def list_pipeline_dataset_logs(): +async def list_pipeline_dataset_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) @@ -463,7 +518,7 @@ def list_pipeline_dataset_logs(): if create_date_to > create_date_from: return get_data_error_result(message="Create data filter is abnormal.") - req = request.get_json() + req = await get_request_json() operation_status = req.get("operation_status", []) if operation_status: @@ -480,12 +535,12 @@ def list_pipeline_dataset_logs(): @manager.route("/delete_pipeline_logs", methods=["POST"]) # noqa: F821 @login_required -def delete_pipeline_logs(): +async def delete_pipeline_logs(): kb_id = request.args.get("kb_id") if not kb_id: return get_json_result(data=False, message='Lack of "KB ID"', code=RetCode.ARGUMENT_ERROR) - req = request.get_json() + req = await get_request_json() log_ids = req.get("log_ids", []) PipelineOperationLogService.delete_by_ids(log_ids) @@ -509,8 +564,8 @@ def pipeline_log_detail(): @manager.route("/run_graphrag", methods=["POST"]) # noqa: F821 @login_required -def run_graphrag(): - req = request.json +async def run_graphrag(): + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -578,8 +633,8 @@ def trace_graphrag(): @manager.route("/run_raptor", methods=["POST"]) # noqa: F821 @login_required -def run_raptor(): - req = request.json +async def run_raptor(): + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -647,8 +702,8 @@ def trace_raptor(): @manager.route("/run_mindmap", methods=["POST"]) # noqa: F821 @login_required -def run_mindmap(): - req = request.json +async def run_mindmap(): + req = await get_request_json() kb_id = req.get("kb_id", "") if not kb_id: @@ -731,6 +786,8 @@ def delete_kb_task(): def cancel_task(task_id): REDIS_CONN.set(f"{task_id}-cancel", "x") + kb_task_id_field: str = "" + kb_task_finish_at: str = "" match pipeline_task_type: case PipelineTaskType.GRAPH_RAG: kb_task_id_field = "graphrag_task_id" @@ -761,7 +818,7 @@ def cancel_task(task_id): @manager.route("/check_embedding", methods=["post"]) # noqa: F821 @login_required -def check_embedding(): +async def check_embedding(): def _guess_vec_field(src: dict) -> str | None: for k in src or {}: @@ -801,30 +858,30 @@ def sample_random_chunks_with_vectors( index_nm = search.index_name(tenant_id) res0 = docStoreConn.search( - selectFields=[], highlightFields=[], + select_fields=[], highlight_fields=[], condition={"kb_id": kb_id, "available_int": 1}, - matchExprs=[], orderBy=OrderByExpr(), + match_expressions=[], order_by=OrderByExpr(), offset=0, limit=1, - indexNames=index_nm, knowledgebaseIds=[kb_id] + index_names=index_nm, knowledgebase_ids=[kb_id] ) - total = docStoreConn.getTotal(res0) + total = docStoreConn.get_total(res0) if total <= 0: return [] n = min(n, total) - offsets = sorted(random.sample(range(total), n)) + offsets = sorted(random.sample(range(min(total,1000)), n)) out = [] for off in offsets: res1 = docStoreConn.search( - selectFields=list(base_fields), - highlightFields=[], + select_fields=list(base_fields), + highlight_fields=[], condition={"kb_id": kb_id, "available_int": 1}, - matchExprs=[], orderBy=OrderByExpr(), + match_expressions=[], order_by=OrderByExpr(), offset=off, limit=1, - indexNames=index_nm, knowledgebaseIds=[kb_id] + index_names=index_nm, knowledgebase_ids=[kb_id] ) - ids = docStoreConn.getChunkIds(res1) + ids = docStoreConn.get_doc_ids(res1) if not ids: continue @@ -845,9 +902,14 @@ def sample_random_chunks_with_vectors( "position_int": full_doc.get("position_int"), "top_int": full_doc.get("top_int"), "content_with_weight": full_doc.get("content_with_weight") or "", + "question_kwd": full_doc.get("question_kwd") or [] }) return out - req = request.json + + def _clean(s: str) -> str: + s = re.sub(r"]{0,12})?>", " ", s or "") + return s if s else "None" + req = await get_request_json() kb_id = req.get("kb_id", "") embd_id = req.get("embd_id", "") n = int(req.get("check_num", 5)) @@ -859,8 +921,10 @@ def sample_random_chunks_with_vectors( results, eff_sims = [], [] for ck in samples: - txt = (ck.get("content_with_weight") or "").strip() - if not txt: + title = ck.get("doc_name") or "Title" + txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or "" + txt_in = _clean(txt_in) + if not txt_in: results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"}) continue @@ -869,10 +933,19 @@ def sample_random_chunks_with_vectors( continue try: - qv, _ = emb_mdl.encode_queries(txt) - sim = _cos_sim(qv, ck["vector"]) - except Exception: - return get_error_data_result(message="embedding failure") + v, _ = emb_mdl.encode([title, txt_in]) + assert len(v[1]) == len(ck["vector"]), f"The dimension ({len(v[1])}) of given embedding model is different from the original ({len(ck['vector'])})" + sim_content = _cos_sim(v[1], ck["vector"]) + title_w = 0.1 + qv_mix = title_w * v[0] + (1 - title_w) * v[1] + sim_mix = _cos_sim(qv_mix, ck["vector"]) + sim = sim_content + mode = "content_only" + if sim_mix > sim: + sim = sim_mix + mode = "title+content" + except Exception as e: + return get_error_data_result(message=f"Embedding failure. {e}") eff_sims.append(sim) results.append({ @@ -892,9 +965,8 @@ def sample_random_chunks_with_vectors( "avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6), "min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6), "max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6), + "match_mode": mode, } - if summary["avg_cos_sim"] > 0.99: + if summary["avg_cos_sim"] > 0.9: return get_json_result(data={"summary": summary, "results": results}) - return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results}) - - + return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results}) diff --git a/api/apps/langfuse_app.py b/api/apps/langfuse_app.py index 151c40fcd40..1d7993d365c 100644 --- a/api/apps/langfuse_app.py +++ b/api/apps/langfuse_app.py @@ -15,28 +15,28 @@ # -from flask import request -from flask_login import current_user, login_required +from api.apps import current_user, login_required from langfuse import Langfuse from api.db.db_models import DB from api.db.services.langfuse_service import TenantLangfuseService -from api.utils.api_utils import get_error_data_result, get_json_result, server_error_response, validate_request +from api.utils.api_utils import get_error_data_result, get_json_result, get_request_json, server_error_response, validate_request @manager.route("/api_key", methods=["POST", "PUT"]) # noqa: F821 @login_required @validate_request("secret_key", "public_key", "host") -def set_api_key(): - req = request.get_json() +async def set_api_key(): + req = await get_request_json() secret_key = req.get("secret_key", "") public_key = req.get("public_key", "") host = req.get("host", "") if not all([secret_key, public_key, host]): return get_error_data_result(message="Missing required fields") + current_user_id = current_user.id langfuse_keys = dict( - tenant_id=current_user.id, + tenant_id=current_user_id, secret_key=secret_key, public_key=public_key, host=host, @@ -46,23 +46,24 @@ def set_api_key(): if not langfuse.auth_check(): return get_error_data_result(message="Invalid Langfuse keys") - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) with DB.atomic(): try: if not langfuse_entry: TenantLangfuseService.save(**langfuse_keys) else: - TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys) + TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys) return get_json_result(data=langfuse_keys) except Exception as e: - server_error_response(e) + return server_error_response(e) @manager.route("/api_key", methods=["GET"]) # noqa: F821 @login_required @validate_request() def get_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -73,7 +74,7 @@ def get_api_key(): except langfuse.api.core.api_error.ApiError as api_err: return get_json_result(message=f"Error from Langfuse: {api_err}") except Exception as e: - server_error_response(e) + return server_error_response(e) langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"] langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"] @@ -85,7 +86,8 @@ def get_api_key(): @login_required @validate_request() def delete_api_key(): - langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id) + current_user_id = current_user.id + langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id) if not langfuse_entry: return get_json_result(message="Have not record any Langfuse keys.") @@ -94,4 +96,4 @@ def delete_api_key(): TenantLangfuseService.delete_model(langfuse_entry) return get_json_result(data=True) except Exception as e: - server_error_response(e) + return server_error_response(e) diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index c34d71cc06a..9a68e825606 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -16,16 +16,16 @@ import logging import json import os -from flask import request -from flask_login import login_required, current_user +from quart import request + +from api.apps import login_required, current_user from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService -from api.utils.api_utils import server_error_response, get_data_error_result, validate_request +from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.constants import StatusEnum, LLMType from api.db.db_models import TenantLLM -from api.utils.api_utils import get_json_result, get_allowed_llm_factories from rag.utils.base64_image import test_image -from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel +from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel @manager.route("/factories", methods=["GET"]) # noqa: F821 @@ -43,7 +43,13 @@ def factories(): mdl_types[m.fid] = set([]) mdl_types[m.fid].add(m.model_type) for f in fac: - f["model_types"] = list(mdl_types.get(f["name"], [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS])) + f["model_types"] = list( + mdl_types.get( + f["name"], + [LLMType.CHAT, LLMType.EMBEDDING, LLMType.RERANK, LLMType.IMAGE2TEXT, LLMType.SPEECH2TEXT, LLMType.TTS, LLMType.OCR], + ) + ) + return get_json_result(data=fac) except Exception as e: return server_error_response(e) @@ -52,8 +58,8 @@ def factories(): @manager.route("/set_api_key", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "api_key") -def set_api_key(): - req = request.json +async def set_api_key(): + req = await get_request_json() # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] @@ -74,7 +80,7 @@ def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True @@ -122,8 +128,8 @@ def set_api_key(): @manager.route("/add_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") -def add_llm(): - req = request.json +async def add_llm(): + req = await get_request_json() factory = req["llm_factory"] api_key = req.get("api_key", "x") llm_name = req.get("llm_name") @@ -142,16 +148,16 @@ def apikey_json(keys): elif factory == "Tencent Hunyuan": req["api_key"] = apikey_json(["hunyuan_sid", "hunyuan_sk"]) - return set_api_key() + return await set_api_key() elif factory == "Tencent Cloud": req["api_key"] = apikey_json(["tencent_cloud_sid", "tencent_cloud_sk"]) - return set_api_key() + return await set_api_key() elif factory == "Bedrock": # For Bedrock, due to its special authentication method # Assemble bedrock_ak, bedrock_sk, bedrock_region - api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"]) + api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"]) elif factory == "LocalAI": llm_name += "___LocalAI" @@ -186,6 +192,9 @@ def apikey_json(keys): elif factory == "OpenRouter": api_key = apikey_json(["api_key", "provider_order"]) + elif factory == "MinerU": + api_key = apikey_json(["api_key", "provider_order"]) + llm = { "tenant_id": current_user.id, "llm_factory": factory, @@ -199,61 +208,83 @@ def apikey_json(keys): msg = "" mdl_nm = llm["llm_name"].split("___")[0] extra = {"provider": factory} - if llm["model_type"] == LLMType.EMBEDDING.value: - assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." - mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) - try: - arr, tc = mdl.encode(["Test if the api key is available"]) - if len(arr[0]) == 0: - raise Exception("Fail") - except Exception as e: - msg += f"\nFail to access embedding model({mdl_nm})." + str(e) - elif llm["model_type"] == LLMType.CHAT.value: - assert factory in ChatModel, f"Chat model from {factory} is not supported yet." - mdl = ChatModel[factory]( - key=llm["api_key"], - model_name=mdl_nm, - base_url=llm["api_base"], - **extra, - ) - try: - m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) - if not tc and m.find("**ERROR**:") >= 0: - raise Exception(m) - except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) - elif llm["model_type"] == LLMType.RERANK: - assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." - try: - mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) - arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) - if len(arr) == 0: - raise Exception("Not known.") - except KeyError: - msg += f"{factory} dose not support this model({factory}/{mdl_nm})" - except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) - elif llm["model_type"] == LLMType.IMAGE2TEXT.value: - assert factory in CvModel, f"Image to text model from {factory} is not supported yet." - mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) - try: - image_data = test_image - m, tc = mdl.describe(image_data) - if not tc and m.find("**ERROR**:") >= 0: - raise Exception(m) - except Exception as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) - elif llm["model_type"] == LLMType.TTS: - assert factory in TTSModel, f"TTS model from {factory} is not supported yet." - mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"]) - try: - for resp in mdl.tts("Hello~ RAGFlower!"): - pass - except RuntimeError as e: - msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) - else: - # TODO: check other type of models - pass + model_type = llm["model_type"] + model_api_key = llm["api_key"] + model_base_url = llm.get("api_base", "") + match model_type: + case LLMType.EMBEDDING.value: + assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." + mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + try: + arr, tc = mdl.encode(["Test if the api key is available"]) + if len(arr[0]) == 0: + raise Exception("Fail") + except Exception as e: + msg += f"\nFail to access embedding model({mdl_nm})." + str(e) + case LLMType.CHAT.value: + assert factory in ChatModel, f"Chat model from {factory} is not supported yet." + mdl = ChatModel[factory]( + key=model_api_key, + model_name=mdl_nm, + base_url=model_base_url, + **extra, + ) + try: + m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], + {"temperature": 0.9}) + if not tc and m.find("**ERROR**:") >= 0: + raise Exception(m) + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + + case LLMType.RERANK.value: + assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." + try: + mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) + if len(arr) == 0: + raise Exception("Not known.") + except KeyError: + msg += f"{factory} dose not support this model({factory}/{mdl_nm})" + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + + case LLMType.IMAGE2TEXT.value: + assert factory in CvModel, f"Image to text model from {factory} is not supported yet." + mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + try: + image_data = test_image + m, tc = mdl.describe(image_data) + if not tc and m.find("**ERROR**:") >= 0: + raise Exception(m) + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + case LLMType.TTS.value: + assert factory in TTSModel, f"TTS model from {factory} is not supported yet." + mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + try: + for resp in mdl.tts("Hello~ RAGFlower!"): + pass + except RuntimeError as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + case LLMType.OCR.value: + assert factory in OcrModel, f"OCR model from {factory} is not supported yet." + try: + mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + ok, reason = mdl.check_available() + if not ok: + raise RuntimeError(reason or "Model not available") + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + case LLMType.SPEECH2TEXT: + assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet." + try: + mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) + # TODO: check the availability + except Exception as e: + msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) + case _: + raise RuntimeError(f"Unknown model type: {model_type}") if msg: return get_data_error_result(message=msg) @@ -267,8 +298,8 @@ def apikey_json(keys): @manager.route("/delete_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") -def delete_llm(): - req = request.json +async def delete_llm(): + req = await get_request_json() TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]]) return get_json_result(data=True) @@ -276,8 +307,8 @@ def delete_llm(): @manager.route("/enable_llm", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory", "llm_name") -def enable_llm(): - req = request.json +async def enable_llm(): + req = await get_request_json() TenantLLMService.filter_update( [TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"], TenantLLM.llm_name == req["llm_name"]], {"status": str(req.get("status", "1"))} ) @@ -287,8 +318,8 @@ def enable_llm(): @manager.route("/delete_factory", methods=["POST"]) # noqa: F821 @login_required @validate_request("llm_factory") -def delete_factory(): - req = request.json +async def delete_factory(): + req = await get_request_json() TenantLLMService.filter_delete([TenantLLM.tenant_id == current_user.id, TenantLLM.llm_factory == req["llm_factory"]]) return get_json_result(data=True) @@ -297,6 +328,7 @@ def delete_factory(): @login_required def my_llms(): try: + TenantLLMService.ensure_mineru_from_env(current_user.id) include_details = request.args.get("include_details", "false").lower() == "true" if include_details: @@ -344,6 +376,7 @@ def list_app(): weighted = [] model_type = request.args.get("model_type") try: + TenantLLMService.ensure_mineru_from_env(current_user.id) objs = TenantLLMService.query(tenant_id=current_user.id) facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key and o.status == StatusEnum.VALID.value]) status = {(o.llm_name + "@" + o.llm_factory) for o in objs if o.status == StatusEnum.VALID.value} diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 66d4474915e..62ae2e3c06b 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from flask import Response, request -from flask_login import current_user, login_required +import asyncio + +from quart import Response, request +from api.apps import current_user, login_required from api.db.db_models import MCPServer from api.db.services.mcp_server_service import MCPServerService @@ -22,15 +24,14 @@ from common.constants import RetCode, VALID_MCP_SERVER_TYPES from common.misc_utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ - get_mcp_tools +from api.utils.api_utils import get_data_error_result, get_json_result, get_mcp_tools, get_request_json, server_error_response, validate_request from api.utils.web_utils import get_float, safe_json_parse -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_mcp() -> Response: +async def list_mcp() -> Response: keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) @@ -40,7 +41,7 @@ def list_mcp() -> Response: else: desc = True - req = request.get_json() + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) try: servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] @@ -72,8 +73,8 @@ def detail() -> Response: @manager.route("/create", methods=["POST"]) # noqa: F821 @login_required @validate_request("name", "url", "server_type") -def create() -> Response: - req = request.get_json() +async def create() -> Response: + req = await get_request_json() server_type = req.get("server_type", "") if server_type not in VALID_MCP_SERVER_TYPES: @@ -107,7 +108,7 @@ def create() -> Response: return get_data_error_result(message="Tenant not found.") mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) - server_tools, err_message = get_mcp_tools([mcp_server], timeout) + server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout) if err_message: return get_data_error_result(err_message) @@ -127,8 +128,8 @@ def create() -> Response: @manager.route("/update", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id") -def update() -> Response: - req = request.get_json() +async def update() -> Response: + req = await get_request_json() mcp_id = req.get("mcp_id", "") e, mcp_server = MCPServerService.get_by_id(mcp_id) @@ -159,7 +160,7 @@ def update() -> Response: req["id"] = mcp_id mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) - server_tools, err_message = get_mcp_tools([mcp_server], timeout) + server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout) if err_message: return get_data_error_result(err_message) @@ -183,8 +184,8 @@ def update() -> Response: @manager.route("/rm", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def rm() -> Response: - req = request.get_json() +async def rm() -> Response: + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) try: @@ -201,8 +202,8 @@ def rm() -> Response: @manager.route("/import", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcpServers") -def import_multiple() -> Response: - req = request.get_json() +async def import_multiple() -> Response: + req = await get_request_json() servers = req.get("mcpServers", {}) if not servers: return get_data_error_result(message="No MCP servers provided.") @@ -243,7 +244,7 @@ def import_multiple() -> Response: headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {} variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}} mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers) - server_tools, err_message = get_mcp_tools([mcp_server], timeout) + server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout) if err_message: results.append({"server": base_name, "success": False, "message": err_message}) continue @@ -268,8 +269,8 @@ def import_multiple() -> Response: @manager.route("/export", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def export_multiple() -> Response: - req = request.get_json() +async def export_multiple() -> Response: + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: @@ -300,8 +301,8 @@ def export_multiple() -> Response: @manager.route("/list_tools", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_ids") -def list_tools() -> Response: - req = request.get_json() +async def list_tools() -> Response: + req = await get_request_json() mcp_ids = req.get("mcp_ids", []) if not mcp_ids: return get_data_error_result(message="No MCP server IDs provided.") @@ -323,9 +324,8 @@ def list_tools() -> Response: tool_call_sessions.append(tool_call_session) try: - tools = tool_call_session.get_tools(timeout) + tools = await asyncio.to_thread(tool_call_session.get_tools, timeout) except Exception as e: - tools = [] return get_data_error_result(message=f"MCP list tools error: {e}") results[server_key] = [] @@ -341,14 +341,14 @@ def list_tools() -> Response: return server_error_response(e) finally: # PERF: blocking call to close sessions — consider moving to background thread or task queue - close_multiple_mcp_toolcall_sessions(tool_call_sessions) + await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions) @manager.route("/test_tool", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id", "tool_name", "arguments") -def test_tool() -> Response: - req = request.get_json() +async def test_tool() -> Response: + req = await get_request_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -368,10 +368,10 @@ def test_tool() -> Response: tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) tool_call_sessions.append(tool_call_session) - result = tool_call_session.tool_call(tool_name, arguments, timeout) + result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout) # PERF: blocking call to close sessions — consider moving to background thread or task queue - close_multiple_mcp_toolcall_sessions(tool_call_sessions) + await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions) return get_json_result(data=result) except Exception as e: return server_error_response(e) @@ -380,8 +380,8 @@ def test_tool() -> Response: @manager.route("/cache_tools", methods=["POST"]) # noqa: F821 @login_required @validate_request("mcp_id", "tools") -def cache_tool() -> Response: - req = request.get_json() +async def cache_tool() -> Response: + req = await get_request_json() mcp_id = req.get("mcp_id", "") if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") @@ -403,8 +403,8 @@ def cache_tool() -> Response: @manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @validate_request("url", "server_type") -def test_mcp() -> Response: - req = request.get_json() +async def test_mcp() -> Response: + req = await get_request_json() url = req.get("url", "") if not url: @@ -425,13 +425,12 @@ def test_mcp() -> Response: tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables) try: - tools = tool_call_session.get_tools(timeout) + tools = await asyncio.to_thread(tool_call_session.get_tools, timeout) except Exception as e: - tools = [] return get_data_error_result(message=f"Test MCP error: {e}") finally: # PERF: blocking call to close sessions — consider moving to background thread or task queue - close_multiple_mcp_toolcall_sessions([tool_call_session]) + await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session]) for tool in tools: tool_dict = tool.model_dump() diff --git a/api/apps/memories_app.py b/api/apps/memories_app.py new file mode 100644 index 00000000000..66fcabb4c99 --- /dev/null +++ b/api/apps/memories_app.py @@ -0,0 +1,228 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +from quart import request +from api.apps import login_required, current_user +from api.db import TenantPermission +from api.db.services.memory_service import MemoryService +from api.db.services.user_service import UserTenantService +from api.db.services.canvas_service import UserCanvasService +from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result +from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT +from memory.services.messages import MessageService +from memory.utils.prompt_util import PromptAssembler +from common.constants import MemoryType, RetCode, ForgettingPolicy + + +@manager.route("", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("name", "memory_type", "embd_id", "llm_id") +async def create_memory(): + req = await get_request_json() + # check name length + name = req["name"] + memory_name = name.strip() + if len(memory_name) == 0: + return get_error_argument_result("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + # check memory_type valid + memory_type = set(req["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") + memory_type = list(memory_type) + + try: + res, memory = MemoryService.create_memory( + tenant_id=current_user.id, + name=memory_name, + memory_type=memory_type, + embd_id=req["embd_id"], + llm_id=req["llm_id"] + ) + + if res: + return get_json_result(message=True, data=format_ret_data_from_memory(memory)) + else: + return get_json_result(message=memory, code=RetCode.SERVER_ERROR) + + except Exception as e: + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("/", methods=["PUT"]) # noqa: F821 +@login_required +async def update_memory(memory_id): + req = await get_request_json() + update_dict = {} + # check name length + if "name" in req: + name = req["name"] + memory_name = name.strip() + if len(memory_name) == 0: + return get_error_argument_result("Memory name cannot be empty or whitespace.") + if len(memory_name) > MEMORY_NAME_LIMIT: + return get_error_argument_result(f"Memory name '{memory_name}' exceeds limit of {MEMORY_NAME_LIMIT}.") + update_dict["name"] = memory_name + # check permissions valid + if req.get("permissions"): + if req["permissions"] not in [e.value for e in TenantPermission]: + return get_error_argument_result(f"Unknown permission '{req['permissions']}'.") + update_dict["permissions"] = req["permissions"] + if req.get("llm_id"): + update_dict["llm_id"] = req["llm_id"] + if req.get("embd_id"): + update_dict["embd_id"] = req["embd_id"] + if req.get("memory_type"): + memory_type = set(req["memory_type"]) + invalid_type = memory_type - {e.name.lower() for e in MemoryType} + if invalid_type: + return get_error_argument_result(f"Memory type '{invalid_type}' is not supported.") + update_dict["memory_type"] = list(memory_type) + # check memory_size valid + if req.get("memory_size"): + if not 0 < int(req["memory_size"]) <= MEMORY_SIZE_LIMIT: + return get_error_argument_result(f"Memory size should be in range (0, {MEMORY_SIZE_LIMIT}] Bytes.") + update_dict["memory_size"] = req["memory_size"] + # check forgetting_policy valid + if req.get("forgetting_policy"): + if req["forgetting_policy"] not in [e.value for e in ForgettingPolicy]: + return get_error_argument_result(f"Forgetting policy '{req['forgetting_policy']}' is not supported.") + update_dict["forgetting_policy"] = req["forgetting_policy"] + # check temperature valid + if "temperature" in req: + temperature = float(req["temperature"]) + if not 0 <= temperature <= 1: + return get_error_argument_result("Temperature should be in range [0, 1].") + update_dict["temperature"] = temperature + # allow update to empty fields + for field in ["avatar", "description", "system_prompt", "user_prompt"]: + if field in req: + update_dict[field] = req[field] + current_memory = MemoryService.get_by_memory_id(memory_id) + if not current_memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + memory_dict = current_memory.to_dict() + memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) + to_update = {} + for k, v in update_dict.items(): + if isinstance(v, list) and set(memory_dict[k]) != set(v): + to_update[k] = v + elif memory_dict[k] != v: + to_update[k] = v + + if not to_update: + return get_json_result(message=True, data=memory_dict) + # check memory empty when update embd_id, memory_type + memory_size = get_memory_size_cache(memory_id, current_memory.tenant_id) + not_allowed_update = [f for f in ["embd_id", "memory_type"] if f in to_update and memory_size > 0] + if not_allowed_update: + return get_error_argument_result(f"Can't update {not_allowed_update} when memory isn't empty.") + if "memory_type" in to_update: + if "system_prompt" not in to_update and judge_system_prompt_is_default(current_memory.system_prompt, current_memory.memory_type): + # update old default prompt, assemble a new one + to_update["system_prompt"] = PromptAssembler.assemble_system_prompt({"memory_type": to_update["memory_type"]}) + + try: + MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update) + updated_memory = MemoryService.get_by_memory_id(memory_id) + return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory)) + + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("/", methods=["DELETE"]) # noqa: F821 +@login_required +async def delete_memory(memory_id): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(message=True, code=RetCode.NOT_FOUND) + try: + MemoryService.delete_memory(memory_id) + if MessageService.has_index(memory.tenant_id, memory_id): + MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id) + return get_json_result(message=True) + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("", methods=["GET"]) # noqa: F821 +@login_required +async def list_memory(): + args = request.args + try: + tenant_ids = args.getlist("tenant_id") + memory_types = args.getlist("memory_type") + storage_type = args.get("storage_type") + keywords = args.get("keywords", "") + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + # make filter dict + filter_dict = {"memory_type": memory_types, "storage_type": storage_type} + if not tenant_ids: + # restrict to current user's tenants + user_tenants = UserTenantService.get_user_tenant_relation_by_user_id(current_user.id) + filter_dict["tenant_id"] = [tenant["tenant_id"] for tenant in user_tenants] + else: + filter_dict["tenant_id"] = tenant_ids + + memory_list, count = MemoryService.get_by_filter(filter_dict, keywords, page, page_size) + [memory.update({"memory_type": get_memory_type_human(memory["memory_type"])}) for memory in memory_list] + return get_json_result(message=True, data={"memory_list": memory_list, "total_count": count}) + + except Exception as e: + logging.error(e) + return get_json_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route("//config", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_config(memory_id): + memory = MemoryService.get_with_owner_name_by_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + return get_json_result(message=True, data=format_ret_data_from_memory(memory)) + + +@manager.route("/", methods=["GET"]) # noqa: F821 +@login_required +async def get_memory_detail(memory_id): + args = request.args + agent_ids = args.getlist("agent_id") + keywords = args.get("keywords", "") + keywords = keywords.strip() + page = int(args.get("page", 1)) + page_size = int(args.get("page_size", 50)) + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + messages = MessageService.list_message( + memory.tenant_id, memory_id, agent_ids, keywords, page, page_size) + agent_name_mapping = {} + if messages["message_list"]: + agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]]) + agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list} + for message in messages["message_list"]: + message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown") + return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True) diff --git a/api/apps/messages_app.py b/api/apps/messages_app.py new file mode 100644 index 00000000000..2963baefa4a --- /dev/null +++ b/api/apps/messages_app.py @@ -0,0 +1,168 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from quart import request +from api.apps import login_required +from api.db.services.memory_service import MemoryService +from common.time_utils import current_timestamp, timestamp_to_date + +from memory.services.messages import MessageService +from api.db.joint_services import memory_message_service +from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result +from common.constants import RetCode + + +@manager.route("", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response") +async def add_message(): + + req = await get_request_json() + memory_ids = req["memory_id"] + agent_id = req["agent_id"] + session_id = req["session_id"] + user_id = req["user_id"] if req.get("user_id") else "" + user_input = req["user_input"] + agent_response = req["agent_response"] + + res = [] + for memory_id in memory_ids: + success, msg = await memory_message_service.save_to_memory( + memory_id, + { + "user_id": user_id, + "agent_id": agent_id, + "session_id": session_id, + "user_input": user_input, + "agent_response": agent_response + } + ) + res.append({ + "memory_id": memory_id, + "success": success, + "message": msg + }) + + if all([r["success"] for r in res]): + return get_json_result(message="Successfully added to memories.") + + return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res) + + +@manager.route("/:", methods=["DELETE"]) # noqa: F821 +@login_required +async def forget_message(memory_id: str, message_id: int): + + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + forget_time = timestamp_to_date(current_timestamp()) + update_succeed = MessageService.update_message( + {"memory_id": memory_id, "message_id": int(message_id)}, + {"forget_at": forget_time}, + memory.tenant_id, memory_id) + if update_succeed: + return get_json_result(message=update_succeed) + else: + return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.") + + +@manager.route("/:", methods=["PUT"]) # noqa: F821 +@login_required +@validate_request("status") +async def update_message(memory_id: str, message_id: int): + req = await get_request_json() + status = req["status"] + if not isinstance(status, bool): + return get_error_argument_result("Status must be a boolean.") + + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id) + if update_succeed: + return get_json_result(message=update_succeed) + else: + return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.") + + +@manager.route("/search", methods=["GET"]) # noqa: F821 +@login_required +async def search_message(): + args = request.args + print(args, flush=True) + empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)] + if empty_fields: + return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.") + + memory_ids = args.getlist("memory_id") + query = args.get("query") + similarity_threshold = float(args.get("similarity_threshold", 0.2)) + keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7)) + top_n = int(args.get("top_n", 5)) + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + + filter_dict = { + "memory_id": memory_ids, + "agent_id": agent_id, + "session_id": session_id + } + params = { + "query": query, + "similarity_threshold": similarity_threshold, + "keywords_similarity_weight": keywords_similarity_weight, + "top_n": top_n + } + res = memory_message_service.query_message(filter_dict, params) + return get_json_result(message=True, data=res) + + +@manager.route("", methods=["GET"]) # noqa: F821 +@login_required +async def get_messages(): + args = request.args + memory_ids = args.getlist("memory_id") + agent_id = args.get("agent_id", "") + session_id = args.get("session_id", "") + limit = int(args.get("limit", 10)) + if not memory_ids: + return get_error_argument_result("memory_ids is required.") + memory_list = MemoryService.get_by_ids(memory_ids) + uids = [memory.tenant_id for memory in memory_list] + res = MessageService.get_recent_messages( + uids, + memory_ids, + agent_id, + session_id, + limit + ) + return get_json_result(message=True, data=res) + + +@manager.route("/:/content", methods=["GET"]) # noqa: F821 +@login_required +async def get_message_content(memory_id:str, message_id: int): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.") + + res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id) + if res: + return get_json_result(message=True, data=res) + else: + return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.") diff --git a/api/apps/plugin_app.py b/api/apps/plugin_app.py index 9ca04416db0..6e7a8769018 100644 --- a/api/apps/plugin_app.py +++ b/api/apps/plugin_app.py @@ -15,8 +15,8 @@ # -from flask import Response -from flask_login import login_required +from quart import Response +from api.apps import login_required from api.utils.api_utils import get_json_result from plugin import GlobalPluginManager diff --git a/api/apps/sdk/agents.py b/api/apps/sdk/agents.py index 208b7a1bef7..e6a68786992 100644 --- a/api/apps/sdk/agents.py +++ b/api/apps/sdk/agents.py @@ -14,20 +14,29 @@ # limitations under the License. # +import asyncio +import base64 +import hashlib +import hmac +import ipaddress import json import logging import time from typing import Any, cast +import jwt + from agent.canvas import Canvas from api.db import CanvasCategory from api.db.services.canvas_service import UserCanvasService +from api.db.services.file_service import FileService from api.db.services.user_canvas_version import UserCanvasVersionService from common.constants import RetCode from common.misc_utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required +from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required from api.utils.api_utils import get_result -from flask import request, Response +from quart import request, Response +from rag.utils.redis_conn import REDIS_CONN @manager.route('/agents', methods=['GET']) # noqa: F821 @@ -41,19 +50,19 @@ def list_agents(tenant_id): return get_error_data_result("The agent doesn't exist.") page_number = int(request.args.get("page", 1)) items_per_page = int(request.args.get("page_size", 30)) - orderby = request.args.get("orderby", "update_time") + order_by = request.args.get("orderby", "update_time") if request.args.get("desc") == "False" or request.args.get("desc") == "false": desc = False else: desc = True - canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title) + canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title) return get_result(data=canvas) @manager.route("/agents", methods=["POST"]) # noqa: F821 @token_required -def create_agent(tenant_id: str): - req: dict[str, Any] = cast(dict[str, Any], request.json) +async def create_agent(tenant_id: str): + req: dict[str, Any] = cast(dict[str, Any], await get_request_json()) req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -89,8 +98,8 @@ def create_agent(tenant_id: str): @manager.route("/agents/", methods=["PUT"]) # noqa: F821 @token_required -def update_agent(tenant_id: str, agent_id: str): - req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], request.json).items() if v is not None} +async def update_agent(tenant_id: str, agent_id: str): + req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None} req["user_id"] = tenant_id if req.get("dsl") is not None: @@ -132,48 +141,785 @@ def delete_agent(tenant_id: str, agent_id: str): UserCanvasService.delete_by_id(agent_id) return get_json_result(data=True) +@manager.route("/webhook/", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821 +@manager.route("/webhook_test/",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821 +async def webhook(agent_id: str): + is_test = request.path.startswith("/api/v1/webhook_test") + start_ts = time.time() -@manager.route('/webhook/', methods=['POST']) # noqa: F821 -@token_required -def webhook(tenant_id: str, agent_id: str): - req = request.json - if not UserCanvasService.accessible(req["id"], tenant_id): - return get_json_result( - data=False, message='Only owner of canvas authorized for this operation.', - code=RetCode.OPERATING_ERROR) + # 1. Fetch canvas by agent_id + exists, cvs = UserCanvasService.get_by_id(agent_id) + if not exists: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST + + # 2. Check canvas category + if cvs.canvas_category == CanvasCategory.DataFlow: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST + + # 3. Load DSL from canvas + dsl = getattr(cvs, "dsl", None) + if not isinstance(dsl, dict): + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST + + # 4. Check webhook configuration in DSL + components = dsl.get("components", {}) + for k, _ in components.items(): + cpn_obj = components[k]["obj"] + if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook": + webhook_cfg = cpn_obj["params"] + + if not webhook_cfg: + return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST + + # 5. Validate request method against webhook_cfg.methods + allowed_methods = webhook_cfg.get("methods", []) + request_method = request.method.upper() + if allowed_methods and request_method not in allowed_methods: + return get_data_error_result( + code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook." + ),RetCode.BAD_REQUEST + + # 6. Validate webhook security + async def validate_webhook_security(security_cfg: dict): + """Validate webhook security rules based on security configuration.""" + + if not security_cfg: + return # No security config → allowed by default + + # 1. Validate max body size + await _validate_max_body_size(security_cfg) + + # 2. Validate IP whitelist + _validate_ip_whitelist(security_cfg) + + # # 3. Validate rate limiting + _validate_rate_limit(security_cfg) + + # 4. Validate authentication + auth_type = security_cfg.get("auth_type", "none") + + if auth_type == "none": + return + + if auth_type == "token": + _validate_token_auth(security_cfg) + + elif auth_type == "basic": + _validate_basic_auth(security_cfg) + + elif auth_type == "jwt": + _validate_jwt_auth(security_cfg) + + else: + raise Exception(f"Unsupported auth_type: {auth_type}") + + async def _validate_max_body_size(security_cfg): + """Check request size does not exceed max_body_size.""" + max_size = security_cfg.get("max_body_size") + if not max_size: + return + + # Convert "10MB" → bytes + units = {"kb": 1024, "mb": 1024**2} + size_str = max_size.lower() + + for suffix, factor in units.items(): + if size_str.endswith(suffix): + limit = int(size_str.replace(suffix, "")) * factor + break + else: + raise Exception("Invalid max_body_size format") + MAX_LIMIT = 10 * 1024 * 1024 # 10MB + if limit > MAX_LIMIT: + raise Exception("max_body_size exceeds maximum allowed size (10MB)") + + content_length = request.content_length or 0 + if content_length > limit: + raise Exception(f"Request body too large: {content_length} > {limit}") + + def _validate_ip_whitelist(security_cfg): + """Allow only IPs listed in ip_whitelist.""" + whitelist = security_cfg.get("ip_whitelist", []) + if not whitelist: + return + + client_ip = request.remote_addr + + + for rule in whitelist: + if "/" in rule: + # CIDR notation + if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False): + return + else: + # Single IP + if client_ip == rule: + return + + raise Exception(f"IP {client_ip} is not allowed by whitelist") + + def _validate_rate_limit(security_cfg): + """Simple in-memory rate limiting.""" + rl = security_cfg.get("rate_limit") + if not rl: + return + + limit = int(rl.get("limit", 60)) + if limit <= 0: + raise Exception("rate_limit.limit must be > 0") + per = rl.get("per", "minute") + + window = { + "second": 1, + "minute": 60, + "hour": 3600, + "day": 86400, + }.get(per) + + if not window: + raise Exception(f"Invalid rate_limit.per: {per}") + + capacity = limit + rate = limit / window + cost = 1 + + key = f"rl:tb:{agent_id}" + now = time.time() + + try: + res = REDIS_CONN.lua_token_bucket( + keys=[key], + args=[capacity, rate, now, cost], + client=REDIS_CONN.REDIS, + ) + + allowed = int(res[0]) + if allowed != 1: + raise Exception("Too many requests (rate limit exceeded)") + + except Exception as e: + raise Exception(f"Rate limit error: {e}") + + def _validate_token_auth(security_cfg): + """Validate header-based token authentication.""" + token_cfg = security_cfg.get("token",{}) + header = token_cfg.get("token_header") + token_value = token_cfg.get("token_value") + + provided = request.headers.get(header) + if provided != token_value: + raise Exception("Invalid token authentication") + + def _validate_basic_auth(security_cfg): + """Validate HTTP Basic Auth credentials.""" + auth_cfg = security_cfg.get("basic_auth", {}) + username = auth_cfg.get("username") + password = auth_cfg.get("password") + + auth = request.authorization + if not auth or auth.username != username or auth.password != password: + raise Exception("Invalid Basic Auth credentials") + + def _validate_jwt_auth(security_cfg): + """Validate JWT token in Authorization header.""" + jwt_cfg = security_cfg.get("jwt", {}) + secret = jwt_cfg.get("secret") + if not secret: + raise Exception("JWT secret not configured") + + auth_header = request.headers.get("Authorization", "") + if not auth_header.startswith("Bearer "): + raise Exception("Missing Bearer token") + + token = auth_header[len("Bearer "):].strip() + if not token: + raise Exception("Empty Bearer token") + + alg = (jwt_cfg.get("algorithm") or "HS256").upper() + + decode_kwargs = { + "key": secret, + "algorithms": [alg], + } + options = {} + if jwt_cfg.get("audience"): + decode_kwargs["audience"] = jwt_cfg["audience"] + options["verify_aud"] = True + else: + options["verify_aud"] = False + + if jwt_cfg.get("issuer"): + decode_kwargs["issuer"] = jwt_cfg["issuer"] + options["verify_iss"] = True + else: + options["verify_iss"] = False + try: + decoded = jwt.decode( + token, + options=options, + **decode_kwargs, + ) + except Exception as e: + raise Exception(f"Invalid JWT: {str(e)}") + + raw_required_claims = jwt_cfg.get("required_claims", []) + if isinstance(raw_required_claims, str): + required_claims = [raw_required_claims] + elif isinstance(raw_required_claims, (list, tuple, set)): + required_claims = list(raw_required_claims) + else: + required_claims = [] + + required_claims = [ + c for c in required_claims + if isinstance(c, str) and c.strip() + ] - e, cvs = UserCanvasService.get_by_id(req["id"]) - if not e: - return get_data_error_result(message="canvas not found.") + RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"} + for claim in required_claims: + if claim in RESERVED_CLAIMS: + raise Exception(f"Reserved JWT claim cannot be required: {claim}") + for claim in required_claims: + if claim not in decoded: + raise Exception(f"Missing JWT claim: {claim}") + + return decoded + + try: + security_config=webhook_cfg.get("security", {}) + await validate_webhook_security(security_config) + except Exception as e: + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST if not isinstance(cvs.dsl, str): - cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + dsl = json.dumps(cvs.dsl, ensure_ascii=False) + try: + canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id) + except Exception as e: + resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)) + resp.status_code = RetCode.BAD_REQUEST + return resp - if cvs.canvas_category == CanvasCategory.DataFlow: - return get_data_error_result(message="Dataflow can not be triggered by webhook.") + # 7. Parse request body + async def parse_webhook_request(content_type): + """Parse request based on content-type and return structured data.""" + + # 1. Query + query_data = {k: v for k, v in request.args.items()} + + # 2. Headers + header_data = {k: v for k, v in request.headers.items()} + # 3. Body + ctype = request.headers.get("Content-Type", "").split(";")[0].strip() + if ctype and ctype != content_type: + raise ValueError( + f"Invalid Content-Type: expect '{content_type}', got '{ctype}'" + ) + + body_data: dict = {} + + try: + if ctype == "application/json": + body_data = await request.get_json() or {} + + elif ctype == "multipart/form-data": + nonlocal canvas + form = await request.form + files = await request.files + + body_data = {} + + for key, value in form.items(): + body_data[key] = value + + if len(files) > 10: + raise Exception("Too many uploaded files") + for key, file in files.items(): + desc = FileService.upload_info( + cvs.user_id, # user + file, # FileStorage + None # url (None for webhook) + ) + file_parsed= await canvas.get_files_async([desc]) + body_data[key] = file_parsed + + elif ctype == "application/x-www-form-urlencoded": + form = await request.form + body_data = dict(form) + + else: + # text/plain / octet-stream / empty / unknown + raw = await request.get_data() + if raw: + try: + body_data = json.loads(raw.decode("utf-8")) + except Exception: + body_data = {} + else: + body_data = {} + + except Exception: + body_data = {} + + return { + "query": query_data, + "headers": header_data, + "body": body_data, + "content_type": ctype, + } + + def extract_by_schema(data, schema, name="section"): + """ + Extract only fields defined in schema. + Required fields must exist. + Optional fields default to type-based default values. + Type validation included. + """ + props = schema.get("properties", {}) + required = schema.get("required", []) + + extracted = {} + + for field, field_schema in props.items(): + field_type = field_schema.get("type") + + # 1. Required field missing + if field in required and field not in data: + raise Exception(f"{name} missing required field: {field}") + + # 2. Optional → default value + if field not in data: + extracted[field] = default_for_type(field_type) + continue + + raw_value = data[field] + + # 3. Auto convert value + try: + value = auto_cast_value(raw_value, field_type) + except Exception as e: + raise Exception(f"{name}.{field} auto-cast failed: {str(e)}") + + # 4. Type validation + if not validate_type(value, field_type): + raise Exception( + f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}" + ) + + extracted[field] = value + + return extracted + + + def default_for_type(t): + """Return default value for the given schema type.""" + if t == "file": + return [] + if t == "object": + return {} + if t == "boolean": + return False + if t == "number": + return 0 + if t == "string": + return "" + if t and t.startswith("array"): + return [] + if t == "null": + return None + return None + + def auto_cast_value(value, expected_type): + """Convert string values into schema type when possible.""" + + # Non-string values already good + if not isinstance(value, str): + return value + + v = value.strip() + + # Boolean + if expected_type == "boolean": + if v.lower() in ["true", "1"]: + return True + if v.lower() in ["false", "0"]: + return False + raise Exception(f"Cannot convert '{value}' to boolean") + + # Number + if expected_type == "number": + # integer + if v.isdigit() or (v.startswith("-") and v[1:].isdigit()): + return int(v) + + # float + try: + return float(v) + except Exception: + raise Exception(f"Cannot convert '{value}' to number") + + # Object + if expected_type == "object": + try: + parsed = json.loads(v) + if isinstance(parsed, dict): + return parsed + else: + raise Exception("JSON is not an object") + except Exception: + raise Exception(f"Cannot convert '{value}' to object") + + # Array + if expected_type.startswith("array"): + try: + parsed = json.loads(v) + if isinstance(parsed, list): + return parsed + else: + raise Exception("JSON is not an array") + except Exception: + raise Exception(f"Cannot convert '{value}' to array") + + # String (accept original) + if expected_type == "string": + return value + + # File + if expected_type == "file": + return value + # Default: do nothing + return value + + + def validate_type(value, t): + """Validate value type against schema type t.""" + if t == "file": + return isinstance(value, list) + + if t == "string": + return isinstance(value, str) + + if t == "number": + return isinstance(value, (int, float)) + + if t == "boolean": + return isinstance(value, bool) + + if t == "object": + return isinstance(value, dict) + + # array / array / array + if t.startswith("array"): + if not isinstance(value, list): + return False + + if "<" in t and ">" in t: + inner = t[t.find("<") + 1 : t.find(">")] + + # Check each element type + for item in value: + if not validate_type(item, inner): + return False + + return True + + return True + parsed = await parse_webhook_request(webhook_cfg.get("content_types")) + SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}}) + + # Extract strictly by schema try: - canvas = Canvas(cvs.dsl, tenant_id, agent_id) + query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query") + header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers") + body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body") except Exception as e: - return get_json_result( - data=False, message=str(e), - code=RetCode.EXCEPTION_ERROR) + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST + + clean_request = { + "query": query_clean, + "headers": header_clean, + "body": body_clean, + "input": parsed + } + + execution_mode = webhook_cfg.get("execution_mode", "Immediately") + response_cfg = webhook_cfg.get("response", {}) + + def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600): + key = f"webhook-trace-{agent_id}-logs" + + raw = REDIS_CONN.get(key) + obj = json.loads(raw) if raw else {"webhooks": {}} + + ws = obj["webhooks"].setdefault( + str(start_ts), + {"start_ts": start_ts, "events": []} + ) + + ws["events"].append({ + "ts": time.time(), + **event + }) + + REDIS_CONN.set_obj(key, obj, ttl) - def sse(): - nonlocal canvas + if execution_mode == "Immediately": + status = response_cfg.get("status", 200) try: - for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req): - yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" + status = int(status) + except (TypeError, ValueError): + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST - cvs.dsl = json.loads(str(canvas)) - UserCanvasService.update_by_id(req["id"], cvs.to_dict()) - except Exception as e: - logging.exception(e) - yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n" - - resp = Response(sse(), mimetype="text/event-stream") - resp.headers.add_header("Cache-control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") - return resp + if not (200 <= status <= 399): + return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST + + body_tpl = response_cfg.get("body_template", "") + + def parse_body(body: str): + if not body: + return None, "application/json" + + try: + parsed = json.loads(body) + return parsed, "application/json" + except (json.JSONDecodeError, TypeError): + return body, "text/plain" + + + body, content_type = parse_body(body_tpl) + resp = Response( + json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body, + status=status, + content_type=content_type, + ) + + async def background_run(): + try: + async for ans in canvas.run( + query="", + user_id=cvs.user_id, + webhook_payload=clean_request + ): + if is_test: + append_webhook_trace(agent_id, start_ts, ans) + + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": True, + } + ) + + cvs.dsl = json.loads(str(canvas)) + UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict()) + + except Exception as e: + logging.exception("Webhook background run failed") + if is_test: + try: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "error", + "message": str(e), + "error_type": type(e).__name__, + } + ) + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": False, + } + ) + except Exception: + logging.exception("Failed to append webhook trace") + + asyncio.create_task(background_run()) + return resp + else: + async def sse(): + nonlocal canvas + contents: list[str] = [] + status = 200 + try: + async for ans in canvas.run( + query="", + user_id=cvs.user_id, + webhook_payload=clean_request, + ): + if ans["event"] == "message": + content = ans["data"]["content"] + if ans["data"].get("start_to_think", False): + content = "" + elif ans["data"].get("end_to_think", False): + content = "" + if content: + contents.append(content) + if ans["event"] == "message_end": + status = int(ans["data"].get("status", status)) + if is_test: + append_webhook_trace( + agent_id, + start_ts, + ans + ) + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": True, + } + ) + final_content = "".join(contents) + return { + "message": final_content, + "success": True, + "code": status, + } + + except Exception as e: + if is_test: + append_webhook_trace( + agent_id, + start_ts, + { + "event": "error", + "message": str(e), + "error_type": type(e).__name__, + } + ) + append_webhook_trace( + agent_id, + start_ts, + { + "event": "finished", + "elapsed_time": time.time() - start_ts, + "success": False, + } + ) + return {"code": 400, "message": str(e),"success":False} + + result = await sse() + return Response( + json.dumps(result), + status=result["code"], + mimetype="application/json", + ) + + +@manager.route("/webhook_trace/", methods=["GET"]) # noqa: F821 +async def webhook_trace(agent_id: str): + def encode_webhook_id(start_ts: str) -> str: + WEBHOOK_ID_SECRET = "webhook_id_secret" + sig = hmac.new( + WEBHOOK_ID_SECRET.encode("utf-8"), + start_ts.encode("utf-8"), + hashlib.sha256, + ).digest() + return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=") + + def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None: + for ts in webhooks.keys(): + if encode_webhook_id(ts) == enc_id: + return ts + return None + since_ts = request.args.get("since_ts", type=float) + webhook_id = request.args.get("webhook_id") + + key = f"webhook-trace-{agent_id}-logs" + raw = REDIS_CONN.get(key) + + if since_ts is None: + now = time.time() + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": now, + "finished": False, + } + ) + + if not raw: + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": since_ts, + "finished": False, + } + ) + + obj = json.loads(raw) + webhooks = obj.get("webhooks", {}) + + if webhook_id is None: + candidates = [ + float(k) for k in webhooks.keys() if float(k) > since_ts + ] + + if not candidates: + return get_json_result( + data={ + "webhook_id": None, + "events": [], + "next_since_ts": since_ts, + "finished": False, + } + ) + + start_ts = min(candidates) + real_id = str(start_ts) + webhook_id = encode_webhook_id(real_id) + + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": [], + "next_since_ts": start_ts, + "finished": False, + } + ) + + real_id = decode_webhook_id(webhook_id, webhooks) + + if not real_id: + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": [], + "next_since_ts": since_ts, + "finished": True, + } + ) + + ws = webhooks.get(str(real_id)) + events = ws.get("events", []) + new_events = [e for e in events if e.get("ts", 0) > since_ts] + + next_ts = since_ts + for e in new_events: + next_ts = max(next_ts, e["ts"]) + + finished = any(e.get("event") == "finished" for e in new_events) + + return get_json_result( + data={ + "webhook_id": webhook_id, + "events": new_events, + "next_since_ts": next_ts, + "finished": finished, + } + ) diff --git a/api/apps/sdk/chat.py b/api/apps/sdk/chat.py index a3f03b4484f..1efb628f1bc 100644 --- a/api/apps/sdk/chat.py +++ b/api/apps/sdk/chat.py @@ -14,22 +14,20 @@ # limitations under the License. # import logging - -from flask import request - +from quart import request from api.db.services.dialog_service import DialogService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_service import TenantService from common.misc_utils import get_uuid from common.constants import RetCode, StatusEnum -from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required +from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required, get_request_json @manager.route("/chats", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id): - req = request.json +async def create(tenant_id): + req = await get_request_json() ids = [i for i in req.get("dataset_ids", []) if i] for kb_id in ids: kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id) @@ -94,7 +92,7 @@ def create(tenant_id): req["tenant_id"] = tenant_id # prompt more parameter default_prompt = { - "system": """You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the knowledge base!" Answers need to consider chat history. + "system": """You are an intelligent assistant. Please summarize the content of the dataset to answer the question. Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" Answers need to consider chat history. Here is the knowledge base: {knowledge} The above is the knowledge base.""", @@ -145,10 +143,10 @@ def create(tenant_id): @manager.route("/chats/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, chat_id): +async def update(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You do not own the chat") - req = request.json + req = await get_request_json() ids = req.get("dataset_ids", []) if "show_quotation" in req: req["do_refer"] = req.pop("show_quotation") @@ -176,7 +174,9 @@ def update(tenant_id, chat_id): req["llm_id"] = llm.pop("model_name") if req.get("llm_id") is not None: llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(req["llm_id"]) - if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type="chat"): + model_type = llm.pop("model_type") + model_type = model_type if model_type in ["chat", "image2text"] else "chat" + if not TenantLLMService.query(tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, model_type=model_type): return get_error_data_result(f"`model_name` {req.get('llm_id')} doesn't exist") req["llm_setting"] = req.pop("llm") e, tenant = TenantService.get_by_id(tenant_id) @@ -228,10 +228,10 @@ def update(tenant_id, chat_id): @manager.route("/chats", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id): +async def delete_chats(tenant_id): errors = [] success_count = 0 - req = request.json + req = await get_request_json() if not req: ids = None else: @@ -251,8 +251,7 @@ def delete(tenant_id): errors.append(f"Assistant({id}) not found.") continue temp_dict = {"status": StatusEnum.INVALID.value} - DialogService.update_by_id(id, temp_dict) - success_count += 1 + success_count += DialogService.update_by_id(id, temp_dict) if errors: if success_count > 0: @@ -288,7 +287,7 @@ def list_chat(tenant_id): chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name) if not chats: return get_result(data=[]) - list_assts = [] + list_assistants = [] key_mapping = { "parameters": "variables", "prologue": "opener", @@ -322,5 +321,5 @@ def list_chat(tenant_id): del res["kb_ids"] res["datasets"] = kb_list res["avatar"] = res.pop("icon") - list_assts.append(res) - return get_result(data=list_assts) + list_assistants.append(res) + return get_result(data=list_assistants) diff --git a/api/apps/sdk/dataset.py b/api/apps/sdk/dataset.py index 8a315ce69d1..7d52c3fec50 100644 --- a/api/apps/sdk/dataset.py +++ b/api/apps/sdk/dataset.py @@ -18,13 +18,14 @@ import logging import os import json -from flask import request +from quart import request from peewee import OperationalError from api.db.db_models import File -from api.db.services.document_service import DocumentService +from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService +from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService from api.db.services.user_service import TenantService from common.constants import RetCode, FileSource, StatusEnum from api.utils.api_utils import ( @@ -53,7 +54,7 @@ @manager.route("/datasets", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id): +async def create(tenant_id): """ Create a new dataset. --- @@ -115,17 +116,19 @@ def create(tenant_id): # | embedding_model| embd_id | # | chunk_method | parser_id | - req, err = validate_and_parse_json_request(request, CreateDatasetReq) + req, err = await validate_and_parse_json_request(request, CreateDatasetReq) if err is not None: return get_error_argument_result(err) - - req = KnowledgebaseService.create_with_name( + e, req = KnowledgebaseService.create_with_name( name = req.pop("name", None), tenant_id = tenant_id, parser_id = req.pop("parser_id", None), **req ) + if not e: + return req + # Insert embedding model(embd id) ok, t = TenantService.get_by_id(tenant_id) if not ok: @@ -144,7 +147,6 @@ def create(tenant_id): ok, k = KnowledgebaseService.get_by_id(req["id"]) if not ok: return get_error_data_result(message="Dataset created failed") - response_data = remap_dictionary_keys(k.to_dict()) return get_result(data=response_data) except Exception as e: @@ -153,7 +155,7 @@ def create(tenant_id): @manager.route("/datasets", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id): +async def delete(tenant_id): """ Delete datasets. --- @@ -191,7 +193,7 @@ def delete(tenant_id): schema: type: object """ - req, err = validate_and_parse_json_request(request, DeleteDatasetReq) + req, err = await validate_and_parse_json_request(request, DeleteDatasetReq) if err is not None: return get_error_argument_result(err) @@ -251,7 +253,7 @@ def delete(tenant_id): @manager.route("/datasets/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, dataset_id): +async def update(tenant_id, dataset_id): """ Update a dataset. --- @@ -317,7 +319,7 @@ def update(tenant_id, dataset_id): # | embedding_model| embd_id | # | chunk_method | parser_id | extras = {"dataset_id": dataset_id} - req, err = validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) + req, err = await validate_and_parse_json_request(request, UpdateDatasetReq, extras=extras, exclude_unset=True) if err is not None: return get_error_argument_result(err) @@ -493,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id): } obj = {"graph": {}, "mind_map": {}} - if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id): + if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id): return get_result(data=obj) sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id]) if not len(sres.ids): @@ -532,3 +534,157 @@ def delete_knowledge_graph(tenant_id, dataset_id): search.index_name(kb.tenant_id), dataset_id) return get_result(data=True) + + +@manager.route("/datasets//run_graphrag", methods=["POST"]) # noqa: F821 +@token_required +def run_graphrag(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}): + logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}") + + return get_result(data={"graphrag_task_id": task_id}) + + +@manager.route("/datasets//trace_graphrag", methods=["GET"]) # noqa: F821 +@token_required +def trace_graphrag(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.graphrag_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_result(data={}) + + return get_result(data=task.to_dict()) + + +@manager.route("/datasets//run_raptor", methods=["POST"]) # noqa: F821 +@token_required +def run_raptor(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if task_id: + ok, task = TaskService.get_by_id(task_id) + if not ok: + logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}") + + if task and task.progress not in [-1, 1]: + return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.") + + documents, _ = DocumentService.get_by_kb_id( + kb_id=dataset_id, + page_number=0, + items_per_page=0, + orderby="create_time", + desc=False, + keywords="", + run_status=[], + types=[], + suffix=[], + ) + if not documents: + return get_error_data_result(message=f"No documents in Dataset {dataset_id}") + + sample_document = documents[0] + document_ids = [document["id"] for document in documents] + + task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids)) + + if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}): + logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}") + + return get_result(data={"raptor_task_id": task_id}) + + +@manager.route("/datasets//trace_raptor", methods=["GET"]) # noqa: F821 +@token_required +def trace_raptor(tenant_id,dataset_id): + if not dataset_id: + return get_error_data_result(message='Lack of "Dataset ID"') + + if not KnowledgebaseService.accessible(dataset_id, tenant_id): + return get_result( + data=False, + message='No authorization.', + code=RetCode.AUTHENTICATION_ERROR + ) + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return get_error_data_result(message="Invalid Dataset ID") + + task_id = kb.raptor_task_id + if not task_id: + return get_result(data={}) + + ok, task = TaskService.get_by_id(task_id) + if not ok: + return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred") + + return get_result(data=task.to_dict()) \ No newline at end of file diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index d2c3485a940..7a11688ddcb 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -15,21 +15,21 @@ # import logging -from flask import request, jsonify +from quart import jsonify from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle -from api.utils.api_utils import validate_request, build_error_result, apikey_required +from common.metadata_utils import meta_filter, convert_conditions +from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request from rag.app.tag import label_question -from api.db.services.dialog_service import meta_filter, convert_conditions from common.constants import RetCode, LLMType from common import settings @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 @apikey_required @validate_request("knowledge_id", "query") -def retrieval(tenant_id): +async def retrieval(tenant_id): """ Dify-compatible retrieval API --- @@ -113,14 +113,14 @@ def retrieval(tenant_id): 404: description: Knowledge base or document not found """ - req = request.json + req = await get_request_json() question = req["query"] kb_id = req["knowledge_id"] use_kg = req.get("use_kg", False) retrieval_setting = req.get("retrieval_setting", {}) similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) top = int(retrieval_setting.get("top_k", 1024)) - metadata_condition = req.get("metadata_condition", {}) + metadata_condition = req.get("metadata_condition", {}) or {} metas = DocumentService.get_meta_by_kbs([kb_id]) doc_ids = [] @@ -131,12 +131,10 @@ def retrieval(tenant_id): return build_error_result(message="Knowledgebase not found!", code=RetCode.NOT_FOUND) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) - print(metadata_condition) - # print("after", convert_conditions(metadata_condition)) - doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition))) - # print("doc_ids", doc_ids) - if not doc_ids and metadata_condition is not None: - doc_ids = ['-999'] + if metadata_condition: + doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + if not doc_ids and metadata_condition: + doc_ids = ["-999"] ranks = settings.retriever.retrieval( question, embd_mdl, diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 4caf2cc8dfb..bef03d38ec4 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -14,13 +14,14 @@ # limitations under the License. # import datetime +import json import logging import pathlib import re from io import BytesIO import xxhash -from flask import request, send_file +from quart import request, send_file from peewee import OperationalError from pydantic import BaseModel, Field, validator @@ -33,9 +34,10 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService -from api.db.services.task_service import TaskService, queue_tasks -from api.db.services.dialog_service import meta_filter, convert_conditions -from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required +from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of +from common.metadata_utils import meta_filter, convert_conditions +from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \ + get_request_json from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search @@ -69,7 +71,7 @@ def validate_positions(cls, value): @manager.route("/datasets//documents", methods=["POST"]) # noqa: F821 @token_required -def upload(dataset_id, tenant_id): +async def upload(dataset_id, tenant_id): """ Upload documents to a dataset. --- @@ -93,6 +95,10 @@ def upload(dataset_id, tenant_id): type: file required: true description: Document files to upload. + - in: formData + name: parent_path + type: string + description: Optional nested path under the parent folder. Uses '/' separators. responses: 200: description: Successfully uploaded documents. @@ -126,9 +132,11 @@ def upload(dataset_id, tenant_id): type: string description: Processing status. """ - if "file" not in request.files: + form = await request.form + files = await request.files + if "file" not in files: return get_error_data_result(message="No file part!", code=RetCode.ARGUMENT_ERROR) - file_objs = request.files.getlist("file") + file_objs = files.getlist("file") for file_obj in file_objs: if file_obj.filename == "": return get_result(message="No file selected!", code=RetCode.ARGUMENT_ERROR) @@ -151,7 +159,7 @@ def upload(dataset_id, tenant_id): e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: raise LookupError(f"Can't find the dataset with ID {dataset_id}!") - err, files = FileService.upload_document(kb, file_objs, tenant_id) + err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=form.get("parent_path")) if err: return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR) # rename key's name @@ -175,7 +183,7 @@ def upload(dataset_id, tenant_id): @manager.route("/datasets//documents/", methods=["PUT"]) # noqa: F821 @token_required -def update_doc(tenant_id, dataset_id, document_id): +async def update_doc(tenant_id, dataset_id, document_id): """ Update a document within a dataset. --- @@ -224,12 +232,12 @@ def update_doc(tenant_id, dataset_id, document_id): schema: type: object """ - req = request.json + req = await get_request_json() if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): return get_error_data_result(message="You don't own the dataset.") e, kb = KnowledgebaseService.get_by_id(dataset_id) if not e: - return get_error_data_result(message="Can't find this knowledgebase!") + return get_error_data_result(message="Can't find this dataset!") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: return get_error_data_result(message="The dataset doesn't own the document.") @@ -314,9 +322,7 @@ def update_doc(tenant_id, dataset_id, document_id): try: if not DocumentService.update_by_id(doc.id, {"status": str(status)}): return get_error_data_result(message="Database error (Document update)!") - settings.docStoreConn.update({"doc_id": doc.id}, {"available_int": status}, search.index_name(kb.tenant_id), doc.kb_id) - return get_result(data=True) except Exception as e: return server_error_response(e) @@ -343,19 +349,17 @@ def update_doc(tenant_id, dataset_id, document_id): } renamed_doc = {} for key, value in doc.to_dict().items(): - if key == "run": - renamed_doc["run"] = run_mapping.get(str(value)) new_key = key_mapping.get(key, key) renamed_doc[new_key] = value if key == "run": - renamed_doc["run"] = run_mapping.get(value) + renamed_doc["run"] = run_mapping.get(str(value)) return get_result(data=renamed_doc) @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 @token_required -def download(tenant_id, dataset_id, document_id): +async def download(tenant_id, dataset_id, document_id): """ Download a document from a dataset. --- @@ -405,10 +409,10 @@ def download(tenant_id, dataset_id, document_id): return construct_json_result(message="This file is empty.", code=RetCode.DATA_ERROR) file = BytesIO(file_stream) # Use send_file with a proper filename and MIME type - return send_file( + return await send_file( file, as_attachment=True, - download_name=doc[0].name, + attachment_filename=doc[0].name, mimetype="application/octet-stream", # Set a default MIME type ) @@ -529,7 +533,7 @@ def list_docs(dataset_id, tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") q = request.args - document_id = q.get("id") + document_id = q.get("id") name = q.get("name") if document_id and not DocumentService.query(id=document_id, kb_id=dataset_id): @@ -538,23 +542,39 @@ def list_docs(dataset_id, tenant_id): return get_error_data_result(message=f"You don't own the document {name}.") page = int(q.get("page", 1)) - page_size = int(q.get("page_size", 30)) + page_size = int(q.get("page_size", 30)) orderby = q.get("orderby", "create_time") desc = str(q.get("desc", "true")).strip().lower() != "false" keywords = q.get("keywords", "") # filters - align with OpenAPI parameter names - suffix = q.getlist("suffix") - run_status = q.getlist("run") - create_time_from = int(q.get("create_time_from", 0)) - create_time_to = int(q.get("create_time_to", 0)) + suffix = q.getlist("suffix") + run_status = q.getlist("run") + create_time_from = int(q.get("create_time_from", 0)) + create_time_to = int(q.get("create_time_to", 0)) + metadata_condition_raw = q.get("metadata_condition") + metadata_condition = {} + if metadata_condition_raw: + try: + metadata_condition = json.loads(metadata_condition_raw) + except Exception: + return get_error_data_result(message="metadata_condition must be valid JSON.") + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") - # map run status (accept text or numeric) - align with API parameter + # map run status (text or numeric) - align with API parameter run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"} run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status] + doc_ids_filter = None + if metadata_condition: + metas = DocumentService.get_flatted_meta_by_kbs([dataset_id]) + doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + if metadata_condition.get("conditions") and not doc_ids_filter: + return get_result(data={"total": 0, "docs": []}) + docs, total = DocumentService.get_list( - dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted + dataset_id, page, page_size, orderby, desc, keywords, document_id, name, suffix, run_status_converted, doc_ids_filter ) # time range filter (0 means no bound) @@ -568,7 +588,7 @@ def list_docs(dataset_id, tenant_id): # rename keys + map run status back to text for output key_mapping = { "chunk_num": "chunk_count", - "kb_id": "dataset_id", + "kb_id": "dataset_id", "token_num": "token_count", "parser_id": "chunk_method", } @@ -583,9 +603,73 @@ def list_docs(dataset_id, tenant_id): return get_result(data={"total": total, "docs": output_docs}) + +@manager.route("/datasets//metadata/summary", methods=["GET"]) # noqa: F821 +@token_required +def metadata_summary(dataset_id, tenant_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") + + try: + summary = DocumentService.get_metadata_summary(dataset_id) + return get_result(data={"summary": summary}) + except Exception as e: + return server_error_response(e) + + +@manager.route("/datasets//metadata/update", methods=["POST"]) # noqa: F821 +@token_required +async def metadata_batch_update(dataset_id, tenant_id): + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") + + req = await get_request_json() + selector = req.get("selector", {}) or {} + updates = req.get("updates", []) or [] + deletes = req.get("deletes", []) or [] + + if not isinstance(selector, dict): + return get_error_data_result(message="selector must be an object.") + if not isinstance(updates, list) or not isinstance(deletes, list): + return get_error_data_result(message="updates and deletes must be lists.") + + metadata_condition = selector.get("metadata_condition", {}) or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + document_ids = selector.get("document_ids", []) or [] + if document_ids and not isinstance(document_ids, list): + return get_error_data_result(message="document_ids must be a list.") + + for upd in updates: + if not isinstance(upd, dict) or not upd.get("key") or "value" not in upd: + return get_error_data_result(message="Each update requires key and value.") + for d in deletes: + if not isinstance(d, dict) or not d.get("key"): + return get_error_data_result(message="Each delete requires key.") + + kb_doc_ids = KnowledgebaseService.list_documents_by_ids([dataset_id]) + target_doc_ids = set(kb_doc_ids) + if document_ids: + invalid_ids = set(document_ids) - set(kb_doc_ids) + if invalid_ids: + return get_error_data_result(message=f"These documents do not belong to dataset {dataset_id}: {', '.join(invalid_ids)}") + target_doc_ids = set(document_ids) + + if metadata_condition: + metas = DocumentService.get_flatted_meta_by_kbs([dataset_id]) + filtered_ids = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))) + target_doc_ids = target_doc_ids & filtered_ids + if metadata_condition.get("conditions") and not target_doc_ids: + return get_result(data={"updated": 0, "matched_docs": 0}) + + target_doc_ids = list(target_doc_ids) + updated = DocumentService.batch_update_metadata(dataset_id, target_doc_ids, updates, deletes) + return get_result(data={"updated": updated, "matched_docs": len(target_doc_ids)}) + @manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id, dataset_id): +async def delete(tenant_id, dataset_id): """ Delete documents from a dataset. --- @@ -624,7 +708,7 @@ def delete(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}. ") - req = request.json + req = await get_request_json() if not req: doc_ids = None else: @@ -695,7 +779,7 @@ def delete(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 @token_required -def parse(tenant_id, dataset_id): +async def parse(tenant_id, dataset_id): """ Start parsing documents into chunks. --- @@ -734,7 +818,7 @@ def parse(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = request.json + req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") doc_list = req.get("document_ids") @@ -778,7 +862,7 @@ def parse(tenant_id, dataset_id): @manager.route("/datasets//chunks", methods=["DELETE"]) # noqa: F821 @token_required -def stop_parsing(tenant_id, dataset_id): +async def stop_parsing(tenant_id, dataset_id): """ Stop parsing documents into chunks. --- @@ -817,7 +901,7 @@ def stop_parsing(tenant_id, dataset_id): """ if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") - req = request.json + req = await get_request_json() if not req.get("document_ids"): return get_error_data_result("`document_ids` is required") @@ -832,6 +916,8 @@ def stop_parsing(tenant_id, dataset_id): return get_error_data_result(message=f"You don't own the document {id}.") if int(doc[0].progress) == 1 or doc[0].progress == 0: return get_error_data_result("Can't stop parsing document with progress at 0 or 1") + # Send cancellation signal via Redis to stop background task + cancel_all_task_of(id) info = {"run": "2", "progress": 0, "chunk_num": 0} DocumentService.update_by_id(id, info) settings.docStoreConn.delete({"doc_id": doc[0].id}, search.index_name(tenant_id), dataset_id) @@ -885,7 +971,7 @@ def list_chunks(tenant_id, dataset_id, document_id): type: string required: false default: "" - description: Chunk Id. + description: Chunk id. - in: header name: Authorization type: string @@ -994,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id): res["chunks"].append(final_chunk) _ = Chunk(**final_chunk) - elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id): + elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True) res["total"] = sres.total for id in sres.ids: @@ -1019,7 +1105,7 @@ def list_chunks(tenant_id, dataset_id, document_id): "/datasets//documents//chunks", methods=["POST"] ) @token_required -def add_chunk(tenant_id, dataset_id, document_id): +async def add_chunk(tenant_id, dataset_id, document_id): """ Add a chunk to a document. --- @@ -1089,7 +1175,7 @@ def add_chunk(tenant_id, dataset_id, document_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = request.json + req = await get_request_json() if not str(req.get("content", "")).strip(): return get_error_data_result(message="`content` is required") if "important_keywords" in req: @@ -1148,7 +1234,7 @@ def add_chunk(tenant_id, dataset_id, document_id): "datasets//documents//chunks", methods=["DELETE"] ) @token_required -def rm_chunk(tenant_id, dataset_id, document_id): +async def rm_chunk(tenant_id, dataset_id, document_id): """ Remove chunks from a document. --- @@ -1195,11 +1281,14 @@ def rm_chunk(tenant_id, dataset_id, document_id): docs = DocumentService.get_by_ids([document_id]) if not docs: raise LookupError(f"Can't find the document with ID {document_id}!") - req = request.json + req = await get_request_json() condition = {"doc_id": document_id} if "chunk_ids" in req: unique_chunk_ids, duplicate_messages = check_duplicate_ids(req["chunk_ids"], "chunk") condition["id"] = unique_chunk_ids + else: + unique_chunk_ids = [] + duplicate_messages = [] chunk_number = settings.docStoreConn.delete(condition, search.index_name(tenant_id), dataset_id) if chunk_number != 0: DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) @@ -1219,7 +1308,7 @@ def rm_chunk(tenant_id, dataset_id, document_id): "/datasets//documents//chunks/", methods=["PUT"] ) @token_required -def update_chunk(tenant_id, dataset_id, document_id, chunk_id): +async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): """ Update a chunk within a document. --- @@ -1281,8 +1370,8 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - req = request.json - if "content" in req: + req = await get_request_json() + if "content" in req and req["content"] is not None: content = req["content"] else: content = chunk.get("content_with_weight", "") @@ -1323,7 +1412,7 @@ def update_chunk(tenant_id, dataset_id, document_id, chunk_id): @manager.route("/retrieval", methods=["POST"]) # noqa: F821 @token_required -def retrieval_test(tenant_id): +async def retrieval_test(tenant_id): """ Retrieve chunks based on a query. --- @@ -1404,7 +1493,7 @@ def retrieval_test(tenant_id): format: float description: Similarity score. """ - req = request.json + req = await get_request_json() if not req.get("dataset_ids"): return get_error_data_result("`dataset_ids` is required.") kb_ids = req["dataset_ids"] @@ -1427,6 +1516,7 @@ def retrieval_test(tenant_id): question = req["question"] doc_ids = req.get("document_ids", []) use_kg = req.get("use_kg", False) + toc_enhance = req.get("toc_enhance", False) langs = req.get("cross_languages", []) if not isinstance(doc_ids, list): return get_error_data_result("`documents` should be a list") @@ -1435,9 +1525,14 @@ def retrieval_test(tenant_id): if doc_id not in doc_ids_list: return get_error_data_result(f"The datasets don't own the document {doc_id}") if not doc_ids: - metadata_condition = req.get("metadata_condition", {}) + metadata_condition = req.get("metadata_condition", {}) or {} metas = DocumentService.get_meta_by_kbs(kb_ids) - doc_ids = meta_filter(metas, convert_conditions(metadata_condition)) + doc_ids = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")) + # If metadata_condition has conditions but no docs match, return empty result + if not doc_ids and metadata_condition.get("conditions"): + return get_result(data={"total": 0, "chunks": [], "doc_aggs": {}}) + if metadata_condition and not doc_ids: + doc_ids = ["-999"] similarity_threshold = float(req.get("similarity_threshold", 0.2)) vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) top = int(req.get("top_k", 1024)) @@ -1457,11 +1552,11 @@ def retrieval_test(tenant_id): rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + question = await cross_languages(kb.tenant_id, None, question, langs) if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + question += await keyword_extraction(chat_mdl, question) ranks = settings.retriever.retrieval( question, @@ -1478,6 +1573,11 @@ def retrieval_test(tenant_id): highlight=highlight, rank_feature=label_question(question, kbs), ) + if toc_enhance: + chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) + cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) + if cks: + ranks["chunks"] = cks if use_kg: ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 733c894c3e0..a618777884e 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -14,35 +14,34 @@ # limitations under the License. # - +import asyncio import pathlib import re - -import flask -from flask import request +from quart import request, make_response from pathlib import Path from api.db.services.document_service import DocumentService from api.db.services.file2document_service import File2DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService -from api.utils.api_utils import server_error_response, token_required +from api.utils.api_utils import get_json_result, get_request_json, server_error_response, token_required from common.misc_utils import get_uuid from api.db import FileType from api.db.services import duplicate_name from api.db.services.file_service import FileService -from api.utils.api_utils import get_json_result from api.utils.file_utils import filename_type +from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings +from common.constants import RetCode @manager.route('/file/upload', methods=['POST']) # noqa: F821 @token_required -def upload(tenant_id): +async def upload(tenant_id): """ Upload a file to the system. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -79,26 +78,28 @@ def upload(tenant_id): type: string description: File type (e.g., document, folder) """ - pf_id = request.form.get("parent_id") + form = await request.form + files = await request.files + pf_id = form.get("parent_id") if not pf_id: root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] - if 'file' not in request.files: - return get_json_result(data=False, message='No file part!', code=400) - file_objs = request.files.getlist('file') + if 'file' not in files: + return get_json_result(data=False, message='No file part!', code=RetCode.BAD_REQUEST) + file_objs = files.getlist('file') for file_obj in file_objs: if file_obj.filename == '': - return get_json_result(data=False, message='No selected file!', code=400) + return get_json_result(data=False, message='No selected file!', code=RetCode.BAD_REQUEST) file_res = [] try: e, pf_folder = FileService.get_by_id(pf_id) if not e: - return get_json_result(data=False, message="Can't find this folder!", code=404) + return get_json_result(data=False, message="Can't find this folder!", code=RetCode.NOT_FOUND) for file_obj in file_objs: # Handle file path @@ -114,13 +115,13 @@ def upload(tenant_id): if file_len != len_id_list: e, file = FileService.get_by_id(file_id_list[len_id_list - 1]) if not e: - return get_json_result(data=False, message="Folder not found!", code=404) + return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND) last_folder = FileService.create_folder(file, file_id_list[len_id_list - 1], file_obj_names, len_id_list) else: e, file = FileService.get_by_id(file_id_list[len_id_list - 2]) if not e: - return get_json_result(data=False, message="Folder not found!", code=404) + return get_json_result(data=False, message="Folder not found!", code=RetCode.NOT_FOUND) last_folder = FileService.create_folder(file, file_id_list[len_id_list - 2], file_obj_names, len_id_list) @@ -151,12 +152,12 @@ def upload(tenant_id): @manager.route('/file/create', methods=['POST']) # noqa: F821 @token_required -def create(tenant_id): +async def create(tenant_id): """ Create a new file or folder. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -193,18 +194,19 @@ def create(tenant_id): type: type: string """ - req = request.json - pf_id = request.json.get("parent_id") - input_file_type = request.json.get("type") + req = await get_request_json() + pf_id = req.get("parent_id") + input_file_type = req.get("type") if not pf_id: root_folder = FileService.get_root_folder(tenant_id) pf_id = root_folder["id"] try: if not FileService.is_parent_folder_exist(pf_id): - return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=400) + return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST) if FileService.query(name=req["name"], parent_id=pf_id): - return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409) + return get_json_result(data=False, message="Duplicated folder name in the same folder.", + code=RetCode.CONFLICT) if input_file_type == FileType.FOLDER.value: file_type = FileType.FOLDER.value @@ -229,12 +231,12 @@ def create(tenant_id): @manager.route('/file/list', methods=['GET']) # noqa: F821 @token_required -def list_files(tenant_id): +async def list_files(tenant_id): """ List files under a specific folder. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -306,13 +308,13 @@ def list_files(tenant_id): try: e, file = FileService.get_by_id(pf_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) files, total = FileService.get_by_pf_id(tenant_id, pf_id, page_number, items_per_page, orderby, desc, keywords) parent_folder = FileService.get_parent_folder(pf_id) if not parent_folder: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) return get_json_result(data={"total": total, "files": files, "parent_folder": parent_folder.to_json()}) except Exception as e: @@ -321,12 +323,12 @@ def list_files(tenant_id): @manager.route('/file/root_folder', methods=['GET']) # noqa: F821 @token_required -def get_root_folder(tenant_id): +async def get_root_folder(tenant_id): """ Get user's root folder. --- tags: - - File Management + - File security: - ApiKeyAuth: [] responses: @@ -357,12 +359,12 @@ def get_root_folder(tenant_id): @manager.route('/file/parent_folder', methods=['GET']) # noqa: F821 @token_required -def get_parent_folder(): +async def get_parent_folder(): """ Get parent folder info of a file. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -392,7 +394,7 @@ def get_parent_folder(): try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) parent_folder = FileService.get_parent_folder(file_id) return get_json_result(data={"parent_folder": parent_folder.to_json()}) @@ -402,12 +404,12 @@ def get_parent_folder(): @manager.route('/file/all_parent_folder', methods=['GET']) # noqa: F821 @token_required -def get_all_parent_folders(tenant_id): +async def get_all_parent_folders(tenant_id): """ Get all parent folders of a file. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -439,7 +441,7 @@ def get_all_parent_folders(tenant_id): try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Folder not found!", code=404) + return get_json_result(message="Folder not found!", code=RetCode.NOT_FOUND) parent_folders = FileService.get_all_parent_folders(file_id) parent_folders_res = [folder.to_json() for folder in parent_folders] @@ -450,12 +452,12 @@ def get_all_parent_folders(tenant_id): @manager.route('/file/rm', methods=['POST']) # noqa: F821 @token_required -def rm(tenant_id): +async def rm(tenant_id): """ Delete one or multiple files/folders. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -481,40 +483,40 @@ def rm(tenant_id): type: boolean example: true """ - req = request.json + req = await get_request_json() file_ids = req["file_ids"] try: for file_id in file_ids: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="File or Folder not found!", code=404) + return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND) if not file.tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if file.type == FileType.FOLDER.value: file_id_list = FileService.get_all_innermost_file_ids(file_id, []) for inner_file_id in file_id_list: e, file = FileService.get_by_id(inner_file_id) if not e: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) settings.STORAGE_IMPL.rm(file.parent_id, file.location) FileService.delete_folder_by_pf_id(tenant_id, file_id) else: settings.STORAGE_IMPL.rm(file.parent_id, file.location) if not FileService.delete(file): - return get_json_result(message="Database error (File removal)!", code=500) + return get_json_result(message="Database error (File removal)!", code=RetCode.SERVER_ERROR) informs = File2DocumentService.get_by_file_id(file_id) for inform in informs: doc_id = inform.document_id e, doc = DocumentService.get_by_id(doc_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) tenant_id = DocumentService.get_tenant_id(doc_id) if not tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if not DocumentService.remove_document(doc, tenant_id): - return get_json_result(message="Database error (Document removal)!", code=500) + return get_json_result(message="Database error (Document removal)!", code=RetCode.SERVER_ERROR) File2DocumentService.delete_by_file_id(file_id) return get_json_result(data=True) @@ -524,12 +526,12 @@ def rm(tenant_id): @manager.route('/file/rename', methods=['POST']) # noqa: F821 @token_required -def rename(tenant_id): +async def rename(tenant_id): """ Rename a file. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -556,27 +558,29 @@ def rename(tenant_id): type: boolean example: true """ - req = request.json + req = await get_request_json() try: e, file = FileService.get_by_id(req["file_id"]) if not e: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path( file.name.lower()).suffix: - return get_json_result(data=False, message="The extension of file can't be changed", code=400) + return get_json_result(data=False, message="The extension of file can't be changed", + code=RetCode.BAD_REQUEST) for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id): if existing_file.name == req["name"]: - return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409) + return get_json_result(data=False, message="Duplicated file name in the same folder.", + code=RetCode.CONFLICT) if not FileService.update_by_id(req["file_id"], {"name": req["name"]}): - return get_json_result(message="Database error (File rename)!", code=500) + return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR) informs = File2DocumentService.get_by_file_id(req["file_id"]) if informs: if not DocumentService.update_by_id(informs[0].document_id, {"name": req["name"]}): - return get_json_result(message="Database error (Document rename)!", code=500) + return get_json_result(message="Database error (Document rename)!", code=RetCode.SERVER_ERROR) return get_json_result(data=True) except Exception as e: @@ -585,12 +589,12 @@ def rename(tenant_id): @manager.route('/file/get/', methods=['GET']) # noqa: F821 @token_required -def get(tenant_id, file_id): +async def get(tenant_id, file_id): """ Download a file. --- tags: - - File Management + - File security: - ApiKeyAuth: [] produces: @@ -606,20 +610,20 @@ def get(tenant_id, file_id): description: File stream schema: type: file - 404: + RetCode.NOT_FOUND: description: File not found """ try: e, file = FileService.get_by_id(file_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) blob = settings.STORAGE_IMPL.get(file.parent_id, file.location) if not blob: b, n = File2DocumentService.get_storage_address(file_id=file_id) blob = settings.STORAGE_IMPL.get(b, n) - response = flask.make_response(blob) + response = await make_response(blob) ext = re.search(r"\.([^.]+)$", file.name) if ext: if file.type == FileType.VISUAL.value: @@ -631,14 +635,29 @@ def get(tenant_id, file_id): return server_error_response(e) +@manager.route("/file/download/", methods=["GET"]) # noqa: F821 +@token_required +async def download_attachment(tenant_id, attachment_id): + try: + ext = request.args.get("ext", "markdown") + data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id) + response = await make_response(data) + response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}")) + + return response + + except Exception as e: + return server_error_response(e) + + @manager.route('/file/mv', methods=['POST']) # noqa: F821 @token_required -def move(tenant_id): +async def move(tenant_id): """ Move one or multiple files to another folder. --- tags: - - File Management + - File security: - ApiKeyAuth: [] parameters: @@ -667,7 +686,7 @@ def move(tenant_id): type: boolean example: true """ - req = request.json + req = await get_request_json() try: file_ids = req["src_file_ids"] parent_id = req["dest_file_id"] @@ -677,13 +696,13 @@ def move(tenant_id): for file_id in file_ids: file = files_dict[file_id] if not file: - return get_json_result(message="File or Folder not found!", code=404) + return get_json_result(message="File or Folder not found!", code=RetCode.NOT_FOUND) if not file.tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) fe, _ = FileService.get_by_id(parent_id) if not fe: - return get_json_result(message="Parent Folder not found!", code=404) + return get_json_result(message="Parent Folder not found!", code=RetCode.NOT_FOUND) FileService.move_file(file_ids, parent_id) return get_json_result(data=True) @@ -693,8 +712,8 @@ def move(tenant_id): @manager.route('/file/convert', methods=['POST']) # noqa: F821 @token_required -def convert(tenant_id): - req = request.json +async def convert(tenant_id): + req = await get_request_json() kb_ids = req["kb_ids"] file_ids = req["file_ids"] file2documents = [] @@ -705,7 +724,7 @@ def convert(tenant_id): for file_id in file_ids: file = files_set[file_id] if not file: - return get_json_result(message="File not found!", code=404) + return get_json_result(message="File not found!", code=RetCode.NOT_FOUND) file_ids_list = [file_id] if file.type == FileType.FOLDER.value: file_ids_list = FileService.get_all_innermost_file_ids(file_id, []) @@ -716,13 +735,13 @@ def convert(tenant_id): doc_id = inform.document_id e, doc = DocumentService.get_by_id(doc_id) if not e: - return get_json_result(message="Document not found!", code=404) + return get_json_result(message="Document not found!", code=RetCode.NOT_FOUND) tenant_id = DocumentService.get_tenant_id(doc_id) if not tenant_id: - return get_json_result(message="Tenant not found!", code=404) + return get_json_result(message="Tenant not found!", code=RetCode.NOT_FOUND) if not DocumentService.remove_document(doc, tenant_id): return get_json_result( - message="Database error (Document removal)!", code=404) + message="Database error (Document removal)!", code=RetCode.NOT_FOUND) File2DocumentService.delete_by_file_id(id) # insert @@ -730,11 +749,11 @@ def convert(tenant_id): e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: return get_json_result( - message="Can't find this knowledgebase!", code=404) + message="Can't find this dataset!", code=RetCode.NOT_FOUND) e, file = FileService.get_by_id(id) if not e: return get_json_result( - message="Can't find this file!", code=404) + message="Can't find this file!", code=RetCode.NOT_FOUND) doc = DocumentService.insert({ "id": get_uuid(), diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 4edb2bb6b91..f9615e36ba1 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -14,38 +14,42 @@ # limitations under the License. # import json +import copy import re import time import tiktoken -from flask import Response, jsonify, request +from quart import Response, jsonify, request from agent.canvas import Canvas from api.db.db_models import APIToken from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService, completion_openai from api.db.services.canvas_service import completion as agent_completion -from api.db.services.conversation_service import ConversationService, iframe_completion -from api.db.services.conversation_service import completion as rag_completion -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter +from api.db.services.conversation_service import ConversationService +from api.db.services.conversation_service import async_iframe_completion as iframe_completion +from api.db.services.conversation_service import async_completion as rag_completion +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from common.misc_utils import get_uuid from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \ - get_result, server_error_response, token_required, validate_request + get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt -from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format +from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format from common.constants import RetCode, LLMType, StatusEnum from common import settings + @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @token_required -def create(tenant_id, chat_id): - req = request.json +async def create(tenant_id, chat_id): + req = await get_request_json() req["dialog_id"] = chat_id dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) if not dia: @@ -56,7 +60,7 @@ def create(tenant_id, chat_id): "name": req.get("name", "New session"), "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}], "user_id": req.get("user_id", ""), - "reference": [{}], + "reference": [], } if not conv.get("name"): return get_error_data_result(message="`name` can not be empty.") @@ -73,7 +77,7 @@ def create(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["POST"]) # noqa: F821 @token_required -def create_agent_session(tenant_id, agent_id): +async def create_agent_session(tenant_id, agent_id): user_id = request.args.get("user_id", tenant_id) e, cvs = UserCanvasService.get_by_id(agent_id) if not e: @@ -84,7 +88,7 @@ def create_agent_session(tenant_id, agent_id): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) session_id = get_uuid() - canvas = Canvas(cvs.dsl, tenant_id, agent_id) + canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id) canvas.reset() cvs.dsl = json.loads(str(canvas)) @@ -97,8 +101,8 @@ def create_agent_session(tenant_id, agent_id): @manager.route("/chats//sessions/", methods=["PUT"]) # noqa: F821 @token_required -def update(tenant_id, chat_id, session_id): - req = request.json +async def update(tenant_id, chat_id, session_id): + req = await get_request_json() req["dialog_id"] = chat_id conv_id = session_id conv = ConversationService.query(id=conv_id, dialog_id=chat_id) @@ -119,17 +123,39 @@ def update(tenant_id, chat_id, session_id): @manager.route("/chats//completions", methods=["POST"]) # noqa: F821 @token_required -def chat_completion(tenant_id, chat_id): - req = request.json +async def chat_completion(tenant_id, chat_id): + req = await get_request_json() if not req: req = {"question": ""} if not req.get("session_id"): req["question"] = "" - if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): + dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value) + if not dia: return get_error_data_result(f"You don't own the chat {chat_id}") + dia = dia[0] if req.get("session_id"): if not ConversationService.query(id=req["session_id"], dialog_id=chat_id): return get_error_data_result(f"You don't own the session {req['session_id']}") + + metadata_condition = req.get("metadata_condition") or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + if metadata_condition and req.get("question"): + metas = DocumentService.get_meta_by_kbs(dia.kb_ids or []) + filtered_doc_ids = meta_filter( + metas, + convert_conditions(metadata_condition), + metadata_condition.get("logic", "and"), + ) + if metadata_condition.get("conditions") and not filtered_doc_ids: + filtered_doc_ids = ["-999"] + + if filtered_doc_ids: + req["doc_ids"] = ",".join(filtered_doc_ids) + else: + req.pop("doc_ids", None) + if req.get("stream", True): resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") @@ -140,7 +166,7 @@ def chat_completion(tenant_id, chat_id): return resp else: answer = None - for ans in rag_completion(tenant_id, chat_id, **req): + async for ans in rag_completion(tenant_id, chat_id, **req): answer = ans break return get_result(data=answer) @@ -149,7 +175,7 @@ def chat_completion(tenant_id, chat_id): @manager.route("/chats_openai//chat/completions", methods=["POST"]) # noqa: F821 @validate_request("model", "messages") # noqa: F821 @token_required -def chat_completion_openai_like(tenant_id, chat_id): +async def chat_completion_openai_like(tenant_id, chat_id): """ OpenAI-like chat completion API that simulates the behavior of OpenAI's completions endpoint. @@ -192,7 +218,19 @@ def chat_completion_openai_like(tenant_id, chat_id): {"role": "user", "content": "Can you tell me how to install neovim"}, ], stream=stream, - extra_body={"reference": reference} + extra_body={ + "reference": reference, + "metadata_condition": { + "logic": "and", + "conditions": [ + { + "name": "author", + "comparison_operator": "is", + "value": "bob" + } + ] + } + } ) if stream: @@ -206,9 +244,13 @@ def chat_completion_openai_like(tenant_id, chat_id): if reference: print(completion.choices[0].message.reference) """ - req = request.get_json() + req = await get_request_json() - need_reference = bool(req.get("reference", False)) + extra_body = req.get("extra_body") or {} + if extra_body and not isinstance(extra_body, dict): + return get_error_data_result("extra_body must be an object.") + + need_reference = bool(extra_body.get("reference", False)) messages = req.get("messages", []) # To prevent empty [] input @@ -226,6 +268,22 @@ def chat_completion_openai_like(tenant_id, chat_id): return get_error_data_result(f"You don't own the chat {chat_id}") dia = dia[0] + metadata_condition = extra_body.get("metadata_condition") or {} + if metadata_condition and not isinstance(metadata_condition, dict): + return get_error_data_result(message="metadata_condition must be an object.") + + doc_ids_str = None + if metadata_condition: + metas = DocumentService.get_meta_by_kbs(dia.kb_ids or []) + filtered_doc_ids = meta_filter( + metas, + convert_conditions(metadata_condition), + metadata_condition.get("logic", "and"), + ) + if metadata_condition.get("conditions") and not filtered_doc_ids: + filtered_doc_ids = ["-999"] + doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None + # Filter system and non-sense assistant messages msg = [] for m in messages: @@ -244,7 +302,7 @@ def chat_completion_openai_like(tenant_id, chat_id): # The value for the usage field on all chunks except for the last one will be null. # The usage field on the last chunk contains token usage statistics for the entire request. # The choices field on the last chunk will always be an empty array []. - def streamed_response_generator(chat_id, dia, msg): + async def streamed_response_generator(chat_id, dia, msg): token_used = 0 answer_cache = "" reasoning_cache = "" @@ -273,14 +331,17 @@ def streamed_response_generator(chat_id, dia, msg): } try: - for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, True, **chat_kwargs): last_ans = ans answer = ans["answer"] reasoning_match = re.search(r"(.*?)", answer, flags=re.DOTALL) if reasoning_match: reasoning_part = reasoning_match.group(1) - content_part = answer[reasoning_match.end():] + content_part = answer[reasoning_match.end() :] else: reasoning_part = "" content_part = answer @@ -325,8 +386,7 @@ def streamed_response_generator(chat_id, dia, msg): response["choices"][0]["delta"]["content"] = None response["choices"][0]["delta"]["reasoning_content"] = None response["choices"][0]["finish_reason"] = "stop" - response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, - "total_tokens": len(prompt) + token_used} + response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used} if need_reference: response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", [])) response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "") @@ -341,7 +401,10 @@ def streamed_response_generator(chat_id, dia, msg): return resp else: answer = None - for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference): + chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} + if doc_ids_str: + chat_kwargs["doc_ids"] = doc_ids_str + async for ans in async_chat(dia, msg, False, **chat_kwargs): # focus answer content only answer = ans break @@ -383,9 +446,9 @@ def streamed_response_generator(chat_id, dia, msg): @manager.route("/agents_openai//chat/completions", methods=["POST"]) # noqa: F821 @validate_request("model", "messages") # noqa: F821 @token_required -def agents_completion_openai_compatibility(tenant_id, agent_id): - req = request.json - tiktokenenc = tiktoken.get_encoding("cl100k_base") +async def agents_completion_openai_compatibility(tenant_id, agent_id): + req = await get_request_json() + tiktoken_encode = tiktoken.get_encoding("cl100k_base") messages = req.get("messages", []) if not messages: return get_error_data_result("You must provide at least one message.") @@ -393,7 +456,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): return get_error_data_result(f"You don't own the agent {agent_id}") filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]] - prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages) + prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages) if not filtered_messages: return jsonify( get_data_openai( @@ -401,7 +464,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): content="No valid messages found (user or assistant).", finish_reason="stop", model=req.get("model", ""), - completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")), + completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")), prompt_tokens=prompt_tokens, ) ) @@ -428,35 +491,51 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): return resp else: # For non-streaming, just return the response directly - response = next( - completion_openai( + async for response in completion_openai( tenant_id, agent_id, question, session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""), stream=False, **req, - ) - ) - return jsonify(response) + ): + return jsonify(response) + + return None @manager.route("/agents//completions", methods=["POST"]) # noqa: F821 @token_required -def agent_completions(tenant_id, agent_id): - req = request.json +async def agent_completions(tenant_id, agent_id): + req = await get_request_json() + return_trace = bool(req.get("return_trace", False)) if req.get("stream", True): - def generate(): - for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + async def generate(): + trace_items = [] + async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): if isinstance(answer, str): try: ans = json.loads(answer[5:]) # remove "data:" except Exception: continue - if ans.get("event") not in ["message", "message_end"]: + event = ans.get("event") + if event == "node_finished": + if return_trace: + data = ans.get("data", {}) + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + ans.setdefault("data", {})["trace"] = trace_items + answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" + yield answer + + if event not in ["message", "message_end"]: continue yield answer @@ -473,7 +552,8 @@ def generate(): full_content = "" reference = {} final_ans = "" - for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + trace_items = [] + async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): try: ans = json.loads(answer[5:]) @@ -483,17 +563,28 @@ def generate(): if ans.get("data", {}).get("reference", None): reference.update(ans["data"]["reference"]) + if return_trace and ans.get("event") == "node_finished": + data = ans.get("data", {}) + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + final_ans = ans except Exception as e: return get_result(data=f"**ERROR**: {str(e)}") final_ans["data"]["content"] = full_content final_ans["data"]["reference"] = reference + if return_trace and final_ans: + final_ans["data"]["trace"] = trace_items return get_result(data=final_ans) @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @token_required -def list_session(tenant_id, chat_id): +async def list_session(tenant_id, chat_id): if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value): return get_error_data_result(message=f"You don't own the assistant {chat_id}.") id = request.args.get("id") @@ -547,7 +638,7 @@ def list_session(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 @token_required -def list_agent_session(tenant_id, agent_id): +async def list_agent_session(tenant_id, agent_id): if not UserCanvasService.query(user_id=tenant_id, id=agent_id): return get_error_data_result(message=f"You don't own the agent {agent_id}.") id = request.args.get("id") @@ -610,13 +701,13 @@ def list_agent_session(tenant_id, agent_id): @manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 @token_required -def delete(tenant_id, chat_id): +async def delete(tenant_id, chat_id): if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): return get_error_data_result(message="You don't own the chat") errors = [] success_count = 0 - req = request.json + req = await get_request_json() convs = ConversationService.query(dialog_id=chat_id) if not req: ids = None @@ -661,10 +752,10 @@ def delete(tenant_id, chat_id): @manager.route("/agents//sessions", methods=["DELETE"]) # noqa: F821 @token_required -def delete_agent_session(tenant_id, agent_id): +async def delete_agent_session(tenant_id, agent_id): errors = [] success_count = 0 - req = request.json + req = await get_request_json() cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") @@ -716,8 +807,8 @@ def delete_agent_session(tenant_id, agent_id): @manager.route("/sessions/ask", methods=["POST"]) # noqa: F821 @token_required -def ask_about(tenant_id): - req = request.json +async def ask_about(tenant_id): + req = await get_request_json() if not req.get("question"): return get_error_data_result("`question` is required.") if not req.get("dataset_ids"): @@ -734,10 +825,10 @@ def ask_about(tenant_id): return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file") uid = tenant_id - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid): + async for ans in async_ask(req["question"], req["kb_ids"], uid): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -755,8 +846,8 @@ def stream(): @manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @token_required -def related_questions(tenant_id): - req = request.json +async def related_questions(tenant_id): + req = await get_request_json() if not req.get("question"): return get_error_data_result("`question` is required.") question = req["question"] @@ -789,7 +880,7 @@ def related_questions(tenant_id): - At the same time, related terms can also help search engines better understand user needs and return more accurate search results. """ - ans = chat_mdl.chat( + ans = await chat_mdl.async_chat( prompt, [ { @@ -806,8 +897,8 @@ def related_questions(tenant_id): @manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 -def chatbot_completions(dialog_id): - req = request.json +async def chatbot_completions(dialog_id): + req = await get_request_json() token = request.headers.get("Authorization").split() if len(token) != 2: @@ -828,12 +919,13 @@ def chatbot_completions(dialog_id): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in iframe_completion(dialog_id, **req): + async for answer in iframe_completion(dialog_id, **req): return get_result(data=answer) + return None @manager.route("/chatbots//info", methods=["GET"]) # noqa: F821 -def chatbots_inputs(dialog_id): +async def chatbots_inputs(dialog_id): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -856,8 +948,8 @@ def chatbots_inputs(dialog_id): @manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 -def agent_bot_completions(agent_id): - req = request.json +async def agent_bot_completions(agent_id): + req = await get_request_json() token = request.headers.get("Authorization").split() if len(token) != 2: @@ -875,12 +967,13 @@ def agent_bot_completions(agent_id): resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - for answer in agent_completion(objs[0].tenant_id, agent_id, **req): + async for answer in agent_completion(objs[0].tenant_id, agent_id, **req): return get_result(data=answer) + return None @manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821 -def begin_inputs(agent_id): +async def begin_inputs(agent_id): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -893,7 +986,7 @@ def begin_inputs(agent_id): if not e: return get_error_data_result(f"Can't find agent by ID: {agent_id}") - canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id) + canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id, canvas_id=cvs.id) return get_result( data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) @@ -901,7 +994,7 @@ def begin_inputs(agent_id): @manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821 @validate_request("question", "kb_ids") -def ask_about_embedded(): +async def ask_about_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -910,7 +1003,7 @@ def ask_about_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await get_request_json() uid = objs[0].tenant_id search_id = req.get("search_id", "") @@ -919,10 +1012,10 @@ def ask_about_embedded(): if search_app := SearchService.get_detail(search_id): search_config = search_app.get("search_config", {}) - def stream(): + async def stream(): nonlocal req, uid try: - for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config): + async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: yield "data:" + json.dumps( @@ -940,7 +1033,7 @@ def stream(): @manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821 @validate_request("kb_id", "question") -def retrieval_test_embedded(): +async def retrieval_test_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -949,7 +1042,7 @@ def retrieval_test_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await get_request_json() page = int(req.get("page", 1)) size = int(req.get("size", 30)) question = req["question"] @@ -965,28 +1058,31 @@ def retrieval_test_embedded(): use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) - tenant_ids = [] - tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") - if req.get("search_id", ""): - search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) - meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) - if not doc_ids: - doc_ids = None + async def _retrieval(): + local_doc_ids = list(doc_ids) if doc_ids else [] + tenant_ids = [] + _question = question + + meta_data_filter = {} + chat_mdl = None + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + else: + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids) - try: tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: @@ -994,7 +1090,7 @@ def retrieval_test_embedded(): tenant_ids.append(tenant.tenant_id) break else: - return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", + return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) @@ -1002,7 +1098,7 @@ def retrieval_test_embedded(): return get_error_data_result(message="Knowledgebase not found!") if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -1012,15 +1108,15 @@ def retrieval_test_embedded(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + _question += await keyword_extraction(chat_mdl, _question) - labels = label_question(question, [kb]) + labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval( - question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + _question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, + local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(question, tenant_ids, kb_ids, embd_mdl, + ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -1030,6 +1126,9 @@ def retrieval_test_embedded(): ranks["labels"] = labels return get_json_result(data=ranks) + + try: + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message="No chunk found! Check the chunk status please!", @@ -1039,7 +1138,7 @@ def retrieval_test_embedded(): @manager.route("/searchbots/related_questions", methods=["POST"]) # noqa: F821 @validate_request("question") -def related_questions_embedded(): +async def related_questions_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1048,7 +1147,7 @@ def related_questions_embedded(): if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - req = request.json + req = await get_request_json() tenant_id = objs[0].tenant_id if not tenant_id: return get_error_data_result(message="permission denined.") @@ -1066,7 +1165,7 @@ def related_questions_embedded(): gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) prompt = load_prompt("related_question") - ans = chat_mdl.chat( + ans = await chat_mdl.async_chat( prompt, [ { @@ -1083,7 +1182,7 @@ def related_questions_embedded(): @manager.route("/searchbots/detail", methods=["GET"]) # noqa: F821 -def detail_share_embedded(): +async def detail_share_embedded(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1115,7 +1214,7 @@ def detail_share_embedded(): @manager.route("/searchbots/mindmap", methods=["POST"]) # noqa: F821 @validate_request("question", "kb_ids") -def mindmap(): +async def mindmap(): token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') @@ -1125,12 +1224,12 @@ def mindmap(): return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id - req = request.json + req = await get_request_json() search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} - mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) + mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) if "error" in mind_map: return server_error_response(Exception(mind_map["error"])) return get_json_result(data=mind_map) diff --git a/api/apps/search_app.py b/api/apps/search_app.py index 79922337138..d82c3b27d65 100644 --- a/api/apps/search_app.py +++ b/api/apps/search_app.py @@ -14,8 +14,8 @@ # limitations under the License. # -from flask import request -from flask_login import current_user, login_required +from quart import request +from api.apps import current_user, login_required from api.constants import DATASET_NAME_LIMIT from api.db.db_models import DB @@ -24,14 +24,14 @@ from api.db.services.user_service import TenantService, UserTenantService from common.misc_utils import get_uuid from common.constants import RetCode, StatusEnum -from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, server_error_response, validate_request +from api.utils.api_utils import get_data_error_result, get_json_result, not_allowed_parameters, get_request_json, server_error_response, validate_request @manager.route("/create", methods=["post"]) # noqa: F821 @login_required @validate_request("name") -def create(): - req = request.get_json() +async def create(): + req = await get_request_json() search_name = req["name"] description = req.get("description", "") if not isinstance(search_name, str): @@ -65,8 +65,8 @@ def create(): @login_required @validate_request("search_id", "name", "search_config", "tenant_id") @not_allowed_parameters("id", "created_by", "create_time", "update_time", "create_date", "update_date", "created_by") -def update(): - req = request.get_json() +async def update(): + req = await get_request_json() if not isinstance(req["name"], str): return get_data_error_result(message="Search name must be string.") if req["name"].strip() == "": @@ -140,7 +140,7 @@ def detail(): @manager.route("/list", methods=["POST"]) # noqa: F821 @login_required -def list_search_app(): +async def list_search_app(): keywords = request.args.get("keywords", "") page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) @@ -150,7 +150,7 @@ def list_search_app(): else: desc = True - req = request.get_json() + req = await get_request_json() owner_ids = req.get("owner_ids", []) try: if not owner_ids: @@ -173,8 +173,8 @@ def list_search_app(): @manager.route("/rm", methods=["post"]) # noqa: F821 @login_required @validate_request("search_id") -def rm(): - req = request.get_json() +async def rm(): + req = await get_request_json() search_id = req["search_id"] if not SearchService.accessible4deletion(search_id, current_user.id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) diff --git a/api/apps/system_app.py b/api/apps/system_app.py index b63f80a6a7e..379b597de9d 100644 --- a/api/apps/system_app.py +++ b/api/apps/system_app.py @@ -17,7 +17,7 @@ from datetime import datetime import json -from flask_login import login_required, current_user +from api.apps import login_required, current_user from api.db.db_models import APIToken from api.db.services.api_service import APITokenService @@ -34,7 +34,7 @@ from timeit import default_timer as timer from rag.utils.redis_conn import REDIS_CONN -from flask import jsonify +from quart import jsonify from api.utils.health_utils import run_health_checks from common import settings @@ -177,7 +177,7 @@ def healthz(): return jsonify(result), (200 if all_ok else 500) -@manager.route("/ping", methods=["GET"]) # noqa: F821 +@manager.route("/ping", methods=["GET"]) # noqa: F821 def ping(): return "pong", 200 @@ -213,7 +213,7 @@ def new_token(): if not tenants: return get_data_error_result(message="Tenant not found!") - tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id + tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id obj = { "tenant_id": tenant_id, "token": generate_confirmation_token(), @@ -268,13 +268,12 @@ def token_list(): if not tenants: return get_data_error_result(message="Tenant not found!") - tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id + tenant_id = [tenant for tenant in tenants if tenant.role == "owner"][0].tenant_id objs = APITokenService.query(tenant_id=tenant_id) objs = [o.to_dict() for o in objs] for o in objs: if not o["beta"]: - o["beta"] = generate_confirmation_token().replace( - "ragflow-", "")[:32] + o["beta"] = generate_confirmation_token().replace("ragflow-", "")[:32] APITokenService.filter_update([APIToken.tenant_id == tenant_id, APIToken.token == o["token"]], o) return get_json_result(data=objs) except Exception as e: @@ -307,13 +306,19 @@ def rm(token): type: boolean description: Deletion status. """ - APITokenService.filter_delete( - [APIToken.tenant_id == current_user.id, APIToken.token == token] - ) - return get_json_result(data=True) + try: + tenants = UserTenantService.query(user_id=current_user.id) + if not tenants: + return get_data_error_result(message="Tenant not found!") + + tenant_id = tenants[0].tenant_id + APITokenService.filter_delete([APIToken.tenant_id == tenant_id, APIToken.token == token]) + return get_json_result(data=True) + except Exception as e: + return server_error_response(e) -@manager.route('/config', methods=['GET']) # noqa: F821 +@manager.route("/config", methods=["GET"]) # noqa: F821 def get_config(): """ Get system configuration. @@ -330,6 +335,4 @@ def get_config(): type: integer 0 means disabled, 1 means enabled description: Whether user registration is enabled """ - return get_json_result(data={ - "registerEnabled": settings.REGISTER_ENABLED - }) + return get_json_result(data={"registerEnabled": settings.REGISTER_ENABLED}) diff --git a/api/apps/tenant_app.py b/api/apps/tenant_app.py index abb096faaa5..be6305e8911 100644 --- a/api/apps/tenant_app.py +++ b/api/apps/tenant_app.py @@ -13,11 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from flask import request -from flask_login import login_required, current_user - -from api.apps import smtp_mail_server +import logging +import asyncio from api.db import UserTenantRole from api.db.db_models import UserTenant from api.db.services.user_service import UserTenantService, UserService @@ -25,9 +22,10 @@ from common.constants import RetCode, StatusEnum from common.misc_utils import get_uuid from common.time_utils import delta_seconds -from api.utils.api_utils import get_json_result, validate_request, server_error_response, get_data_error_result +from api.utils.api_utils import get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from api.utils.web_utils import send_invite_email from common import settings +from api.apps import login_required, current_user @manager.route("//user/list", methods=["GET"]) # noqa: F821 @@ -51,14 +49,14 @@ def user_list(tenant_id): @manager.route('//user', methods=['POST']) # noqa: F821 @login_required @validate_request("email") -def create(tenant_id): +async def create(tenant_id): if current_user.id != tenant_id: return get_json_result( data=False, message='No authorization.', code=RetCode.AUTHENTICATION_ERROR) - req = request.json + req = await get_request_json() invite_user_email = req["email"] invite_users = UserService.query(email=invite_user_email) if not invite_users: @@ -83,20 +81,24 @@ def create(tenant_id): role=UserTenantRole.INVITE, status=StatusEnum.VALID.value) - if smtp_mail_server and settings.SMTP_CONF: - from threading import Thread + try: user_name = "" _, user = UserService.get_by_id(current_user.id) if user: user_name = user.nickname - Thread( - target=send_invite_email, - args=(invite_user_email, settings.MAIL_FRONTEND_URL, tenant_id, user_name or current_user.email), - daemon=True - ).start() - + asyncio.create_task( + send_invite_email( + to_email=invite_user_email, + invite_url=settings.MAIL_FRONTEND_URL, + tenant_id=tenant_id, + inviter=user_name or current_user.email + ) + ) + except Exception as e: + logging.exception(f"Failed to send invite email to {invite_user_email}: {e}") + return get_json_result(data=False, message="Failed to send invite email.", code=RetCode.SERVER_ERROR) usr = invite_users[0].to_dict() usr = {k: v for k, v in usr.items() if k in ["id", "avatar", "email", "nickname"]} diff --git a/api/apps/user_app.py b/api/apps/user_app.py index 06130cce7fe..e1ad157bc72 100644 --- a/api/apps/user_app.py +++ b/api/apps/user_app.py @@ -21,9 +21,9 @@ import secrets import time from datetime import datetime +import base64 -from flask import redirect, request, session, make_response -from flask_login import current_user, login_required, login_user, logout_user +from quart import make_response, redirect, request, session from werkzeug.security import check_password_hash, generate_password_hash from api.apps.auth import get_auth_client @@ -40,12 +40,13 @@ from api.utils.api_utils import ( get_data_error_result, get_json_result, + get_request_json, server_error_response, validate_request, ) from api.utils.crypt import decrypt from rag.utils.redis_conn import REDIS_CONN -from api.apps import smtp_mail_server +from api.apps import login_required, current_user, login_user, logout_user from api.utils.web_utils import ( send_email_html, OTP_LENGTH, @@ -58,10 +59,11 @@ captcha_key, ) from common import settings +from common.http_client import async_request @manager.route("/login", methods=["POST", "GET"]) # noqa: F821 -def login(): +async def login(): """ User login endpoint. --- @@ -91,10 +93,14 @@ def login(): schema: type: object """ - if not request.json: + json_body = await get_request_json() + if not json_body: return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Unauthorized!") - email = request.json.get("email", "") + email = json_body.get("email", "") + if email == "admin@ragflow.io": + return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="Default admin account cannot be used to login normal services!") + users = UserService.query(email=email) if not users: return get_json_result( @@ -103,7 +109,7 @@ def login(): message=f"Email: {email} is not registered!", ) - password = request.json.get("password") + password = json_body.get("password") try: password = decrypt(password) except BaseException: @@ -121,11 +127,12 @@ def login(): response_data = user.to_json() user.access_token = get_uuid() login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) + user.update_time = current_timestamp() + user.update_date = datetime_format(datetime.now()) user.save() msg = "Welcome back!" - return construct_response(data=response_data, auth=user.get_id(), message=msg) + + return await construct_response(data=response_data, auth=user.get_id(), message=msg) else: return get_json_result( data=False, @@ -135,7 +142,7 @@ def login(): @manager.route("/login/channels", methods=["GET"]) # noqa: F821 -def get_login_channels(): +async def get_login_channels(): """ Get all supported authentication channels. """ @@ -156,7 +163,7 @@ def get_login_channels(): @manager.route("/login/", methods=["GET"]) # noqa: F821 -def oauth_login(channel): +async def oauth_login(channel): channel_config = settings.OAUTH_CONFIG.get(channel) if not channel_config: raise ValueError(f"Invalid channel name: {channel}") @@ -169,7 +176,7 @@ def oauth_login(channel): @manager.route("/oauth/callback/", methods=["GET"]) # noqa: F821 -def oauth_callback(channel): +async def oauth_callback(channel): """ Handle the OAuth/OIDC callback for various channels dynamically. """ @@ -191,7 +198,10 @@ def oauth_callback(channel): return redirect("/?error=missing_code") # Exchange authorization code for access token - token_info = auth_cli.exchange_code_for_token(code) + if hasattr(auth_cli, "async_exchange_code_for_token"): + token_info = await auth_cli.async_exchange_code_for_token(code) + else: + token_info = auth_cli.exchange_code_for_token(code) access_token = token_info.get("access_token") if not access_token: return redirect("/?error=token_failed") @@ -199,7 +209,10 @@ def oauth_callback(channel): id_token = token_info.get("id_token") # Fetch user info - user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) + if hasattr(auth_cli, "async_fetch_user_info"): + user_info = await auth_cli.async_fetch_user_info(access_token, id_token=id_token) + else: + user_info = auth_cli.fetch_user_info(access_token, id_token=id_token) if not user_info.email: return redirect("/?error=email_missing") @@ -258,7 +271,7 @@ def oauth_callback(channel): @manager.route("/github_callback", methods=["GET"]) # noqa: F821 -def github_callback(): +async def github_callback(): """ **Deprecated**, Use `/oauth/callback/` instead. @@ -278,9 +291,8 @@ def github_callback(): schema: type: object """ - import requests - - res = requests.post( + res = await async_request( + "POST", settings.GITHUB_OAUTH.get("url"), data={ "client_id": settings.GITHUB_OAUTH.get("client_id"), @@ -298,7 +310,7 @@ def github_callback(): session["access_token"] = res["access_token"] session["access_token_from"] = "github" - user_info = user_info_from_github(session["access_token"]) + user_info = await user_info_from_github(session["access_token"]) email_address = user_info["email"] users = UserService.query(email=email_address) user_id = get_uuid() @@ -347,7 +359,7 @@ def github_callback(): @manager.route("/feishu_callback", methods=["GET"]) # noqa: F821 -def feishu_callback(): +async def feishu_callback(): """ Feishu OAuth callback endpoint. --- @@ -365,9 +377,8 @@ def feishu_callback(): schema: type: object """ - import requests - - app_access_token_res = requests.post( + app_access_token_res = await async_request( + "POST", settings.FEISHU_OAUTH.get("app_access_token_url"), data=json.dumps( { @@ -381,7 +392,8 @@ def feishu_callback(): if app_access_token_res["code"] != 0: return redirect("/?error=%s" % app_access_token_res) - res = requests.post( + res = await async_request( + "POST", settings.FEISHU_OAUTH.get("user_access_token_url"), data=json.dumps( { @@ -402,7 +414,7 @@ def feishu_callback(): return redirect("/?error=contact:user.email:readonly not in scope") session["access_token"] = res["data"]["access_token"] session["access_token_from"] = "feishu" - user_info = user_info_from_feishu(session["access_token"]) + user_info = await user_info_from_feishu(session["access_token"]) email_address = user_info["email"] users = UserService.query(email=email_address) user_id = get_uuid() @@ -450,36 +462,34 @@ def feishu_callback(): return redirect("/?auth=%s" % user.get_id()) -def user_info_from_feishu(access_token): - import requests - +async def user_info_from_feishu(access_token): headers = { "Content-Type": "application/json; charset=utf-8", "Authorization": f"Bearer {access_token}", } - res = requests.get("https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) + res = await async_request("GET", "https://open.feishu.cn/open-apis/authen/v1/user_info", headers=headers) user_info = res.json()["data"] user_info["email"] = None if user_info.get("email") == "" else user_info["email"] return user_info -def user_info_from_github(access_token): - import requests - +async def user_info_from_github(access_token): headers = {"Accept": "application/json", "Authorization": f"token {access_token}"} - res = requests.get(f"https://api.github.com/user?access_token={access_token}", headers=headers) + res = await async_request("GET", f"https://api.github.com/user?access_token={access_token}", headers=headers) user_info = res.json() - email_info = requests.get( + email_info_response = await async_request( + "GET", f"https://api.github.com/user/emails?access_token={access_token}", headers=headers, - ).json() + ) + email_info = email_info_response.json() user_info["email"] = next((email for email in email_info if email["primary"]), None)["email"] return user_info @manager.route("/logout", methods=["GET"]) # noqa: F821 @login_required -def log_out(): +async def log_out(): """ User logout endpoint. --- @@ -501,7 +511,7 @@ def log_out(): @manager.route("/setting", methods=["POST"]) # noqa: F821 @login_required -def setting_user(): +async def setting_user(): """ Update user settings. --- @@ -530,7 +540,7 @@ def setting_user(): type: object """ update_dict = {} - request_data = request.json + request_data = await get_request_json() if request_data.get("password"): new_password = request_data.get("new_password") if not check_password_hash(current_user.password, decrypt(request_data["password"])): @@ -569,7 +579,7 @@ def setting_user(): @manager.route("/info", methods=["GET"]) # noqa: F821 @login_required -def user_profile(): +async def user_profile(): """ Get user profile information. --- @@ -650,7 +660,7 @@ def user_register(user_id, user): tenant_llm = get_init_tenant_llm(user_id) if not UserService.save(**user): - return + return None TenantService.insert(**tenant) UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) @@ -660,7 +670,7 @@ def user_register(user_id, user): @manager.route("/register", methods=["POST"]) # noqa: F821 @validate_request("nickname", "email", "password") -def user_add(): +async def user_add(): """ Register a new user. --- @@ -697,7 +707,7 @@ def user_add(): code=RetCode.OPERATING_ERROR, ) - req = request.json + req = await get_request_json() email_address = req["email"] # Validate the email address @@ -737,7 +747,7 @@ def user_add(): raise Exception(f"Same email: {email_address} exists!") user = users[0] login_user(user) - return construct_response( + return await construct_response( data=user.to_json(), auth=user.get_id(), message=f"{nickname}, welcome aboard!", @@ -754,7 +764,7 @@ def user_add(): @manager.route("/tenant_info", methods=["GET"]) # noqa: F821 @login_required -def tenant_info(): +async def tenant_info(): """ Get tenant information. --- @@ -793,7 +803,7 @@ def tenant_info(): @manager.route("/set_tenant_info", methods=["POST"]) # noqa: F821 @login_required @validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id") -def set_tenant_info(): +async def set_tenant_info(): """ Update tenant information. --- @@ -830,17 +840,17 @@ def set_tenant_info(): schema: type: object """ - req = request.json + req = await get_request_json() try: tid = req.pop("tenant_id") TenantService.update_by_id(tid, req) return get_json_result(data=True) except Exception as e: return server_error_response(e) - + @manager.route("/forget/captcha", methods=["GET"]) # noqa: F821 -def forget_get_captcha(): +async def forget_get_captcha(): """ GET /forget/captcha?email= - Generate an image captcha and cache it in Redis under key captcha:{email} with TTL = OTP_TTL_SECONDS. @@ -862,19 +872,19 @@ def forget_get_captcha(): from captcha.image import ImageCaptcha image = ImageCaptcha(width=300, height=120, font_sizes=[50, 60, 70]) img_bytes = image.generate(captcha_text).read() - response = make_response(img_bytes) + response = await make_response(img_bytes) response.headers.set("Content-Type", "image/JPEG") return response @manager.route("/forget/otp", methods=["POST"]) # noqa: F821 -def forget_send_otp(): +async def forget_send_otp(): """ POST /forget/otp - Verify the image captcha stored at captcha:{email} (case-insensitive). - On success, generate an email OTP (A–Z with length = OTP_LENGTH), store hash + salt (and timestamp) in Redis with TTL, reset attempts and cooldown, and send the OTP via email. """ - req = request.get_json() + req = await get_request_json() email = req.get("email") or "" captcha = (req.get("captcha") or "").strip() @@ -917,47 +927,45 @@ def forget_send_otp(): ttl_min = OTP_TTL_SECONDS // 60 - if not smtp_mail_server: - logging.warning("SMTP mail server not initialized; skip sending email.") - else: - try: - send_email_html( - subject="Your Password Reset Code", - to_email=email, - template_key="reset_code", - code=otp, - ttl_min=ttl_min, - ) - except Exception: - return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email") - + try: + await send_email_html( + subject="Your Password Reset Code", + to_email=email, + template_key="reset_code", + code=otp, + ttl_min=ttl_min, + ) + + except Exception as e: + logging.exception(e) + return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to send email") + return get_json_result(data=True, code=RetCode.SUCCESS, message="verification passed, email sent") -@manager.route("/forget", methods=["POST"]) # noqa: F821 -def forget(): +def _verified_key(email: str) -> str: + return f"otp:verified:{email}" + + +@manager.route("/forget/verify-otp", methods=["POST"]) # noqa: F821 +async def forget_verify_otp(): """ - POST: Verify email + OTP and reset password, then log the user in. - Request JSON: { email, otp, new_password, confirm_new_password } + Verify email + OTP only. On success: + - consume the OTP and attempt counters + - set a short-lived verified flag in Redis for the email + Request JSON: { email, otp } """ - req = request.get_json() + req = await get_request_json() email = req.get("email") or "" otp = (req.get("otp") or "").strip() - new_pwd = req.get("new_password") - new_pwd2 = req.get("confirm_new_password") - - if not all([email, otp, new_pwd, new_pwd2]): - return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email, otp and passwords are required") - # For reset, passwords are provided as-is (no decrypt needed) - if new_pwd != new_pwd2: - return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match") + if not all([email, otp]): + return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and otp are required") users = UserService.query(email=email) if not users: return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email") - user = users[0] # Verify OTP from Redis k_code, k_attempts, k_last, k_lock = otp_keys(email) if REDIS_CONN.get(k_lock): @@ -973,7 +981,6 @@ def forget(): except Exception: return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="otp storage corrupted") - # Case-insensitive verification: OTP generated uppercase calc = hash_code(otp.upper(), salt) if calc != stored_hash: # bump attempts @@ -986,23 +993,70 @@ def forget(): REDIS_CONN.set(k_lock, int(time.time()), ATTEMPT_LOCK_SECONDS) return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="expired otp") - # Success: consume OTP and reset password + # Success: consume OTP and attempts; mark verified REDIS_CONN.delete(k_code) REDIS_CONN.delete(k_attempts) REDIS_CONN.delete(k_last) REDIS_CONN.delete(k_lock) + # set verified flag with limited TTL, reuse OTP_TTL_SECONDS or smaller window + try: + REDIS_CONN.set(_verified_key(email), "1", OTP_TTL_SECONDS) + except Exception: + return get_json_result(data=False, code=RetCode.SERVER_ERROR, message="failed to set verification state") + + return get_json_result(data=True, code=RetCode.SUCCESS, message="otp verified") + + +@manager.route("/forget/reset-password", methods=["POST"]) # noqa: F821 +async def forget_reset_password(): + """ + Reset password after successful OTP verification. + Requires: { email, new_password, confirm_new_password } + Steps: + - check verified flag in Redis + - update user password + - auto login + - clear verified flag + """ + + req = await get_request_json() + email = req.get("email") or "" + new_pwd = req.get("new_password") + new_pwd2 = req.get("confirm_new_password") + + new_pwd_base64 = decrypt(new_pwd) + new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8') + new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8') + + REDIS_CONN.get(_verified_key(email)) + if not REDIS_CONN.get(_verified_key(email)): + return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="email not verified") + + if not all([email, new_pwd, new_pwd2]): + return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required") + + if new_pwd_string != new_pwd2_string: + return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match") + + users = UserService.query_user_by_email(email=email) + if not users: + return get_json_result(data=False, code=RetCode.DATA_ERROR, message="invalid email") + + user = users[0] try: - UserService.update_user_password(user.id, new_pwd) + UserService.update_user_password(user.id, new_pwd_base64) except Exception as e: logging.exception(e) return get_json_result(data=False, code=RetCode.EXCEPTION_ERROR, message="failed to reset password") - # Auto login (reuse login flow) - user.access_token = get_uuid() - login_user(user) - user.update_time = (current_timestamp(),) - user.update_date = (datetime_format(datetime.now()),) - user.save() + # clear verified flag + try: + REDIS_CONN.delete(_verified_key(email)) + except Exception: + pass + msg = "Password reset successful. Logged in." - return construct_response(data=user.to_json(), auth=user.get_id(), message=msg) + return await construct_response(data=user.to_json(), auth=user.get_id(), message=msg) + + diff --git a/api/constants.py b/api/constants.py index 464b7d8e669..9edaa844c0f 100644 --- a/api/constants.py +++ b/api/constants.py @@ -24,3 +24,5 @@ DATASET_NAME_LIMIT = 128 FILE_NAME_LEN_LIMIT = 255 +MEMORY_NAME_LIMIT = 128 +MEMORY_SIZE_LIMIT = 10*1024*1024 # Byte diff --git a/api/db/db_models.py b/api/db/db_models.py index 68bf37ce4c6..738e26a06ac 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -25,7 +25,7 @@ from enum import Enum from functools import wraps -from flask_login import UserMixin +from quart_auth import AuthUser from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from peewee import InterfaceError, OperationalError, BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate @@ -305,6 +305,7 @@ def begin(self): time.sleep(self.retry_delay * (2 ** attempt)) else: raise + return None class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase): @@ -594,7 +595,7 @@ def fill_db_model_object(model_object, human_model_dict): return model_object -class User(DataBaseModel, UserMixin): +class User(DataBaseModel, AuthUser): id = CharField(max_length=32, primary_key=True) access_token = CharField(max_length=255, null=True, index=True) nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True) @@ -748,7 +749,7 @@ class Knowledgebase(DataBaseModel): parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True) pipeline_id = CharField(max_length=32, null=True, help_text="Pipeline ID", index=True) - parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) + parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0}) pagerank = IntegerField(default=0, index=False) graphrag_task_id = CharField(max_length=32, null=True, help_text="Graph RAG task ID", index=True) @@ -772,8 +773,8 @@ class Document(DataBaseModel): thumbnail = TextField(null=True, help_text="thumbnail base64 string") kb_id = CharField(max_length=256, null=False, index=True) parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True) - pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True) - parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]}) + pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True) + parser_config = JSONField(null=False, default={"pages": [[1, 1000000]], "table_context_size": 0, "image_context_size": 0}) source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True) type = CharField(max_length=32, null=False, help_text="file extension", index=True) created_by = CharField(max_length=32, null=False, help_text="who created it", index=True) @@ -876,7 +877,7 @@ class Meta: class Conversation(DataBaseModel): id = CharField(max_length=32, primary_key=True) dialog_id = CharField(max_length=32, null=False, index=True) - name = CharField(max_length=255, null=True, help_text="converastion name", index=True) + name = CharField(max_length=255, null=True, help_text="conversation name", index=True) message = JSONField(null=True) reference = JSONField(null=True, default=[]) user_id = CharField(max_length=255, null=True, help_text="user_id", index=True) @@ -1112,6 +1113,91 @@ class Meta: db_table = "sync_logs" +class EvaluationDataset(DataBaseModel): + """Ground truth dataset for RAG evaluation""" + id = CharField(max_length=32, primary_key=True) + tenant_id = CharField(max_length=32, null=False, index=True, help_text="tenant ID") + name = CharField(max_length=255, null=False, index=True, help_text="dataset name") + description = TextField(null=True, help_text="dataset description") + kb_ids = JSONField(null=False, help_text="knowledge base IDs to evaluate against") + created_by = CharField(max_length=32, null=False, index=True, help_text="creator user ID") + create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp") + update_time = BigIntegerField(null=False, help_text="last update timestamp") + status = IntegerField(null=False, default=1, help_text="1=valid, 0=invalid") + + class Meta: + db_table = "evaluation_datasets" + + +class EvaluationCase(DataBaseModel): + """Individual test case in an evaluation dataset""" + id = CharField(max_length=32, primary_key=True) + dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets") + question = TextField(null=False, help_text="test question") + reference_answer = TextField(null=True, help_text="optional ground truth answer") + relevant_doc_ids = JSONField(null=True, help_text="expected relevant document IDs") + relevant_chunk_ids = JSONField(null=True, help_text="expected relevant chunk IDs") + metadata = JSONField(null=True, help_text="additional context/tags") + create_time = BigIntegerField(null=False, help_text="creation timestamp") + + class Meta: + db_table = "evaluation_cases" + + +class EvaluationRun(DataBaseModel): + """A single evaluation run""" + id = CharField(max_length=32, primary_key=True) + dataset_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_datasets") + dialog_id = CharField(max_length=32, null=False, index=True, help_text="dialog configuration being evaluated") + name = CharField(max_length=255, null=False, help_text="run name") + config_snapshot = JSONField(null=False, help_text="dialog config at time of evaluation") + metrics_summary = JSONField(null=True, help_text="aggregated metrics") + status = CharField(max_length=32, null=False, default="PENDING", help_text="PENDING/RUNNING/COMPLETED/FAILED") + created_by = CharField(max_length=32, null=False, index=True, help_text="user who started the run") + create_time = BigIntegerField(null=False, index=True, help_text="creation timestamp") + complete_time = BigIntegerField(null=True, help_text="completion timestamp") + + class Meta: + db_table = "evaluation_runs" + + +class EvaluationResult(DataBaseModel): + """Result for a single test case in an evaluation run""" + id = CharField(max_length=32, primary_key=True) + run_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_runs") + case_id = CharField(max_length=32, null=False, index=True, help_text="FK to evaluation_cases") + generated_answer = TextField(null=False, help_text="generated answer") + retrieved_chunks = JSONField(null=False, help_text="chunks that were retrieved") + metrics = JSONField(null=False, help_text="all computed metrics") + execution_time = FloatField(null=False, help_text="response time in seconds") + token_usage = JSONField(null=True, help_text="prompt/completion tokens") + create_time = BigIntegerField(null=False, help_text="creation timestamp") + + class Meta: + db_table = "evaluation_results" + + +class Memory(DataBaseModel): + id = CharField(max_length=32, primary_key=True) + name = CharField(max_length=128, null=False, index=False, help_text="Memory name") + avatar = TextField(null=True, help_text="avatar base64 string") + tenant_id = CharField(max_length=32, null=False, index=True) + memory_type = IntegerField(null=False, default=1, index=True, help_text="Bit flags (LSB->MSB): 1=raw, 2=semantic, 4=episodic, 8=procedural. E.g., 5 enables raw + episodic.") + storage_type = CharField(max_length=32, default='table', null=False, index=True, help_text="table|graph") + embd_id = CharField(max_length=128, null=False, index=False, help_text="embedding model ID") + llm_id = CharField(max_length=128, null=False, index=False, help_text="chat model ID") + permissions = CharField(max_length=16, null=False, index=True, help_text="me|team", default="me") + description = TextField(null=True, help_text="description") + memory_size = IntegerField(default=5242880, null=False, index=False) + forgetting_policy = CharField(max_length=32, null=False, default="FIFO", index=False, help_text="LRU|FIFO") + temperature = FloatField(default=0.5, index=False) + system_prompt = TextField(null=True, help_text="system prompt", index=False) + user_prompt = TextField(null=True, help_text="user prompt", index=False) + + class Meta: + db_table = "memory" + + def migrate_db(): logging.disable(logging.ERROR) migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) @@ -1292,4 +1378,43 @@ def migrate_db(): migrate(migrator.add_column("llm_factories", "rank", IntegerField(default=0, index=False))) except Exception: pass + + # RAG Evaluation tables + try: + migrate(migrator.add_column("evaluation_datasets", "id", CharField(max_length=32, primary_key=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "tenant_id", CharField(max_length=32, null=False, index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "name", CharField(max_length=255, null=False, index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "description", TextField(null=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "kb_ids", JSONField(null=False))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "created_by", CharField(max_length=32, null=False, index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "create_time", BigIntegerField(null=False, index=True))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "update_time", BigIntegerField(null=False))) + except Exception: + pass + try: + migrate(migrator.add_column("evaluation_datasets", "status", IntegerField(null=False, default=1))) + except Exception: + pass + logging.disable(logging.NOTSET) diff --git a/api/db/init_data.py b/api/db/init_data.py index 4a9ad067afd..77f676f0962 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import json import os @@ -29,19 +30,23 @@ from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm from api.db.services.user_service import TenantService, UserTenantService +from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache from common.constants import LLMType from common.file_utils import get_project_base_directory from common import settings from api.common.base64 import encode_to_base64 +DEFAULT_SUPERUSER_NICKNAME = os.getenv("DEFAULT_SUPERUSER_NICKNAME", "admin") +DEFAULT_SUPERUSER_EMAIL = os.getenv("DEFAULT_SUPERUSER_EMAIL", "admin@ragflow.io") +DEFAULT_SUPERUSER_PASSWORD = os.getenv("DEFAULT_SUPERUSER_PASSWORD", "admin") -def init_superuser(): +def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_EMAIL, password=DEFAULT_SUPERUSER_PASSWORD, role=UserTenantRole.OWNER): user_info = { "id": uuid.uuid1().hex, - "password": encode_to_base64("admin"), - "nickname": "admin", + "password": encode_to_base64(password), + "nickname": nickname, "is_superuser": True, - "email": "admin@ragflow.io", + "email": email, "creator": "system", "status": "1", } @@ -58,7 +63,7 @@ def init_superuser(): "tenant_id": user_info["id"], "user_id": user_info["id"], "invited_by": user_info["id"], - "role": UserTenantRole.OWNER + "role": role } tenant_llm = get_init_tenant_llm(user_info["id"]) @@ -70,11 +75,10 @@ def init_superuser(): UserTenantService.insert(**usr_tenant) TenantLLMService.insert_many(tenant_llm) logging.info( - "Super user initialized. email: admin@ragflow.io, password: admin. Changing the password after login is strongly recommended.") + f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) - msg = chat_mdl.chat(system="", history=[ - {"role": "user", "content": "Hello!"}], gen_conf={}) + msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})) if msg.find("ERROR: ") == 0: logging.error( "'{}' doesn't work. {}".format( @@ -166,6 +170,8 @@ def init_web_data(): # init_superuser() add_graph_templates() + init_message_id_sequence() + init_memory_size_cache() logging.info("init web data success:{}".format(time.time() - start_time)) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py new file mode 100644 index 00000000000..79848cad5c3 --- /dev/null +++ b/api/db/joint_services/memory_message_service.py @@ -0,0 +1,389 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from typing import List + +from api.db.services.task_service import TaskService +from common import settings +from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms +from common.constants import MemoryType, LLMType +from common.doc_store.doc_store_base import FusionExpr +from common.misc_utils import get_uuid +from api.db.db_utils import bulk_insert_into_db +from api.db.db_models import Task +from api.db.services.memory_service import MemoryService +from api.db.services.tenant_llm_service import TenantLLMService +from api.db.services.llm_service import LLMBundle +from api.utils.memory_utils import get_memory_type_human +from memory.services.messages import MessageService +from memory.services.query import MsgTextQuery, get_vector +from memory.utils.prompt_util import PromptAssembler +from memory.utils.msg_util import get_json_result_from_llm_response +from rag.utils.redis_conn import REDIS_CONN + + +async def save_to_memory(memory_id: str, message_dict: dict): + """ + :param memory_id: + :param message_dict: { + "user_id": str, + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str + } + """ + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return False, f"Memory '{memory_id}' not found." + + tenant_id = memory.tenant_id + extracted_content = await extract_by_llm( + tenant_id, + memory.llm_id, + {"temperature": memory.temperature}, + get_memory_type_human(memory.memory_type), + message_dict.get("user_input", ""), + message_dict.get("agent_response", "") + ) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract + raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory") + message_list = [{ + "message_id": raw_message_id, + "message_type": MemoryType.RAW.name.lower(), + "source_id": 0, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", + "valid_at": timestamp_to_date(current_timestamp()), + "invalid_at": None, + "forget_at": None, + "status": True + }, *[{ + "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), + "message_type": content["message_type"], + "source_id": raw_message_id, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": content["content"], + "valid_at": content["valid_at"], + "invalid_at": content["invalid_at"] if content["invalid_at"] else None, + "forget_at": None, + "status": True + } for content in extracted_content]] + return await embed_and_save(memory, message_list) + + +async def save_extracted_to_memory_only(memory_id: str, message_dict, source_message_id: int): + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + return False, f"Memory '{memory_id}' not found." + + if memory.memory_type == MemoryType.RAW.value: + return True, f"Memory '{memory_id}' don't need to extract." + + tenant_id = memory.tenant_id + extracted_content = await extract_by_llm( + tenant_id, + memory.llm_id, + {"temperature": memory.temperature}, + get_memory_type_human(memory.memory_type), + message_dict.get("user_input", ""), + message_dict.get("agent_response", "") + ) + message_list = [{ + "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"), + "message_type": content["message_type"], + "source_id": source_message_id, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": content["content"], + "valid_at": content["valid_at"], + "invalid_at": content["invalid_at"] if content["invalid_at"] else None, + "forget_at": None, + "status": True + } for content in extracted_content] + if not message_list: + return True, "No memory extracted from raw message." + + return await embed_and_save(memory, message_list) + + +async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str, + agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]: + llm_type = TenantLLMService.llm_id2llm_type(llm_id) + if not llm_type: + raise RuntimeError(f"Unknown type of LLM '{llm_id}'") + if not system_prompt: + system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type}) + conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}" + conversation_time = timestamp_to_date(current_timestamp()) + user_prompts = [] + if user_prompt: + user_prompts.append({"role": "user", "content": user_prompt}) + user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"}) + else: + user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) + llm = LLMBundle(tenant_id, llm_type, llm_id) + res = await llm.async_chat(system_prompt, user_prompts, extract_conf) + res_json = get_json_result_from_llm_response(res) + return [{ + "content": extracted_content["content"], + "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]), + "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "", + "message_type": message_type + } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list] + + +async def embed_and_save(memory, message_list: list[dict]): + embedding_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id) + vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list]) + for idx, msg in enumerate(message_list): + msg["content_embed"] = vector_list[idx] + vector_dimension = len(vector_list[0]) + if not MessageService.has_index(memory.tenant_id, memory.id): + created = MessageService.create_index(memory.tenant_id, memory.id, vector_size=vector_dimension) + if not created: + return False, "Failed to create message index." + + new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list]) + current_memory_size = get_memory_size_cache(memory.tenant_id, memory.id) + if new_msg_size + current_memory_size > memory.memory_size: + size_to_delete = current_memory_size + new_msg_size - memory.memory_size + if memory.forgetting_policy == "FIFO": + message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory.id, memory.tenant_id, + size_to_delete) + MessageService.delete_message({"message_id": message_ids_to_delete}, memory.tenant_id, memory.id) + decrease_memory_size_cache(memory.id, delete_size) + else: + return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete." + fail_cases = MessageService.insert_message(message_list, memory.tenant_id, memory.id) + if fail_cases: + return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases) + + increase_memory_size_cache(memory.id, new_msg_size) + return True, "Message saved successfully." + + +def query_message(filter_dict: dict, params: dict): + """ + :param filter_dict: { + "memory_id": List[str], + "agent_id": optional + "session_id": optional + } + :param params: { + "query": question str, + "similarity_threshold": float, + "keywords_similarity_weight": float, + "top_n": int + } + """ + memory_ids = filter_dict["memory_id"] + memory_list = MemoryService.get_by_ids(memory_ids) + if not memory_list: + return [] + + condition_dict = {k: v for k, v in filter_dict.items() if v} + uids = [memory.tenant_id for memory in memory_list] + + question = params["query"] + question = question.strip() + memory = memory_list[0] + embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id) + match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"]) + match_text, _ = MsgTextQuery().question(question, min_match=params["similarity_threshold"]) + keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7) + fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(1 - keywords_similarity_weight), str(keywords_similarity_weight)])}) + + return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"]) + + +def init_message_id_sequence(): + message_id_redis_key = "id_generator:memory" + if REDIS_CONN.exist(message_id_redis_key): + current_max_id = REDIS_CONN.get(message_id_redis_key) + logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.") + else: + max_id = 1 + exist_memory_list = MemoryService.get_all_memory() + if not exist_memory_list: + REDIS_CONN.set(message_id_redis_key, max_id) + else: + max_id = MessageService.get_max_message_id( + uid_list=[m.tenant_id for m in exist_memory_list], + memory_ids=[m.id for m in exist_memory_list] + ) + REDIS_CONN.set(message_id_redis_key, max_id) + logging.info(f"Init message_id sequence done, current max id is {max_id}.") + + +def get_memory_size_cache(memory_id: str, uid: str): + redis_key = f"memory_{memory_id}" + if REDIS_CONN.exist(redis_key): + return int(REDIS_CONN.get(redis_key)) + else: + memory_size_map = MessageService.calculate_memory_size( + [memory_id], + [uid] + ) + memory_size = memory_size_map.get(memory_id, 0) + set_memory_size_cache(memory_id, memory_size) + return memory_size + + +def set_memory_size_cache(memory_id: str, size: int): + redis_key = f"memory_{memory_id}" + return REDIS_CONN.set(redis_key, size) + + +def increase_memory_size_cache(memory_id: str, size: int): + redis_key = f"memory_{memory_id}" + return REDIS_CONN.incrby(redis_key, size) + + +def decrease_memory_size_cache(memory_id: str, size: int): + redis_key = f"memory_{memory_id}" + return REDIS_CONN.decrby(redis_key, size) + + +def init_memory_size_cache(): + memory_list = MemoryService.get_all_memory() + if not memory_list: + logging.info("No memory found, no need to init memory size.") + else: + for m in memory_list: + get_memory_size_cache(m.id, m.tenant_id) + logging.info("Memory size cache init done.") + + +def judge_system_prompt_is_default(system_prompt: str, memory_type: int|list[str]): + memory_type_list = memory_type if isinstance(memory_type, list) else get_memory_type_human(memory_type) + return system_prompt == PromptAssembler.assemble_system_prompt({"memory_type": memory_type_list}) + + +async def queue_save_to_memory_task(memory_ids: list[str], message_dict: dict): + """ + :param memory_ids: + :param message_dict: { + "user_id": str, + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str + } + """ + def new_task(_memory_id: str, _source_id: int): + return { + "id": get_uuid(), + "doc_id": _memory_id, + "task_type": "memory", + "progress": 0.0, + "digest": str(_source_id) + } + + not_found_memory = [] + failed_memory = [] + for memory_id in memory_ids: + memory = MemoryService.get_by_memory_id(memory_id) + if not memory: + not_found_memory.append(memory_id) + continue + + raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory") + raw_message = { + "message_id": raw_message_id, + "message_type": MemoryType.RAW.name.lower(), + "source_id": 0, + "memory_id": memory_id, + "user_id": "", + "agent_id": message_dict["agent_id"], + "session_id": message_dict["session_id"], + "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}", + "valid_at": timestamp_to_date(current_timestamp()), + "invalid_at": None, + "forget_at": None, + "status": True + } + res, msg = await embed_and_save(memory, [raw_message]) + if not res: + failed_memory.append({"memory_id": memory_id, "fail_msg": msg}) + continue + + task = new_task(memory_id, raw_message_id) + bulk_insert_into_db(Task, [task], replace_on_conflict=True) + task_message = { + "id": task["id"], + "task_id": task["id"], + "task_type": task["task_type"], + "memory_id": memory_id, + "source_id": raw_message_id, + "message_dict": message_dict + } + if not REDIS_CONN.queue_product(settings.get_svr_queue_name(priority=0), message=task_message): + failed_memory.append({"memory_id": memory_id, "fail_msg": "Can't access Redis."}) + + error_msg = "" + if not_found_memory: + error_msg = f"Memory {not_found_memory} not found." + if failed_memory: + error_msg += "".join([f"Memory {fm['memory_id']} failed. Detail: {fm['fail_msg']}" for fm in failed_memory]) + + if error_msg: + return False, error_msg + + return True, "All add to task." + + +async def handle_save_to_memory_task(task_param: dict): + """ + :param task_param: { + "id": task_id + "memory_id": id + "source_id": id + "message_dict": { + "user_id": str, + "agent_id": str, + "session_id": str, + "user_input": str, + "agent_response": str + } + } + """ + _, task = TaskService.get_by_id(task_param["id"]) + if not task: + return False, f"Task {task_param['id']} is not found." + if task.progress == -1: + return False, f"Task {task_param['id']} is already failed." + now_time = current_timestamp() + TaskService.update_by_id(task_param["id"], {"begin_at": timestamp_to_date(now_time)}) + + memory_id = task_param["memory_id"] + source_id = task_param["source_id"] + message_dict = task_param["message_dict"] + success, msg = await save_extracted_to_memory_only(memory_id, message_dict, source_id) + if success: + TaskService.update_progress(task.id, {"progress": 1.0, "progress_msg": msg}) + return True, msg + + logging.error(msg) + TaskService.update_progress(task.id, {"progress": -1, "progress_msg": None}) + return False, msg diff --git a/api/db/joint_services/user_account_service.py b/api/db/joint_services/user_account_service.py index 34ceee64818..2e4dfeaab23 100644 --- a/api/db/joint_services/user_account_service.py +++ b/api/db/joint_services/user_account_service.py @@ -34,6 +34,8 @@ from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.user_canvas_version import UserCanvasVersionService from api.db.services.user_service import TenantService, UserService, UserTenantService +from api.db.services.memory_service import MemoryService +from memory.services.messages import MessageService from rag.nlp import search from common.constants import ActiveEnum from common import settings @@ -153,7 +155,7 @@ def delete_user_data(user_id: str) -> dict: done_msg += "Start to delete owned tenant.\n" tenant_id = owned_tenant[0]["tenant_id"] kb_ids = KnowledgebaseService.get_kb_ids(usr.id) - # step1.1 delete knowledgebase related file and info + # step1.1 delete dataset related file and info if kb_ids: # step1.1.1 delete files in storage, remove bucket for kb_id in kb_ids: @@ -182,7 +184,7 @@ def delete_user_data(user_id: str) -> dict: search.index_name(tenant_id), kb_ids) done_msg += f"- Deleted {r} chunk records.\n" kb_delete_res = KnowledgebaseService.delete_by_ids(kb_ids) - done_msg += f"- Deleted {kb_delete_res} knowledgebase records.\n" + done_msg += f"- Deleted {kb_delete_res} dataset records.\n" # step1.1.4 delete agents agent_delete_res = delete_user_agents(usr.id) done_msg += f"- Deleted {agent_delete_res['agents_deleted_count']} agent, {agent_delete_res['version_deleted_count']} versions records.\n" @@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict: done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n" langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id) done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n" - # step1.3 delete own tenant + # step1.3 delete memory and messages + user_memory = MemoryService.get_by_tenant_id(tenant_id) + if user_memory: + for memory in user_memory: + if MessageService.has_index(tenant_id, memory.id): + MessageService.delete_index(tenant_id, memory.id) + done_msg += " Deleted memory index." + memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory]) + done_msg += f"Deleted {memory_delete_res} memory datasets." + # step1.4 delete own tenant tenant_delete_res = TenantService.delete_by_id(tenant_id) done_msg += f"- Deleted {tenant_delete_res} tenant.\n" # step2 delete user-tenant relation @@ -258,7 +269,7 @@ def delete_user_data(user_id: str) -> dict: # step2.1.5 delete document record doc_delete_res = DocumentService.delete_by_ids([d['id'] for d in created_documents]) done_msg += f"- Deleted {doc_delete_res} documents.\n" - # step2.1.6 update knowledge base doc&chunk&token cnt + # step2.1.6 update dataset doc&chunk&token cnt for kb_id, doc_num in kb_doc_info.items(): KnowledgebaseService.decrease_document_num_in_delete(kb_id, doc_num) @@ -273,7 +284,7 @@ def delete_user_data(user_id: str) -> dict: except Exception as e: logging.exception(e) - return {"success": False, "message": f"Error: {str(e)}. Already done:\n{done_msg}"} + return {"success": False, "message": "An internal error occurred during user deletion. Some operations may have completed.","details": done_msg} def delete_user_agents(user_id: str) -> dict: diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 5a0f82c2b00..763e9c4601e 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -123,6 +123,19 @@ def get_by_canvas_id(cls, pid): logging.exception(e) return False, None + @classmethod + @DB.connection_context() + def get_basic_info_by_canvas_ids(cls, canvas_id): + fields = [ + cls.model.id, + cls.model.avatar, + cls.model.user_id, + cls.model.title, + cls.model.permission, + cls.model.canvas_category + ] + return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts() + @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, @@ -177,7 +190,7 @@ def accessible(cls, canvas_id, tenant_id): return True -def completion(tenant_id, agent_id, session_id=None, **kwargs): +async def completion(tenant_id, agent_id, session_id=None, **kwargs): query = kwargs.get("query", "") or kwargs.get("question", "") files = kwargs.get("files", []) inputs = kwargs.get("inputs", {}) @@ -198,7 +211,7 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs): if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) session_id=get_uuid() - canvas = Canvas(cvs.dsl, tenant_id, agent_id) + canvas = Canvas(cvs.dsl, tenant_id, agent_id, canvas_id=cvs.id) canvas.reset() conv = { "id": session_id, @@ -219,10 +232,14 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs): "id": message_id }) txt = "" - for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): ans["session_id"] = session_id if ans["event"] == "message": txt += ans["data"]["content"] + if ans["data"].get("start_to_think", False): + txt += "" + elif ans["data"].get("end_to_think", False): + txt += "" yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id}) @@ -233,7 +250,7 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs): API4ConversationService.append_message(conv["id"], conv) -def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): +async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): tiktoken_encoder = tiktoken.get_encoding("cl100k_base") prompt_tokens = len(tiktoken_encoder.encode(str(question))) user_id = kwargs.get("user_id", "") @@ -241,7 +258,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru if stream: completion_tokens = 0 try: - for ans in completion( + async for ans in completion( tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, @@ -300,7 +317,7 @@ def completion_openai(tenant_id, agent_id, question, session_id=None, stream=Tru try: all_content = "" reference = {} - for ans in completion( + async for ans in completion( tenant_id=tenant_id, agent_id=agent_id, session_id=session_id, diff --git a/api/db/services/common_service.py b/api/db/services/common_service.py index 5b906b5a861..60db241cc8e 100644 --- a/api/db/services/common_service.py +++ b/api/db/services/common_service.py @@ -169,10 +169,12 @@ def insert(cls, **kwargs): """ if "id" not in kwargs: kwargs["id"] = get_uuid() - kwargs["create_time"] = current_timestamp() - kwargs["create_date"] = datetime_format(datetime.now()) - kwargs["update_time"] = current_timestamp() - kwargs["update_date"] = datetime_format(datetime.now()) + timestamp = current_timestamp() + cur_datetime = datetime_format(datetime.now()) + kwargs["create_time"] = timestamp + kwargs["create_date"] = cur_datetime + kwargs["update_time"] = timestamp + kwargs["update_date"] = cur_datetime sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj @@ -207,10 +209,14 @@ def update_many_by_id(cls, data_list): data_list (list): List of dictionaries containing record data to update. Each dictionary must include an 'id' field. """ + + timestamp = current_timestamp() + cur_datetime = datetime_format(datetime.now()) + for data in data_list: + data["update_time"] = timestamp + data["update_date"] = cur_datetime with DB.atomic(): for data in data_list: - data["update_time"] = current_timestamp() - data["update_date"] = datetime_format(datetime.now()) cls.model.update(data).where(cls.model.id == data["id"]).execute() @classmethod diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 2ff16669d3a..660530c824b 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -15,6 +15,7 @@ # import logging from datetime import datetime +import os from typing import Tuple, List from anthropic import BaseModel @@ -24,7 +25,6 @@ from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService -from api.db.services.file_service import FileService from common.misc_utils import get_uuid from common.constants import TaskStatus from common.time_utils import current_timestamp, timestamp_to_date @@ -68,9 +68,10 @@ def list(cls, tenant_id): @classmethod def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str): + from api.db.services.file_service import FileService e, conn = cls.get_by_id(connector_id) if not e: - return + return None SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id]) docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id) err = FileService.delete_docs([d.id for d in docs], tenant_id) @@ -103,7 +104,8 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) Knowledgebase.avatar.alias("kb_avatar"), Connector2Kb.auto_parse, cls.model.from_beginning.alias("reindex"), - cls.model.status + cls.model.status, + cls.model.update_time ] if not connector_id: fields.append(Connector.config) @@ -116,7 +118,11 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) if connector_id: query = query.where(cls.model.connector_id == connector_id) else: - interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE") + database_type = os.getenv("DB_TYPE", "mysql") + if "postgres" in database_type.lower(): + interval_expr = SQL("make_interval(mins => t2.refresh_freq)") + else: + interval_expr = SQL("INTERVAL `t2`.`refresh_freq` MINUTE") query = query.where( Connector.input_type == InputType.POLL, Connector.status == TaskStatus.SCHEDULE, @@ -125,11 +131,11 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) ) query = query.distinct().order_by(cls.model.update_time.desc()) - totbal = query.count() + total = query.count() if page_number: query = query.paginate(page_number, items_per_page) - return list(query.dicts()), totbal + return list(query.dicts()), total @classmethod def start(cls, id, connector_id): @@ -191,6 +197,7 @@ def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_co @classmethod def duplicate_and_parse(cls, kb, docs, tenant_id, src, auto_parse=True): + from api.db.services.file_service import FileService if not docs: return None @@ -207,9 +214,21 @@ def read(self) -> bytes: err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) + # Create a mapping from filename to metadata for later use + metadata_map = {} + for d in docs: + if d.get("metadata"): + filename = d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else "") + metadata_map[filename] = d["metadata"] + kb_table_num_map = {} for doc, _ in doc_blob_pairs: doc_ids.append(doc["id"]) + + # Set metadata if available for this document + if doc["name"] in metadata_map: + DocumentService.update_by_id(doc["id"], {"meta_fields": metadata_map[doc["name"]]}) + if not auto_parse or auto_parse == "0": continue DocumentService.run(tenant_id, doc, kb_table_num_map) @@ -242,7 +261,7 @@ def link_connectors(cls, kb_id:str, connectors: list[dict], tenant_id:str): "id": get_uuid(), "connector_id": conn_id, "kb_id": kb_id, - "auto_parse": conn.get("auto_parse", "1") + "auto_parse": conn.get("auto_parse", "1") }) SyncLogsService.schedule(conn_id, kb_id, reindex=True) diff --git a/api/db/services/conversation_service.py b/api/db/services/conversation_service.py index 60f8e55b1cd..2a5b06601dc 100644 --- a/api/db/services/conversation_service.py +++ b/api/db/services/conversation_service.py @@ -19,7 +19,7 @@ from api.db.db_models import Conversation, DB from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService -from api.db.services.dialog_service import DialogService, chat +from api.db.services.dialog_service import DialogService, async_chat from common.misc_utils import get_uuid import json @@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id): conv.reference[-1] = reference return ans - -def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): +async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs): assert name, "`name` can not be empty." dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) assert dia, "You do not own the chat." @@ -112,11 +111,21 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None "reference": {}, "audio_binary": None, "id": None, - "session_id": session_id + "session_id": session_id }}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" return + else: + answer = { + "answer": conv["message"][0]["content"], + "reference": {}, + "audio_binary": None, + "id": None, + "session_id": session_id + } + yield answer + return conv = ConversationService.query(id=session_id, dialog_id=chat_id) if not conv: @@ -148,7 +157,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n" ConversationService.update_by_id(conv.id, conv.to_dict()) @@ -160,14 +169,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) ConversationService.update_by_id(conv.id, conv.to_dict()) break yield answer - -def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): +async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs): e, dia = DialogService.get_by_id(dialog_id) assert e, "Dialog not found" if not session_id: @@ -222,7 +230,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg if stream: try: - for ans in chat(dia, msg, True, **kwargs): + async for ans in async_chat(dia, msg, True, **kwargs): ans = structure_answer(conv, ans, message_id, session_id) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" @@ -235,7 +243,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg else: answer = None - for ans in chat(dia, msg, False, **kwargs): + async for ans in async_chat(dia, msg, False, **kwargs): answer = structure_answer(conv, ans, message_id, session_id) API4ConversationService.append_message(conv.id, conv.to_dict()) break diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index f54ebf70980..4bc24210b20 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -21,10 +21,10 @@ from datetime import datetime from functools import partial from timeit import default_timer as timer -import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher +from api.db.services.file_service import FileService from common.constants import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService @@ -32,6 +32,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter from api.db.services.tenant_llm_service import TenantLLMService from common.time_utils import current_timestamp, datetime_format from graphrag.general.mind_map_extractor import MindMapExtractor @@ -39,7 +40,7 @@ from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ - gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY + PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces @@ -177,7 +178,11 @@ def get_all_dialogs_by_tenant_id(cls, tenant_id): offset += limit return res -def chat_solo(dialog, messages, stream=True): + +async def async_chat_solo(dialog, messages, stream=True): + attachments = "" + if "files" in messages[-1]: + attachments = "\n\n".join(FileService.get_files(messages[-1]["files"])) if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: @@ -188,10 +193,13 @@ def chat_solo(dialog, messages, stream=True): if prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] + if attachments and msg: + msg[-1]["content"] += attachments if stream: last_ans = "" delta_ans = "" - for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): + answer = "" + async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: @@ -202,7 +210,7 @@ def chat_solo(dialog, messages, stream=True): if delta_ans: yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()} else: - answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting) + answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} @@ -270,77 +278,10 @@ def replacement(match): return answer, idx -def convert_conditions(metadata_condition): - if metadata_condition is None: - metadata_condition = {} - op_mapping = { - "is": "=", - "not is": "≠" - } - return [ - { - "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), - "key": cond["name"], - "value": cond["value"] - } - for cond in metadata_condition.get("conditions", []) - ] - - -def meta_filter(metas: dict, filters: list[dict]): - doc_ids = set([]) - - def filter_out(v2docs, operator, value): - ids = [] - for input, docids in v2docs.items(): - if operator in ["=", "≠", ">", "<", "≥", "≤"]: - try: - input = float(input) - value = float(value) - except Exception: - input = str(input) - value = str(value) - - for conds in [ - (operator == "contains", str(value).lower() in str(input).lower()), - (operator == "not contains", str(value).lower() not in str(input).lower()), - (operator == "start with", str(input).lower().startswith(str(value).lower())), - (operator == "end with", str(input).lower().endswith(str(value).lower())), - (operator == "empty", not input), - (operator == "not empty", input), - (operator == "=", input == value), - (operator == "≠", input != value), - (operator == ">", input > value), - (operator == "<", input < value), - (operator == "≥", input >= value), - (operator == "≤", input <= value), - ]: - try: - if all(conds): - ids.extend(docids) - break - except Exception: - pass - return ids - - for k, v2docs in metas.items(): - for f in filters: - if k != f["key"]: - continue - ids = filter_out(v2docs, f["op"], f["value"]) - if not doc_ids: - doc_ids = set(ids) - else: - doc_ids = doc_ids & set(ids) - if not doc_ids: - return [] - return list(doc_ids) - - -def chat(dialog, messages, stream=True, **kwargs): +async def async_chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): - for ans in chat_solo(dialog, messages, stream): + async for ans in async_chat_solo(dialog, messages, stream): yield ans return @@ -375,15 +316,18 @@ def chat(dialog, messages, stream=True, **kwargs): retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] + attachments_= "" if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] + if "files" in messages[-1]: + attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"])) prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) + ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans return @@ -397,27 +341,25 @@ def chat(dialog, messages, stream=True, **kwargs): prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): - questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] + questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): - questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] + questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) - if dialog.meta_data_filter.get("method") == "auto": - filters = gen_meta_filter(chat_mdl, metas, questions[-1]) - attachments.extend(meta_filter(metas, filters)) - if not attachments: - attachments = None - elif dialog.meta_data_filter.get("method") == "manual": - attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"])) - if not attachments: - attachments = None + attachments = await apply_meta_data_filter( + dialog.meta_data_filter, + metas, + questions[-1], + chat_mdl, + attachments, + ) if prompt_config.get("keyword", False): - questions[-1] += keyword_extraction(chat_mdl, questions[-1]) + questions[-1] += await keyword_extraction(chat_mdl, questions[-1]) refine_question_ts = timer() @@ -445,7 +387,7 @@ def chat(dialog, messages, stream=True, **kwargs): ), ) - for think in reasoner.thinking(kbinfos, " ".join(questions)): + async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): if isinstance(think, str): thought = think knowledges = [t for t in think.split("\n") if t] @@ -464,7 +406,7 @@ def chat(dialog, messages, stream=True, **kwargs): dialog.vector_similarity_weight, doc_ids=attachments, top=dialog.top_k, - aggs=False, + aggs=True, rerank_mdl=rerank_mdl, rank_feature=label_question(" ".join(questions), kbs), ) @@ -472,6 +414,7 @@ def chat(dialog, messages, stream=True, **kwargs): cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) if cks: kbinfos["chunks"] = cks + kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids) if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) @@ -492,12 +435,13 @@ def chat(dialog, messages, stream=True, **kwargs): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)} - return {"answer": prompt_config["empty_response"], "reference": kbinfos} + yield {"answer": prompt_config["empty_response"], "reference": kbinfos} + return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting - msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] + msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() @@ -596,7 +540,7 @@ def decorate_answer(answer): if stream: last_ans = "" answer = "" - for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): + async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): if thought: ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) answer = ans @@ -610,17 +554,19 @@ def decorate_answer(answer): yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield decorate_answer(thought + answer) else: - answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf) + answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res + return + -def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): +async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = """ -You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. +You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. Ensure that: 1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it. 2. Write only the SQL, no explanations or additional text. @@ -636,9 +582,9 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) tried_times = 0 - def get_table(): + async def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) + sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) sql = re.sub(r"^.*", "", sql, flags=re.DOTALL) logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) @@ -664,7 +610,11 @@ def get_table(): if kb_ids: kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" if "where" not in sql.lower(): - sql += f" WHERE {kb_filter}" + o = sql.lower().split("order by") + if len(o) > 1: + sql = o[0] + f" WHERE {kb_filter} order by " + o[1] + else: + sql += f" WHERE {kb_filter}" else: sql += f" AND {kb_filter}" @@ -672,10 +622,9 @@ def get_table(): tried_times += 1 return settings.retriever.sql_retrieval(sql, format="json"), sql - tbl, sql = get_table() - if tbl is None: - return None - if tbl.get("error") and tried_times <= 2: + try: + tbl, sql = await get_table() + except Exception as e: user_prompt = """ Table name: {}; Table of database fields are as follows: @@ -689,16 +638,14 @@ def get_table(): The SQL error you provided last time is as follows: {} - Error issued by database as follows: - {} - Please correct the error and write SQL again, only SQL, without any other explanations or text. - """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"]) - tbl, sql = get_table() - logging.debug("TRY it again: {}".format(sql)) + """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e) + try: + tbl, sql = await get_table() + except Exception: + return - logging.debug("GET table: {}".format(tbl)) - if tbl.get("error") or len(tbl["rows"]) == 0: + if len(tbl["rows"]) == 0: return None docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) @@ -742,17 +689,51 @@ def get_table(): "prompt": sys_prompt, } +def clean_tts_text(text: str) -> str: + if not text: + return "" + + text = text.encode("utf-8", "ignore").decode("utf-8", "ignore") + + text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) + + emoji_pattern = re.compile( + "[\U0001F600-\U0001F64F" + "\U0001F300-\U0001F5FF" + "\U0001F680-\U0001F6FF" + "\U0001F1E0-\U0001F1FF" + "\U00002700-\U000027BF" + "\U0001F900-\U0001F9FF" + "\U0001FA70-\U0001FAFF" + "\U0001FAD0-\U0001FAFF]+", + flags=re.UNICODE + ) + text = emoji_pattern.sub("", text) + + text = re.sub(r"\s+", " ", text).strip() + + MAX_LEN = 500 + if len(text) > MAX_LEN: + text = text[:MAX_LEN] + + return text def tts(tts_mdl, text): if not tts_mdl or not text: - return + return None + text = clean_tts_text(text) + if not text: + return None bin = b"" - for chunk in tts_mdl.tts(text): - bin += chunk + try: + for chunk in tts_mdl.tts(text): + bin += chunk + except Exception as e: + logging.error(f"TTS failed: {e}, text={text!r}") + return None return binascii.hexlify(bin).decode("utf-8") - -def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): +async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) @@ -775,15 +756,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - filters = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) - if not doc_ids: - doc_ids = None + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) kbinfos = retriever.retrieval( question=question, @@ -796,7 +769,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), top=search_config.get("top_k", 1024), doc_ids=doc_ids, - aggs=False, + aggs=True, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs) ) @@ -826,13 +799,13 @@ def decorate_answer(answer): return {"answer": answer, "reference": refs} answer = "" - for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}): + async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) -def gen_mindmap(question, kb_ids, tenant_id, search_config={}): +async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): meta_data_filter = search_config.get("meta_data_filter", {}) doc_ids = search_config.get("doc_ids", []) rerank_id = search_config.get("rerank_id", "") @@ -850,15 +823,7 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - filters = gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters)) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) - if not doc_ids: - doc_ids = None + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) ranks = settings.retriever.retrieval( question=question, @@ -876,5 +841,5 @@ def gen_mindmap(question, kb_ids, tenant_id, search_config={}): rank_feature=label_question(question, kbs), ) mindmap = MindMapExtractor(chat_mdl) - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) + mind_map = await mindmap([c["content_with_weight"] for c in ranks["chunks"]]) return mind_map.output diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index a64ae16deb7..a05d1783d9e 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging import random @@ -22,7 +23,6 @@ from datetime import datetime from io import BytesIO -import trio import xxhash from peewee import fn, Case, JOIN @@ -33,14 +33,16 @@ from api.db.db_utils import bulk_insert_into_db from api.db.services.common_service import CommonService from api.db.services.knowledgebase_service import KnowledgebaseService +from common.metadata_utils import dedupe_list from common.misc_utils import get_uuid from common.time_utils import current_timestamp, get_format_time from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME from rag.nlp import rag_tokenizer, search from rag.utils.redis_conn import REDIS_CONN -from rag.utils.doc_store_conn import OrderByExpr +from common.doc_store.doc_store_base import OrderByExpr from common import settings + class DocumentService(CommonService): model = Document @@ -78,7 +80,7 @@ def get_cls_model_fields(cls): @classmethod @DB.connection_context() def get_list(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords, id, name, suffix=None, run = None): + orderby, desc, keywords, id, name, suffix=None, run = None, doc_ids=None): fields = cls.get_cls_model_fields() docs = cls.model.select(*[*fields, UserCanvas.title]).join(File2Document, on = (File2Document.document_id == cls.model.id))\ .join(File, on = (File.id == File2Document.file_id))\ @@ -95,6 +97,8 @@ def get_list(cls, kb_id, page_number, items_per_page, docs = docs.where( fn.LOWER(cls.model.name).contains(keywords.lower()) ) + if doc_ids: + docs = docs.where(cls.model.id.in_(doc_ids)) if suffix: docs = docs.where(cls.model.suffix.in_(suffix)) if run: @@ -113,7 +117,7 @@ def get_list(cls, kb_id, page_number, items_per_page, def check_doc_health(cls, tenant_id: str, filename): import os MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0)) - if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER: + if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id): raise RuntimeError("Exceed the maximum file number of a free user!") if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT: raise RuntimeError("Exceed the maximum length of file name!") @@ -121,33 +125,37 @@ def check_doc_health(cls, tenant_id: str, filename): @classmethod @DB.connection_context() - def get_by_kb_id(cls, kb_id, page_number, items_per_page, - orderby, desc, keywords, run_status, types, suffix): + def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids=None, return_empty_metadata=False): fields = cls.get_cls_model_fields() if keywords: - docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\ - .join(File2Document, on=(File2Document.document_id == cls.model.id))\ - .join(File, on=(File.id == File2Document.file_id))\ - .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\ - .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\ - .where( - (cls.model.kb_id == kb_id), - (fn.LOWER(cls.model.name).contains(keywords.lower())) - ) + docs = ( + cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname]) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(File, on=(File.id == File2Document.file_id)) + .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER) + .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER) + .where((cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower()))) + ) else: - docs = cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname])\ - .join(File2Document, on=(File2Document.document_id == cls.model.id))\ - .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER)\ - .join(File, on=(File.id == File2Document.file_id))\ - .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER)\ + docs = ( + cls.model.select(*[*fields, UserCanvas.title.alias("pipeline_name"), User.nickname]) + .join(File2Document, on=(File2Document.document_id == cls.model.id)) + .join(UserCanvas, on=(cls.model.pipeline_id == UserCanvas.id), join_type=JOIN.LEFT_OUTER) + .join(File, on=(File.id == File2Document.file_id)) + .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER) .where(cls.model.kb_id == kb_id) + ) + if doc_ids: + docs = docs.where(cls.model.id.in_(doc_ids)) if run_status: docs = docs.where(cls.model.run.in_(run_status)) if types: docs = docs.where(cls.model.type.in_(types)) if suffix: docs = docs.where(cls.model.suffix.in_(suffix)) + if return_empty_metadata: + docs = docs.where(fn.COALESCE(fn.JSON_LENGTH(cls.model.meta_fields), 0) == 0) count = docs.count() if desc: @@ -155,7 +163,6 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, else: docs = docs.order_by(cls.model.getter_by(orderby).asc()) - if page_number and items_per_page: docs = docs.paginate(page_number, items_per_page) @@ -175,6 +182,16 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix): "1": 2, "2": 2 } + "metadata": { + "key1": { + "key1_value1": 1, + "key1_value2": 2, + }, + "key2": { + "key2_value1": 2, + "key2_value2": 1, + }, + } }, total where "1" => RUNNING, "2" => CANCEL """ @@ -195,19 +212,42 @@ def get_filter_by_kb_id(cls, kb_id, keywords, run_status, types, suffix): if suffix: query = query.where(cls.model.suffix.in_(suffix)) - rows = query.select(cls.model.run, cls.model.suffix) + rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields) total = rows.count() suffix_counter = {} run_status_counter = {} + metadata_counter = {} + empty_metadata_count = 0 for row in rows: suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1 run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1 - + meta_fields = row.meta_fields or {} + if not meta_fields: + empty_metadata_count += 1 + continue + has_valid_meta = False + for key, value in meta_fields.items(): + values = value if isinstance(value, list) else [value] + for vv in values: + if vv is None: + continue + if isinstance(vv, str) and not vv.strip(): + continue + sv = str(vv) + if key not in metadata_counter: + metadata_counter[key] = {} + metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1 + has_valid_meta = True + if not has_valid_meta: + empty_metadata_count += 1 + + metadata_counter["empty_metadata"] = {"true": empty_metadata_count} return { "suffix": suffix_counter, - "run_status": run_status_counter + "run_status": run_status_counter, + "metadata": metadata_counter, }, total @classmethod @@ -302,27 +342,13 @@ def remove_document(cls, doc, tenant_id): cls.clear_chunk_num(doc.id) try: TaskService.filter_delete([Task.doc_id == doc.id]) - page = 0 - page_size = 1000 - all_chunk_ids = [] - while True: - chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), - page * page_size, page_size, search.index_name(tenant_id), - [doc.kb_id]) - chunk_ids = settings.docStoreConn.getChunkIds(chunks) - if not chunk_ids: - break - all_chunk_ids.extend(chunk_ids) - page += 1 - for cid in all_chunk_ids: - if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): - settings.STORAGE_IMPL.rm(doc.kb_id, cid) + cls.delete_chunk_images(doc, tenant_id) if doc.thumbnail and not doc.thumbnail.startswith(IMG_BASE64_PREFIX): if settings.STORAGE_IMPL.obj_exist(doc.kb_id, doc.thumbnail): settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail) settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) - graph_source = settings.docStoreConn.getFields( + graph_source = settings.docStoreConn.get_fields( settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"] ) if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]: @@ -338,6 +364,23 @@ def remove_document(cls, doc, tenant_id): pass return cls.delete_by_id(doc.id) + @classmethod + @DB.connection_context() + def delete_chunk_images(cls, doc, tenant_id): + page = 0 + page_size = 1000 + while True: + chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(), + page * page_size, page_size, search.index_name(tenant_id), + [doc.kb_id]) + chunk_ids = settings.docStoreConn.get_doc_ids(chunks) + if not chunk_ids: + break + for cid in chunk_ids: + if settings.STORAGE_IMPL.obj_exist(doc.kb_id, cid): + settings.STORAGE_IMPL.rm(doc.kb_id, cid) + page += 1 + @classmethod @DB.connection_context() def get_newly_uploaded(cls): @@ -464,7 +507,7 @@ def get_tenant_id(cls, doc_id): cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: - return + return None return docs[0]["tenant_id"] @classmethod @@ -473,7 +516,7 @@ def get_knowledgebase_id(cls, doc_id): docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id) docs = docs.dicts() if not docs: - return + return None return docs[0]["kb_id"] @classmethod @@ -486,7 +529,7 @@ def get_tenant_id_by_name(cls, name): cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: - return + return None return docs[0]["tenant_id"] @classmethod @@ -533,7 +576,7 @@ def get_embd_id(cls, doc_id): cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value) docs = docs.dicts() if not docs: - return + return None return docs[0]["embd_id"] @classmethod @@ -569,7 +612,7 @@ def get_doc_id_by_doc_name(cls, doc_name): .where(cls.model.name == doc_name) doc_id = doc_id.dicts() if not doc_id: - return + return None return doc_id[0]["id"] @classmethod @@ -643,6 +686,13 @@ def update_meta_fields(cls, doc_id, meta_fields): @classmethod @DB.connection_context() def get_meta_by_kbs(cls, kb_ids): + """ + Legacy metadata aggregator (backward-compatible). + - Does NOT expand list values and a list is kept as one string key. + Example: {"tags": ["foo","bar"]} -> meta["tags"]["['foo', 'bar']"] = [doc_id] + - Expects meta_fields is a dict. + Use when existing callers rely on the old list-as-string semantics. + """ fields = [ cls.model.id, cls.model.meta_fields, @@ -653,12 +703,184 @@ def get_meta_by_kbs(cls, kb_ids): for k,v in r.meta_fields.items(): if k not in meta: meta[k] = {} - v = str(v) - if v not in meta[k]: - meta[k][v] = [] - meta[k][v].append(doc_id) + if not isinstance(v, list): + v = [v] + for vv in v: + if vv not in meta[k]: + if isinstance(vv, list) or isinstance(vv, dict): + continue + meta[k][vv] = [] + meta[k][vv].append(doc_id) + return meta + + @classmethod + @DB.connection_context() + def get_flatted_meta_by_kbs(cls, kb_ids): + """ + - Parses stringified JSON meta_fields when possible and skips non-dict or unparsable values. + - Expands list values into individual entries. + Example: {"tags": ["foo","bar"], "author": "alice"} -> + meta["tags"]["foo"] = [doc_id], meta["tags"]["bar"] = [doc_id], meta["author"]["alice"] = [doc_id] + Prefer for metadata_condition filtering and scenarios that must respect list semantics. + """ + fields = [ + cls.model.id, + cls.model.meta_fields, + ] + meta = {} + for r in cls.model.select(*fields).where(cls.model.kb_id.in_(kb_ids)): + doc_id = r.id + meta_fields = r.meta_fields or {} + if isinstance(meta_fields, str): + try: + meta_fields = json.loads(meta_fields) + except Exception: + continue + if not isinstance(meta_fields, dict): + continue + for k, v in meta_fields.items(): + if k not in meta: + meta[k] = {} + values = v if isinstance(v, list) else [v] + for vv in values: + if vv is None: + continue + sv = str(vv) + if sv not in meta[k]: + meta[k][sv] = [] + meta[k][sv].append(doc_id) return meta + @classmethod + @DB.connection_context() + def get_metadata_summary(cls, kb_id): + fields = [cls.model.id, cls.model.meta_fields] + summary = {} + for r in cls.model.select(*fields).where(cls.model.kb_id == kb_id): + meta_fields = r.meta_fields or {} + if isinstance(meta_fields, str): + try: + meta_fields = json.loads(meta_fields) + except Exception: + continue + if not isinstance(meta_fields, dict): + continue + for k, v in meta_fields.items(): + values = v if isinstance(v, list) else [v] + for vv in values: + if not vv: + continue + sv = str(vv) + if k not in summary: + summary[k] = {} + summary[k][sv] = summary[k].get(sv, 0) + 1 + return {k: sorted([(val, cnt) for val, cnt in v.items()], key=lambda x: x[1], reverse=True) for k, v in summary.items()} + + @classmethod + @DB.connection_context() + def batch_update_metadata(cls, kb_id, doc_ids, updates=None, deletes=None): + updates = updates or [] + deletes = deletes or [] + if not doc_ids: + return 0 + + def _normalize_meta(meta): + if isinstance(meta, str): + try: + meta = json.loads(meta) + except Exception: + return {} + if not isinstance(meta, dict): + return {} + return deepcopy(meta) + + def _str_equal(a, b): + return str(a) == str(b) + + def _apply_updates(meta): + changed = False + for upd in updates: + key = upd.get("key") + if not key or key not in meta: + continue + + new_value = upd.get("value") + match_provided = "match" in upd + if isinstance(meta[key], list): + if not match_provided: + if isinstance(new_value, list): + meta[key] = dedupe_list(new_value) + else: + meta[key] = new_value + changed = True + else: + match_value = upd.get("match") + replaced = False + new_list = [] + for item in meta[key]: + if _str_equal(item, match_value): + new_list.append(new_value) + replaced = True + else: + new_list.append(item) + if replaced: + meta[key] = dedupe_list(new_list) + changed = True + else: + if not match_provided: + meta[key] = new_value + changed = True + else: + match_value = upd.get("match") + if _str_equal(meta[key], match_value): + meta[key] = new_value + changed = True + return changed + + def _apply_deletes(meta): + changed = False + for d in deletes: + key = d.get("key") + if not key or key not in meta: + continue + value = d.get("value", None) + if isinstance(meta[key], list): + if value is None: + del meta[key] + changed = True + continue + new_list = [item for item in meta[key] if not _str_equal(item, value)] + if len(new_list) != len(meta[key]): + if new_list: + meta[key] = new_list + else: + del meta[key] + changed = True + else: + if value is None or _str_equal(meta[key], value): + del meta[key] + changed = True + return changed + + updated_docs = 0 + with DB.atomic(): + rows = cls.model.select(cls.model.id, cls.model.meta_fields).where( + (cls.model.id.in_(doc_ids)) & (cls.model.kb_id == kb_id) + ) + for r in rows: + meta = _normalize_meta(r.meta_fields or {}) + original_meta = deepcopy(meta) + changed = _apply_updates(meta) + changed = _apply_deletes(meta) or changed + if changed and meta != original_meta: + cls.model.update( + meta_fields=meta, + update_time=current_timestamp(), + update_date=get_format_time() + ).where(cls.model.id == r.id).execute() + updated_docs += 1 + return updated_docs + @classmethod @DB.connection_context() def update_progress(cls): @@ -715,13 +937,17 @@ def _sync_progress(cls, docs:list[dict]): prg = 1 status = TaskStatus.DONE.value - # only for special task and parsed docs and unfinised + # only for special task and parsed docs and unfinished freeze_progress = special_task_running and doc_progress >= 1 and not finished msg = "\n".join(sorted(msg)) + begin_at = d.get("process_begin_at") + if not begin_at: + begin_at = datetime.now() + # fallback + cls.update_by_id(d["id"], {"process_begin_at": begin_at}) + info = { - "process_duration": datetime.timestamp( - datetime.now()) - - d["process_begin_at"].timestamp(), + "process_duration": max(datetime.timestamp(datetime.now()) - begin_at.timestamp(), 0), "run": status} if prg != 0 and not freeze_progress: info["progress"] = prg @@ -901,12 +1127,12 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id): e, dia = DialogService.get_by_id(conv.dialog_id) if not dia.kb_ids: - raise LookupError("No knowledge base associated with this conversation. " - "Please add a knowledge base before uploading documents") + raise LookupError("No dataset associated with this conversation. " + "Please add a dataset before uploading documents") kb_id = dia.kb_ids[0] e, kb = KnowledgebaseService.get_by_id(kb_id) if not e: - raise LookupError("Can't find this knowledgebase!") + raise LookupError("Can't find this dataset!") embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) @@ -922,7 +1148,7 @@ def dummy(prog=None, msg=""): ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email } - parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"} + parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0} exe = ThreadPoolExecutor(max_workers=12) threads = [] doc_nm = {} @@ -974,13 +1200,13 @@ def dummy(prog=None, msg=""): def embedding(doc_id, cnts, batch_size=16): nonlocal embd_mdl, chunk_counts, token_counts - vects = [] + vectors = [] for i in range(0, len(cnts), batch_size): vts, c = embd_mdl.encode(cnts[i: i + batch_size]) - vects.extend(vts.tolist()) + vectors.extend(vts.tolist()) chunk_counts[doc_id] += len(cnts[i:i + batch_size]) token_counts[doc_id] += c - return vects + return vectors idxnm = search.index_name(kb.tenant_id) try_create_idx = True @@ -994,7 +1220,7 @@ def embedding(doc_id, cnts, batch_size=16): from graphrag.general.mind_map_extractor import MindMapExtractor mindmap = MindMapExtractor(llm_bdl) try: - mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]) + mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])) mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2) if len(mind_map) < 32: raise Exception("Few content: " + mind_map) @@ -1011,15 +1237,15 @@ def embedding(doc_id, cnts, batch_size=16): except Exception: logging.exception("Mind map generation error") - vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) - assert len(cks) == len(vects) + vectors = embedding(doc_id, [c["content_with_weight"] for c in cks]) + assert len(cks) == len(vectors) for i, d in enumerate(cks): - v = vects[i] + v = vectors[i] d["q_%d_vec" % len(v)] = v for b in range(0, len(cks), es_bulk_size): if try_create_idx: - if not settings.docStoreConn.indexExist(idxnm, kb_id): - settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0])) + if not settings.docStoreConn.index_exist(idxnm, kb_id): + settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0])) try_create_idx = False settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id) diff --git a/api/db/services/evaluation_service.py b/api/db/services/evaluation_service.py new file mode 100644 index 00000000000..3f523b1d8c1 --- /dev/null +++ b/api/db/services/evaluation_service.py @@ -0,0 +1,638 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +RAG Evaluation Service + +Provides functionality for evaluating RAG system performance including: +- Dataset management +- Test case management +- Evaluation execution +- Metrics computation +- Configuration recommendations +""" + +import asyncio +import logging +import queue +import threading +from typing import List, Dict, Any, Optional, Tuple +from datetime import datetime +from timeit import default_timer as timer + +from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult +from api.db.services.common_service import CommonService +from api.db.services.dialog_service import DialogService +from common.misc_utils import get_uuid +from common.time_utils import current_timestamp +from common.constants import StatusEnum + + +class EvaluationService(CommonService): + """Service for managing RAG evaluations""" + + model = EvaluationDataset + + # ==================== Dataset Management ==================== + + @classmethod + def create_dataset(cls, name: str, description: str, kb_ids: List[str], + tenant_id: str, user_id: str) -> Tuple[bool, str]: + """ + Create a new evaluation dataset. + + Args: + name: Dataset name + description: Dataset description + kb_ids: List of knowledge base IDs to evaluate against + tenant_id: Tenant ID + user_id: User ID who creates the dataset + + Returns: + (success, dataset_id or error_message) + """ + try: + timestamp= current_timestamp() + dataset_id = get_uuid() + dataset = { + "id": dataset_id, + "tenant_id": tenant_id, + "name": name, + "description": description, + "kb_ids": kb_ids, + "created_by": user_id, + "create_time": timestamp, + "update_time": timestamp, + "status": StatusEnum.VALID.value + } + + if not EvaluationDataset.create(**dataset): + return False, "Failed to create dataset" + + return True, dataset_id + except Exception as e: + logging.error(f"Error creating evaluation dataset: {e}") + return False, str(e) + + @classmethod + def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]: + """Get dataset by ID""" + try: + dataset = EvaluationDataset.get_by_id(dataset_id) + if dataset: + return dataset.to_dict() + return None + except Exception as e: + logging.error(f"Error getting dataset {dataset_id}: {e}") + return None + + @classmethod + def list_datasets(cls, tenant_id: str, user_id: str, + page: int = 1, page_size: int = 20) -> Dict[str, Any]: + """List datasets for a tenant""" + try: + query = EvaluationDataset.select().where( + (EvaluationDataset.tenant_id == tenant_id) & + (EvaluationDataset.status == StatusEnum.VALID.value) + ).order_by(EvaluationDataset.create_time.desc()) + + total = query.count() + datasets = query.paginate(page, page_size) + + return { + "total": total, + "datasets": [d.to_dict() for d in datasets] + } + except Exception as e: + logging.error(f"Error listing datasets: {e}") + return {"total": 0, "datasets": []} + + @classmethod + def update_dataset(cls, dataset_id: str, **kwargs) -> bool: + """Update dataset""" + try: + kwargs["update_time"] = current_timestamp() + return EvaluationDataset.update(**kwargs).where( + EvaluationDataset.id == dataset_id + ).execute() > 0 + except Exception as e: + logging.error(f"Error updating dataset {dataset_id}: {e}") + return False + + @classmethod + def delete_dataset(cls, dataset_id: str) -> bool: + """Soft delete dataset""" + try: + return EvaluationDataset.update( + status=StatusEnum.INVALID.value, + update_time=current_timestamp() + ).where(EvaluationDataset.id == dataset_id).execute() > 0 + except Exception as e: + logging.error(f"Error deleting dataset {dataset_id}: {e}") + return False + + # ==================== Test Case Management ==================== + + @classmethod + def add_test_case(cls, dataset_id: str, question: str, + reference_answer: Optional[str] = None, + relevant_doc_ids: Optional[List[str]] = None, + relevant_chunk_ids: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]: + """ + Add a test case to a dataset. + + Args: + dataset_id: Dataset ID + question: Test question + reference_answer: Optional ground truth answer + relevant_doc_ids: Optional list of relevant document IDs + relevant_chunk_ids: Optional list of relevant chunk IDs + metadata: Optional additional metadata + + Returns: + (success, case_id or error_message) + """ + try: + case_id = get_uuid() + case = { + "id": case_id, + "dataset_id": dataset_id, + "question": question, + "reference_answer": reference_answer, + "relevant_doc_ids": relevant_doc_ids, + "relevant_chunk_ids": relevant_chunk_ids, + "metadata": metadata, + "create_time": current_timestamp() + } + + if not EvaluationCase.create(**case): + return False, "Failed to create test case" + + return True, case_id + except Exception as e: + logging.error(f"Error adding test case: {e}") + return False, str(e) + + @classmethod + def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]: + """Get all test cases for a dataset""" + try: + cases = EvaluationCase.select().where( + EvaluationCase.dataset_id == dataset_id + ).order_by(EvaluationCase.create_time) + + return [c.to_dict() for c in cases] + except Exception as e: + logging.error(f"Error getting test cases for dataset {dataset_id}: {e}") + return [] + + @classmethod + def delete_test_case(cls, case_id: str) -> bool: + """Delete a test case""" + try: + return EvaluationCase.delete().where( + EvaluationCase.id == case_id + ).execute() > 0 + except Exception as e: + logging.error(f"Error deleting test case {case_id}: {e}") + return False + + @classmethod + def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]: + """ + Bulk import test cases from a list. + + Args: + dataset_id: Dataset ID + cases: List of test case dictionaries + + Returns: + (success_count, failure_count) + """ + success_count = 0 + failure_count = 0 + + for case_data in cases: + success, _ = cls.add_test_case( + dataset_id=dataset_id, + question=case_data.get("question", ""), + reference_answer=case_data.get("reference_answer"), + relevant_doc_ids=case_data.get("relevant_doc_ids"), + relevant_chunk_ids=case_data.get("relevant_chunk_ids"), + metadata=case_data.get("metadata") + ) + + if success: + success_count += 1 + else: + failure_count += 1 + + return success_count, failure_count + + # ==================== Evaluation Execution ==================== + + @classmethod + def start_evaluation(cls, dataset_id: str, dialog_id: str, + user_id: str, name: Optional[str] = None) -> Tuple[bool, str]: + """ + Start an evaluation run. + + Args: + dataset_id: Dataset ID + dialog_id: Dialog configuration to evaluate + user_id: User ID who starts the run + name: Optional run name + + Returns: + (success, run_id or error_message) + """ + try: + # Get dialog configuration + success, dialog = DialogService.get_by_id(dialog_id) + if not success: + return False, "Dialog not found" + + # Create evaluation run + run_id = get_uuid() + if not name: + name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" + + run = { + "id": run_id, + "dataset_id": dataset_id, + "dialog_id": dialog_id, + "name": name, + "config_snapshot": dialog.to_dict(), + "metrics_summary": None, + "status": "RUNNING", + "created_by": user_id, + "create_time": current_timestamp(), + "complete_time": None + } + + if not EvaluationRun.create(**run): + return False, "Failed to create evaluation run" + + # Execute evaluation asynchronously (in production, use task queue) + # For now, we'll execute synchronously + cls._execute_evaluation(run_id, dataset_id, dialog) + + return True, run_id + except Exception as e: + logging.error(f"Error starting evaluation: {e}") + return False, str(e) + + @classmethod + def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any): + """ + Execute evaluation for all test cases. + + This method runs the RAG pipeline for each test case and computes metrics. + """ + try: + # Get all test cases + test_cases = cls.get_test_cases(dataset_id) + + if not test_cases: + EvaluationRun.update( + status="FAILED", + complete_time=current_timestamp() + ).where(EvaluationRun.id == run_id).execute() + return + + # Execute each test case + results = [] + for case in test_cases: + result = cls._evaluate_single_case(run_id, case, dialog) + if result: + results.append(result) + + # Compute summary metrics + metrics_summary = cls._compute_summary_metrics(results) + + # Update run status + EvaluationRun.update( + status="COMPLETED", + metrics_summary=metrics_summary, + complete_time=current_timestamp() + ).where(EvaluationRun.id == run_id).execute() + + except Exception as e: + logging.error(f"Error executing evaluation {run_id}: {e}") + EvaluationRun.update( + status="FAILED", + complete_time=current_timestamp() + ).where(EvaluationRun.id == run_id).execute() + + @classmethod + def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any], + dialog: Any) -> Optional[Dict[str, Any]]: + """ + Evaluate a single test case. + + Args: + run_id: Evaluation run ID + case: Test case dictionary + dialog: Dialog configuration + + Returns: + Result dictionary or None if failed + """ + try: + # Prepare messages + messages = [{"role": "user", "content": case["question"]}] + + # Execute RAG pipeline + start_time = timer() + answer = "" + retrieved_chunks = [] + + + def _sync_from_async_gen(async_gen): + result_queue: queue.Queue = queue.Queue() + + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def consume(): + try: + async for item in async_gen: + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) + + loop.run_until_complete(consume()) + loop.close() + + threading.Thread(target=runner, daemon=True).start() + + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item + + + def chat(dialog, messages, stream=True, **kwargs): + from api.db.services.dialog_service import async_chat + + return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs)) + + for ans in chat(dialog, messages, stream=False): + if isinstance(ans, dict): + answer = ans.get("answer", "") + retrieved_chunks = ans.get("reference", {}).get("chunks", []) + break + + execution_time = timer() - start_time + + # Compute metrics + metrics = cls._compute_metrics( + question=case["question"], + generated_answer=answer, + reference_answer=case.get("reference_answer"), + retrieved_chunks=retrieved_chunks, + relevant_chunk_ids=case.get("relevant_chunk_ids"), + dialog=dialog + ) + + # Save result + result_id = get_uuid() + result = { + "id": result_id, + "run_id": run_id, + "case_id": case["id"], + "generated_answer": answer, + "retrieved_chunks": retrieved_chunks, + "metrics": metrics, + "execution_time": execution_time, + "token_usage": None, # TODO: Track token usage + "create_time": current_timestamp() + } + + EvaluationResult.create(**result) + + return result + except Exception as e: + logging.error(f"Error evaluating case {case.get('id')}: {e}") + return None + + @classmethod + def _compute_metrics(cls, question: str, generated_answer: str, + reference_answer: Optional[str], + retrieved_chunks: List[Dict[str, Any]], + relevant_chunk_ids: Optional[List[str]], + dialog: Any) -> Dict[str, float]: + """ + Compute evaluation metrics for a single test case. + + Returns: + Dictionary of metric names to values + """ + metrics = {} + + # Retrieval metrics (if ground truth chunks provided) + if relevant_chunk_ids: + retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks] + metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids)) + + # Generation metrics + if generated_answer: + # Basic metrics + metrics["answer_length"] = len(generated_answer) + metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0 + + # TODO: Implement advanced metrics using LLM-as-judge + # - Faithfulness (hallucination detection) + # - Answer relevance + # - Context relevance + # - Semantic similarity (if reference answer provided) + + return metrics + + @classmethod + def _compute_retrieval_metrics(cls, retrieved_ids: List[str], + relevant_ids: List[str]) -> Dict[str, float]: + """ + Compute retrieval metrics. + + Args: + retrieved_ids: List of retrieved chunk IDs + relevant_ids: List of relevant chunk IDs (ground truth) + + Returns: + Dictionary of retrieval metrics + """ + if not relevant_ids: + return {} + + retrieved_set = set(retrieved_ids) + relevant_set = set(relevant_ids) + + # Precision: proportion of retrieved that are relevant + precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0 + + # Recall: proportion of relevant that were retrieved + recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0 + + # F1 score + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 + + # Hit rate: whether any relevant chunk was retrieved + hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0 + + # MRR (Mean Reciprocal Rank): position of first relevant chunk + mrr = 0.0 + for i, chunk_id in enumerate(retrieved_ids, 1): + if chunk_id in relevant_set: + mrr = 1.0 / i + break + + return { + "precision": precision, + "recall": recall, + "f1_score": f1, + "hit_rate": hit_rate, + "mrr": mrr + } + + @classmethod + def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Compute summary metrics across all test cases. + + Args: + results: List of result dictionaries + + Returns: + Summary metrics dictionary + """ + if not results: + return {} + + # Aggregate metrics + metric_sums = {} + metric_counts = {} + + for result in results: + metrics = result.get("metrics", {}) + for key, value in metrics.items(): + if isinstance(value, (int, float)): + metric_sums[key] = metric_sums.get(key, 0) + value + metric_counts[key] = metric_counts.get(key, 0) + 1 + + # Compute averages + summary = { + "total_cases": len(results), + "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results) + } + + for key in metric_sums: + summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key] + + return summary + + # ==================== Results & Analysis ==================== + + @classmethod + def get_run_results(cls, run_id: str) -> Dict[str, Any]: + """Get results for an evaluation run""" + try: + run = EvaluationRun.get_by_id(run_id) + if not run: + return {} + + results = EvaluationResult.select().where( + EvaluationResult.run_id == run_id + ).order_by(EvaluationResult.create_time) + + return { + "run": run.to_dict(), + "results": [r.to_dict() for r in results] + } + except Exception as e: + logging.error(f"Error getting run results {run_id}: {e}") + return {} + + @classmethod + def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]: + """ + Analyze evaluation results and provide configuration recommendations. + + Args: + run_id: Evaluation run ID + + Returns: + List of recommendation dictionaries + """ + try: + run = EvaluationRun.get_by_id(run_id) + if not run or not run.metrics_summary: + return [] + + metrics = run.metrics_summary + recommendations = [] + + # Low precision: retrieving irrelevant chunks + if metrics.get("avg_precision", 1.0) < 0.7: + recommendations.append({ + "issue": "Low Precision", + "severity": "high", + "description": "System is retrieving many irrelevant chunks", + "suggestions": [ + "Increase similarity_threshold to filter out less relevant chunks", + "Enable reranking to improve chunk ordering", + "Reduce top_k to return fewer chunks" + ] + }) + + # Low recall: missing relevant chunks + if metrics.get("avg_recall", 1.0) < 0.7: + recommendations.append({ + "issue": "Low Recall", + "severity": "high", + "description": "System is missing relevant chunks", + "suggestions": [ + "Increase top_k to retrieve more chunks", + "Lower similarity_threshold to be more inclusive", + "Enable hybrid search (keyword + semantic)", + "Check chunk size - may be too large or too small" + ] + }) + + # Slow response time + if metrics.get("avg_execution_time", 0) > 5.0: + recommendations.append({ + "issue": "Slow Response Time", + "severity": "medium", + "description": f"Average response time is {metrics['avg_execution_time']:.2f}s", + "suggestions": [ + "Reduce top_k to retrieve fewer chunks", + "Optimize embedding model selection", + "Consider caching frequently asked questions" + ] + }) + + return recommendations + except Exception as e: + logging.error(f"Error generating recommendations for run {run_id}: {e}") + return [] diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index 5a3632e97d3..d6a157b2d1e 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio +import base64 import logging import re +import sys +import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path +from typing import Union -from flask_login import current_user from peewee import fn from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileType @@ -31,7 +35,7 @@ from common.constants import TaskStatus, FileSource, ParserType from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.task_service import TaskService -from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img +from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path from rag.llm.cv_model import GptV4 from common import settings @@ -90,13 +94,13 @@ def get_by_pf_id(cls, tenant_id, pf_id, page_number, items_per_page, orderby, de @classmethod @DB.connection_context() def get_kb_id_by_file_id(cls, file_id): - # Get knowledge base IDs associated with a file + # Get dataset IDs associated with a file # Args: # file_id: File ID # Returns: - # List of dictionaries containing knowledge base IDs and names + # List of dictionaries containing dataset IDs and names kbs = ( - cls.model.select(*[Knowledgebase.id, Knowledgebase.name]) + cls.model.select(*[Knowledgebase.id, Knowledgebase.name, File2Document.document_id]) .join(File2Document, on=(File2Document.file_id == file_id)) .join(Document, on=(File2Document.document_id == Document.id)) .join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id)) @@ -106,7 +110,7 @@ def get_kb_id_by_file_id(cls, file_id): return [] kbs_info_list = [] for kb in list(kbs.dicts()): - kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]}) + kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"], "document_id": kb["document_id"]}) return kbs_info_list @classmethod @@ -184,6 +188,7 @@ def get_all_file_ids_by_tenant_id(cls, tenant_id): @classmethod @DB.connection_context() def create_folder(cls, file, parent_id, name, count): + from api.apps import current_user # Recursively create folder structure # Args: # file: Current file object @@ -242,7 +247,7 @@ def get_root_folder(cls, tenant_id): @classmethod @DB.connection_context() def get_kb_folder(cls, tenant_id): - # Get knowledge base folder for tenant + # Get dataset folder for tenant # Args: # tenant_id: Tenant ID # Returns: @@ -258,7 +263,7 @@ def get_kb_folder(cls, tenant_id): @classmethod @DB.connection_context() def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value, size=0, location=""): - # Create a new file from knowledge base + # Create a new file from dataset # Args: # tenant_id: Tenant ID # name: File name @@ -287,7 +292,7 @@ def new_a_file_from_kb(cls, tenant_id, name, parent_id, ty=FileType.FOLDER.value @classmethod @DB.connection_context() def init_knowledgebase_docs(cls, root_id, tenant_id): - # Initialize knowledge base documents + # Initialize dataset documents # Args: # root_id: Root folder ID # tenant_id: Tenant ID @@ -329,7 +334,7 @@ def get_all_parent_folders(cls, start_id): current_id = start_id while current_id: e, file = cls.get_by_id(current_id) - if file.parent_id != file.id and e: + if e and file.parent_id != file.id: parent_folders.append(file) current_id = file.parent_id else: @@ -423,13 +428,15 @@ def move_file(cls, file_ids, folder_id): @classmethod @DB.connection_context() - def upload_document(self, kb, file_objs, user_id, src="local"): + def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None): root_folder = self.get_root_folder(user_id) pf_id = root_folder["id"] self.init_knowledgebase_docs(pf_id, user_id) kb_root_folder = self.get_kb_folder(user_id) kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) + safe_parent_path = sanitize_path(parent_path) + err, files = [], [] for file in file_objs: try: @@ -439,7 +446,7 @@ def upload_document(self, kb, file_objs, user_id, src="local"): if filetype == FileType.OTHER.value: raise RuntimeError("This type of file has not been supported yet!") - location = filename + location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}" while settings.STORAGE_IMPL.obj_exist(kb.id, location): location += "_" @@ -506,6 +513,7 @@ def parse_docs(file_objs, user_id): @staticmethod def parse(filename, blob, img_base64=True, tenant_id=None): from rag.app import audio, email, naive, picture, presentation + from api.apps import current_user def dummy(prog=None, msg=""): pass @@ -517,7 +525,7 @@ def dummy(prog=None, msg=""): if img_base64 and file_type == FileType.VISUAL.value: return GptV4.image2base64(blob) cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs) - return "\n".join([ck["content_with_weight"] for ck in cks]) + return f"\n -----------------\nFile: {filename}\nContent as following: \n" + "\n".join([ck["content_with_weight"] for ck in cks]) @staticmethod def get_parser(doc_type, filename, default): @@ -585,3 +593,80 @@ def delete_docs(cls, doc_ids, tenant_id): errors += str(e) return errors + + @staticmethod + def upload_info(user_id, file, url: str|None=None): + def structured(filename, filetype, blob, content_type): + nonlocal user_id + if filetype == FileType.PDF.value: + blob = read_potential_broken_pdf(blob) + + location = get_uuid() + FileService.put_blob(user_id, location, blob) + + return { + "id": location, + "name": filename, + "size": sys.getsizeof(blob), + "extension": filename.split(".")[-1].lower(), + "mime_type": content_type, + "created_by": user_id, + "created_at": time.time(), + "preview_url": None + } + + if url: + from crawl4ai import ( + AsyncWebCrawler, + BrowserConfig, + CrawlerRunConfig, + DefaultMarkdownGenerator, + PruningContentFilter, + CrawlResult + ) + filename = re.sub(r"\?.*", "", url.split("/")[-1]) + async def adownload(): + browser_config = BrowserConfig( + headless=True, + verbose=False, + ) + async with AsyncWebCrawler(config=browser_config) as crawler: + crawler_config = CrawlerRunConfig( + markdown_generator=DefaultMarkdownGenerator( + content_filter=PruningContentFilter() + ), + pdf=True, + screenshot=False + ) + result: CrawlResult = await crawler.arun( + url=url, + config=crawler_config + ) + return result + page = asyncio.run(adownload()) + if page.pdf: + if filename.split(".")[-1].lower() != "pdf": + filename += ".pdf" + return structured(filename, "pdf", page.pdf, page.response_headers["content-type"]) + + return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id) + + DocumentService.check_doc_health(user_id, file.filename) + return structured(file.filename, filename_type(file.filename), file.read(), file.content_type) + + @staticmethod + def get_files(files: Union[None, list[dict]]) -> list[str]: + if not files: + return [] + def image_to_base64(file): + return "data:{};base64,{}".format(file["mime_type"], + base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) + exe = ThreadPoolExecutor(max_workers=5) + threads = [] + for file in files: + if file["mime_type"].find("image") >=0: + threads.append(exe.submit(image_to_base64, file)) + continue + threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"])) + return [th.result() for th in threads] + diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index 03179da49bf..5f506888c0d 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -28,10 +28,11 @@ from api.constants import DATASET_NAME_LIMIT from api.utils.api_utils import get_parser_config, get_data_error_result + class KnowledgebaseService(CommonService): - """Service class for managing knowledge base operations. + """Service class for managing dataset operations. - This class extends CommonService to provide specialized functionality for knowledge base + This class extends CommonService to provide specialized functionality for dataset management, including document parsing status tracking, access control, and configuration management. It handles operations such as listing, creating, updating, and deleting knowledge bases, as well as managing their associated documents and permissions. @@ -40,7 +41,7 @@ class KnowledgebaseService(CommonService): - Document parsing status verification - Knowledge base access control - Parser configuration management - - Tenant-based knowledge base organization + - Tenant-based dataset organization Attributes: model: The Knowledgebase model class for database operations. @@ -50,18 +51,18 @@ class KnowledgebaseService(CommonService): @classmethod @DB.connection_context() def accessible4deletion(cls, kb_id, user_id): - """Check if a knowledge base can be deleted by a specific user. + """Check if a dataset can be deleted by a specific user. - This method verifies whether a user has permission to delete a knowledge base - by checking if they are the creator of that knowledge base. + This method verifies whether a user has permission to delete a dataset + by checking if they are the creator of that dataset. Args: - kb_id (str): The unique identifier of the knowledge base to check. + kb_id (str): The unique identifier of the dataset to check. user_id (str): The unique identifier of the user attempting the deletion. Returns: - bool: True if the user has permission to delete the knowledge base, - False if the user doesn't have permission or the knowledge base doesn't exist. + bool: True if the user has permission to delete the dataset, + False if the user doesn't have permission or the dataset doesn't exist. Example: >>> KnowledgebaseService.accessible4deletion("kb123", "user456") @@ -70,10 +71,10 @@ def accessible4deletion(cls, kb_id, user_id): Note: - This method only checks creator permissions - A return value of False can mean either: - 1. The knowledge base doesn't exist - 2. The user is not the creator of the knowledge base + 1. The dataset doesn't exist + 2. The user is not the creator of the dataset """ - # Check if a knowledge base can be deleted by a user + # Check if a dataset can be deleted by a user docs = cls.model.select( cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1) docs = docs.dicts() @@ -84,7 +85,7 @@ def accessible4deletion(cls, kb_id, user_id): @classmethod @DB.connection_context() def is_parsed_done(cls, kb_id): - # Check if all documents in the knowledge base have completed parsing + # Check if all documents in the dataset have completed parsing # # Args: # kb_id: Knowledge base ID @@ -95,13 +96,13 @@ def is_parsed_done(cls, kb_id): from common.constants import TaskStatus from api.db.services.document_service import DocumentService - # Get knowledge base information + # Get dataset information kbs = cls.query(id=kb_id) if not kbs: return False, "Knowledge base not found" kb = kbs[0] - # Get all documents in the knowledge base + # Get all documents in the dataset docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], []) # Check parsing status of each document @@ -118,9 +119,9 @@ def is_parsed_done(cls, kb_id): @classmethod @DB.connection_context() def list_documents_by_ids(cls, kb_ids): - # Get document IDs associated with given knowledge base IDs + # Get document IDs associated with given dataset IDs # Args: - # kb_ids: List of knowledge base IDs + # kb_ids: List of dataset IDs # Returns: # List of document IDs doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where( @@ -234,11 +235,11 @@ def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id): @classmethod @DB.connection_context() def get_kb_ids(cls, tenant_id): - # Get all knowledge base IDs for a tenant + # Get all dataset IDs for a tenant # Args: # tenant_id: Tenant ID # Returns: - # List of knowledge base IDs + # List of dataset IDs fields = [ cls.model.id, ] @@ -249,11 +250,11 @@ def get_kb_ids(cls, tenant_id): @classmethod @DB.connection_context() def get_detail(cls, kb_id): - # Get detailed information about a knowledge base + # Get detailed information about a dataset # Args: # kb_id: Knowledge base ID # Returns: - # Dictionary containing knowledge base details + # Dictionary containing dataset details fields = [ cls.model.id, cls.model.embd_id, @@ -293,13 +294,13 @@ def get_detail(cls, kb_id): @classmethod @DB.connection_context() def update_parser_config(cls, id, config): - # Update parser configuration for a knowledge base + # Update parser configuration for a dataset # Args: # id: Knowledge base ID # config: New parser configuration e, m = cls.get_by_id(id) if not e: - raise LookupError(f"knowledgebase({id}) not found.") + raise LookupError(f"dataset({id}) not found.") def dfs_update(old, new): # Deep update of nested configuration @@ -324,7 +325,7 @@ def dfs_update(old, new): def delete_field_map(cls, id): e, m = cls.get_by_id(id) if not e: - raise LookupError(f"knowledgebase({id}) not found.") + raise LookupError(f"dataset({id}) not found.") m.parser_config.pop("field_map", None) cls.update_by_id(id, {"parser_config": m.parser_config}) @@ -334,7 +335,7 @@ def delete_field_map(cls, id): def get_field_map(cls, ids): # Get field mappings for knowledge bases # Args: - # ids: List of knowledge base IDs + # ids: List of dataset IDs # Returns: # Dictionary of field mappings conf = {} @@ -346,7 +347,7 @@ def get_field_map(cls, ids): @classmethod @DB.connection_context() def get_by_name(cls, kb_name, tenant_id): - # Get knowledge base by name and tenant ID + # Get dataset by name and tenant ID # Args: # kb_name: Knowledge base name # tenant_id: Tenant ID @@ -364,9 +365,9 @@ def get_by_name(cls, kb_name, tenant_id): @classmethod @DB.connection_context() def get_all_ids(cls): - # Get all knowledge base IDs + # Get all dataset IDs # Returns: - # List of all knowledge base IDs + # List of all dataset IDs return [m["id"] for m in cls.model.select(cls.model.id).dicts()] @@ -391,12 +392,12 @@ def create_with_name( """ # Validate name if not isinstance(name, str): - return get_data_error_result(message="Dataset name must be string.") + return False, get_data_error_result(message="Dataset name must be string.") dataset_name = name.strip() if dataset_name == "": - return get_data_error_result(message="Dataset name can't be empty.") + return False, get_data_error_result(message="Dataset name can't be empty.") if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT: - return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}") + return False, get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}") # Deduplicate name within tenant dataset_name = duplicate_name( @@ -409,7 +410,7 @@ def create_with_name( # Verify tenant exists ok, _t = TenantService.get_by_id(tenant_id) if not ok: - return False, "Tenant not found." + return False, get_data_error_result(message="Tenant not found.") # Build payload kb_id = get_uuid() @@ -419,12 +420,14 @@ def create_with_name( "tenant_id": tenant_id, "created_by": tenant_id, "parser_id": (parser_id or "naive"), - **kwargs + **kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc. } - # Default parser_config (align with kb_app.create) — do not accept external overrides + # Update parser_config (always override with validated default/merged config) payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config")) - return payload + payload["parser_config"]["llm_id"] = _t.llm_id + + return True, payload @classmethod @@ -469,7 +472,7 @@ def get_list(cls, joined_tenant_ids, user_id, @classmethod @DB.connection_context() def accessible(cls, kb_id, user_id): - # Check if a knowledge base is accessible by a user + # Check if a dataset is accessible by a user # Args: # kb_id: Knowledge base ID # user_id: User ID @@ -486,12 +489,12 @@ def accessible(cls, kb_id, user_id): @classmethod @DB.connection_context() def get_kb_by_id(cls, kb_id, user_id): - # Get knowledge base by ID and user ID + # Get dataset by ID and user ID # Args: # kb_id: Knowledge base ID # user_id: User ID # Returns: - # List containing knowledge base information + # List containing dataset information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() @@ -500,12 +503,12 @@ def get_kb_by_id(cls, kb_id, user_id): @classmethod @DB.connection_context() def get_kb_by_name(cls, kb_name, user_id): - # Get knowledge base by name and user ID + # Get dataset by name and user ID # Args: # kb_name: Knowledge base name # user_id: User ID # Returns: - # List containing knowledge base information + # List containing dataset information kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id) ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1) kbs = kbs.dicts() diff --git a/api/db/services/langfuse_service.py b/api/db/services/langfuse_service.py index af4233bec83..b571110aba2 100644 --- a/api/db/services/langfuse_service.py +++ b/api/db/services/langfuse_service.py @@ -64,10 +64,13 @@ def update_by_tenant(cls, tenant_id, langfuse_keys): @classmethod def save(cls, **kwargs): - kwargs["create_time"] = current_timestamp() - kwargs["create_date"] = datetime_format(datetime.now()) - kwargs["update_time"] = current_timestamp() - kwargs["update_date"] = datetime_format(datetime.now()) + current_ts = current_timestamp() + current_date = datetime_format(datetime.now()) + + kwargs["create_time"] = current_ts + kwargs["create_date"] = current_date + kwargs["update_time"] = current_ts + kwargs["update_date"] = current_date obj = cls.model.create(**kwargs) return obj diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 6ccbf5a94b1..e5505af8849 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -13,15 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import inspect import logging +import queue import re -from common.token_utils import num_tokens_from_string +import threading from functools import partial from typing import Generator + from api.db.db_models import LLM from api.db.services.common_service import CommonService from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService +from common.constants import LLMType +from common.token_utils import num_tokens_from_string class LLMService(CommonService): @@ -30,8 +35,17 @@ class LLMService(CommonService): def get_init_tenant_llm(user_id): from common import settings + tenant_llm = [] + model_configs = { + LLMType.CHAT: settings.CHAT_CFG, + LLMType.EMBEDDING: settings.EMBEDDING_CFG, + LLMType.SPEECH2TEXT: settings.ASR_CFG, + LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG, + LLMType.RERANK: settings.RERANK_CFG, + } + seen = set() factory_configs = [] for factory_config in [ @@ -54,8 +68,8 @@ def get_init_tenant_llm(user_id): "llm_factory": factory_config["factory"], "llm_name": llm.llm_name, "model_type": llm.model_type, - "api_key": factory_config["api_key"], - "api_base": factory_config["base_url"], + "api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]), + "api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]), "max_tokens": llm.max_tokens if llm.max_tokens else 8192, } ) @@ -80,8 +94,8 @@ def bind_tools(self, toolcall_session, tools): def encode(self, texts: list): if self.langfuse: - generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts}) - + generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts}) + safe_texts = [] for text in texts: token_size = num_tokens_from_string(text) @@ -90,12 +104,12 @@ def encode(self, texts: list): safe_texts.append(text[:target_len]) else: safe_texts.append(text) - + embeddings, used_tokens = self.mdl.encode(safe_texts) llm_name = getattr(self, "llm_name", None) if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): - logging.error("LLMBundle.encode can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.error("LLMBundle.encode can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -110,7 +124,7 @@ def encode_queries(self, query: str): emd, used_tokens = self.mdl.encode_queries(query) llm_name = getattr(self, "llm_name", None) if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, llm_name): - logging.error("LLMBundle.encode_queries can't update token usage for {}/EMBEDDING used_tokens: {}".format(self.tenant_id, used_tokens)) + logging.error("LLMBundle.encode_queries can't update token usage for /EMBEDDING used_tokens: {}".format(used_tokens)) if self.langfuse: generation.update(usage_details={"total_tokens": used_tokens}) @@ -174,6 +188,68 @@ def transcription(self, audio): return txt + def stream_transcription(self, audio): + mdl = self.mdl + supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription")) + if supports_stream: + if self.langfuse: + generation = self.langfuse.start_generation( + trace_context=self.trace_context, + name="stream_transcription", + metadata={"model": self.llm_name}, + ) + final_text = "" + used_tokens = 0 + + try: + for evt in mdl.stream_transcription(audio): + if evt.get("event") == "final": + final_text = evt.get("text", "") + + yield evt + + except Exception as e: + err = {"event": "error", "text": str(e)} + yield err + final_text = final_text or "" + finally: + if final_text: + used_tokens = num_tokens_from_string(final_text) + TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens) + + if self.langfuse: + generation.update( + output={"output": final_text}, + usage_details={"total_tokens": used_tokens}, + ) + generation.end() + + return + + if self.langfuse: + generation = self.langfuse.start_generation( + trace_context=self.trace_context, + name="stream_transcription", + metadata={"model": self.llm_name}, + ) + + full_text, used_tokens = mdl.transcription(audio) + if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens): + logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}") + + if self.langfuse: + generation.update( + output={"output": full_text}, + usage_details={"total_tokens": used_tokens}, + ) + generation.end() + + yield { + "event": "final", + "text": full_text, + "streaming": False, + } + def tts(self, text: str) -> Generator[bytes, None, None]: if self.langfuse: generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text}) @@ -218,57 +294,150 @@ def _clean_param(chat_partial, **kwargs): return kwargs else: return {k: v for k, v in kwargs.items() if k in allowed_params} - def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: + + def _run_coroutine_sync(self, coro): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result_queue: queue.Queue = queue.Queue() + + def runner(): + try: + result_queue.put((True, asyncio.run(coro))) + except Exception as e: + result_queue.put((False, e)) + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + + success, value = result_queue.get_nowait() + if success: + return value + raise value + + def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs): + result_queue: queue.Queue = queue.Queue() + + def runner(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def consume(): + try: + async for item in async_gen_fn(*args, **kwargs): + result_queue.put(item) + except Exception as e: + result_queue.put(e) + finally: + result_queue.put(StopIteration) + + loop.run_until_complete(consume()) + loop.close() + + threading.Thread(target=runner, daemon=True).start() + + while True: + item = result_queue.get() + if item is StopIteration: + break + if isinstance(item, Exception): + raise item + yield item + + def _bridge_sync_stream(self, gen): + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue() + + def worker(): + try: + for item in gen: + loop.call_soon_threadsafe(queue.put_nowait, item) + except Exception as e: + loop.call_soon_threadsafe(queue.put_nowait, e) + finally: + loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) + + threading.Thread(target=worker, daemon=True).start() + return queue + + async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"): + base_fn = self.mdl.async_chat_with_tools + elif hasattr(self.mdl, "async_chat"): + base_fn = self.mdl.async_chat + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools") + + generation = None if self.langfuse: generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history}) - chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs) - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs) - + chat_partial = partial(base_fn, system, history, gen_conf) use_kwargs = self._clean_param(chat_partial, **kwargs) - txt, used_tokens = chat_partial(**use_kwargs) - txt = self._remove_reasoning_content(txt) + try: + txt, used_tokens = await chat_partial(**use_kwargs) + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) + generation.end() + raise + + txt = self._remove_reasoning_content(txt) if not self.verbose_tool_use: txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): - logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) + if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name): + logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens)) - if self.langfuse: + if generation: generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens}) generation.end() return txt - def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): + total_tokens = 0 + ans = "" + if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"): + stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None) + elif hasattr(self.mdl, "async_chat_streamly"): + stream_fn = getattr(self.mdl, "async_chat_streamly", None) + else: + raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools") + + generation = None if self.langfuse: generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) - ans = "" - chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf) - total_tokens = 0 - if self.is_tools and self.mdl.is_tools: - chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf) - use_kwargs = self._clean_param(chat_partial, **kwargs) - for txt in chat_partial(**use_kwargs): - if isinstance(txt, int): - total_tokens = txt - if self.langfuse: - generation.update(output={"output": ans}) + if stream_fn: + chat_partial = partial(stream_fn, system, history, gen_conf) + use_kwargs = self._clean_param(chat_partial, **kwargs) + try: + async for txt in chat_partial(**use_kwargs): + if isinstance(txt, int): + total_tokens = txt + break + + if txt.endswith(""): + ans = ans[: -len("")] + + if not self.verbose_tool_use: + txt = re.sub(r".*?", "", txt, flags=re.DOTALL) + + ans += txt + yield ans + except Exception as e: + if generation: + generation.update(output={"error": str(e)}) generation.end() - break - - if txt.endswith(""): - ans = ans[: -len("")] - - if not self.verbose_tool_use: - txt = re.sub(r".*?", "", txt, flags=re.DOTALL) - - ans += txt - yield ans - - if total_tokens > 0: - if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): - logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) + raise + if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name): + logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens)) + if generation: + generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens}) + generation.end() + return diff --git a/api/db/services/memory_service.py b/api/db/services/memory_service.py new file mode 100644 index 00000000000..8a65d15e24d --- /dev/null +++ b/api/db/services/memory_service.py @@ -0,0 +1,170 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List + +from api.db.db_models import DB, Memory, User +from api.db.services import duplicate_name +from api.db.services.common_service import CommonService +from api.utils.memory_utils import calculate_memory_type +from api.constants import MEMORY_NAME_LIMIT +from common.misc_utils import get_uuid +from common.time_utils import get_format_time, current_timestamp +from memory.utils.prompt_util import PromptAssembler + + +class MemoryService(CommonService): + # Service class for manage memory operations + model = Memory + + @classmethod + @DB.connection_context() + def get_by_memory_id(cls, memory_id: str): + return cls.model.select().where(cls.model.id == memory_id).first() + + @classmethod + @DB.connection_context() + def get_by_tenant_id(cls, tenant_id: str): + return cls.model.select().where(cls.model.tenant_id == tenant_id) + + @classmethod + @DB.connection_context() + def get_all_memory(cls): + memory_list = cls.model.select() + return list(memory_list) + + @classmethod + @DB.connection_context() + def get_with_owner_name_by_id(cls, memory_id: str): + fields = [ + cls.model.id, + cls.model.name, + cls.model.avatar, + cls.model.tenant_id, + User.nickname.alias("owner_name"), + cls.model.memory_type, + cls.model.storage_type, + cls.model.embd_id, + cls.model.llm_id, + cls.model.permissions, + cls.model.description, + cls.model.memory_size, + cls.model.forgetting_policy, + cls.model.temperature, + cls.model.system_prompt, + cls.model.user_prompt, + cls.model.create_date, + cls.model.create_time + ] + memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( + cls.model.id == memory_id + ).first() + return memory + + @classmethod + @DB.connection_context() + def get_by_filter(cls, filter_dict: dict, keywords: str, page: int = 1, page_size: int = 50): + fields = [ + cls.model.id, + cls.model.name, + cls.model.avatar, + cls.model.tenant_id, + User.nickname.alias("owner_name"), + cls.model.memory_type, + cls.model.storage_type, + cls.model.permissions, + cls.model.description, + cls.model.create_time, + cls.model.create_date + ] + memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)) + if filter_dict.get("tenant_id"): + memories = memories.where(cls.model.tenant_id.in_(filter_dict["tenant_id"])) + if filter_dict.get("memory_type"): + memory_type_int = calculate_memory_type(filter_dict["memory_type"]) + memories = memories.where(cls.model.memory_type.bin_and(memory_type_int) > 0) + if filter_dict.get("storage_type"): + memories = memories.where(cls.model.storage_type == filter_dict["storage_type"]) + if keywords: + memories = memories.where(cls.model.name.contains(keywords)) + count = memories.count() + memories = memories.order_by(cls.model.update_time.desc()) + memories = memories.paginate(page, page_size) + + return list(memories.dicts()), count + + @classmethod + @DB.connection_context() + def create_memory(cls, tenant_id: str, name: str, memory_type: List[str], embd_id: str, llm_id: str): + # Deduplicate name within tenant + memory_name = duplicate_name( + cls.query, + name=name, + tenant_id=tenant_id + ) + if len(memory_name) > MEMORY_NAME_LIMIT: + return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}." + + timestamp = current_timestamp() + format_time = get_format_time() + # build create dict + memory_info = { + "id": get_uuid(), + "name": memory_name, + "memory_type": calculate_memory_type(memory_type), + "tenant_id": tenant_id, + "embd_id": embd_id, + "llm_id": llm_id, + "system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}), + "create_time": timestamp, + "create_date": format_time, + "update_time": timestamp, + "update_date": format_time, + } + obj = cls.model(**memory_info).save(force_insert=True) + + if not obj: + return False, "Could not create new memory." + + db_row = cls.model.select().where(cls.model.id == memory_info["id"]).first() + + return obj, db_row + + @classmethod + @DB.connection_context() + def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict): + if not update_dict: + return 0 + if "temperature" in update_dict and isinstance(update_dict["temperature"], str): + update_dict["temperature"] = float(update_dict["temperature"]) + if "memory_type" in update_dict and isinstance(update_dict["memory_type"], list): + update_dict["memory_type"] = calculate_memory_type(update_dict["memory_type"]) + if "name" in update_dict: + update_dict["name"] = duplicate_name( + cls.query, + name=update_dict["name"], + tenant_id=tenant_id + ) + update_dict.update({ + "update_time": current_timestamp(), + "update_date": get_format_time() + }) + + return cls.model.update(update_dict).where(cls.model.id == memory_id).execute() + + @classmethod + @DB.connection_context() + def delete_memory(cls, memory_id: str): + return cls.model.delete().where(cls.model.id == memory_id).execute() diff --git a/api/db/services/pipeline_operation_log_service.py b/api/db/services/pipeline_operation_log_service.py index c3c333665ff..9846d79c123 100644 --- a/api/db/services/pipeline_operation_log_service.py +++ b/api/db/services/pipeline_operation_log_service.py @@ -121,7 +121,7 @@ def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: else: ok, kb_info = KnowledgebaseService.get_by_id(document.kb_id) if not ok: - raise RuntimeError(f"Cannot find knowledge base {document.kb_id} for referred_document {referred_document_id}") + raise RuntimeError(f"Cannot find dataset {document.kb_id} for referred_document {referred_document_id}") tenant_id = kb_info.tenant_id title = document.parser_id @@ -169,11 +169,12 @@ def create(cls, document_id, pipeline_id, task_type, fake_document_ids=[], dsl: operation_status=operation_status, avatar=avatar, ) - log["create_time"] = current_timestamp() - log["create_date"] = datetime_format(datetime.now()) - log["update_time"] = current_timestamp() - log["update_date"] = datetime_format(datetime.now()) - + timestamp = current_timestamp() + datetime_now = datetime_format(datetime.now()) + log["create_time"] = timestamp + log["create_date"] = datetime_now + log["update_time"] = timestamp + log["update_date"] = datetime_now with DB.atomic(): obj = cls.save(**log) diff --git a/api/db/services/search_service.py b/api/db/services/search_service.py index 1c7687b5447..7366a9708b6 100644 --- a/api/db/services/search_service.py +++ b/api/db/services/search_service.py @@ -28,10 +28,13 @@ class SearchService(CommonService): @classmethod def save(cls, **kwargs): - kwargs["create_time"] = current_timestamp() - kwargs["create_date"] = datetime_format(datetime.now()) - kwargs["update_time"] = current_timestamp() - kwargs["update_date"] = datetime_format(datetime.now()) + current_ts = current_timestamp() + current_date = datetime_format(datetime.now()) + + kwargs["create_time"] = current_ts + kwargs["create_date"] = current_date + kwargs["update_time"] = current_ts + kwargs["update_date"] = current_date obj = cls.model.create(**kwargs) return obj diff --git a/api/db/services/task_service.py b/api/db/services/task_service.py index 9c771223f6e..065d2376dd7 100644 --- a/api/db/services/task_service.py +++ b/api/db/services/task_service.py @@ -76,7 +76,7 @@ def get_task(cls, task_id, doc_ids=[]): """Retrieve detailed task information by task ID. This method fetches comprehensive task details including associated document, - knowledge base, and tenant information. It also handles task retry logic and + dataset, and tenant information. It also handles task retry logic and progress updates. Args: @@ -121,6 +121,13 @@ def get_task(cls, task_id, doc_ids=[]): .where(cls.model.id == task_id) ) docs = list(docs.dicts()) + # Assuming docs = list(docs.dicts()) + if docs: + kb_config = docs[0]['kb_parser_config'] # Dict from Knowledgebase.parser_config + mineru_method = kb_config.get('mineru_parse_method', 'auto') + mineru_formula = kb_config.get('mineru_formula_enable', True) + mineru_table = kb_config.get('mineru_table_enable', True) + print(mineru_method, mineru_formula, mineru_table) if not docs: return None diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index f971be3d42a..65771f60f41 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -14,15 +14,17 @@ # limitations under the License. # import os +import json import logging +from peewee import IntegrityError from langfuse import Langfuse from common import settings -from common.constants import LLMType +from common.constants import MINERU_DEFAULT_CONFIG, MINERU_ENV_KEYS, LLMType from api.db.db_models import DB, LLMFactories, TenantLLM from api.db.services.common_service import CommonService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.user_service import TenantService -from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel +from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel class LLMFactoriesService(CommonService): @@ -95,7 +97,7 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None): if llm_type == LLMType.EMBEDDING.value: mdlnm = tenant.embd_id if not llm_name else llm_name elif llm_type == LLMType.SPEECH2TEXT.value: - mdlnm = tenant.asr_id + mdlnm = tenant.asr_id if not llm_name else llm_name elif llm_type == LLMType.IMAGE2TEXT.value: mdlnm = tenant.img2txt_id if not llm_name else llm_name elif llm_type == LLMType.CHAT.value: @@ -104,6 +106,10 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None): mdlnm = tenant.rerank_id if not llm_name else llm_name elif llm_type == LLMType.TTS: mdlnm = tenant.tts_id if not llm_name else llm_name + elif llm_type == LLMType.OCR: + if not llm_name: + raise LookupError("OCR model name is required") + mdlnm = llm_name else: assert False, "LLM type error" @@ -137,31 +143,31 @@ def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kw return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) - if llm_type == LLMType.RERANK: + elif llm_type == LLMType.RERANK: if model_config["llm_factory"] not in RerankModel: return None return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"]) - if llm_type == LLMType.IMAGE2TEXT.value: + elif llm_type == LLMType.IMAGE2TEXT.value: if model_config["llm_factory"] not in CvModel: return None return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs) - if llm_type == LLMType.CHAT.value: + elif llm_type == LLMType.CHAT.value: if model_config["llm_factory"] not in ChatModel: return None return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs) - if llm_type == LLMType.SPEECH2TEXT: + elif llm_type == LLMType.SPEECH2TEXT: if model_config["llm_factory"] not in Seq2txtModel: return None return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"]) - if llm_type == LLMType.TTS: + elif llm_type == LLMType.TTS: if model_config["llm_factory"] not in TTSModel: return None return TTSModel[model_config["llm_factory"]]( @@ -169,6 +175,17 @@ def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kw model_config["llm_name"], base_url=model_config["api_base"], ) + + elif llm_type == LLMType.OCR: + if model_config["llm_factory"] not in OcrModel: + return None + return OcrModel[model_config["llm_factory"]]( + key=model_config["api_key"], + model_name=model_config["llm_name"], + base_url=model_config.get("api_base", ""), + **kwargs, + ) + return None @classmethod @@ -186,6 +203,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None): LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name, LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name, LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name, + LLMType.OCR.value: llm_name, } mdlnm = llm_map.get(llm_type) @@ -218,6 +236,68 @@ def get_openai_models(cls): ~(cls.model.llm_name == "text-embedding-3-large")).dicts() return list(objs) + @classmethod + def _collect_mineru_env_config(cls) -> dict | None: + cfg = MINERU_DEFAULT_CONFIG + found = False + for key in MINERU_ENV_KEYS: + val = os.environ.get(key) + if val: + found = True + cfg[key] = val + return cfg if found else None + + @classmethod + @DB.connection_context() + def ensure_mineru_from_env(cls, tenant_id: str) -> str | None: + """ + Ensure a MinerU OCR model exists for the tenant if env variables are present. + Return the existing or newly created llm_name, or None if env not set. + """ + cfg = cls._collect_mineru_env_config() + if not cfg: + return None + + saved_mineru_models = cls.query(tenant_id=tenant_id, llm_factory="MinerU", model_type=LLMType.OCR.value) + + def _parse_api_key(raw: str) -> dict: + try: + return json.loads(raw or "{}") + except Exception: + return {} + + for item in saved_mineru_models: + api_cfg = _parse_api_key(item.api_key) + normalized = {k: api_cfg.get(k, MINERU_DEFAULT_CONFIG.get(k)) for k in MINERU_ENV_KEYS} + if normalized == cfg: + return item.llm_name + + used_names = {item.llm_name for item in saved_mineru_models} + idx = 1 + base_name = "mineru-from-env" + while True: + candidate = f"{base_name}-{idx}" + if candidate in used_names: + idx += 1 + continue + + try: + cls.save( + tenant_id=tenant_id, + llm_factory="MinerU", + llm_name=candidate, + model_type=LLMType.OCR.value, + api_key=json.dumps(cfg), + api_base="", + max_tokens=0, + ) + return candidate + except IntegrityError: + logging.warning("MinerU env model %s already exists for tenant %s, retry with next name", candidate, tenant_id) + used_names.add(candidate) + idx += 1 + continue + @classmethod @DB.connection_context() def delete_by_tenant_id(cls, tenant_id): diff --git a/api/db/services/user_service.py b/api/db/services/user_service.py index b5e754dbd24..20d8c3230f6 100644 --- a/api/db/services/user_service.py +++ b/api/db/services/user_service.py @@ -116,10 +116,13 @@ def save(cls, **kwargs): kwargs["password"] = generate_password_hash( str(kwargs["password"])) - kwargs["create_time"] = current_timestamp() - kwargs["create_date"] = datetime_format(datetime.now()) - kwargs["update_time"] = current_timestamp() - kwargs["update_date"] = datetime_format(datetime.now()) + current_ts = current_timestamp() + current_date = datetime_format(datetime.now()) + + kwargs["create_time"] = current_ts + kwargs["create_date"] = current_date + kwargs["update_time"] = current_ts + kwargs["update_date"] = current_date obj = cls.model(**kwargs).save(force_insert=True) return obj @@ -161,7 +164,7 @@ def is_admin(cls, user_id): @classmethod @DB.connection_context() def get_all_users(cls): - users = cls.model.select() + users = cls.model.select().order_by(cls.model.email) return list(users) diff --git a/api/ragflow_server.py b/api/ragflow_server.py index 868e054aeb1..26cd045c4de 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -20,28 +20,26 @@ from common.log_utils import init_root_logger from plugin import GlobalPluginManager -init_root_logger("ragflow_server") import logging import os import signal import sys -import time import traceback import threading import uuid +import faulthandler -from werkzeug.serving import run_simple -from api.apps import app, smtp_mail_server +from api.apps import app from api.db.runtime_config import RuntimeConfig from api.db.services.document_service import DocumentService from common.file_utils import get_project_base_directory from common import settings from api.db.db_models import init_database_tables as init_web_db -from api.db.init_data import init_web_data +from api.db.init_data import init_web_data, init_superuser from common.versions import get_ragflow_version from common.config_utils import show_configs -from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions +from common.mcp_tool_call_conn import shutdown_all_mcp_sessions from rag.utils.redis_conn import RedisDistributedLock stop_event = threading.Event() @@ -70,10 +68,12 @@ def signal_handler(sig, frame): logging.info("Received interrupt signal, shutting down...") shutdown_all_mcp_sessions() stop_event.set() - time.sleep(1) + stop_event.wait(1) sys.exit(0) if __name__ == '__main__': + faulthandler.enable() + init_root_logger("ragflow_server") logging.info(r""" ____ ___ ______ ______ __ / __ \ / | / ____// ____// /____ _ __ @@ -110,11 +110,16 @@ def signal_handler(sig, frame): parser.add_argument( "--debug", default=False, help="debug mode", action="store_true" ) + parser.add_argument( + "--init-superuser", default=False, help="init superuser", action="store_true" + ) args = parser.parse_args() if args.version: print(get_ragflow_version()) sys.exit(0) + if args.init_superuser: + init_superuser() RuntimeConfig.DEBUG = args.debug if RuntimeConfig.DEBUG: logging.info("run on debug mode") @@ -138,31 +143,12 @@ def delayed_start_update_progress(): else: threading.Timer(1.0, delayed_start_update_progress).start() - # init smtp server - if settings.SMTP_CONF: - app.config["MAIL_SERVER"] = settings.MAIL_SERVER - app.config["MAIL_PORT"] = settings.MAIL_PORT - app.config["MAIL_USE_SSL"] = settings.MAIL_USE_SSL - app.config["MAIL_USE_TLS"] = settings.MAIL_USE_TLS - app.config["MAIL_USERNAME"] = settings.MAIL_USERNAME - app.config["MAIL_PASSWORD"] = settings.MAIL_PASSWORD - app.config["MAIL_DEFAULT_SENDER"] = settings.MAIL_DEFAULT_SENDER - smtp_mail_server.init_app(app) - - # start http server try: logging.info("RAGFlow HTTP server start...") - run_simple( - hostname=settings.HOST_IP, - port=settings.HOST_PORT, - application=app, - threaded=True, - use_reloader=RuntimeConfig.DEBUG, - use_debugger=RuntimeConfig.DEBUG, - ) + app.run(host=settings.HOST_IP, port=settings.HOST_PORT) except Exception: traceback.print_exc() stop_event.set() - time.sleep(1) + stop_event.wait(1) os.kill(os.getpid(), signal.SIGKILL) diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 4cace9eca02..afb4ff772de 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -14,30 +14,30 @@ # limitations under the License. # +import asyncio import functools +import inspect import json import logging import os import time from copy import deepcopy from functools import wraps +from typing import Any import requests -import trio -from flask import ( +from quart import ( Response, jsonify, + request ) -from flask_login import current_user -from flask import ( - request as flask_request, -) + from peewee import OperationalError from common.constants import ActiveEnum from api.db.db_models import APIToken from api.utils.json_encode import CustomJSONEncoder -from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions +from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from api.db.services.tenant_llm_service import LLMFactoriesService from common.connection_utils import timeout from common.constants import RetCode @@ -46,6 +46,41 @@ requests.models.complexjson.dumps = functools.partial(json.dumps, cls=CustomJSONEncoder) +async def _coerce_request_data() -> dict: + """Fetch JSON body with sane defaults; fallback to form data.""" + payload: Any = None + last_error: Exception | None = None + + try: + payload = await request.get_json(force=True, silent=True) + except Exception as e: + last_error = e + payload = None + + if payload is None: + try: + form = await request.form + payload = form.to_dict() + except Exception as e: + last_error = e + payload = None + + if payload is None: + if last_error is not None: + raise last_error + raise ValueError("No JSON body or form data found in request.") + + if isinstance(payload, dict): + return payload or {} + + if isinstance(payload, str): + raise AttributeError("'str' object has no attribute 'get'") + + raise TypeError(f"Unsupported request payload type: {type(payload)!r}") + +async def get_request_json(): + return await _coerce_request_data() + def serialize_for_json(obj): """ Recursively serialize objects to make them JSON serializable. @@ -84,7 +119,8 @@ def get_data_error_result(code=RetCode.DATA_ERROR, message="Sorry! Data missing! def server_error_response(e): - logging.exception(e) + # Quart invokes this handler outside the original except block, so we must pass exc_info manually. + logging.error("Unhandled exception during request", exc_info=(type(e), e, e.__traceback__)) try: msg = repr(e).lower() if getattr(e, "code", None) == 401 or ("unauthorized" in msg) or ("401" in msg): @@ -105,31 +141,38 @@ def server_error_response(e): def validate_request(*args, **kwargs): + def process_args(input_arguments): + no_arguments = [] + error_arguments = [] + for arg in args: + if arg not in input_arguments: + no_arguments.append(arg) + for k, v in kwargs.items(): + config_value = input_arguments.get(k, None) + if config_value is None: + no_arguments.append(k) + elif isinstance(v, (tuple, list)): + if config_value not in v: + error_arguments.append((k, set(v))) + elif config_value != v: + error_arguments.append((k, v)) + if no_arguments or error_arguments: + error_string = "" + if no_arguments: + error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) + if error_arguments: + error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) + return error_string + return None + def wrapper(func): @wraps(func) - def decorated_function(*_args, **_kwargs): - input_arguments = flask_request.json or flask_request.form.to_dict() - no_arguments = [] - error_arguments = [] - for arg in args: - if arg not in input_arguments: - no_arguments.append(arg) - for k, v in kwargs.items(): - config_value = input_arguments.get(k, None) - if config_value is None: - no_arguments.append(k) - elif isinstance(v, (tuple, list)): - if config_value not in v: - error_arguments.append((k, set(v))) - elif config_value != v: - error_arguments.append((k, v)) - if no_arguments or error_arguments: - error_string = "" - if no_arguments: - error_string += "required argument are missing: {}; ".format(",".join(no_arguments)) - if error_arguments: - error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments])) - return get_json_result(code=RetCode.ARGUMENT_ERROR, message=error_string) + async def decorated_function(*_args, **_kwargs): + errs = process_args(await _coerce_request_data()) + if errs: + return get_json_result(code=RetCode.ARGUMENT_ERROR, message=errs) + if inspect.iscoroutinefunction(func): + return await func(*_args, **_kwargs) return func(*_args, **_kwargs) return decorated_function @@ -138,30 +181,34 @@ def decorated_function(*_args, **_kwargs): def not_allowed_parameters(*params): - def decorator(f): - def wrapper(*args, **kwargs): - input_arguments = flask_request.json or flask_request.form.to_dict() + def decorator(func): + async def wrapper(*args, **kwargs): + input_arguments = await _coerce_request_data() for param in params: if param in input_arguments: return get_json_result(code=RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed") - return f(*args, **kwargs) - + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return wrapper return decorator -def active_required(f): - @wraps(f) - def wrapper(*args, **kwargs): +def active_required(func): + @wraps(func) + async def wrapper(*args, **kwargs): from api.db.services import UserService + from api.apps import current_user user_id = current_user.id usr = UserService.filter_by_id(user_id) # check is_active if not usr or not usr.is_active == ActiveEnum.ACTIVE.value: return get_json_result(code=RetCode.FORBIDDEN, message="User isn't active, please activate first.") - return f(*args, **kwargs) + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return wrapper @@ -173,12 +220,15 @@ def get_json_result(code: RetCode = RetCode.SUCCESS, message="success", data=Non def apikey_required(func): @wraps(func) - def decorated_function(*args, **kwargs): - token = flask_request.headers.get("Authorization").split()[1] + async def decorated_function(*args, **kwargs): + token = request.headers.get("Authorization").split()[1] objs = APIToken.query(token=token) if not objs: return build_error_result(message="API-KEY is invalid!", code=RetCode.FORBIDDEN) kwargs["tenant_id"] = objs[0].tenant_id + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) return decorated_function @@ -199,23 +249,38 @@ def construct_json_result(code: RetCode = RetCode.SUCCESS, message="success", da def token_required(func): - @wraps(func) - def decorated_function(*args, **kwargs): + def get_tenant_id(**kwargs): if os.environ.get("DISABLE_SDK"): - return get_json_result(data=False, message="`Authorization` can't be empty") - authorization_str = flask_request.headers.get("Authorization") + return False, get_json_result(data=False, message="`Authorization` can't be empty") + authorization_str = request.headers.get("Authorization") if not authorization_str: - return get_json_result(data=False, message="`Authorization` can't be empty") + return False, get_json_result(data=False, message="`Authorization` can't be empty") authorization_list = authorization_str.split() if len(authorization_list) < 2: - return get_json_result(data=False, message="Please check your authorization format.") + return False, get_json_result(data=False, message="Please check your authorization format.") token = authorization_list[1] objs = APIToken.query(token=token) if not objs: - return get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) + return False, get_json_result(data=False, message="Authentication error: API key is invalid!", code=RetCode.AUTHENTICATION_ERROR) kwargs["tenant_id"] = objs[0].tenant_id + return True, kwargs + + @wraps(func) + def decorated_function(*args, **kwargs): + e, kwargs = get_tenant_id(**kwargs) + if not e: + return kwargs return func(*args, **kwargs) + @wraps(func) + async def adecorated_function(*args, **kwargs): + e, kwargs = get_tenant_id(**kwargs) + if not e: + return kwargs + return await func(*args, **kwargs) + + if inspect.iscoroutinefunction(func): + return adecorated_function return decorated_function @@ -279,6 +344,10 @@ def get_parser_config(chunk_method, parser_config): chunk_method = "naive" # Define default configurations for each chunking method + base_defaults = { + "table_context_size": 0, + "image_context_size": 0, + } key_mapping = { "naive": { "layout_recognize": "DeepDOC", @@ -331,16 +400,19 @@ def get_parser_config(chunk_method, parser_config): default_config = key_mapping[chunk_method] - # If no parser_config provided, return default + # If no parser_config provided, return default merged with base defaults if not parser_config: - return default_config + if default_config is None: + return deep_merge(base_defaults, {}) + return deep_merge(base_defaults, default_config) # If parser_config is provided, merge with defaults to ensure required fields exist if default_config is None: - return parser_config + return deep_merge(base_defaults, parser_config) - # Ensure raptor and graphrag fields have default values if not provided - merged_config = deep_merge(default_config, parser_config) + # Ensure raptor and graph_rag fields have default values if not provided + merged_config = deep_merge(base_defaults, default_config) + merged_config = deep_merge(merged_config, parser_config) return merged_config @@ -610,18 +682,32 @@ async def is_strong_enough(chat_model, embedding_model): async def _is_strong_enough(): nonlocal chat_model, embedding_model if embedding_model: - with trio.fail_after(10): - _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"])) + await asyncio.wait_for( + asyncio.to_thread(embedding_model.encode, ["Are you strong enough!?"]), + timeout=10 + ) + if chat_model: - with trio.fail_after(30): - res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}], {})) - if res.find("**ERROR**") >= 0: + res = await asyncio.wait_for( + chat_model.async_chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}]), + timeout=30 + ) + if "**ERROR**" in res: raise Exception(res) # Pressure test for GraphRAG task - async with trio.open_nursery() as nursery: - for _ in range(count): - nursery.start_soon(_is_strong_enough) + tasks = [ + asyncio.create_task(_is_strong_enough()) + for _ in range(count) + ] + try: + await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Pressure test failed: {e}") + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise def get_allowed_llm_factories() -> list: diff --git a/api/utils/commands.py b/api/utils/commands.py index a1a8d025aca..a3df7b507dd 100644 --- a/api/utils/commands.py +++ b/api/utils/commands.py @@ -18,7 +18,7 @@ import click import re -from flask import Flask +from quart import Quart from werkzeug.security import generate_password_hash from api.db.services import UserService @@ -73,6 +73,7 @@ def reset_email(email, new_email, email_confirm): UserService.update_user(user[0].id,user_dict) click.echo(click.style('Congratulations!, email has been reset.', fg='green')) -def register_commands(app: Flask): + +def register_commands(app: Quart): app.cli.add_command(reset_password) app.cli.add_command(reset_email) diff --git a/api/utils/email_templates.py b/api/utils/email_templates.py index 10473908a88..767eb7b9270 100644 --- a/api/utils/email_templates.py +++ b/api/utils/email_templates.py @@ -1,21 +1,37 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + """ Reusable HTML email templates and registry. """ # Invitation email template INVITE_EMAIL_TMPL = """ -

Hi {{email}},

-

{{inviter}} has invited you to join their team (ID: {{tenant_id}}).

-

Click the link below to complete your registration:
-{{invite_url}}

-

If you did not request this, please ignore this email.

+Hi {{email}}, +{{inviter}} has invited you to join their team (ID: {{tenant_id}}). +Click the link below to complete your registration: +{{invite_url}} +If you did not request this, please ignore this email. """ # Password reset code template RESET_CODE_EMAIL_TMPL = """ -

Hello,

-

Your password reset code is: {{ code }}

-

This code will expire in {{ ttl_min }} minutes.

+Hello, +Your password reset code is: {{ code }} +This code will expire in {{ ttl_min }} minutes. """ # Template registry diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 5f0fa70f451..4cad64c35ce 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -42,7 +42,7 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): + if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|mdx|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename): @@ -164,3 +164,23 @@ def try_open(blob): return repaired return blob + + +def sanitize_path(raw_path: str | None) -> str: + """Normalize and sanitize a user-provided path segment. + + - Converts backslashes to forward slashes + - Strips leading/trailing slashes + - Removes '.' and '..' segments + - Restricts characters to A-Za-z0-9, underscore, dash, and '/' + """ + if not raw_path: + return "" + backslash_re = re.compile(r"[\\]+") + unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]") + normalized = backslash_re.sub("/", raw_path) + normalized = normalized.strip("/") + parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")] + sanitized = "/".join(parts) + sanitized = unsafe_re.sub("", sanitized) + return sanitized diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py index 88e5aaebbee..0a7ab6e7a6f 100644 --- a/api/utils/health_utils.py +++ b/api/utils/health_utils.py @@ -173,7 +173,8 @@ def check_task_executor_alive(): heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats] task_executor_heartbeats[task_executor_id] = heartbeats if task_executor_heartbeats: - return {"status": "alive", "message": task_executor_heartbeats} + status = "alive" if any(task_executor_heartbeats.values()) else "timeout" + return {"status": status, "message": task_executor_heartbeats} else: return {"status": "timeout", "message": "Not found any task executor."} except Exception as e: diff --git a/api/utils/json_encode.py b/api/utils/json_encode.py index b21addd4f9b..fa5ea973aa0 100644 --- a/api/utils/json_encode.py +++ b/api/utils/json_encode.py @@ -1,3 +1,19 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import datetime import json from enum import Enum, IntEnum diff --git a/api/utils/memory_utils.py b/api/utils/memory_utils.py new file mode 100644 index 00000000000..bb78949518b --- /dev/null +++ b/api/utils/memory_utils.py @@ -0,0 +1,54 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List +from common.constants import MemoryType + +def format_ret_data_from_memory(memory): + return { + "id": memory.id, + "name": memory.name, + "avatar": memory.avatar, + "tenant_id": memory.tenant_id, + "owner_name": memory.owner_name if hasattr(memory, "owner_name") else None, + "memory_type": get_memory_type_human(memory.memory_type), + "storage_type": memory.storage_type, + "embd_id": memory.embd_id, + "llm_id": memory.llm_id, + "permissions": memory.permissions, + "description": memory.description, + "memory_size": memory.memory_size, + "forgetting_policy": memory.forgetting_policy, + "temperature": memory.temperature, + "system_prompt": memory.system_prompt, + "user_prompt": memory.user_prompt, + "create_time": memory.create_time, + "create_date": memory.create_date, + "update_time": memory.update_time, + "update_date": memory.update_date + } + + +def get_memory_type_human(memory_type: int) -> List[str]: + return [mem_type.name.lower() for mem_type in MemoryType if memory_type & mem_type.value] + + +def calculate_memory_type(memory_type_name_list: List[str]) -> int: + memory_type = 0 + type_value_map = {mem_type.name.lower(): mem_type.value for mem_type in MemoryType} + for mem_type in memory_type_name_list: + if mem_type in type_value_map: + memory_type |= type_value_map[mem_type] + return memory_type diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index caf3f0924aa..2dcace53fe9 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -14,10 +14,11 @@ # limitations under the License. # from collections import Counter +import string from typing import Annotated, Any, Literal from uuid import UUID -from flask import Request +from quart import Request from pydantic import ( BaseModel, ConfigDict, @@ -25,6 +26,7 @@ StringConstraints, ValidationError, field_validator, + model_validator, ) from pydantic_core import PydanticCustomError from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -32,7 +34,7 @@ from api.constants import DATASET_NAME_LIMIT -def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: +async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: """ Validates and parses JSON requests through a multi-stage validation pipeline. @@ -81,7 +83,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] from the final output after validation """ try: - payload = request.get_json() or {} + payload = await request.get_json() or {} except UnsupportedMediaType: return None, f"Unsupported content type: Expected application/json, got {request.content_type}" except BadRequest: @@ -329,6 +331,7 @@ class RaptorConfig(Base): threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)] max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] + auto_disable_for_structured_data: Annotated[bool, Field(default=True)] class GraphragConfig(Base): @@ -361,10 +364,9 @@ class CreateDatasetReq(Base): description: Annotated[str | None, Field(default=None, max_length=65535)] embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")] permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)] - chunk_method: Annotated[ - Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], - Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"), - ] + chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")] + parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)] + pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")] parser_config: Annotated[ParserConfig | None, Field(default=None)] @field_validator("avatar", mode="after") @@ -525,6 +527,93 @@ def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserCon raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)}) return v + @field_validator("pipeline_id", mode="after") + @classmethod + def validate_pipeline_id(cls, v: str | None) -> str | None: + """Validate pipeline_id as 32-char lowercase hex string if provided. + + Rules: + - None or empty string: treat as None (not set) + - Must be exactly length 32 + - Must contain only hex digits (0-9a-fA-F); normalized to lowercase + """ + if v is None: + return None + if v == "": + return None + if len(v) != 32: + raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters") + if any(ch not in string.hexdigits for ch in v): + raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal") + return v.lower() + + @model_validator(mode="after") + def validate_parser_dependency(self) -> "CreateDatasetReq": + """ + Mixed conditional validation: + - If parser_id is omitted (field not set): + * If both parse_type and pipeline_id are omitted → default chunk_method = "naive" + * If both parse_type and pipeline_id are provided → allow ingestion pipeline mode + - If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage) + + Raises: + PydanticCustomError with code 'dependency_error' on violation. + """ + # Omitted chunk_method (not in fields) logic + if self.chunk_method is None and "chunk_method" not in self.model_fields_set: + # All three absent → default naive + if self.parse_type is None and self.pipeline_id is None: + object.__setattr__(self, "chunk_method", "naive") + return self + # parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed) + if self.parse_type is None or self.pipeline_id is None: + missing = [] + if self.parse_type is None: + missing.append("parse_type") + if self.pipeline_id is None: + missing.append("pipeline_id") + raise PydanticCustomError( + "dependency_error", + "parser_id omitted → required fields missing: {fields}", + {"fields": ", ".join(missing)}, + ) + # Both provided → allow pipeline mode + return self + + # parser_id provided (valid): MUST NOT have parse_type or pipeline_id + if isinstance(self.chunk_method, str): + if self.parse_type is not None or self.pipeline_id is not None: + invalid = [] + if self.parse_type is not None: + invalid.append("parse_type") + if self.pipeline_id is not None: + invalid.append("pipeline_id") + raise PydanticCustomError( + "dependency_error", + "parser_id provided → disallowed fields present: {fields}", + {"fields": ", ".join(invalid)}, + ) + return self + + @field_validator("chunk_method", mode="wrap") + @classmethod + def validate_chunk_method(cls, v: Any, handler) -> Any: + """Wrap validation to unify error messages, including type errors (e.g. list).""" + allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"} + error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" + # Omitted field: handler won't be invoked (wrap still gets value); None treated as explicit invalid + if v is None: + raise PydanticCustomError("literal_error", error_msg) + try: + # Run inner validation (type checking) + result = handler(v) + except Exception: + raise PydanticCustomError("literal_error", error_msg) + # After handler, enforce enumeration + if not isinstance(result, str) or result == "" or result not in allowed: + raise PydanticCustomError("literal_error", error_msg) + return result + class UpdateDatasetReq(CreateDatasetReq): dataset_id: Annotated[str, Field(...)] diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index e0e47f472e6..11e8428b77c 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -20,10 +20,11 @@ import re import socket from urllib.parse import urlparse - -from api.apps import smtp_mail_server -from flask_mail import Message -from flask import render_template_string +import aiosmtplib +from email.mime.text import MIMEText +from email.header import Header +from common import settings +from quart import render_template_string from api.utils.email_templates import EMAIL_TEMPLATES from selenium import webdriver from selenium.common.exceptions import TimeoutException @@ -35,11 +36,11 @@ from webdriver_manager.chrome import ChromeDriverManager -OTP_LENGTH = 8 -OTP_TTL_SECONDS = 5 * 60 -ATTEMPT_LIMIT = 5 -ATTEMPT_LOCK_SECONDS = 30 * 60 -RESEND_COOLDOWN_SECONDS = 60 +OTP_LENGTH = 4 +OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes +ATTEMPT_LIMIT = 5 # maximum attempts +ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes +RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute CONTENT_TYPE_MAP = { @@ -68,6 +69,7 @@ # Web "md": "text/markdown", "markdown": "text/markdown", + "mdx": "text/markdown", "htm": "text/html", "html": "text/html", "json": "application/json", @@ -183,27 +185,34 @@ def get_float(req: dict, key: str, default: float | int = 10.0) -> float: return parsed if parsed > 0 else default except (TypeError, ValueError): return default + + +async def send_email_html(to_email: str, subject: str, template_key: str, **context): + + body = await render_template_string(EMAIL_TEMPLATES.get(template_key), **context) + msg = MIMEText(body, "plain", "utf-8") + msg["Subject"] = Header(subject, "utf-8") + msg["From"] = f"{settings.MAIL_DEFAULT_SENDER[0]} <{settings.MAIL_DEFAULT_SENDER[1]}>" + msg["To"] = to_email + smtp = aiosmtplib.SMTP( + hostname=settings.MAIL_SERVER, + port=settings.MAIL_PORT, + use_tls=True, + timeout=10, + ) -def send_email_html(subject: str, to_email: str, template_key: str, **context): - """Generic HTML email sender using shared templates. - template_key must exist in EMAIL_TEMPLATES. - """ - from api.apps import app - tmpl = EMAIL_TEMPLATES.get(template_key) - if not tmpl: - raise ValueError(f"Unknown email template: {template_key}") - with app.app_context(): - msg = Message(subject=subject, recipients=[to_email]) - msg.html = render_template_string(tmpl, **context) - smtp_mail_server.send(msg) + await smtp.connect() + await smtp.login(settings.MAIL_USERNAME, settings.MAIL_PASSWORD) + await smtp.send_message(msg) + await smtp.quit() -def send_invite_email(to_email, invite_url, tenant_id, inviter): +async def send_invite_email(to_email, invite_url, tenant_id, inviter): # Reuse the generic HTML sender with 'invite' template - send_email_html( - subject="RAGFlow Invitation", + await send_email_html( to_email=to_email, + subject="RAGFlow Invitation", template_key="invite", email=to_email, invite_url=invite_url, @@ -230,4 +239,4 @@ def hash_code(code: str, salt: bytes) -> str: def captcha_key(email: str) -> str: return f"captcha:{email}" - + \ No newline at end of file diff --git a/check_comment_ascii.py b/check_comment_ascii.py new file mode 100644 index 00000000000..57d188b6c2d --- /dev/null +++ b/check_comment_ascii.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +""" +Check whether given python files contain non-ASCII comments. + +How to check the whole git repo: + +``` +$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py +``` +""" + +import sys +import tokenize +import ast +import pathlib +import re + +ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline + + +def check(src: str, name: str) -> int: + """ + docstring line 1 + docstring line 2 + """ + ok = 1 + # A common comment begins with `#` + with tokenize.open(src) as fp: + for tk in tokenize.generate_tokens(fp.readline): + if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string): + print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}") + ok = 0 + # A docstring begins and ends with `'''` + for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc): + print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}") + ok = 0 + return ok + + +if __name__ == "__main__": + status = 0 + for file in sys.argv[1:]: + if not check(file, file): + status = 1 + sys.exit(status) diff --git a/common/connection_utils.py b/common/connection_utils.py index 618584ae978..86ebc371d8c 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -19,9 +19,8 @@ import threading from typing import Any, Callable, Coroutine, Optional, Type, Union import asyncio -import trio from functools import wraps -from flask import make_response, jsonify +from quart import make_response, jsonify from common.constants import RetCode TimeoutException = Union[Type[BaseException], BaseException] @@ -70,11 +69,10 @@ async def async_wrapper(*args, **kwargs) -> Any: for a in range(attempts): try: if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) + return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) else: return await func(*args, **kwargs) - except trio.TooSlowError: + except asyncio.TimeoutError: if a < attempts - 1: continue if on_timeout is not None: @@ -103,7 +101,7 @@ async def async_wrapper(*args, **kwargs) -> Any: return decorator -def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): +async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): result_dict = {"code": code, "message": message, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -111,7 +109,27 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth= continue else: response_dict[key] = value - response = make_response(jsonify(response_dict)) + response = await make_response(jsonify(response_dict)) + if auth: + response.headers["Authorization"] = auth + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Method"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Expose-Headers"] = "Authorization" + return response + + +def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): + import flask + result_dict = {"code": code, "message": message, "data": data} + response_dict = {} + for key, value in result_dict.items(): + if value is None and key != "code": + continue + else: + response_dict[key] = value + response = flask.make_response(flask.jsonify(response_dict)) if auth: response.headers["Authorization"] = auth response.headers["Access-Control-Allow-Origin"] = "*" diff --git a/common/constants.py b/common/constants.py index dd24b4ead7e..23a75505941 100644 --- a/common/constants.py +++ b/common/constants.py @@ -49,10 +49,12 @@ class RetCode(IntEnum, CustomEnum): RUNNING = 106 PERMISSION_ERROR = 108 AUTHENTICATION_ERROR = 109 + BAD_REQUEST = 400 UNAUTHORIZED = 401 SERVER_ERROR = 500 FORBIDDEN = 403 NOT_FOUND = 404 + CONFLICT = 409 class StatusEnum(Enum): @@ -72,6 +74,7 @@ class LLMType(StrEnum): IMAGE2TEXT = 'image2text' RERANK = 'rerank' TTS = 'tts' + OCR = 'ocr' class TaskStatus(StrEnum): @@ -118,7 +121,18 @@ class FileSource(StrEnum): SHAREPOINT = "sharepoint" SLACK = "slack" TEAMS = "teams" - + WEBDAV = "webdav" + MOODLE = "moodle" + DROPBOX = "dropbox" + BOX = "box" + R2 = "r2" + OCI_STORAGE = "oci_storage" + GOOGLE_CLOUD_STORAGE = "google_cloud_storage" + AIRTABLE = "airtable" + ASANA = "asana" + GITHUB = "github" + GITLAB = "gitlab" + IMAP = "imap" class PipelineTaskType(StrEnum): PARSE = "Parse" @@ -126,6 +140,7 @@ class PipelineTaskType(StrEnum): RAPTOR = "RAPTOR" GRAPH_RAG = "GraphRAG" MINDMAP = "Mindmap" + MEMORY = "Memory" VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, @@ -144,6 +159,24 @@ class Storage(Enum): AWS_S3 = 4 OSS = 5 OPENDAL = 6 + GCS = 7 + + +class MemoryType(Enum): + RAW = 0b0001 # 1 << 0 = 1 (0b00000001) + SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010) + EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100) + PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000) + + +class MemoryStorageType(StrEnum): + TABLE = "table" + GRAPH = "graph" + + +class ForgettingPolicy(StrEnum): + FIFO = "FIFO" + # environment # ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT" @@ -194,3 +227,13 @@ class Storage(Enum): SVR_QUEUE_NAME = "rag_flow_svr_queue" SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" TAG_FLD = "tag_feas" + + +MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"] +MINERU_DEFAULT_CONFIG = { + "MINERU_APISERVER": "", + "MINERU_OUTPUT_DIR": "", + "MINERU_BACKEND": "pipeline", + "MINERU_SERVER_URL": "", + "MINERU_DELETE_OUTPUT": 1, +} diff --git a/common/crypto_utils.py b/common/crypto_utils.py new file mode 100644 index 00000000000..5dcbd2937fa --- /dev/null +++ b/common/crypto_utils.py @@ -0,0 +1,374 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from cryptography.hazmat.primitives import hashes + + +class BaseCrypto: + """Base class for cryptographic algorithms""" + + # Magic header to identify encrypted data + ENCRYPTED_MAGIC = b'RAGF' + + def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16): + """ + Initialize cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + block_size: Block size + key_length: Key length + iv_length: Initialization vector length + """ + self.block_size = block_size + self.key_length = key_length + self.iv_length = iv_length + + # Normalize key + self.key = self._normalize_key(key) + self.iv = iv + + def _normalize_key(self, key): + """Normalize key length""" + if isinstance(key, str): + key = key.encode('utf-8') + + # Use PBKDF2 for key derivation to ensure correct key length + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=self.key_length, + salt=b"ragflow_crypto_salt", # Fixed salt to ensure consistent key derivation results + iterations=100000, + backend=default_backend() + ) + + return kdf.derive(key) + + def encrypt(self, data): + """ + Encrypt data (template method) + + Args: + data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes), format: magic_header + iv + encrypted_data + """ + # Generate random IV + iv = os.urandom(self.iv_length) if not self.iv else self.iv + + # Use PKCS7 padding + padder = padding.PKCS7(self.block_size * 8).padder() + padded_data = padder.update(data) + padder.finalize() + + # Delegate to subclass for specific encryption + ciphertext = self._encrypt(padded_data, iv) + + # Return Magic Header + IV + encrypted data + return self.ENCRYPTED_MAGIC + iv + ciphertext + + def decrypt(self, encrypted_data): + """ + Decrypt data (template method) + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # Check if data is encrypted by magic header + if not encrypted_data.startswith(self.ENCRYPTED_MAGIC): + # Not encrypted, return as-is + return encrypted_data + + # Remove magic header + encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):] + + # Separate IV and encrypted data + iv = encrypted_data[:self.iv_length] + ciphertext = encrypted_data[self.iv_length:] + + # Delegate to subclass for specific decryption + padded_data = self._decrypt(ciphertext, iv) + + # Remove padding + unpadder = padding.PKCS7(self.block_size * 8).unpadder() + data = unpadder.update(padded_data) + unpadder.finalize() + + return data + + def _encrypt(self, padded_data, iv): + """ + Encrypt padded data with specific algorithm + + Args: + padded_data: Padded data to encrypt + iv: Initialization vector + + Returns: + Encrypted data + """ + raise NotImplementedError("_encrypt method must be implemented by subclass") + + def _decrypt(self, ciphertext, iv): + """ + Decrypt ciphertext with specific algorithm + + Args: + ciphertext: Ciphertext to decrypt + iv: Initialization vector + + Returns: + Decrypted padded data + """ + raise NotImplementedError("_decrypt method must be implemented by subclass") + + +class AESCrypto(BaseCrypto): + """Base class for AES cryptographic algorithm""" + + def __init__(self, key, iv=None, key_length=32): + """ + Initialize AES cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + key_length: Key length (16 for AES-128, 32 for AES-256) + """ + super().__init__(key, iv, block_size=16, key_length=key_length, iv_length=16) + + def _encrypt(self, padded_data, iv): + """AES encryption implementation""" + # Create encryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """AES decryption implementation""" + # Create decryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class AES128CBC(AESCrypto): + """AES-128-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-128-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=16) + + +class AES256CBC(AESCrypto): + """AES-256-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-256-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=32) + + +class SM4CBC(BaseCrypto): + """SM4-CBC cryptographic algorithm using cryptography library for better performance""" + + def __init__(self, key, iv=None): + """ + Initialize SM4-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, block_size=16, key_length=16, iv_length=16) + + def _encrypt(self, padded_data, iv): + """SM4 encryption implementation using cryptography library""" + # Create encryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """SM4 decryption implementation using cryptography library""" + # Create decryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class CryptoUtil: + """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances""" + + # Supported cryptographic algorithms mapping + SUPPORTED_ALGORITHMS = { + "aes-128-cbc": AES128CBC, + "aes-256-cbc": AES256CBC, + "sm4-cbc": SM4CBC + } + + def __init__(self, algorithm="aes-256-cbc", key=None, iv=None): + """ + Initialize cryptographic utility + + Args: + algorithm: Cryptographic algorithm, default is aes-256-cbc + key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None + iv: Initialization vector, automatically generated if None + """ + if algorithm not in self.SUPPORTED_ALGORITHMS: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + if not key: + raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set") + + # Create cryptographic algorithm instance + self.algorithm_name = algorithm + self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv) + + def encrypt(self, data): + """ + Encrypt data + + Args: + data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes) + """ + # import time + # start_time = time.time() + encrypted = self.crypto.encrypt(data) + # end_time = time.time() + # logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return encrypted + + def decrypt(self, encrypted_data): + """ + Decrypt data + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # import time + # start_time = time.time() + decrypted = self.crypto.decrypt(encrypted_data) + # end_time = time.time() + # logging.info(f"Decryption completed, data length: {len(encrypted_data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return decrypted + + +# Test code +if __name__ == "__main__": + # Test AES encryption + crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456") + test_data = b"Hello, RAGFlow! This is a test for encryption." + + encrypted = crypto.encrypt(test_data) + decrypted = crypto.decrypt(encrypted) + + print("AES Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted}") + print(f"Decrypted: {decrypted}") + print(f"Success: {test_data == decrypted}") + print() + + # Test SM4 encryption + try: + crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456") + encrypted_sm4 = crypto_sm4.encrypt(test_data) + decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4) + + print("SM4 Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted_sm4}") + print(f"Decrypted: {decrypted_sm4}") + print(f"Success: {test_data == decrypted_sm4}") + except Exception as e: + print(f"SM4 Test Failed: {e}") + import traceback + traceback.print_exc() + + # Test with specific algorithm classes directly + print("\nDirect Algorithm Class Test:") + + # Test AES-128-CBC + aes128 = AES128CBC(key="test_key_123456") + encrypted_aes128 = aes128.encrypt(test_data) + decrypted_aes128 = aes128.decrypt(encrypted_aes128) + print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}") + + # Test AES-256-CBC + aes256 = AES256CBC(key="test_key_123456") + encrypted_aes256 = aes256.encrypt(test_data) + decrypted_aes256 = aes256.decrypt(encrypted_aes256) + print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}") + + # Test SM4-CBC + try: + sm4 = SM4CBC(key="test_key_123456") + encrypted_sm4 = sm4.encrypt(test_data) + decrypted_sm4 = sm4.decrypt(encrypted_sm4) + print(f"SM4-CBC test: {'passed' if decrypted_sm4 == test_data else 'failed'}") + except Exception as e: + print(f"SM4-CBC test failed: {e}") diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 0802a52852a..2619e779dcd 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -1,6 +1,26 @@ """ Thanks to https://github.com/onyx-dot-app/onyx + +Content of this directory is under the "MIT Expat" license as defined below. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. """ from .blob_connector import BlobStorageConnector @@ -11,9 +31,14 @@ from .discord_connector import DiscordConnector from .dropbox_connector import DropboxConnector from .google_drive.connector import GoogleDriveConnector -from .jira_connector import JiraConnector +from .jira.connector import JiraConnector from .sharepoint_connector import SharePointConnector from .teams_connector import TeamsConnector +from .webdav_connector import WebDAVConnector +from .moodle_connector import MoodleConnector +from .airtable_connector import AirtableConnector +from .asana_connector import AsanaConnector +from .imap_connector import ImapConnector from .config import BlobType, DocumentSource from .models import Document, TextSection, ImageSection, BasicExpertInfo from .exceptions import ( @@ -36,6 +61,8 @@ "JiraConnector", "SharePointConnector", "TeamsConnector", + "WebDAVConnector", + "MoodleConnector", "BlobType", "DocumentSource", "Document", @@ -46,5 +73,8 @@ "ConnectorValidationError", "CredentialExpiredError", "InsufficientPermissionsError", - "UnexpectedValidationError" + "UnexpectedValidationError", + "AirtableConnector", + "AsanaConnector", + "ImapConnector" ] diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py new file mode 100644 index 00000000000..6f0b5a930cd --- /dev/null +++ b/common/data_source/airtable_connector.py @@ -0,0 +1,169 @@ +from datetime import datetime, timezone +import logging +from typing import Any, Generator + +import requests + +from pyairtable import Api as AirtableApi + +from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ConnectorMissingCredentialError +from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.utils import extract_size_bytes, get_file_ext + +class AirtableClientNotSetUpError(PermissionError): + def __init__(self) -> None: + super().__init__( + "Airtable client is not set up. Did you forget to call load_credentials()?" + ) + + +class AirtableConnector(LoadConnector, PollConnector): + """ + Lightweight Airtable connector. + + This connector ingests Airtable attachments as raw blobs without + parsing file content or generating text/image sections. + """ + + def __init__( + self, + base_id: str, + table_name_or_id: str, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.base_id = base_id + self.table_name_or_id = table_name_or_id + self.batch_size = batch_size + self._airtable_client: AirtableApi | None = None + self.size_threshold = AIRTABLE_CONNECTOR_SIZE_THRESHOLD + + # ------------------------- + # Credentials + # ------------------------- + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self._airtable_client = AirtableApi(credentials["airtable_access_token"]) + return None + + @property + def airtable_client(self) -> AirtableApi: + if not self._airtable_client: + raise AirtableClientNotSetUpError() + return self._airtable_client + + # ------------------------- + # Core logic + # ------------------------- + def load_from_state(self) -> GenerateDocumentsOutput: + """ + Fetch all Airtable records and ingest attachments as raw blobs. + + Each attachment is converted into a single Document(blob=...). + """ + if not self._airtable_client: + raise ConnectorMissingCredentialError("Airtable credentials not loaded") + + table = self.airtable_client.table(self.base_id, self.table_name_or_id) + records = table.all() + + logging.info( + f"Starting Airtable blob ingestion for table {self.table_name_or_id}, " + f"{len(records)} records found." + ) + + batch: list[Document] = [] + + for record in records: + print(record) + record_id = record.get("id") + fields = record.get("fields", {}) + created_time = record.get("createdTime") + + for field_value in fields.values(): + # We only care about attachment fields (lists of dicts with url/filename) + if not isinstance(field_value, list): + continue + + for attachment in field_value: + url = attachment.get("url") + filename = attachment.get("filename") + attachment_id = attachment.get("id") + + if not url or not filename or not attachment_id: + continue + + try: + resp = requests.get(url, timeout=30) + resp.raise_for_status() + content = resp.content + except Exception: + logging.exception( + f"Failed to download attachment {filename} " + f"(record={record_id})" + ) + continue + size_bytes = extract_size_bytes(attachment) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + batch.append( + Document( + id=f"airtable:{record_id}:{attachment_id}", + blob=content, + source=DocumentSource.AIRTABLE, + semantic_identifier=filename, + extension=get_file_ext(filename), + size_bytes=size_bytes if size_bytes else 0, + doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) + ) + ) + + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: + """Poll source to get documents""" + start_dt = datetime.fromtimestamp(start, tz=timezone.utc) + end_dt = datetime.fromtimestamp(end, tz=timezone.utc) + + for batch in self.load_from_state(): + filtered: list[Document] = [] + + for doc in batch: + if not doc.doc_updated_at: + continue + + doc_dt = doc.doc_updated_at.astimezone(timezone.utc) + + if start_dt <= doc_dt < end_dt: + filtered.append(doc) + + if filtered: + yield filtered + +if __name__ == "__main__": + import os + + logging.basicConfig(level=logging.DEBUG) + connector = AirtableConnector("xxx","xxx") + connector.load_credentials({"airtable_access_token": os.environ.get("AIRTABLE_ACCESS_TOKEN")}) + connector.validate_connector_settings() + document_batches = connector.load_from_state() + try: + first_batch = next(document_batches) + print(f"Loaded {len(first_batch)} documents in first batch.") + for doc in first_batch: + print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") + except StopIteration: + print("No documents available in Dropbox.") \ No newline at end of file diff --git a/common/data_source/asana_connector.py b/common/data_source/asana_connector.py new file mode 100644 index 00000000000..1dddcb6df2b --- /dev/null +++ b/common/data_source/asana_connector.py @@ -0,0 +1,454 @@ +from collections.abc import Iterator +import time +from datetime import datetime +import logging +from typing import Any, Dict +import asana +import requests +from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource +from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.utils import extract_size_bytes, get_file_ext + + + +# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints +class AsanaTask: + def __init__( + self, + id: str, + title: str, + text: str, + link: str, + last_modified: datetime, + project_gid: str, + project_name: str, + ) -> None: + self.id = id + self.title = title + self.text = text + self.link = link + self.last_modified = last_modified + self.project_gid = project_gid + self.project_name = project_name + + def __str__(self) -> str: + return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}" + + +class AsanaAPI: + def __init__( + self, api_token: str, workspace_gid: str, team_gid: str | None + ) -> None: + self._user = None + self.workspace_gid = workspace_gid + self.team_gid = team_gid + + self.configuration = asana.Configuration() + self.api_client = asana.ApiClient(self.configuration) + self.tasks_api = asana.TasksApi(self.api_client) + self.attachments_api = asana.AttachmentsApi(self.api_client) + self.stories_api = asana.StoriesApi(self.api_client) + self.users_api = asana.UsersApi(self.api_client) + self.project_api = asana.ProjectsApi(self.api_client) + self.project_memberships_api = asana.ProjectMembershipsApi(self.api_client) + self.workspaces_api = asana.WorkspacesApi(self.api_client) + + self.api_error_count = 0 + self.configuration.access_token = api_token + self.task_count = 0 + + def get_tasks( + self, project_gids: list[str] | None, start_date: str + ) -> Iterator[AsanaTask]: + """Get all tasks from the projects with the given gids that were modified since the given date. + If project_gids is None, get all tasks from all projects in the workspace.""" + logging.info("Starting to fetch Asana projects") + projects = self.project_api.get_projects( + opts={ + "workspace": self.workspace_gid, + "opt_fields": "gid,name,archived,modified_at", + } + ) + start_seconds = int(time.mktime(datetime.now().timetuple())) + projects_list = [] + project_count = 0 + for project_info in projects: + project_gid = project_info["gid"] + if project_gids is None or project_gid in project_gids: + projects_list.append(project_gid) + else: + logging.debug( + f"Skipping project: {project_gid} - not in accepted project_gids" + ) + project_count += 1 + if project_count % 100 == 0: + logging.info(f"Processed {project_count} projects") + logging.info(f"Found {len(projects_list)} projects to process") + for project_gid in projects_list: + for task in self._get_tasks_for_project( + project_gid, start_date, start_seconds + ): + yield task + logging.info(f"Completed fetching {self.task_count} tasks from Asana") + if self.api_error_count > 0: + logging.warning( + f"Encountered {self.api_error_count} API errors during task fetching" + ) + + def _get_tasks_for_project( + self, project_gid: str, start_date: str, start_seconds: int + ) -> Iterator[AsanaTask]: + project = self.project_api.get_project(project_gid, opts={}) + project_name = project.get("name", project_gid) + team = project.get("team") or {} + team_gid = team.get("gid") + + if project.get("archived"): + logging.info(f"Skipping archived project: {project_name} ({project_gid})") + return + if not team_gid: + logging.info( + f"Skipping project without a team: {project_name} ({project_gid})" + ) + return + if project.get("privacy_setting") == "private": + if self.team_gid and team_gid != self.team_gid: + logging.info( + f"Skipping private project not in configured team: {project_name} ({project_gid})" + ) + return + logging.info( + f"Processing private project in configured team: {project_name} ({project_gid})" + ) + + simple_start_date = start_date.split(".")[0].split("+")[0] + logging.info( + f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})" + ) + + opts = { + "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at," + "created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes," + "modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on," + "workspace,permalink_url", + "modified_since": start_date, + } + tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts) + for data in tasks_from_api: + self.task_count += 1 + if self.task_count % 10 == 0: + end_seconds = time.mktime(datetime.now().timetuple()) + runtime_seconds = end_seconds - start_seconds + if runtime_seconds > 0: + logging.info( + f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds " + f"({self.task_count / runtime_seconds:.2f} tasks/second)" + ) + + logging.debug(f"Processing Asana task: {data['name']}") + + text = self._construct_task_text(data) + + try: + text += self._fetch_and_add_comments(data["gid"]) + + last_modified_date = self.format_date(data["modified_at"]) + text += f"Last modified: {last_modified_date}\n" + + task = AsanaTask( + id=data["gid"], + title=data["name"], + text=text, + link=data["permalink_url"], + last_modified=datetime.fromisoformat(data["modified_at"]), + project_gid=project_gid, + project_name=project_name, + ) + yield task + except Exception: + logging.error( + f"Error processing task {data['gid']} in project {project_gid}", + exc_info=True, + ) + self.api_error_count += 1 + + def _construct_task_text(self, data: Dict) -> str: + text = f"{data['name']}\n\n" + + if data["notes"]: + text += f"{data['notes']}\n\n" + + if data["created_by"] and data["created_by"]["gid"]: + creator = self.get_user(data["created_by"]["gid"])["name"] + created_date = self.format_date(data["created_at"]) + text += f"Created by: {creator} on {created_date}\n" + + if data["due_on"]: + due_date = self.format_date(data["due_on"]) + text += f"Due date: {due_date}\n" + + if data["completed_at"]: + completed_date = self.format_date(data["completed_at"]) + text += f"Completed on: {completed_date}\n" + + text += "\n" + return text + + def _fetch_and_add_comments(self, task_gid: str) -> str: + text = "" + stories_opts: Dict[str, str] = {} + story_start = time.time() + stories = self.stories_api.get_stories_for_task(task_gid, stories_opts) + + story_count = 0 + comment_count = 0 + + for story in stories: + story_count += 1 + if story["resource_subtype"] == "comment_added": + comment = self.stories_api.get_story( + story["gid"], opts={"opt_fields": "text,created_by,created_at"} + ) + commenter = self.get_user(comment["created_by"]["gid"])["name"] + text += f"Comment by {commenter}: {comment['text']}\n\n" + comment_count += 1 + + story_duration = time.time() - story_start + logging.debug( + f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds" + ) + + return text + + def get_attachments(self, task_gid: str) -> list[dict]: + """ + Fetch full attachment info (including download_url) for a task. + """ + attachments: list[dict] = [] + + try: + # Step 1: list attachment compact records + for att in self.attachments_api.get_attachments_for_object( + parent=task_gid, + opts={} + ): + gid = att.get("gid") + if not gid: + continue + + try: + # Step 2: expand to full attachment + full = self.attachments_api.get_attachment( + attachment_gid=gid, + opts={ + "opt_fields": "name,download_url,size,created_at" + } + ) + + if full.get("download_url"): + attachments.append(full) + + except Exception: + logging.exception( + f"Failed to fetch attachment detail {gid} for task {task_gid}" + ) + self.api_error_count += 1 + + except Exception: + logging.exception(f"Failed to list attachments for task {task_gid}") + self.api_error_count += 1 + + return attachments + + def get_accessible_emails( + self, + workspace_id: str, + project_ids: list[str] | None, + team_id: str | None, + ): + + ws_users = self.users_api.get_users( + opts={ + "workspace": workspace_id, + "opt_fields": "gid,name,email" + } + ) + + workspace_users = { + u["gid"]: u.get("email") + for u in ws_users + if u.get("email") + } + + if not project_ids: + return set(workspace_users.values()) + + + project_emails = set() + + for pid in project_ids: + project = self.project_api.get_project( + pid, + opts={"opt_fields": "team,privacy_setting"} + ) + + if project["privacy_setting"] == "private": + if team_id and project.get("team", {}).get("gid") != team_id: + continue + + memberships = self.project_memberships_api.get_project_membership( + pid, + opts={"opt_fields": "user.gid,user.email"} + ) + + for m in memberships: + email = m["user"].get("email") + if email: + project_emails.add(email) + + return project_emails + + def get_user(self, user_gid: str) -> Dict: + if self._user is not None: + return self._user + self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"}) + + if not self._user: + logging.warning(f"Unable to fetch user information for user_gid: {user_gid}") + return {"name": "Unknown"} + return self._user + + def format_date(self, date_str: str) -> str: + date = datetime.fromisoformat(date_str) + return time.strftime("%Y-%m-%d", date.timetuple()) + + def get_time(self) -> str: + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + +class AsanaConnector(LoadConnector, PollConnector): + def __init__( + self, + asana_workspace_id: str, + asana_project_ids: str | None = None, + asana_team_id: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, + ) -> None: + self.workspace_id = asana_workspace_id + self.project_ids_to_index: list[str] | None = ( + asana_project_ids.split(",") if asana_project_ids else None + ) + self.asana_team_id = asana_team_id if asana_team_id else None + self.batch_size = batch_size + self.continue_on_failure = continue_on_failure + self.size_threshold = None + logging.info( + f"AsanaConnector initialized with workspace_id: {asana_workspace_id}" + ) + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.api_token = credentials["asana_api_token_secret"] + self.asana_client = AsanaAPI( + api_token=self.api_token, + workspace_gid=self.workspace_id, + team_gid=self.asana_team_id, + ) + self.workspace_users_email = self.asana_client.get_accessible_emails(self.workspace_id, self.project_ids_to_index, self.asana_team_id) + logging.info("Asana credentials loaded and API client initialized") + return None + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: + start_time = datetime.fromtimestamp(start).isoformat() + logging.info(f"Starting Asana poll from {start_time}") + docs_batch: list[Document] = [] + tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time) + for task in tasks: + docs = self._task_to_documents(task) + docs_batch.extend(docs) + + if len(docs_batch) >= self.batch_size: + logging.info(f"Yielding batch of {len(docs_batch)} documents") + yield docs_batch + docs_batch = [] + + if docs_batch: + logging.info(f"Yielding final batch of {len(docs_batch)} documents") + yield docs_batch + + logging.info("Asana poll completed") + + def load_from_state(self) -> GenerateDocumentsOutput: + logging.info("Starting full index of all Asana tasks") + return self.poll_source(start=0, end=None) + + def _task_to_documents(self, task: AsanaTask) -> list[Document]: + docs: list[Document] = [] + + attachments = self.asana_client.get_attachments(task.id) + + for att in attachments: + try: + resp = requests.get(att["download_url"], timeout=30) + resp.raise_for_status() + file_blob = resp.content + filename = att.get("name", "attachment") + size_bytes = extract_size_bytes(att) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + docs.append( + Document( + id=f"asana:{task.id}:{att['gid']}", + blob=file_blob, + extension=get_file_ext(filename) or "", + size_bytes=size_bytes, + doc_updated_at=task.last_modified, + source=DocumentSource.ASANA, + semantic_identifier=filename, + primary_owners=list(self.workspace_users_email), + ) + ) + except Exception: + logging.exception( + f"Failed to download attachment {att.get('gid')} for task {task.id}" + ) + + return docs + + + +if __name__ == "__main__": + import time + import os + + logging.info("Starting Asana connector test") + connector = AsanaConnector( + os.environ["WORKSPACE_ID"], + os.environ["PROJECT_IDS"], + os.environ["TEAM_ID"], + ) + connector.load_credentials( + { + "asana_api_token_secret": os.environ["API_TOKEN"], + } + ) + logging.info("Loading all documents from Asana") + all_docs = connector.load_from_state() + current = time.time() + one_day_ago = current - 24 * 60 * 60 # 1 day + logging.info("Polling for documents updated in the last 24 hours") + latest_docs = connector.poll_source(one_day_ago, current) + for docs in all_docs: + for doc in docs: + print(doc.id) + logging.info("Asana connector test completed") \ No newline at end of file diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 0bec7cbe643..1ab39189d79 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -56,7 +56,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None # Validate credentials if self.bucket_type == BlobType.R2: - if not all( + if not all( credentials.get(key) for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"] ): @@ -64,15 +64,23 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None elif self.bucket_type == BlobType.S3: authentication_method = credentials.get("authentication_method", "access_key") + if authentication_method == "access_key": if not all( credentials.get(key) for key in ["aws_access_key_id", "aws_secret_access_key"] ): raise ConnectorMissingCredentialError("Amazon S3") + elif authentication_method == "iam_role": if not credentials.get("aws_role_arn"): raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required") + + elif authentication_method == "assume_role": + pass + + else: + raise ConnectorMissingCredentialError("Unsupported S3 authentication method") elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: if not all( @@ -87,6 +95,13 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None ): raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure") + elif self.bucket_type == BlobType.S3_COMPATIBLE: + if not all( + credentials.get(key) + for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"] + ): + raise ConnectorMissingCredentialError("S3 Compatible Storage") + else: raise ValueError(f"Unsupported bucket type: {self.bucket_type}") @@ -113,55 +128,72 @@ def _yield_blob_objects( paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) - batch: list[Document] = [] + # Collect all objects first to count filename occurrences + all_objects = [] for page in pages: if "Contents" not in page: continue - for obj in page["Contents"]: if obj["Key"].endswith("/"): continue - last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + if start < last_modified <= end: + all_objects.append(obj) + + # Count filename occurrences to determine which need full paths + filename_counts: dict[str, int] = {} + for obj in all_objects: + file_name = os.path.basename(obj["Key"]) + filename_counts[file_name] = filename_counts.get(file_name, 0) + 1 - if not (start < last_modified <= end): + batch: list[Document] = [] + for obj in all_objects: + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + file_name = os.path.basename(obj["Key"]) + key = obj["Key"] + + size_bytes = extract_size_bytes(obj) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + + try: + blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) + if blob is None: continue - file_name = os.path.basename(obj["Key"]) - key = obj["Key"] - - size_bytes = extract_size_bytes(obj) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." + # Use full path only if filename appears multiple times + if filename_counts.get(file_name, 0) > 1: + relative_path = key + if self.prefix and key.startswith(self.prefix): + relative_path = key[len(self.prefix):] + semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name + else: + semantic_id = file_name + + batch.append( + Document( + id=f"{self.bucket_type}:{self.bucket_name}:{key}", + blob=blob, + source=DocumentSource(self.bucket_type.value), + semantic_identifier=semantic_id, + extension=get_file_ext(file_name), + doc_updated_at=last_modified, + size_bytes=size_bytes if size_bytes else 0 ) - continue - try: - blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) - if blob is None: - continue - - batch.append( - Document( - id=f"{self.bucket_type}:{self.bucket_name}:{key}", - blob=blob, - source=DocumentSource(self.bucket_type.value), - semantic_identifier=file_name, - extension=get_file_ext(file_name), - doc_updated_at=last_modified, - size_bytes=size_bytes if size_bytes else 0 - ) - ) - if len(batch) == self.batch_size: - yield batch - batch = [] + ) + if len(batch) == self.batch_size: + yield batch + batch = [] - except Exception: - logging.exception(f"Error decoding object {key}") + except Exception: + logging.exception(f"Error decoding object {key}") if batch: yield batch @@ -269,4 +301,4 @@ def validate_connector_settings(self) -> None: except ConnectorMissingCredentialError as e: print(f"Error: {e}") except Exception as e: - print(f"An unexpected error occurred: {e}") \ No newline at end of file + print(f"An unexpected error occurred: {e}") diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py new file mode 100644 index 00000000000..3006e709c9c --- /dev/null +++ b/common/data_source/box_connector.py @@ -0,0 +1,162 @@ +"""Box connector""" +import logging +from datetime import datetime, timezone +from typing import Any + +from box_sdk_gen import BoxClient +from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch +from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.utils import get_file_ext + +class BoxConnector(LoadConnector, PollConnector): + def __init__(self, folder_id: str, batch_size: int = INDEX_BATCH_SIZE, use_marker: bool = True) -> None: + self.batch_size = batch_size + self.folder_id = "0" if not folder_id else folder_id + self.use_marker = use_marker + + + def load_credentials(self, auth: Any): + self.box_client = BoxClient(auth=auth) + return None + + + def validate_connector_settings(self): + if self.box_client is None: + raise ConnectorMissingCredentialError("Box") + + try: + self.box_client.users.get_user_me() + except Exception as e: + logging.exception("[Box]: Failed to validate Box credentials") + raise ConnectorValidationError(f"Unexpected error during Box settings validation: {e}") + + + def _yield_files_recursive( + self, + folder_id, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: + + if self.box_client is None: + raise ConnectorMissingCredentialError("Box") + + result = self.box_client.folders.get_folder_items( + folder_id=folder_id, + limit=self.batch_size, + usemarker=self.use_marker + ) + + while True: + batch: list[Document] = [] + for entry in result.entries: + if entry.type == 'file' : + file = self.box_client.files.get_file_by_id( + entry.id + ) + raw_time = ( + getattr(file, "created_at", None) + or getattr(file, "content_created_at", None) + ) + + if raw_time: + modified_time = self._box_datetime_to_epoch_seconds(raw_time) + if start is not None and modified_time <= start: + continue + if end is not None and modified_time > end: + continue + + content_bytes = self.box_client.downloads.download_file(file.id) + + batch.append( + Document( + id=f"box:{file.id}", + blob=content_bytes.read(), + source=DocumentSource.BOX, + semantic_identifier=file.name, + extension=get_file_ext(file.name), + doc_updated_at=modified_time, + size_bytes=file.size, + metadata=file.metadata + ) + ) + elif entry.type == 'folder': + yield from self._yield_files_recursive(folder_id=entry.id, start=start, end=end) + + if batch: + yield batch + + if not result.next_marker: + break + + result = self.box_client.folders.get_folder_items( + folder_id=folder_id, + limit=self.batch_size, + marker=result.next_marker, + usemarker=True + ) + + + def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch: + """Convert a Box SDK datetime to Unix epoch seconds (UTC). + Only supports datetime; any non-datetime should be filtered out by caller. + """ + if not isinstance(dt, datetime): + raise TypeError(f"box_datetime_to_epoch_seconds expects datetime, got {type(dt)}") + + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + + return SecondsSinceUnixEpoch(int(dt.timestamp())) + + + def poll_source(self, start, end): + return self._yield_files_recursive(folder_id=self.folder_id, start=start, end=end) + + + def load_from_state(self): + return self._yield_files_recursive(folder_id=self.folder_id, start=None, end=None) + + +# from flask import Flask, request, redirect + +# from box_sdk_gen import BoxClient, BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions + +# app = Flask(__name__) + +# AUTH = BoxOAuth( +# OAuthConfig(client_id="8suvn9ik7qezsq2dub0ye6ubox61081z", client_secret="QScvhLgBcZrb2ck1QP1ovkutpRhI2QcN") +# ) + + +# @app.route("/") +# def get_auth(): +# auth_url = AUTH.get_authorize_url( +# options=GetAuthorizeUrlOptions(redirect_uri="http://localhost:4999/oauth2callback") +# ) +# return redirect(auth_url, code=302) + + +# @app.route("/oauth2callback") +# def callback(): +# AUTH.get_tokens_authorization_code_grant(request.args.get("code")) +# box = BoxConnector() +# box.load_credentials({"auth": AUTH}) + +# lst = [] +# for file in box.load_from_state(): +# for f in file: +# lst.append(f.semantic_identifier) + +# return lst + +if __name__ == "__main__": + pass + # app.run(port=4999) \ No newline at end of file diff --git a/common/data_source/config.py b/common/data_source/config.py index 02684dbacc9..bca13b5bed6 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -13,6 +13,7 @@ def get_current_tz_offset() -> int: return round(time_diff.total_seconds() / 3600) +ONE_MINUTE = 60 ONE_HOUR = 3600 ONE_DAY = ONE_HOUR * 24 @@ -31,6 +32,7 @@ class BlobType(str, Enum): R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" + S3_COMPATIBLE = "s3_compatible" class DocumentSource(str, Enum): @@ -42,11 +44,22 @@ class DocumentSource(str, Enum): OCI_STORAGE = "oci_storage" SLACK = "slack" CONFLUENCE = "confluence" + JIRA = "jira" GOOGLE_DRIVE = "google_drive" GMAIL = "gmail" DISCORD = "discord" - - + WEBDAV = "webdav" + MOODLE = "moodle" + S3_COMPATIBLE = "s3_compatible" + DROPBOX = "dropbox" + BOX = "box" + AIRTABLE = "airtable" + ASANA = "asana" + GITHUB = "github" + GITLAB = "gitlab" + IMAP = "imap" + + class FileOrigin(str, Enum): """File origins""" CONNECTOR = "connector" @@ -76,6 +89,7 @@ class FileOrigin(str, Enum): "space", "metadata.labels", "history.lastUpdated", + "ancestors", ] @@ -178,6 +192,21 @@ class FileOrigin(str, Enum): os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) ) +JIRA_CONNECTOR_LABELS_TO_SKIP = [ + ignored_tag + for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") + if ignored_tag +] +JIRA_CONNECTOR_MAX_TICKET_SIZE = int( + os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024) +) +JIRA_SYNC_TIME_BUFFER_SECONDS = int( + os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE) +) +JIRA_TIMEZONE_OFFSET = float( + os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset()) +) + OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get( @@ -195,6 +224,7 @@ class FileOrigin(str, Enum): "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") +GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() @@ -204,6 +234,9 @@ class FileOrigin(str, Enum): _PROBLEMATIC_EXPANSIONS = "body.storage.value" _REPLACEMENT_EXPANSIONS = "body.view.value" +BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/box/oauth/web/callback") + +GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None class HtmlBasedConnectorTransformLinksStrategy(str, Enum): # remove links entirely @@ -226,6 +259,18 @@ class HtmlBasedConnectorTransformLinksStrategy(str, Enum): "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside" ).split(",") +AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + +ASANA_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + +IMAP_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + _USER_NOT_FOUND = "Unknown Confluence User" _COMMENT_EXPANSION_FIELDS = ["body.storage.value"] diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index aed16ad2b66..d2494c3de74 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -126,7 +126,7 @@ def __init__( def _renew_credentials(self) -> tuple[dict[str, Any], bool]: """credential_json - the current json credentials Returns a tuple - 1. The up to date credentials + 1. The up-to-date credentials 2. True if the credentials were updated This method is intended to be used within a distributed lock. @@ -179,14 +179,14 @@ def _renew_credentials(self) -> tuple[dict[str, Any], bool]: credential_json["confluence_refresh_token"], ) - # store the new credentials to redis and to the db thru the provider - # redis: we use a 5 min TTL because we are given a 10 minute grace period + # store the new credentials to redis and to the db through the provider + # redis: we use a 5 min TTL because we are given a 10 minutes grace period # when keys are rotated. it's easier to expire the cached credentials # reasonably frequently rather than trying to handle strong synchronization # between the db and redis everywhere the credentials might be updated new_credential_str = json.dumps(new_credentials) self.redis_client.set( - self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL + self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL ) self._credentials_provider.set_credentials(new_credentials) @@ -690,7 +690,7 @@ def cql_paginate_all_expansions( ) -> Iterator[dict[str, Any]]: """ This function will paginate through the top level query first, then - paginate through all of the expansions. + paginate through all the expansions. """ def _traverse_and_update(data: dict | list) -> None: @@ -717,7 +717,7 @@ def paginated_cql_user_retrieval( """ The search/user endpoint can be used to fetch users. It's a separate endpoint from the content/search endpoint used only for users. - Otherwise it's very similar to the content/search endpoint. + It's very similar to the content/search endpoint. """ # this is needed since there is a live bug with Confluence Server/Data Center @@ -863,7 +863,7 @@ def get_user_email_from_username__server( # For now, we'll just return None and log a warning. This means # we will keep retrying to get the email every group sync. email = None - # We may want to just return a string that indicates failure so we dont + # We may want to just return a string that indicates failure so we don't # keep retrying # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" _USER_EMAIL_CACHE[user_name] = email @@ -912,7 +912,7 @@ def extract_text_from_confluence_html( confluence_object: dict[str, Any], fetched_titles: set[str], ) -> str: - """Parse a Confluence html page and replace the 'user Id' by the real + """Parse a Confluence html page and replace the 'user id' by the real User Display Name Args: @@ -1110,7 +1110,10 @@ def _make_attachment_link( ) -> str | None: download_link = "" - if "api.atlassian.com" in confluence_client.url: + from urllib.parse import urlparse + netloc =urlparse(confluence_client.url).hostname + if netloc == "api.atlassian.com" or (netloc and netloc.endswith(".api.atlassian.com")): + # if "api.atlassian.com" in confluence_client.url: # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get if not parent_content_id: logging.warning( @@ -1308,6 +1311,9 @@ def __init__( self._low_timeout_confluence_client: OnyxConfluence | None = None self._fetched_titles: set[str] = set() self.allow_images = False + # Track document names to detect duplicates + self._document_name_counts: dict[str, int] = {} + self._document_name_paths: dict[str, list[str]] = {} # Remove trailing slash from wiki_base if present self.wiki_base = wiki_base.rstrip("/") @@ -1510,6 +1516,40 @@ def _convert_page_to_document( self.wiki_base, page["_links"]["webui"], self.is_cloud ) + # Build hierarchical path for semantic identifier + space_name = page.get("space", {}).get("name", "") + + # Build path from ancestors + path_parts = [] + if space_name: + path_parts.append(space_name) + + # Add ancestor pages to path if available + if "ancestors" in page and page["ancestors"]: + for ancestor in page["ancestors"]: + ancestor_title = ancestor.get("title", "") + if ancestor_title: + path_parts.append(ancestor_title) + + # Add current page title + path_parts.append(page_title) + + # Track page names for duplicate detection + full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title + + # Count occurrences of this page title + if page_title not in self._document_name_counts: + self._document_name_counts[page_title] = 0 + self._document_name_paths[page_title] = [] + self._document_name_counts[page_title] += 1 + self._document_name_paths[page_title].append(full_path) + + # Use simple name if no duplicates, otherwise use full path + if self._document_name_counts[page_title] == 1: + semantic_identifier = page_title + else: + semantic_identifier = full_path + # Get the page content page_content = extract_text_from_confluence_html( self.confluence_client, page, self._fetched_titles @@ -1556,12 +1596,13 @@ def _convert_page_to_document( return Document( id=page_url, source=DocumentSource.CONFLUENCE, - semantic_identifier=page_title, + semantic_identifier=semantic_identifier, extension=".html", # Confluence pages are HTML blob=page_content.encode("utf-8"), # Encode page content as bytes - size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes doc_updated_at=datetime_from_string(page["version"]["when"]), + size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes primary_owners=primary_owners if primary_owners else None, + metadata=metadata if metadata else None, ) except Exception as e: logging.error(f"Error converting page {page.get('id', 'unknown')}: {e}") @@ -1597,7 +1638,6 @@ def _fetch_page_attachments( expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), ): media_type: str = attachment.get("metadata", {}).get("mediaType", "") - # TODO(rkuo): this check is partially redundant with validate_attachment_filetype # and checks in convert_attachment_to_content/process_attachment # but doing the check here avoids an unnecessary download. Due for refactoring. @@ -1665,6 +1705,34 @@ def _fetch_page_attachments( self.wiki_base, attachment["_links"]["webui"], self.is_cloud ) + # Build semantic identifier with space and page context + attachment_title = attachment.get("title", object_url) + space_name = page.get("space", {}).get("name", "") + page_title = page.get("title", "") + + # Create hierarchical name: Space / Page / Attachment + attachment_path_parts = [] + if space_name: + attachment_path_parts.append(space_name) + if page_title: + attachment_path_parts.append(page_title) + attachment_path_parts.append(attachment_title) + + full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title + + # Track attachment names for duplicate detection + if attachment_title not in self._document_name_counts: + self._document_name_counts[attachment_title] = 0 + self._document_name_paths[attachment_title] = [] + self._document_name_counts[attachment_title] += 1 + self._document_name_paths[attachment_title].append(full_attachment_path) + + # Use simple name if no duplicates, otherwise use full path + if self._document_name_counts[attachment_title] == 1: + attachment_semantic_identifier = attachment_title + else: + attachment_semantic_identifier = full_attachment_path + primary_owners: list[BasicExpertInfo] | None = None if "version" in attachment and "by" in attachment["version"]: author = attachment["version"]["by"] @@ -1676,11 +1744,12 @@ def _fetch_page_attachments( extension = Path(attachment.get("title", "")).suffix or ".unknown" + attachment_doc = Document( id=attachment_id, # sections=sections, source=DocumentSource.CONFLUENCE, - semantic_identifier=attachment.get("title", object_url), + semantic_identifier=attachment_semantic_identifier, extension=extension, blob=file_blob, size_bytes=len(file_blob), @@ -1737,7 +1806,7 @@ def _fetch_document_batches( start_ts, end, self.batch_size ) logging.debug(f"page_query_url: {page_query_url}") - + # store the next page start for confluence server, cursor for confluence cloud def store_next_page_url(next_page_url: str) -> None: checkpoint.next_page_url = next_page_url @@ -1788,6 +1857,7 @@ def _build_page_retrieval_url( cql_url = self.confluence_client.build_cql_url( page_query, expand=",".join(_PAGE_EXPANSION_FIELDS) ) + logging.info(f"[Confluence Connector] Building CQL URL {cql_url}") return update_param_in_path(cql_url, "limit", str(limit)) @override diff --git a/common/data_source/connector_runner.py b/common/data_source/connector_runner.py new file mode 100644 index 00000000000..d47d6512842 --- /dev/null +++ b/common/data_source/connector_runner.py @@ -0,0 +1,217 @@ +import sys +import time +import logging +from collections.abc import Generator +from datetime import datetime +from typing import Generic +from typing import TypeVar +from common.data_source.interfaces import ( + BaseConnector, + CheckpointedConnector, + CheckpointedConnectorWithPermSync, + CheckpointOutput, + LoadConnector, + PollConnector, +) +from common.data_source.models import ConnectorCheckpoint, ConnectorFailure, Document + + +TimeRange = tuple[datetime, datetime] + +CT = TypeVar("CT", bound=ConnectorCheckpoint) + + +def batched_doc_ids( + checkpoint_connector_generator: CheckpointOutput[CT], + batch_size: int, +) -> Generator[set[str], None, None]: + batch: set[str] = set() + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( + checkpoint_connector_generator + ): + if document is not None: + batch.add(document.id) + elif ( + failure and failure.failed_document and failure.failed_document.document_id + ): + batch.add(failure.failed_document.document_id) + + if len(batch) >= batch_size: + yield batch + batch = set() + if len(batch) > 0: + yield batch + + +class CheckpointOutputWrapper(Generic[CT]): + """ + Wraps a CheckpointOutput generator to give things back in a more digestible format, + specifically for Document outputs. + The connector format is easier for the connector implementor (e.g. it enforces exactly + one new checkpoint is returned AND that the checkpoint is at the end), thus the different + formats. + """ + + def __init__(self) -> None: + self.next_checkpoint: CT | None = None + + def __call__( + self, + checkpoint_connector_generator: CheckpointOutput[CT], + ) -> Generator[ + tuple[Document | None, ConnectorFailure | None, CT | None], + None, + None, + ]: + # grabs the final return value and stores it in the `next_checkpoint` variable + def _inner_wrapper( + checkpoint_connector_generator: CheckpointOutput[CT], + ) -> CheckpointOutput[CT]: + self.next_checkpoint = yield from checkpoint_connector_generator + return self.next_checkpoint # not used + + for document_or_failure in _inner_wrapper(checkpoint_connector_generator): + if isinstance(document_or_failure, Document): + yield document_or_failure, None, None + elif isinstance(document_or_failure, ConnectorFailure): + yield None, document_or_failure, None + else: + raise ValueError( + f"Invalid document_or_failure type: {type(document_or_failure)}" + ) + + if self.next_checkpoint is None: + raise RuntimeError( + "Checkpoint is None. This should never happen - the connector should always return a checkpoint." + ) + + yield None, None, self.next_checkpoint + + +class ConnectorRunner(Generic[CT]): + """ + Handles: + - Batching + - Additional exception logging + - Combining different connector types to a single interface + """ + + def __init__( + self, + connector: BaseConnector, + batch_size: int, + # cannot be True for non-checkpointed connectors + include_permissions: bool, + time_range: TimeRange | None = None, + ): + if not isinstance(connector, CheckpointedConnector) and include_permissions: + raise ValueError( + "include_permissions cannot be True for non-checkpointed connectors" + ) + + self.connector = connector + self.time_range = time_range + self.batch_size = batch_size + self.include_permissions = include_permissions + + self.doc_batch: list[Document] = [] + + def run(self, checkpoint: CT) -> Generator[ + tuple[list[Document] | None, ConnectorFailure | None, CT | None], + None, + None, + ]: + """Adds additional exception logging to the connector.""" + try: + if isinstance(self.connector, CheckpointedConnector): + if self.time_range is None: + raise ValueError("time_range is required for CheckpointedConnector") + + start = time.monotonic() + if self.include_permissions: + if not isinstance( + self.connector, CheckpointedConnectorWithPermSync + ): + raise ValueError( + "Connector does not support permission syncing" + ) + load_from_checkpoint = ( + self.connector.load_from_checkpoint_with_perm_sync + ) + else: + load_from_checkpoint = self.connector.load_from_checkpoint + checkpoint_connector_generator = load_from_checkpoint( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + checkpoint=checkpoint, + ) + next_checkpoint: CT | None = None + # this is guaranteed to always run at least once with next_checkpoint being non-None + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( + checkpoint_connector_generator + ): + if document is not None and isinstance(document, Document): + self.doc_batch.append(document) + + if failure is not None: + yield None, failure, None + + if len(self.doc_batch) >= self.batch_size: + yield self.doc_batch, None, None + self.doc_batch = [] + + # yield remaining documents + if len(self.doc_batch) > 0: + yield self.doc_batch, None, None + self.doc_batch = [] + + yield None, None, next_checkpoint + + logging.debug( + f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint." + ) + + else: + finished_checkpoint = self.connector.build_dummy_checkpoint() + finished_checkpoint.has_more = False + + if isinstance(self.connector, PollConnector): + if self.time_range is None: + raise ValueError("time_range is required for PollConnector") + + for document_batch in self.connector.poll_source( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + ): + yield document_batch, None, None + + yield None, None, finished_checkpoint + elif isinstance(self.connector, LoadConnector): + for document_batch in self.connector.load_from_state(): + yield document_batch, None, None + + yield None, None, finished_checkpoint + else: + raise ValueError(f"Invalid connector. type: {type(self.connector)}") + except Exception: + exc_type, _, exc_traceback = sys.exc_info() + + # Traverse the traceback to find the last frame where the exception was raised + tb = exc_traceback + if tb is None: + logging.error("No traceback found for exception") + raise + + while tb.tb_next: + tb = tb.tb_next # Move to the next frame in the traceback + + # Get the local variables from the frame where the exception occurred + local_vars = tb.tb_frame.f_locals + local_vars_str = "\n".join( + f"{key}: {value}" for key, value in local_vars.items() + ) + logging.error( + f"Error in connector. type: {exc_type};\n" + f"local_vars below -> \n{local_vars_str[:1024]}" + ) + raise \ No newline at end of file diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 93a0477b078..e65a6324185 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -33,7 +33,7 @@ def _convert_message_to_document( metadata: dict[str, str | list[str]] = {} semantic_substring = "" - # Only messages from TextChannels will make it here but we have to check for it anyways + # Only messages from TextChannels will make it here, but we have to check for it anyway if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name): metadata["Channel"] = channel_name semantic_substring += f" in Channel: #{channel_name}" @@ -65,6 +65,7 @@ def _convert_message_to_document( blob=message.content.encode("utf-8"), extension=".txt", size_bytes=len(message.content.encode("utf-8")), + metadata=metadata if metadata else None, ) @@ -175,7 +176,7 @@ def _manage_async_retrieval( # parse requested_start_date_string to datetime pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None - # Set start_time to the later of start and pull_date, or whichever is provided + # Set start_time to the most recent of start and pull_date, or whichever is provided start_time = max(filter(None, [start, pull_date])) if start or pull_date else None end_time: datetime | None = end @@ -232,8 +233,8 @@ class DiscordConnector(LoadConnector, PollConnector): def __init__( self, - server_ids: list[str] = [], - channel_names: list[str] = [], + server_ids: list[str] | None = None, + channel_names: list[str] | None = None, # YYYY-MM-DD start_date: str | None = None, batch_size: int = INDEX_BATCH_SIZE, diff --git a/common/data_source/dropbox_connector.py b/common/data_source/dropbox_connector.py index fd349baa111..0e7131d8f3b 100644 --- a/common/data_source/dropbox_connector.py +++ b/common/data_source/dropbox_connector.py @@ -1,13 +1,24 @@ """Dropbox connector""" +import logging +from datetime import timezone from typing import Any from dropbox import Dropbox from dropbox.exceptions import ApiError, AuthError +from dropbox.files import FileMetadata, FolderMetadata -from common.data_source.config import INDEX_BATCH_SIZE -from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError +from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, + InsufficientPermissionsError, +) from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch +from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.utils import get_file_ext + +logger = logging.getLogger(__name__) class DropboxConnector(LoadConnector, PollConnector): @@ -19,29 +30,29 @@ def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load Dropbox credentials""" - try: - access_token = credentials.get("dropbox_access_token") - if not access_token: - raise ConnectorMissingCredentialError("Dropbox access token is required") - - self.dropbox_client = Dropbox(access_token) - return None - except Exception as e: - raise ConnectorMissingCredentialError(f"Dropbox: {e}") + access_token = credentials.get("dropbox_access_token") + if not access_token: + raise ConnectorMissingCredentialError("Dropbox access token is required") + + self.dropbox_client = Dropbox(access_token) + return None def validate_connector_settings(self) -> None: """Validate Dropbox connector settings""" - if not self.dropbox_client: + if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") - + try: - # Test connection by getting current account info - self.dropbox_client.users_get_current_account() - except (AuthError, ApiError) as e: - if "invalid_access_token" in str(e).lower(): - raise InsufficientPermissionsError("Invalid Dropbox access token") - else: - raise ConnectorValidationError(f"Dropbox validation error: {e}") + self.dropbox_client.files_list_folder(path="", limit=1) + except AuthError as e: + logger.exception("[Dropbox]: Failed to validate Dropbox credentials") + raise ConnectorValidationError(f"Dropbox credential is invalid: {e}") + except ApiError as e: + if e.error is not None and "insufficient_permissions" in str(e.error).lower(): + raise InsufficientPermissionsError("Your Dropbox token does not have sufficient permissions.") + raise ConnectorValidationError(f"Unexpected Dropbox error during validation: {e.user_message_text or e}") + except Exception as e: + raise ConnectorValidationError(f"Unexpected error during Dropbox settings validation: {e}") def _download_file(self, path: str) -> bytes: """Download a single file from Dropbox.""" @@ -54,26 +65,145 @@ def _get_shared_link(self, path: str) -> str: """Create a shared link for a file in Dropbox.""" if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") - + try: - # Try to get existing shared links first shared_links = self.dropbox_client.sharing_list_shared_links(path=path) if shared_links.links: return shared_links.links[0].url + + link_metadata = self.dropbox_client.sharing_create_shared_link_with_settings(path) + return link_metadata.url + except ApiError as err: + logger.exception(f"[Dropbox]: Failed to create a shared link for {path}: {err}") + return "" + + def _yield_files_recursive( + self, + path: str, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None, + ) -> GenerateDocumentsOutput: + """Yield files in batches from a specified Dropbox folder, including subfolders.""" + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + # Collect all files first to count filename occurrences + all_files = [] + self._collect_files_recursive(path, start, end, all_files) + + # Count filename occurrences + filename_counts: dict[str, int] = {} + for entry, _ in all_files: + filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1 + + # Process files in batches + batch: list[Document] = [] + for entry, downloaded_file in all_files: + modified_time = entry.client_modified + if modified_time.tzinfo is None: + modified_time = modified_time.replace(tzinfo=timezone.utc) + else: + modified_time = modified_time.astimezone(timezone.utc) - # Create a new shared link - link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path) - return link_settings.url - except Exception: - # Fallback to basic link format - return f"https://www.dropbox.com/home{path}" - - def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: + # Use full path only if filename appears multiple times + if filename_counts.get(entry.name, 0) > 1: + # Remove leading slash and replace slashes with ' / ' + relative_path = entry.path_display.lstrip('/') + semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name + else: + semantic_id = entry.name + + batch.append( + Document( + id=f"dropbox:{entry.id}", + blob=downloaded_file, + source=DocumentSource.DROPBOX, + semantic_identifier=semantic_id, + extension=get_file_ext(entry.name), + doc_updated_at=modified_time, + size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file), + ) + ) + + if len(batch) == self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def _collect_files_recursive( + self, + path: str, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None, + all_files: list, + ) -> None: + """Recursively collect all files matching time criteria.""" + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + result = self.dropbox_client.files_list_folder( + path, + recursive=False, + include_non_downloadable_files=False, + ) + + while True: + for entry in result.entries: + if isinstance(entry, FileMetadata): + modified_time = entry.client_modified + if modified_time.tzinfo is None: + modified_time = modified_time.replace(tzinfo=timezone.utc) + else: + modified_time = modified_time.astimezone(timezone.utc) + + time_as_seconds = modified_time.timestamp() + if start is not None and time_as_seconds <= start: + continue + if end is not None and time_as_seconds > end: + continue + + try: + downloaded_file = self._download_file(entry.path_display) + all_files.append((entry, downloaded_file)) + except Exception: + logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}") + continue + + elif isinstance(entry, FolderMetadata): + self._collect_files_recursive(entry.path_lower, start, end, all_files) + + if not result.has_more: + break + + result = self.dropbox_client.files_list_folder_continue(result.cursor) + + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll Dropbox for recent file changes""" - # Simplified implementation - in production this would handle actual polling - return [] + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + for batch in self._yield_files_recursive("", start, end): + yield batch - def load_from_state(self) -> Any: + def load_from_state(self) -> GenerateDocumentsOutput: """Load files from Dropbox state""" - # Simplified implementation - return [] \ No newline at end of file + return self._yield_files_recursive("", None, None) + + +if __name__ == "__main__": + import os + + logging.basicConfig(level=logging.DEBUG) + connector = DropboxConnector() + connector.load_credentials({"dropbox_access_token": os.environ.get("DROPBOX_ACCESS_TOKEN")}) + connector.validate_connector_settings() + document_batches = connector.load_from_state() + try: + first_batch = next(document_batches) + print(f"Loaded {len(first_batch)} documents in first batch.") + for doc in first_batch: + print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") + except StopIteration: + print("No documents available in Dropbox.") diff --git a/common/data_source/file_types.py b/common/data_source/file_types.py index bf7eafaaaba..be4d56d7b5b 100644 --- a/common/data_source/file_types.py +++ b/common/data_source/file_types.py @@ -18,6 +18,7 @@ class UploadMimeTypes: "text/plain", "text/markdown", "text/x-markdown", + "text/mdx", "text/x-config", "text/tab-separated-values", "application/json", diff --git a/web/src/pages/add-knowledge/components/knowledge-dataset/index.less b/common/data_source/github/__init__.py similarity index 100% rename from web/src/pages/add-knowledge/components/knowledge-dataset/index.less rename to common/data_source/github/__init__.py diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py new file mode 100644 index 00000000000..2e6d5f2af93 --- /dev/null +++ b/common/data_source/github/connector.py @@ -0,0 +1,973 @@ +import copy +import logging +from collections.abc import Callable +from collections.abc import Generator +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from enum import Enum +from typing import Any +from typing import cast + +from github import Github, Auth +from github import RateLimitExceededException +from github import Repository +from github.GithubException import GithubException +from github.Issue import Issue +from github.NamedUser import NamedUser +from github.PaginatedList import PaginatedList +from github.PullRequest import PullRequest +from pydantic import BaseModel +from typing_extensions import override +from common.data_source.google_util.util import sanitize_filename +from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, + CredentialExpiredError, + InsufficientPermissionsError, + UnexpectedValidationError, +) +from common.data_source.interfaces import CheckpointedConnectorWithPermSyncGH, CheckpointOutput +from common.data_source.models import ( + ConnectorCheckpoint, + ConnectorFailure, + Document, + DocumentFailure, + ExternalAccess, + SecondsSinceUnixEpoch, +) +from common.data_source.connector_runner import ConnectorRunner +from .models import SerializedRepository +from .rate_limit_utils import sleep_after_rate_limit_exception +from .utils import deserialize_repository +from .utils import get_external_access_permission + +ITEMS_PER_PAGE = 100 +CURSOR_LOG_FREQUENCY = 50 + +_MAX_NUM_RATE_LIMIT_RETRIES = 5 + +ONE_DAY = timedelta(days=1) +SLIM_BATCH_SIZE = 100 +# Cases +# X (from start) standard run, no fallback to cursor-based pagination +# X (from start) standard run errors, fallback to cursor-based pagination +# X error in the middle of a page +# X no errors: run to completion +# X (from checkpoint) standard run, no fallback to cursor-based pagination +# X (from checkpoint) continue from cursor-based pagination +# - retrying +# - no retrying + +# things to check: +# checkpoint state on return +# checkpoint progress (no infinite loop) + + +class DocMetadata(BaseModel): + repo: str + + +def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str: + if "_PaginatedList__nextUrl" in pag_list.__dict__: + return "_PaginatedList__nextUrl" + for key in pag_list.__dict__: + if "__nextUrl" in key: + return key + for key in pag_list.__dict__: + if "nextUrl" in key: + return key + return "" + + +def get_nextUrl( + pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str +) -> str | None: + return getattr(pag_list, nextUrl_key) if nextUrl_key else None + + +def set_nextUrl( + pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str +) -> None: + if nextUrl_key: + setattr(pag_list, nextUrl_key, nextUrl) + elif nextUrl: + raise ValueError("Next URL key not found: " + str(pag_list.__dict__)) + + +def _paginate_until_error( + git_objs: Callable[[], PaginatedList[PullRequest | Issue]], + cursor_url: str | None, + prev_num_objs: int, + cursor_url_callback: Callable[[str | None, int], None], + retrying: bool = False, +) -> Generator[PullRequest | Issue, None, None]: + num_objs = prev_num_objs + pag_list = git_objs() + nextUrl_key = get_nextUrl_key(pag_list) + if cursor_url: + set_nextUrl(pag_list, nextUrl_key, cursor_url) + elif retrying: + # if we are retrying, we want to skip the objects retrieved + # over previous calls. Unfortunately, this WILL retrieve all + # pages before the one we are resuming from, so we really + # don't want this case to be hit often + logging.warning( + "Retrying from a previous cursor-based pagination call. " + "This will retrieve all pages before the one we are resuming from, " + "which may take a while and consume many API calls." + ) + pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:]) + num_objs = 0 + + try: + # this for loop handles cursor-based pagination + for issue_or_pr in pag_list: + num_objs += 1 + yield issue_or_pr + # used to store the current cursor url in the checkpoint. This value + # is updated during iteration over pag_list. + cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs) + + if num_objs % CURSOR_LOG_FREQUENCY == 0: + logging.info( + f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}" + ) + + except Exception as e: + logging.exception(f"Error during cursor-based pagination: {e}") + if num_objs - prev_num_objs > 0: + raise + + if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying: + logging.info( + "Assuming that this error is due to cursor " + "expiration because no objects were retrieved. " + "Retrying from the first page." + ) + yield from _paginate_until_error( + git_objs, None, prev_num_objs, cursor_url_callback, retrying=True + ) + return + + # for no cursor url or if we reach this point after a retry, raise the error + raise + + +def _get_batch_rate_limited( + # We pass in a callable because we want git_objs to produce a fresh + # PaginatedList each time it's called to avoid using the same object for cursor-based pagination + # from a partial offset-based pagination call. + git_objs: Callable[[], PaginatedList], + page_num: int, + cursor_url: str | None, + prev_num_objs: int, + cursor_url_callback: Callable[[str | None, int], None], + github_client: Github, + attempt_num: int = 0, +) -> Generator[PullRequest | Issue, None, None]: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github" + ) + try: + if cursor_url: + # when this is set, we are resuming from an earlier + # cursor-based pagination call. + yield from _paginate_until_error( + git_objs, cursor_url, prev_num_objs, cursor_url_callback + ) + return + objs = list(git_objs().get_page(page_num)) + # fetch all data here to disable lazy loading later + # this is needed to capture the rate limit exception here (if one occurs) + for obj in objs: + if hasattr(obj, "raw_data"): + getattr(obj, "raw_data") + yield from objs + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + yield from _get_batch_rate_limited( + git_objs, + page_num, + cursor_url, + prev_num_objs, + cursor_url_callback, + github_client, + attempt_num + 1, + ) + except GithubException as e: + if not ( + e.status == 422 + and ( + "cursor" in (e.message or "") + or "cursor" in (e.data or {}).get("message", "") + ) + ): + raise + # Fallback to a cursor-based pagination strategy + # This can happen for "large datasets," but there's no documentation + # On the error on the web as far as we can tell. + # Error message: + # "Pagination with the page parameter is not supported for large datasets, + # please use cursor based pagination (after/before)" + yield from _paginate_until_error( + git_objs, cursor_url, prev_num_objs, cursor_url_callback + ) + + +def _get_userinfo(user: NamedUser) -> dict[str, str]: + def _safe_get(attr_name: str) -> str | None: + try: + return cast(str | None, getattr(user, attr_name)) + except GithubException: + logging.debug(f"Error getting {attr_name} for user") + return None + + return { + k: v + for k, v in { + "login": _safe_get("login"), + "name": _safe_get("name"), + "email": _safe_get("email"), + }.items() + if v is not None + } + + +def _convert_pr_to_document( + pull_request: PullRequest, repo_external_access: ExternalAccess | None +) -> Document: + repo_name = pull_request.base.repo.full_name if pull_request.base else "" + doc_metadata = DocMetadata(repo=repo_name) + file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b"" + name = sanitize_filename(pull_request.title, "md") + + return Document( + id=pull_request.html_url, + blob= file_content_byte, + source=DocumentSource.GITHUB, + external_access=repo_external_access, + semantic_identifier=f"{pull_request.number}:{name}", + # updated_at is UTC time but is timezone unaware, explicitly add UTC + # as there is logic in indexing to prevent wrong timestamped docs + # due to local time discrepancies with UTC + doc_updated_at=( + pull_request.updated_at.replace(tzinfo=timezone.utc) + if pull_request.updated_at + else None + ), + extension=".md", + # this metadata is used in perm sync + size_bytes=len(file_content_byte) if file_content_byte else 0, + primary_owners=[], + doc_metadata=doc_metadata.model_dump(), + metadata={ + k: [str(vi) for vi in v] if isinstance(v, list) else str(v) + for k, v in { + "object_type": "PullRequest", + "id": pull_request.number, + "merged": pull_request.merged, + "state": pull_request.state, + "user": _get_userinfo(pull_request.user) if pull_request.user else None, + "assignees": [ + _get_userinfo(assignee) for assignee in pull_request.assignees + ], + "repo": ( + pull_request.base.repo.full_name if pull_request.base else None + ), + "num_commits": str(pull_request.commits), + "num_files_changed": str(pull_request.changed_files), + "labels": [label.name for label in pull_request.labels], + "created_at": ( + pull_request.created_at.replace(tzinfo=timezone.utc) + if pull_request.created_at + else None + ), + "updated_at": ( + pull_request.updated_at.replace(tzinfo=timezone.utc) + if pull_request.updated_at + else None + ), + "closed_at": ( + pull_request.closed_at.replace(tzinfo=timezone.utc) + if pull_request.closed_at + else None + ), + "merged_at": ( + pull_request.merged_at.replace(tzinfo=timezone.utc) + if pull_request.merged_at + else None + ), + "merged_by": ( + _get_userinfo(pull_request.merged_by) + if pull_request.merged_by + else None + ), + }.items() + if v is not None + }, + ) + + +def _fetch_issue_comments(issue: Issue) -> str: + comments = issue.get_comments() + return "\nComment: ".join(comment.body for comment in comments) + + +def _convert_issue_to_document( + issue: Issue, repo_external_access: ExternalAccess | None +) -> Document: + repo_name = issue.repository.full_name if issue.repository else "" + doc_metadata = DocMetadata(repo=repo_name) + file_content_byte = issue.body.encode('utf-8') if issue.body else b"" + name = sanitize_filename(issue.title, "md") + + return Document( + id=issue.html_url, + blob=file_content_byte, + source=DocumentSource.GITHUB, + extension=".md", + external_access=repo_external_access, + semantic_identifier=f"{issue.number}:{name}", + # updated_at is UTC time but is timezone unaware + doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc), + # this metadata is used in perm sync + doc_metadata=doc_metadata.model_dump(), + size_bytes=len(file_content_byte) if file_content_byte else 0, + primary_owners=[_get_userinfo(issue.user) if issue.user else None], + metadata={ + k: [str(vi) for vi in v] if isinstance(v, list) else str(v) + for k, v in { + "object_type": "Issue", + "id": issue.number, + "state": issue.state, + "user": _get_userinfo(issue.user) if issue.user else None, + "assignees": [_get_userinfo(assignee) for assignee in issue.assignees], + "repo": issue.repository.full_name if issue.repository else None, + "labels": [label.name for label in issue.labels], + "created_at": ( + issue.created_at.replace(tzinfo=timezone.utc) + if issue.created_at + else None + ), + "updated_at": ( + issue.updated_at.replace(tzinfo=timezone.utc) + if issue.updated_at + else None + ), + "closed_at": ( + issue.closed_at.replace(tzinfo=timezone.utc) + if issue.closed_at + else None + ), + "closed_by": ( + _get_userinfo(issue.closed_by) if issue.closed_by else None + ), + }.items() + if v is not None + }, + ) + + +class GithubConnectorStage(Enum): + START = "start" + PRS = "prs" + ISSUES = "issues" + + +class GithubConnectorCheckpoint(ConnectorCheckpoint): + stage: GithubConnectorStage + curr_page: int + + cached_repo_ids: list[int] | None = None + cached_repo: SerializedRepository | None = None + + # Used for the fallback cursor-based pagination strategy + num_retrieved: int + cursor_url: str | None = None + + def reset(self) -> None: + """ + Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None) + """ + self.curr_page = 0 + self.num_retrieved = 0 + self.cursor_url = None + + +def make_cursor_url_callback( + checkpoint: GithubConnectorCheckpoint, +) -> Callable[[str | None, int], None]: + def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None: + # we want to maintain the old cursor url so code after retrieval + # can determine that we are using the fallback cursor-based pagination strategy + if cursor_url: + checkpoint.cursor_url = cursor_url + checkpoint.num_retrieved = num_objs + + return cursor_url_callback + + +class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpoint]): + def __init__( + self, + repo_owner: str, + repositories: str | None = None, + state_filter: str = "all", + include_prs: bool = True, + include_issues: bool = False, + ) -> None: + self.repo_owner = repo_owner + self.repositories = repositories + self.state_filter = state_filter + self.include_prs = include_prs + self.include_issues = include_issues + self.github_client: Github | None = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + # defaults to 30 items per page, can be set to as high as 100 + token = credentials["github_access_token"] + auth = Auth.Token(token) + + if GITHUB_CONNECTOR_BASE_URL: + self.github_client = Github( + auth=auth, + base_url=GITHUB_CONNECTOR_BASE_URL, + per_page=ITEMS_PER_PAGE, + ) + else: + self.github_client = Github( + auth=auth, + per_page=ITEMS_PER_PAGE, + ) + + return None + + def get_github_repo( + self, github_client: Github, attempt_num: int = 0 + ) -> Repository.Repository: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github" + ) + + try: + return github_client.get_repo(f"{self.repo_owner}/{self.repositories}") + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + return self.get_github_repo(github_client, attempt_num + 1) + + def get_github_repos( + self, github_client: Github, attempt_num: int = 0 + ) -> list[Repository.Repository]: + """Get specific repositories based on comma-separated repo_name string.""" + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" + ) + + try: + repos = [] + # Split repo_name by comma and strip whitespace + repo_names = [ + name.strip() for name in (cast(str, self.repositories)).split(",") + ] + + for repo_name in repo_names: + if repo_name: # Skip empty strings + try: + repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}") + repos.append(repo) + except GithubException as e: + logging.warning( + f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}" + ) + + return repos + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + return self.get_github_repos(github_client, attempt_num + 1) + + def get_all_repos( + self, github_client: Github, attempt_num: int = 0 + ) -> list[Repository.Repository]: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" + ) + + try: + # Try to get organization first + try: + org = github_client.get_organization(self.repo_owner) + return list(org.get_repos()) + + except GithubException: + # If not an org, try as a user + user = github_client.get_user(self.repo_owner) + return list(user.get_repos()) + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + return self.get_all_repos(github_client, attempt_num + 1) + + def _pull_requests_func( + self, repo: Repository.Repository + ) -> Callable[[], PaginatedList[PullRequest]]: + return lambda: repo.get_pulls( + state=self.state_filter, sort="updated", direction="desc" + ) + + def _issues_func( + self, repo: Repository.Repository + ) -> Callable[[], PaginatedList[Issue]]: + return lambda: repo.get_issues( + state=self.state_filter, sort="updated", direction="desc" + ) + + def _fetch_from_github( + self, + checkpoint: GithubConnectorCheckpoint, + start: datetime | None = None, + end: datetime | None = None, + include_permissions: bool = False, + ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]: + if self.github_client is None: + raise ConnectorMissingCredentialError("GitHub") + + checkpoint = copy.deepcopy(checkpoint) + + # First run of the connector, fetch all repos and store in checkpoint + if checkpoint.cached_repo_ids is None: + repos = [] + if self.repositories: + if "," in self.repositories: + # Multiple repositories specified + repos = self.get_github_repos(self.github_client) + else: + # Single repository (backward compatibility) + repos = [self.get_github_repo(self.github_client)] + else: + # All repositories + repos = self.get_all_repos(self.github_client) + if not repos: + checkpoint.has_more = False + return checkpoint + + curr_repo = repos.pop() + checkpoint.cached_repo_ids = [repo.id for repo in repos] + checkpoint.cached_repo = SerializedRepository( + id=curr_repo.id, + headers=curr_repo.raw_headers, + raw_data=curr_repo.raw_data, + ) + checkpoint.stage = GithubConnectorStage.PRS + checkpoint.curr_page = 0 + # save checkpoint with repo ids retrieved + return checkpoint + + if checkpoint.cached_repo is None: + raise ValueError("No repo saved in checkpoint") + + # Deserialize the repository from the checkpoint + repo = deserialize_repository(checkpoint.cached_repo, self.github_client) + + cursor_url_callback = make_cursor_url_callback(checkpoint) + repo_external_access: ExternalAccess | None = None + if include_permissions: + repo_external_access = get_external_access_permission( + repo, self.github_client + ) + if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS: + logging.info(f"Fetching PRs for repo: {repo.name}") + + pr_batch = _get_batch_rate_limited( + self._pull_requests_func(repo), + checkpoint.curr_page, + checkpoint.cursor_url, + checkpoint.num_retrieved, + cursor_url_callback, + self.github_client, + ) + checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback + done_with_prs = False + num_prs = 0 + pr = None + print("start: ", start) + for pr in pr_batch: + num_prs += 1 + print("-"*40) + print("PR name", pr.title) + print("updated at", pr.updated_at) + print("-"*40) + print("\n") + # we iterate backwards in time, so at this point we stop processing prs + if ( + start is not None + and pr.updated_at + and pr.updated_at.replace(tzinfo=timezone.utc) <= start + ): + done_with_prs = True + break + # Skip PRs updated after the end date + if ( + end is not None + and pr.updated_at + and pr.updated_at.replace(tzinfo=timezone.utc) > end + ): + continue + try: + yield _convert_pr_to_document( + cast(PullRequest, pr), repo_external_access + ) + except Exception as e: + error_msg = f"Error converting PR to document: {e}" + logging.exception(error_msg) + yield ConnectorFailure( + failed_document=DocumentFailure( + document_id=str(pr.id), document_link=pr.html_url + ), + failure_message=error_msg, + exception=e, + ) + continue + + # If we reach this point with a cursor url in the checkpoint, we were using + # the fallback cursor-based pagination strategy. That strategy tries to get all + # PRs, so having curosr_url set means we are done with prs. However, we need to + # return AFTER the checkpoint reset to avoid infinite loops. + + # if we found any PRs on the page and there are more PRs to get, return the checkpoint. + # In offset mode, while indexing without time constraints, the pr batch + # will be empty when we're done. + used_cursor = checkpoint.cursor_url is not None + if num_prs > 0 and not done_with_prs and not used_cursor: + return checkpoint + + # if we went past the start date during the loop or there are no more + # prs to get, we move on to issues + checkpoint.stage = GithubConnectorStage.ISSUES + checkpoint.reset() + + if used_cursor: + # save the checkpoint after changing stage; next run will continue from issues + return checkpoint + + checkpoint.stage = GithubConnectorStage.ISSUES + + if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES: + logging.info(f"Fetching issues for repo: {repo.name}") + + issue_batch = list( + _get_batch_rate_limited( + self._issues_func(repo), + checkpoint.curr_page, + checkpoint.cursor_url, + checkpoint.num_retrieved, + cursor_url_callback, + self.github_client, + ) + ) + checkpoint.curr_page += 1 + done_with_issues = False + num_issues = 0 + for issue in issue_batch: + num_issues += 1 + issue = cast(Issue, issue) + # we iterate backwards in time, so at this point we stop processing prs + if ( + start is not None + and issue.updated_at.replace(tzinfo=timezone.utc) <= start + ): + done_with_issues = True + break + # Skip PRs updated after the end date + if ( + end is not None + and issue.updated_at.replace(tzinfo=timezone.utc) > end + ): + continue + + if issue.pull_request is not None: + # PRs are handled separately + continue + + try: + yield _convert_issue_to_document(issue, repo_external_access) + except Exception as e: + error_msg = f"Error converting issue to document: {e}" + logging.exception(error_msg) + yield ConnectorFailure( + failed_document=DocumentFailure( + document_id=str(issue.id), + document_link=issue.html_url, + ), + failure_message=error_msg, + exception=e, + ) + continue + + # if we found any issues on the page, and we're not done, return the checkpoint. + # don't return if we're using cursor-based pagination to avoid infinite loops + if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url: + return checkpoint + + # if we went past the start date during the loop or there are no more + # issues to get, we move on to the next repo + checkpoint.stage = GithubConnectorStage.PRS + checkpoint.reset() + + checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0 + if checkpoint.cached_repo_ids: + next_id = checkpoint.cached_repo_ids.pop() + next_repo = self.github_client.get_repo(next_id) + checkpoint.cached_repo = SerializedRepository( + id=next_id, + headers=next_repo.raw_headers, + raw_data=next_repo.raw_data, + ) + checkpoint.stage = GithubConnectorStage.PRS + checkpoint.reset() + + if checkpoint.cached_repo_ids: + logging.info( + f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})" + ) + else: + logging.info("No more repos remaining") + + return checkpoint + + def _load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + include_permissions: bool = False, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) + # add a day for timezone safety + end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY + + # Move start time back by 3 hours, since some Issues/PRs are getting dropped + # Could be due to delayed processing on GitHub side + # The non-updated issues since last poll will be shortcut-ed and not embedded + # adjusted_start_datetime = start_datetime - timedelta(hours=3) + + adjusted_start_datetime = start_datetime + + epoch = datetime.fromtimestamp(0, tz=timezone.utc) + if adjusted_start_datetime < epoch: + adjusted_start_datetime = epoch + + return self._fetch_from_github( + checkpoint, + start=adjusted_start_datetime, + end=end_datetime, + include_permissions=include_permissions, + ) + + @override + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + return self._load_from_checkpoint( + start, end, checkpoint, include_permissions=False + ) + + @override + def load_from_checkpoint_with_perm_sync( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + return self._load_from_checkpoint( + start, end, checkpoint, include_permissions=True + ) + + def validate_connector_settings(self) -> None: + if self.github_client is None: + raise ConnectorMissingCredentialError("GitHub credentials not loaded.") + + if not self.repo_owner: + raise ConnectorValidationError( + "Invalid connector settings: 'repo_owner' must be provided." + ) + + try: + if self.repositories: + if "," in self.repositories: + # Multiple repositories specified + repo_names = [name.strip() for name in self.repositories.split(",")] + if not repo_names: + raise ConnectorValidationError( + "Invalid connector settings: No valid repository names provided." + ) + + # Validate at least one repository exists and is accessible + valid_repos = False + validation_errors = [] + + for repo_name in repo_names: + if not repo_name: + continue + + try: + test_repo = self.github_client.get_repo( + f"{self.repo_owner}/{repo_name}" + ) + logging.info( + f"Successfully accessed repository: {self.repo_owner}/{repo_name}" + ) + test_repo.get_contents("") + valid_repos = True + # If at least one repo is valid, we can proceed + break + except GithubException as e: + validation_errors.append( + f"Repository '{repo_name}': {e.data.get('message', str(e))}" + ) + + if not valid_repos: + error_msg = ( + "None of the specified repositories could be accessed: " + ) + error_msg += ", ".join(validation_errors) + raise ConnectorValidationError(error_msg) + else: + # Single repository (backward compatibility) + test_repo = self.github_client.get_repo( + f"{self.repo_owner}/{self.repositories}" + ) + test_repo.get_contents("") + else: + # Try to get organization first + try: + org = self.github_client.get_organization(self.repo_owner) + total_count = org.get_repos().totalCount + if total_count == 0: + raise ConnectorValidationError( + f"Found no repos for organization: {self.repo_owner}. " + "Does the credential have the right scopes?" + ) + except GithubException as e: + # Check for missing SSO + MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower() + if MISSING_SSO_ERROR_MESSAGE in str(e).lower(): + SSO_GUIDE_LINK = ( + "https://docs.github.com/en/enterprise-cloud@latest/authentication/" + "authenticating-with-saml-single-sign-on/" + "authorizing-a-personal-access-token-for-use-with-saml-single-sign-on" + ) + raise ConnectorValidationError( + f"Your GitHub token is missing authorization to access the " + f"`{self.repo_owner}` organization. Please follow the guide to " + f"authorize your token: {SSO_GUIDE_LINK}" + ) + # If not an org, try as a user + user = self.github_client.get_user(self.repo_owner) + + # Check if we can access any repos + total_count = user.get_repos().totalCount + if total_count == 0: + raise ConnectorValidationError( + f"Found no repos for user: {self.repo_owner}. " + "Does the credential have the right scopes?" + ) + + except RateLimitExceededException: + raise UnexpectedValidationError( + "Validation failed due to GitHub rate-limits being exceeded. Please try again later." + ) + + except GithubException as e: + if e.status == 401: + raise CredentialExpiredError( + "GitHub credential appears to be invalid or expired (HTTP 401)." + ) + elif e.status == 403: + raise InsufficientPermissionsError( + "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)." + ) + elif e.status == 404: + if self.repositories: + if "," in self.repositories: + raise ConnectorValidationError( + f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}" + ) + else: + raise ConnectorValidationError( + f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}" + ) + else: + raise ConnectorValidationError( + f"GitHub user or organization not found: {self.repo_owner}" + ) + else: + raise ConnectorValidationError( + f"Unexpected GitHub error (status={e.status}): {e.data}" + ) + + except Exception as exc: + raise Exception( + f"Unexpected error during GitHub settings validation: {exc}" + ) + + def validate_checkpoint_json( + self, checkpoint_json: str + ) -> GithubConnectorCheckpoint: + return GithubConnectorCheckpoint.model_validate_json(checkpoint_json) + + def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint: + return GithubConnectorCheckpoint( + stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0 + ) + + +if __name__ == "__main__": + # Initialize the connector + connector = GithubConnector( + repo_owner="EvoAgentX", + repositories="EvoAgentX", + include_issues=True, + include_prs=False, + ) + connector.load_credentials( + {"github_access_token": ""} + ) + + if connector.github_client: + get_external_access_permission( + connector.get_github_repos(connector.github_client).pop(), + connector.github_client, + ) + + # Create a time range from epoch to now + end_time = datetime.now(timezone.utc) + start_time = datetime.fromtimestamp(0, tz=timezone.utc) + time_range = (start_time, end_time) + + # Initialize the runner with a batch size of 10 + runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner( + connector, batch_size=10, include_permissions=False, time_range=time_range + ) + + # Get initial checkpoint + checkpoint = connector.build_dummy_checkpoint() + + # Run the connector + while checkpoint.has_more: + for doc_batch, failure, next_checkpoint in runner.run(checkpoint): + if doc_batch: + print(f"Retrieved batch of {len(doc_batch)} documents") + for doc in doc_batch: + print(f"Document: {doc.semantic_identifier}") + if failure: + print(f"Failure: {failure.failure_message}") + if next_checkpoint: + checkpoint = next_checkpoint \ No newline at end of file diff --git a/common/data_source/github/models.py b/common/data_source/github/models.py new file mode 100644 index 00000000000..9754bfa8db8 --- /dev/null +++ b/common/data_source/github/models.py @@ -0,0 +1,17 @@ +from typing import Any + +from github import Repository +from github.Requester import Requester +from pydantic import BaseModel + + +class SerializedRepository(BaseModel): + # id is part of the raw_data as well, just pulled out for convenience + id: int + headers: dict[str, str | int] + raw_data: dict[str, Any] + + def to_Repository(self, requester: Requester) -> Repository.Repository: + return Repository.Repository( + requester, self.headers, self.raw_data, completed=True + ) \ No newline at end of file diff --git a/common/data_source/github/rate_limit_utils.py b/common/data_source/github/rate_limit_utils.py new file mode 100644 index 00000000000..d683bad08d2 --- /dev/null +++ b/common/data_source/github/rate_limit_utils.py @@ -0,0 +1,24 @@ +import time +import logging +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +from github import Github + + +def sleep_after_rate_limit_exception(github_client: Github) -> None: + """ + Sleep until the GitHub rate limit resets. + + Args: + github_client: The GitHub client that hit the rate limit + """ + sleep_time = github_client.get_rate_limit().core.reset.replace( + tzinfo=timezone.utc + ) - datetime.now(tz=timezone.utc) + sleep_time += timedelta(minutes=1) # add an extra minute just to be safe + logging.info( + "Ran into Github rate-limit. Sleeping %s seconds.", sleep_time.seconds + ) + time.sleep(sleep_time.total_seconds()) \ No newline at end of file diff --git a/common/data_source/github/utils.py b/common/data_source/github/utils.py new file mode 100644 index 00000000000..93b843bc841 --- /dev/null +++ b/common/data_source/github/utils.py @@ -0,0 +1,44 @@ +import logging + +from github import Github +from github.Repository import Repository + +from common.data_source.models import ExternalAccess + +from .models import SerializedRepository + + +def get_external_access_permission( + repo: Repository, github_client: Github +) -> ExternalAccess: + """ + Get the external access permission for a repository. + This functionality requires Enterprise Edition. + """ + # RAGFlow doesn't implement the Onyx EE external-permissions system. + # Default to private/unknown permissions. + return ExternalAccess.empty() + + +def deserialize_repository( + cached_repo: SerializedRepository, github_client: Github +) -> Repository: + """ + Deserialize a SerializedRepository back into a Repository object. + """ + # Try to access the requester - different PyGithub versions may use different attribute names + try: + # Try to get the requester using getattr to avoid linter errors + requester = getattr(github_client, "_requester", None) + if requester is None: + requester = getattr(github_client, "_Github__requester", None) + if requester is None: + # If we can't find the requester attribute, we need to fall back to recreating the repo + raise AttributeError("Could not find requester attribute") + + return cached_repo.to_Repository(requester) + except Exception as e: + # If all else fails, re-fetch the repo directly + logging.warning("Failed to deserialize repository: %s. Attempting to re-fetch.", e) + repo_id = cached_repo.id + return github_client.get_repo(repo_id) \ No newline at end of file diff --git a/common/data_source/gitlab_connector.py b/common/data_source/gitlab_connector.py new file mode 100644 index 00000000000..0d2c0dab775 --- /dev/null +++ b/common/data_source/gitlab_connector.py @@ -0,0 +1,340 @@ +import fnmatch +import itertools +from collections import deque +from collections.abc import Iterable +from collections.abc import Iterator +from datetime import datetime +from datetime import timezone +from typing import Any +from typing import TypeVar +import gitlab +from gitlab.v4.objects import Project + +from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE +from common.data_source.exceptions import ConnectorMissingCredentialError +from common.data_source.exceptions import ConnectorValidationError +from common.data_source.exceptions import CredentialExpiredError +from common.data_source.exceptions import InsufficientPermissionsError +from common.data_source.exceptions import UnexpectedValidationError +from common.data_source.interfaces import GenerateDocumentsOutput +from common.data_source.interfaces import LoadConnector +from common.data_source.interfaces import PollConnector +from common.data_source.interfaces import SecondsSinceUnixEpoch +from common.data_source.models import BasicExpertInfo +from common.data_source.models import Document +from common.data_source.utils import get_file_ext + +T = TypeVar("T") + + + +# List of directories/Files to exclude +exclude_patterns = [ + "logs", + ".github/", + ".gitlab/", + ".pre-commit-config.yaml", +] + + +def _batch_gitlab_objects(git_objs: Iterable[T], batch_size: int) -> Iterator[list[T]]: + it = iter(git_objs) + while True: + batch = list(itertools.islice(it, batch_size)) + if not batch: + break + yield batch + + +def get_author(author: Any) -> BasicExpertInfo: + return BasicExpertInfo( + display_name=author.get("name"), + ) + + +def _convert_merge_request_to_document(mr: Any) -> Document: + mr_text = mr.description or "" + doc = Document( + id=mr.web_url, + blob=mr_text, + source=DocumentSource.GITLAB, + semantic_identifier=mr.title, + extension=".md", + # updated_at is UTC time but is timezone unaware, explicitly add UTC + # as there is logic in indexing to prevent wrong timestamped docs + # due to local time discrepancies with UTC + doc_updated_at=mr.updated_at.replace(tzinfo=timezone.utc), + size_bytes=len(mr_text.encode("utf-8")), + primary_owners=[get_author(mr.author)], + metadata={"state": mr.state, "type": "MergeRequest", "web_url": mr.web_url}, + ) + return doc + + +def _convert_issue_to_document(issue: Any) -> Document: + issue_text = issue.description or "" + doc = Document( + id=issue.web_url, + blob=issue_text, + source=DocumentSource.GITLAB, + semantic_identifier=issue.title, + extension=".md", + # updated_at is UTC time but is timezone unaware, explicitly add UTC + # as there is logic in indexing to prevent wrong timestamped docs + # due to local time discrepancies with UTC + doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc), + size_bytes=len(issue_text.encode("utf-8")), + primary_owners=[get_author(issue.author)], + metadata={ + "state": issue.state, + "type": issue.type if issue.type else "Issue", + "web_url": issue.web_url, + }, + ) + return doc + + +def _convert_code_to_document( + project: Project, file: Any, url: str, projectName: str, projectOwner: str +) -> Document: + + # Dynamically get the default branch from the project object + default_branch = project.default_branch + + # Fetch the file content using the correct branch + file_content_obj = project.files.get( + file_path=file["path"], ref=default_branch # Use the default branch + ) + # BoxConnector uses raw bytes for blob. Keep the same here. + file_content_bytes = file_content_obj.decode() + file_url = f"{url}/{projectOwner}/{projectName}/-/blob/{default_branch}/{file['path']}" + + # Try to use the last commit timestamp for incremental sync. + # Falls back to "now" if the commit lookup fails. + last_commit_at = None + try: + # Query commit history for this file on the default branch. + commits = project.commits.list( + ref_name=default_branch, + path=file["path"], + per_page=1, + ) + if commits: + # committed_date is ISO string like "2024-01-01T00:00:00.000+00:00" + committed_date = commits[0].committed_date + if isinstance(committed_date, str): + last_commit_at = datetime.strptime( + committed_date, "%Y-%m-%dT%H:%M:%S.%f%z" + ).astimezone(timezone.utc) + elif isinstance(committed_date, datetime): + last_commit_at = committed_date.astimezone(timezone.utc) + except Exception: + last_commit_at = None + + # Create and return a Document object + doc = Document( + # Use a stable ID so reruns don't create duplicates. + id=file_url, + blob=file_content_bytes, + source=DocumentSource.GITLAB, + semantic_identifier=file.get("name"), + extension=get_file_ext(file.get("name")), + doc_updated_at=last_commit_at or datetime.now(tz=timezone.utc), + size_bytes=len(file_content_bytes) if file_content_bytes is not None else 0, + primary_owners=[], # Add owners if needed + metadata={ + "type": "CodeFile", + "path": file.get("path"), + "ref": default_branch, + "project": f"{projectOwner}/{projectName}", + "web_url": file_url, + }, + ) + return doc + + +def _should_exclude(path: str) -> bool: + """Check if a path matches any of the exclude patterns.""" + return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns) + + +class GitlabConnector(LoadConnector, PollConnector): + def __init__( + self, + project_owner: str, + project_name: str, + batch_size: int = INDEX_BATCH_SIZE, + state_filter: str = "all", + include_mrs: bool = True, + include_issues: bool = True, + include_code_files: bool = False, + ) -> None: + self.project_owner = project_owner + self.project_name = project_name + self.batch_size = batch_size + self.state_filter = state_filter + self.include_mrs = include_mrs + self.include_issues = include_issues + self.include_code_files = include_code_files + self.gitlab_client: gitlab.Gitlab | None = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.gitlab_client = gitlab.Gitlab( + credentials["gitlab_url"], private_token=credentials["gitlab_access_token"] + ) + return None + + def validate_connector_settings(self) -> None: + if self.gitlab_client is None: + raise ConnectorMissingCredentialError("GitLab") + + try: + self.gitlab_client.auth() + self.gitlab_client.projects.get( + f"{self.project_owner}/{self.project_name}", + lazy=True, + ) + + except gitlab.exceptions.GitlabAuthenticationError as e: + raise CredentialExpiredError( + "Invalid or expired GitLab credentials." + ) from e + + except gitlab.exceptions.GitlabAuthorizationError as e: + raise InsufficientPermissionsError( + "Insufficient permissions to access GitLab resources." + ) from e + + except gitlab.exceptions.GitlabGetError as e: + raise ConnectorValidationError( + "GitLab project not found or not accessible." + ) from e + + except Exception as e: + raise UnexpectedValidationError( + f"Unexpected error while validating GitLab settings: {e}" + ) from e + + def _fetch_from_gitlab( + self, start: datetime | None = None, end: datetime | None = None + ) -> GenerateDocumentsOutput: + if self.gitlab_client is None: + raise ConnectorMissingCredentialError("Gitlab") + project: Project = self.gitlab_client.projects.get( + f"{self.project_owner}/{self.project_name}" + ) + + start_utc = start.astimezone(timezone.utc) if start else None + end_utc = end.astimezone(timezone.utc) if end else None + + # Fetch code files + if self.include_code_files: + # Fetching using BFS as project.report_tree with recursion causing slow load + queue = deque([""]) # Start with the root directory + while queue: + current_path = queue.popleft() + files = project.repository_tree(path=current_path, all=True) + for file_batch in _batch_gitlab_objects(files, self.batch_size): + code_doc_batch: list[Document] = [] + for file in file_batch: + if _should_exclude(file["path"]): + continue + + if file["type"] == "blob": + + doc = _convert_code_to_document( + project, + file, + self.gitlab_client.url, + self.project_name, + self.project_owner, + ) + + # Apply incremental window filtering for code files too. + if start_utc is not None and doc.doc_updated_at <= start_utc: + continue + if end_utc is not None and doc.doc_updated_at > end_utc: + continue + + code_doc_batch.append(doc) + elif file["type"] == "tree": + queue.append(file["path"]) + + if code_doc_batch: + yield code_doc_batch + + if self.include_mrs: + merge_requests = project.mergerequests.list( + state=self.state_filter, + order_by="updated_at", + sort="desc", + iterator=True, + ) + + for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size): + mr_doc_batch: list[Document] = [] + for mr in mr_batch: + mr.updated_at = datetime.strptime( + mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z" + ) + if start_utc is not None and mr.updated_at <= start_utc: + yield mr_doc_batch + return + if end_utc is not None and mr.updated_at > end_utc: + continue + mr_doc_batch.append(_convert_merge_request_to_document(mr)) + yield mr_doc_batch + + if self.include_issues: + issues = project.issues.list(state=self.state_filter, iterator=True) + + for issue_batch in _batch_gitlab_objects(issues, self.batch_size): + issue_doc_batch: list[Document] = [] + for issue in issue_batch: + issue.updated_at = datetime.strptime( + issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z" + ) + # Avoid re-syncing the last-seen item. + if start_utc is not None and issue.updated_at <= start_utc: + yield issue_doc_batch + return + if end_utc is not None and issue.updated_at > end_utc: + continue + issue_doc_batch.append(_convert_issue_to_document(issue)) + yield issue_doc_batch + + def load_from_state(self) -> GenerateDocumentsOutput: + return self._fetch_from_gitlab() + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> GenerateDocumentsOutput: + start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) + end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + return self._fetch_from_gitlab(start_datetime, end_datetime) + + +if __name__ == "__main__": + import os + + connector = GitlabConnector( + # gitlab_url="https://gitlab.com/api/v4", + project_owner=os.environ["PROJECT_OWNER"], + project_name=os.environ["PROJECT_NAME"], + batch_size=INDEX_BATCH_SIZE, + state_filter="all", + include_mrs=True, + include_issues=True, + include_code_files=True, + ) + + connector.load_credentials( + { + "gitlab_access_token": os.environ["GITLAB_ACCESS_TOKEN"], + "gitlab_url": os.environ["GITLAB_URL"], + } + ) + document_batches = connector.load_from_state() + for f in document_batches: + print("Batch:", f) + print("Finished loading from state.") \ No newline at end of file diff --git a/common/data_source/gmail_connector.py b/common/data_source/gmail_connector.py index 67ebfae989a..e64db984714 100644 --- a/common/data_source/gmail_connector.py +++ b/common/data_source/gmail_connector.py @@ -1,6 +1,5 @@ import logging from typing import Any - from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.errors import HttpError @@ -9,10 +8,10 @@ from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, SCOPE_INSTRUCTIONS, USER_FIELDS from common.data_source.google_util.resource import get_admin_service, get_gmail_service -from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval +from common.data_source.google_util.util import _execute_single_retrieval, execute_paginated_retrieval, sanitize_filename, clean_string from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch, SlimConnectorWithPermSync from common.data_source.models import BasicExpertInfo, Document, ExternalAccess, GenerateDocumentsOutput, GenerateSlimDocumentOutput, SlimDocument, TextSection -from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, time_str_to_utc +from common.data_source.utils import build_time_range_query, clean_email_and_extract_name, get_message_body, is_mail_service_disabled_error, gmail_time_str_to_utc # Constants for Gmail API fields THREAD_LIST_FIELDS = "nextPageToken, threads(id)" @@ -67,7 +66,6 @@ def message_to_section(message: dict[str, Any]) -> tuple[TextSection, dict[str, message_data += f"{name}: {value}\n" message_body_text: str = get_message_body(payload) - return TextSection(link=link, text=message_body_text + message_data), metadata @@ -97,13 +95,15 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = message_metadata.get("subject", "") + semantic_identifier = clean_string(semantic_identifier) + semantic_identifier = sanitize_filename(semantic_identifier) if message_metadata.get("updated_at"): updated_at = message_metadata.get("updated_at") - + updated_at_datetime = None if updated_at: - updated_at_datetime = time_str_to_utc(updated_at) + updated_at_datetime = gmail_time_str_to_utc(updated_at) thread_id = full_thread.get("id") if not thread_id: @@ -115,15 +115,24 @@ def thread_to_document(full_thread: dict[str, Any], email_used_to_fetch_thread: if not semantic_identifier: semantic_identifier = "(no subject)" + combined_sections = "\n\n".join( + sec.text for sec in sections if hasattr(sec, "text") + ) + blob = combined_sections + size_bytes = len(blob) + extension = '.txt' + return Document( id=thread_id, semantic_identifier=semantic_identifier, - sections=sections, + blob=blob, + size_bytes=size_bytes, + extension=extension, source=DocumentSource.GMAIL, primary_owners=primary_owners, secondary_owners=secondary_owners, doc_updated_at=updated_at_datetime, - metadata={}, + metadata=message_metadata, external_access=ExternalAccess( external_user_emails={email_used_to_fetch_thread}, external_user_group_ids=set(), @@ -214,15 +223,13 @@ def _fetch_threads( q=query, continue_on_404_or_403=True, ): - full_threads = _execute_single_retrieval( + full_thread = _execute_single_retrieval( retrieval_function=gmail_service.users().threads().get, - list_key=None, userId=user_email, fields=THREAD_FIELDS, id=thread["id"], continue_on_404_or_403=True, ) - full_thread = list(full_threads)[0] doc = thread_to_document(full_thread, user_email) if doc is None: continue @@ -310,4 +317,30 @@ def retrieve_all_slim_docs_perm_sync( if __name__ == "__main__": - pass + import time + import os + from common.data_source.google_util.util import get_credentials_from_env + logging.basicConfig(level=logging.INFO) + try: + email = os.environ.get("GMAIL_TEST_EMAIL", "newyorkupperbay@gmail.com") + creds = get_credentials_from_env(email, oauth=True, source="gmail") + print("Credentials loaded successfully") + print(f"{creds=}") + + connector = GmailConnector(batch_size=2) + print("GmailConnector initialized") + connector.load_credentials(creds) + print("Credentials loaded into connector") + + print("Gmail is ready to use") + + for file in connector._fetch_threads( + int(time.time()) - 1 * 24 * 60 * 60, + int(time.time()), + ): + print("new batch","-"*80) + for f in file: + print(f) + print("\n\n") + except Exception as e: + logging.exception(f"Error loading credentials: {e}") \ No newline at end of file diff --git a/common/data_source/google_drive/connector.py b/common/data_source/google_drive/connector.py index fb88d0ed050..39017dd4a1d 100644 --- a/common/data_source/google_drive/connector.py +++ b/common/data_source/google_drive/connector.py @@ -1,7 +1,6 @@ """Google Drive connector""" import copy -import json import logging import os import sys @@ -32,7 +31,6 @@ from common.data_source.google_drive.model import DriveRetrievalStage, GoogleDriveCheckpoint, GoogleDriveFileType, RetrievedDriveFile, StageCompletion from common.data_source.google_util.auth import get_google_creds from common.data_source.google_util.constant import DB_CREDENTIALS_PRIMARY_ADMIN_KEY, MISSING_SCOPES_ERROR_STR, USER_FIELDS -from common.data_source.google_util.oauth_flow import ensure_oauth_token_dict from common.data_source.google_util.resource import GoogleDriveService, get_admin_service, get_drive_service from common.data_source.google_util.util import GoogleFields, execute_paginated_retrieval, get_file_owners from common.data_source.google_util.util_threadpool_concurrency import ThreadSafeDict @@ -1138,39 +1136,6 @@ def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoin return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) -def get_credentials_from_env(email: str, oauth: bool = False) -> dict: - try: - if oauth: - raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] - else: - raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] - except KeyError: - raise ValueError("Missing Google Drive credentials in environment variables") - - try: - credential_dict = json.loads(raw_credential_string) - except json.JSONDecodeError: - raise ValueError("Invalid JSON in Google Drive credentials") - - if oauth: - credential_dict = ensure_oauth_token_dict(credential_dict, DocumentSource.GOOGLE_DRIVE) - - refried_credential_string = json.dumps(credential_dict) - - DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" - DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" - DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" - DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" - - cred_key = DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY - - return { - cred_key: refried_credential_string, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, - DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded", - } - - class CheckpointOutputWrapper: """ Wraps a CheckpointOutput generator to give things back in a more digestible format. @@ -1236,7 +1201,7 @@ def yield_all_docs_from_checkpoint_connector( if __name__ == "__main__": import time - + from common.data_source.google_util.util import get_credentials_from_env logging.basicConfig(level=logging.DEBUG) try: @@ -1245,7 +1210,7 @@ def yield_all_docs_from_checkpoint_connector( creds = get_credentials_from_env(email, oauth=True) print("Credentials loaded successfully") print(f"{creds=}") - + # sys.exit(0) connector = GoogleDriveConnector( include_shared_drives=False, shared_drive_urls=None, diff --git a/common/data_source/google_drive/doc_conversion.py b/common/data_source/google_drive/doc_conversion.py index d697c1b2b86..5ab68f9bfb8 100644 --- a/common/data_source/google_drive/doc_conversion.py +++ b/common/data_source/google_drive/doc_conversion.py @@ -76,7 +76,7 @@ MAX_RETRIEVER_EMAILS = 20 CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read -# This is not a standard valid unicode char, it is used by the docs advanced API to +# This is not a standard valid Unicode char, it is used by the docs advanced API to # represent smart chips (elements like dates and doc links). SMART_CHIP_CHAR = "\ue907" WEB_VIEW_LINK_KEY = "webViewLink" diff --git a/common/data_source/google_drive/file_retrieval.py b/common/data_source/google_drive/file_retrieval.py index ee6ea6b62c9..00bade1570a 100644 --- a/common/data_source/google_drive/file_retrieval.py +++ b/common/data_source/google_drive/file_retrieval.py @@ -141,7 +141,7 @@ def crawl_folders_for_files( # Only mark a folder as done if it was fully traversed without errors # This usually indicates that the owner of the folder was impersonated. # In cases where this never happens, most likely the folder owner is - # not part of the google workspace in question (or for oauth, the authenticated + # not part of the Google Workspace in question (or for oauth, the authenticated # user doesn't own the folder) if found_files: update_traversed_ids_func(parent_id) @@ -232,7 +232,7 @@ def get_files_in_shared_drive( **kwargs, ): # If we found any files, mark this drive as traversed. When a user has access to a drive, - # they have access to all the files in the drive. Also not a huge deal if we re-traverse + # they have access to all the files in the drive. Also, not a huge deal if we re-traverse # empty drives. # NOTE: ^^ the above is not actually true due to folder restrictions: # https://support.google.com/a/users/answer/12380484?hl=en @@ -341,6 +341,6 @@ def get_all_files_for_oauth( # Just in case we need to get the root folder id def get_root_folder_id(service: Resource) -> str: - # we dont paginate here because there is only one root folder per user + # we don't paginate here because there is only one root folder per user # https://developers.google.com/drive/api/guides/v2-to-v3-reference return service.files().get(fileId="root", fields=GoogleFields.ID.value).execute()[GoogleFields.ID.value] diff --git a/common/data_source/google_drive/model.py b/common/data_source/google_drive/model.py index d0e89c24e35..d66cc21a54e 100644 --- a/common/data_source/google_drive/model.py +++ b/common/data_source/google_drive/model.py @@ -22,7 +22,7 @@ class GDriveMimeType(str, Enum): MARKDOWN = "text/markdown" -# These correspond to The major stages of retrieval for google drive. +# These correspond to The major stages of retrieval for Google Drive. # The stages for the oauth flow are: # get_all_files_for_oauth(), # get_all_drive_ids(), @@ -117,7 +117,7 @@ def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion] class RetrievedDriveFile(BaseModel): """ - Describes a file that has been retrieved from google drive. + Describes a file that has been retrieved from Google Drive. user_email is the email of the user that the file was retrieved by impersonating. If an error worthy of being reported is encountered, error should be set and later propagated as a ConnectorFailure. diff --git a/common/data_source/google_util/constant.py b/common/data_source/google_util/constant.py index 8ab75fa141c..89c9afaf55b 100644 --- a/common/data_source/google_util/constant.py +++ b/common/data_source/google_util/constant.py @@ -49,11 +49,11 @@ class GoogleOAuthAuthenticationMethod(str, Enum): SCOPE_INSTRUCTIONS = "" -GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """ +WEB_OAUTH_POPUP_TEMPLATE = """ - Google Drive Authorization + {title} Icon_24px_CloudStorage_Color \ No newline at end of file diff --git a/web/src/assets/svg/data-source/imap.svg b/web/src/assets/svg/data-source/imap.svg new file mode 100644 index 00000000000..82a815425a0 --- /dev/null +++ b/web/src/assets/svg/data-source/imap.svg @@ -0,0 +1,7 @@ + + + + diff --git a/web/src/assets/svg/data-source/jira.svg b/web/src/assets/svg/data-source/jira.svg new file mode 100644 index 00000000000..8f9cd8b97e4 --- /dev/null +++ b/web/src/assets/svg/data-source/jira.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/data-source/moodle.svg b/web/src/assets/svg/data-source/moodle.svg new file mode 100644 index 00000000000..d268572cda0 --- /dev/null +++ b/web/src/assets/svg/data-source/moodle.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/web/src/assets/svg/data-source/oracle-storage.svg b/web/src/assets/svg/data-source/oracle-storage.svg new file mode 100644 index 00000000000..90768f8bc9c --- /dev/null +++ b/web/src/assets/svg/data-source/oracle-storage.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/web/src/assets/svg/data-source/r2.svg b/web/src/assets/svg/data-source/r2.svg new file mode 100644 index 00000000000..d31b48fb07c --- /dev/null +++ b/web/src/assets/svg/data-source/r2.svg @@ -0,0 +1,5 @@ + \ No newline at end of file diff --git a/web/src/assets/svg/data-source/webdav.svg b/web/src/assets/svg/data-source/webdav.svg new file mode 100644 index 00000000000..a970d38fe9d --- /dev/null +++ b/web/src/assets/svg/data-source/webdav.svg @@ -0,0 +1,15 @@ + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/empty/no data bri.svg b/web/src/assets/svg/empty/no-data-bri.svg similarity index 100% rename from web/src/components/empty/no data bri.svg rename to web/src/assets/svg/empty/no-data-bri.svg diff --git a/web/src/components/empty/no data.svg b/web/src/assets/svg/empty/no-data-dark.svg similarity index 100% rename from web/src/components/empty/no data.svg rename to web/src/assets/svg/empty/no-data-dark.svg diff --git a/web/src/assets/svg/empty/no-search-data-bri.svg b/web/src/assets/svg/empty/no-search-data-bri.svg new file mode 100644 index 00000000000..4ec53186881 --- /dev/null +++ b/web/src/assets/svg/empty/no-search-data-bri.svg @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/empty/no-search-data-dark.svg b/web/src/assets/svg/empty/no-search-data-dark.svg new file mode 100644 index 00000000000..448543b907a --- /dev/null +++ b/web/src/assets/svg/empty/no-search-data-dark.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/file-icon/mdx.svg b/web/src/assets/svg/file-icon/mdx.svg new file mode 100644 index 00000000000..a6fb749df63 --- /dev/null +++ b/web/src/assets/svg/file-icon/mdx.svg @@ -0,0 +1,10 @@ + + + + + + \ No newline at end of file diff --git a/web/src/assets/svg/home-icon/memory-bri.svg b/web/src/assets/svg/home-icon/memory-bri.svg new file mode 100644 index 00000000000..cb9194cdcbe --- /dev/null +++ b/web/src/assets/svg/home-icon/memory-bri.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/web/src/assets/svg/home-icon/memory.svg b/web/src/assets/svg/home-icon/memory.svg new file mode 100644 index 00000000000..f50d755f423 --- /dev/null +++ b/web/src/assets/svg/home-icon/memory.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/llm/jiekouai.svg b/web/src/assets/svg/llm/jiekouai.svg new file mode 100644 index 00000000000..914929ad7ee --- /dev/null +++ b/web/src/assets/svg/llm/jiekouai.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/src/assets/svg/llm/mineru-bright.svg b/web/src/assets/svg/llm/mineru-bright.svg new file mode 100644 index 00000000000..7b4c3257b0d --- /dev/null +++ b/web/src/assets/svg/llm/mineru-bright.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/assets/svg/llm/mineru-dark.svg b/web/src/assets/svg/llm/mineru-dark.svg new file mode 100644 index 00000000000..755fe0f3c5f --- /dev/null +++ b/web/src/assets/svg/llm/mineru-dark.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/src/components/api-service/chat-overview-modal/api-content.tsx b/web/src/components/api-service/chat-overview-modal/api-content.tsx index 7daf428876b..ebdc36581be 100644 --- a/web/src/components/api-service/chat-overview-modal/api-content.tsx +++ b/web/src/components/api-service/chat-overview-modal/api-content.tsx @@ -1,3 +1,4 @@ +import { useIsDarkTheme } from '@/components/theme-provider'; import { useSetModalState, useTranslate } from '@/hooks/common-hooks'; import { LangfuseCard } from '@/pages/user-setting/setting-model/langfuse'; import apiDoc from '@parent/docs/references/http_api_reference.md'; @@ -28,6 +29,8 @@ const ApiContent = ({ const { handlePreview } = usePreviewChat(idKey); + const isDarkTheme = useIsDarkTheme(); + return (
@@ -47,7 +50,10 @@ const ApiContent = ({
- +
{apiKeyVisible && (
- {text} + {text} ), }, diff --git a/web/src/components/api-service/hooks.ts b/web/src/components/api-service/hooks.ts index 16878d84d8b..8d82ef31abb 100644 --- a/web/src/components/api-service/hooks.ts +++ b/web/src/components/api-service/hooks.ts @@ -9,7 +9,7 @@ import { useFetchManualSystemTokenList, useFetchSystemTokenList, useRemoveSystemToken, -} from '@/hooks/user-setting-hooks'; +} from '@/hooks/use-user-setting-request'; import { IStats } from '@/interfaces/database/chat'; import { useQueryClient } from '@tanstack/react-query'; import { message } from 'antd'; diff --git a/web/src/components/auto-keywords-item.tsx b/web/src/components/auto-keywords-item.tsx deleted file mode 100644 index a18c161e44c..00000000000 --- a/web/src/components/auto-keywords-item.tsx +++ /dev/null @@ -1,48 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { Flex, Form, InputNumber, Slider } from 'antd'; - -export const AutoKeywordsItem = () => { - const { t } = useTranslate('knowledgeDetails'); - - return ( - - - - - - - - - - - - - ); -}; - -export const AutoQuestionsItem = () => { - const { t } = useTranslate('knowledgeDetails'); - - return ( - - - - - - - - - - - - - ); -}; diff --git a/web/src/components/avatar-upload.tsx b/web/src/components/avatar-upload.tsx index 7a85e08defb..9f4a3707662 100644 --- a/web/src/components/avatar-upload.tsx +++ b/web/src/components/avatar-upload.tsx @@ -5,12 +5,14 @@ import { forwardRef, useCallback, useEffect, + useRef, useState, } from 'react'; import { useTranslation } from 'react-i18next'; import { Avatar, AvatarFallback, AvatarImage } from './ui/avatar'; import { Button } from './ui/button'; import { Input } from './ui/input'; +import { Modal } from './ui/modal/modal'; type AvatarUploadProps = { value?: string; @@ -22,14 +24,24 @@ export const AvatarUpload = forwardRef( function AvatarUpload({ value, onChange, tips }, ref) { const { t } = useTranslation(); const [avatarBase64Str, setAvatarBase64Str] = useState(''); // Avatar Image base64 + const [isCropModalOpen, setIsCropModalOpen] = useState(false); + const [imageToCrop, setImageToCrop] = useState(null); + const [cropArea, setCropArea] = useState({ x: 0, y: 0, size: 200 }); + const imageRef = useRef(null); + const canvasRef = useRef(null); + const containerRef = useRef(null); + const isDraggingRef = useRef(false); + const dragStartRef = useRef({ x: 0, y: 0 }); + const [imageScale, setImageScale] = useState(1); + const [imageOffset, setImageOffset] = useState({ x: 0, y: 0 }); const handleChange: ChangeEventHandler = useCallback( async (ev) => { const file = ev.target?.files?.[0]; if (/\.(jpg|jpeg|png|webp|bmp)$/i.test(file?.name ?? '')) { - const str = await transformFile2Base64(file!); - setAvatarBase64Str(str); - onChange?.(str); + const str = await transformFile2Base64(file!, 1000); + setImageToCrop(str); + setIsCropModalOpen(true); } ev.target.value = ''; }, @@ -41,17 +53,209 @@ export const AvatarUpload = forwardRef( onChange?.(''); }, [onChange]); + const handleCrop = useCallback(() => { + if (!imageRef.current || !canvasRef.current) return; + + const canvas = canvasRef.current; + const ctx = canvas.getContext('2d'); + const image = imageRef.current; + + if (!ctx) return; + + // Set canvas size to 64x64 (avatar size) + canvas.width = 64; + canvas.height = 64; + + // Draw cropped image on canvas + ctx.drawImage( + image, + cropArea.x, + cropArea.y, + cropArea.size, + cropArea.size, + 0, + 0, + 64, + 64, + ); + + // Convert to base64 + const croppedImageBase64 = canvas.toDataURL('image/png'); + setAvatarBase64Str(croppedImageBase64); + onChange?.(croppedImageBase64); + setIsCropModalOpen(false); + }, [cropArea, onChange]); + + const handleCancelCrop = useCallback(() => { + setIsCropModalOpen(false); + setImageToCrop(null); + }, []); + + const initCropArea = useCallback(() => { + if (!imageRef.current || !containerRef.current) return; + + const image = imageRef.current; + const container = containerRef.current; + + // Calculate image scale to fit container + const scale = Math.min( + container.clientWidth / image.width, + container.clientHeight / image.height, + ); + setImageScale(scale); + + // Calculate image offset to center it + const scaledWidth = image.width * scale; + const scaledHeight = image.height * scale; + const offsetX = (container.clientWidth - scaledWidth) / 2; + const offsetY = (container.clientHeight - scaledHeight) / 2; + setImageOffset({ x: offsetX, y: offsetY }); + + // Initialize crop area to center of image + const size = Math.min(scaledWidth, scaledHeight) * 0.8; // 80% of the smaller dimension + const x = (image.width - size / scale) / 2; + const y = (image.height - size / scale) / 2; + + setCropArea({ x, y, size: size / scale }); + }, []); + + const handleMouseMove = useCallback( + (e: MouseEvent) => { + if ( + !isDraggingRef.current || + !imageRef.current || + !containerRef.current + ) + return; + + const image = imageRef.current; + const container = containerRef.current; + const containerRect = container.getBoundingClientRect(); + + // Calculate mouse position relative to container + const mouseX = e.clientX - containerRect.left; + const mouseY = e.clientY - containerRect.top; + + // Calculate mouse position relative to image + const imageX = (mouseX - imageOffset.x) / imageScale; + const imageY = (mouseY - imageOffset.y) / imageScale; + + // Calculate new crop area position based on mouse movement + let newX = imageX - dragStartRef.current.x; + let newY = imageY - dragStartRef.current.y; + + // Boundary checks + newX = Math.max(0, Math.min(newX, image.width - cropArea.size)); + newY = Math.max(0, Math.min(newY, image.height - cropArea.size)); + + setCropArea((prev) => ({ + ...prev, + x: newX, + y: newY, + })); + }, + [cropArea.size, imageScale, imageOffset], + ); + + const handleMouseUp = useCallback(() => { + isDraggingRef.current = false; + document.removeEventListener('mousemove', handleMouseMove); + document.removeEventListener('mouseup', handleMouseUp); + }, [handleMouseMove]); + + const handleMouseDown = useCallback( + (e: React.MouseEvent) => { + e.preventDefault(); + e.stopPropagation(); + isDraggingRef.current = true; + if (imageRef.current && containerRef.current) { + const container = containerRef.current; + const containerRect = container.getBoundingClientRect(); + + // Calculate mouse position relative to container + const mouseX = e.clientX - containerRect.left; + const mouseY = e.clientY - containerRect.top; + + // Calculate mouse position relative to image + const imageX = (mouseX - imageOffset.x) / imageScale; + const imageY = (mouseY - imageOffset.y) / imageScale; + + // Store the offset between mouse position and crop area position + dragStartRef.current = { + x: imageX - cropArea.x, + y: imageY - cropArea.y, + }; + } + document.addEventListener('mousemove', handleMouseMove); + document.addEventListener('mouseup', handleMouseUp); + }, + [cropArea, imageScale, imageOffset], + ); + + const handleWheel = useCallback((e: React.WheelEvent) => { + if (!imageRef.current) return; + + e.preventDefault(); + const image = imageRef.current; + const delta = e.deltaY > 0 ? 0.9 : 1.1; // Zoom factor + + setCropArea((prev) => { + const newSize = Math.max( + 20, + Math.min(prev.size * delta, Math.min(image.width, image.height)), + ); + + // Adjust position to keep crop area centered + const centerRatioX = (prev.x + prev.size / 2) / image.width; + const centerRatioY = (prev.y + prev.size / 2) / image.height; + + const newX = centerRatioX * image.width - newSize / 2; + const newY = centerRatioY * image.height - newSize / 2; + + // Boundary checks + const boundedX = Math.max(0, Math.min(newX, image.width - newSize)); + const boundedY = Math.max(0, Math.min(newY, image.height - newSize)); + + return { + x: boundedX, + y: boundedY, + size: newSize, + }; + }); + }, []); + useEffect(() => { if (value) { setAvatarBase64Str(value); } }, [value]); + useEffect(() => { + const container = containerRef.current; + setTimeout(() => { + console.log('container', container); + // initCropArea(); + if (imageToCrop && container && isCropModalOpen) { + container.addEventListener( + 'wheel', + handleWheel as unknown as EventListener, + { passive: false }, + ); + return () => { + container.removeEventListener( + 'wheel', + handleWheel as unknown as EventListener, + ); + }; + } + }, 100); + }, [handleWheel, containerRef.current]); + return (
{!avatarBase64Str ? ( -
+

{t('common.upload')}

@@ -60,7 +264,7 @@ export const AvatarUpload = forwardRef( ) : (
- +
@@ -93,6 +297,79 @@ export const AvatarUpload = forwardRef(
{tips ?? t('knowledgeConfiguration.photoTip')}
+ + {/* Crop Modal */} + { + setIsCropModalOpen(open); + if (!open) { + setImageToCrop(null); + } + }} + title={t('setting.cropImage')} + size="small" + onCancel={handleCancelCrop} + onOk={handleCrop} + // footer={ + //
+ // + // + //
+ // } + > +
+ {imageToCrop && ( +
+
+ To crop + {imageRef.current && ( +
+ )} +
+
+

+ {t('setting.cropTip')} +

+
+ +
+ )} +
+
); }, diff --git a/web/src/components/back-button/index.tsx b/web/src/components/back-button/index.tsx index c790d688280..118042128b6 100644 --- a/web/src/components/back-button/index.tsx +++ b/web/src/components/back-button/index.tsx @@ -29,7 +29,10 @@ const BackButton: React.FC = ({ return (
diff --git a/web/src/components/key-input.tsx b/web/src/components/key-input.tsx index 4c6c2f8228e..788e4cfe3aa 100644 --- a/web/src/components/key-input.tsx +++ b/web/src/components/key-input.tsx @@ -8,7 +8,10 @@ type KeyInputProps = { } & Omit; export const KeyInput = forwardRef( - function KeyInput({ value, onChange, searchValue = /[^a-zA-Z0-9_]/g }, ref) { + function KeyInput( + { value, onChange, searchValue = /[^a-zA-Z0-9_]/g, ...props }, + ref, + ) { const handleChange = useCallback( (e: ChangeEvent) => { const value = e.target.value ?? ''; @@ -18,6 +21,6 @@ export const KeyInput = forwardRef( [onChange, searchValue], ); - return ; + return ; }, ); diff --git a/web/src/components/knowledge-base-item.tsx b/web/src/components/knowledge-base-item.tsx index cb907df5616..e3958853a17 100644 --- a/web/src/components/knowledge-base-item.tsx +++ b/web/src/components/knowledge-base-item.tsx @@ -1,77 +1,70 @@ import { DocumentParserType } from '@/constants/knowledge'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useFetchKnowledgeList } from '@/hooks/knowledge-hooks'; +import { useFetchKnowledgeList } from '@/hooks/use-knowledge-request'; +import { IKnowledge } from '@/interfaces/database/knowledge'; import { useBuildQueryVariableOptions } from '@/pages/agent/hooks/use-get-begin-query'; -import { UserOutlined } from '@ant-design/icons'; -import { Avatar as AntAvatar, Form, Select, Space } from 'antd'; import { toLower } from 'lodash'; -import { useMemo } from 'react'; +import { useEffect, useMemo, useState } from 'react'; import { useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { RAGFlowAvatar } from './ragflow-avatar'; import { FormControl, FormField, FormItem, FormLabel } from './ui/form'; -import { MultiSelect } from './ui/multi-select'; +import { MultiSelect, MultiSelectOptionType } from './ui/multi-select'; -interface KnowledgeBaseItemProps { - label?: string; - tooltipText?: string; - name?: string; - required?: boolean; - onChange?(): void; +function buildQueryVariableOptionsByShowVariable(showVariable?: boolean) { + return showVariable ? useBuildQueryVariableOptions : () => []; } -const KnowledgeBaseItem = ({ - label, - tooltipText, - name, - required = true, - onChange, -}: KnowledgeBaseItemProps) => { - const { t } = useTranslate('chat'); - - const { list: knowledgeList } = useFetchKnowledgeList(true); - - const filteredKnowledgeList = knowledgeList.filter( - (x) => x.parser_id !== DocumentParserType.Tag, - ); - - const knowledgeOptions = filteredKnowledgeList.map((x) => ({ - label: ( - - } src={x.avatar} /> - {x.name} - - ), - value: x.id, - })); - - return ( - - - +export function useDisableDifferenceEmbeddingDataset() { + const [datasetOptions, setDatasetOptions] = useState( + [], ); -}; - -export default KnowledgeBaseItem; + const [datasetSelectEmbedId, setDatasetSelectEmbedId] = useState(''); + const { list: datasetListOrigin } = useFetchKnowledgeList(true); + + useEffect(() => { + const datasetListMap = datasetListOrigin + .filter((x) => x.parser_id !== DocumentParserType.Tag) + .map((item: IKnowledge) => { + return { + label: item.name, + icon: () => ( + + ), + suffix: ( +
+ {item.embd_id} +
+ ), + value: item.id, + disabled: + item.embd_id !== datasetSelectEmbedId && + datasetSelectEmbedId !== '', + }; + }); + setDatasetOptions(datasetListMap); + }, [datasetListOrigin, datasetSelectEmbedId]); + + const handleDatasetSelectChange = ( + value: string[], + onChange: (value: string[]) => void, + ) => { + if (value.length) { + const data = datasetListOrigin?.find((item) => item.id === value[0]); + setDatasetSelectEmbedId(data?.embd_id ?? ''); + } else { + setDatasetSelectEmbedId(''); + } + onChange?.(value); + }; -function buildQueryVariableOptionsByShowVariable(showVariable?: boolean) { - return showVariable ? useBuildQueryVariableOptions : () => []; + return { + datasetOptions, + handleDatasetSelectChange, + }; } export function KnowledgeBaseFormField({ @@ -82,22 +75,12 @@ export function KnowledgeBaseFormField({ const form = useFormContext(); const { t } = useTranslation(); - const { list: knowledgeList } = useFetchKnowledgeList(true); - - const filteredKnowledgeList = knowledgeList.filter( - (x) => x.parser_id !== DocumentParserType.Tag, - ); + const { datasetOptions, handleDatasetSelectChange } = + useDisableDifferenceEmbeddingDataset(); const nextOptions = buildQueryVariableOptionsByShowVariable(showVariable)(); - const knowledgeOptions = filteredKnowledgeList.map((x) => ({ - label: x.name, - value: x.id, - icon: () => ( - - ), - })); - + const knowledgeOptions = datasetOptions; const options = useMemo(() => { if (showVariable) { return [ @@ -140,11 +123,14 @@ export function KnowledgeBaseFormField({ { + handleDatasetSelectChange(value, field.onChange); + }} placeholder={t('chat.knowledgeBasesMessage')} variant="inverted" maxCount={100} defaultValue={field.value} + showSelectAll={false} {...field} /> diff --git a/web/src/components/large-model-form-field.tsx b/web/src/components/large-model-form-field.tsx index e805eb5413a..0e266258c34 100644 --- a/web/src/components/large-model-form-field.tsx +++ b/web/src/components/large-model-form-field.tsx @@ -68,7 +68,7 @@ export function LargeModelFormField({ - + diff --git a/web/src/components/layout-recognize-form-field.tsx b/web/src/components/layout-recognize-form-field.tsx index 43f0abccb88..965eee83356 100644 --- a/web/src/components/layout-recognize-form-field.tsx +++ b/web/src/components/layout-recognize-form-field.tsx @@ -1,10 +1,11 @@ import { LlmModelType } from '@/constants/knowledge'; import { useTranslate } from '@/hooks/common-hooks'; -import { useSelectLlmOptionsByModelType } from '@/hooks/llm-hooks'; +import { useSelectLlmOptionsByModelType } from '@/hooks/use-llm-request'; import { cn } from '@/lib/utils'; import { camelCase } from 'lodash'; import { ReactNode, useMemo } from 'react'; import { useFormContext } from 'react-hook-form'; +import { MinerUOptionsFormField } from './mineru-options-form-field'; import { SelectWithSearch } from './originui/select-with-search'; import { FormControl, @@ -17,7 +18,6 @@ import { export const enum ParseDocumentType { DeepDOC = 'DeepDOC', PlainText = 'Plain Text', - MinerU = 'MinerU', Docling = 'Docling', TCADPParser = 'TCADP Parser', } @@ -27,11 +27,13 @@ export function LayoutRecognizeFormField({ horizontal = true, optionsWithoutLLM, label, + showMineruOptions = true, }: { name?: string; horizontal?: boolean; optionsWithoutLLM?: { value: string; label: string }[]; label?: ReactNode; + showMineruOptions?: boolean; }) { const form = useFormContext(); @@ -44,7 +46,6 @@ export function LayoutRecognizeFormField({ : [ ParseDocumentType.DeepDOC, ParseDocumentType.PlainText, - ParseDocumentType.MinerU, ParseDocumentType.Docling, ParseDocumentType.TCADPParser, ].map((x) => ({ @@ -52,7 +53,10 @@ export function LayoutRecognizeFormField({ value: x, })); - const image2TextList = allOptions[LlmModelType.Image2text].map((x) => { + const image2TextList = [ + ...allOptions[LlmModelType.Image2text], + ...allOptions[LlmModelType.Ocr], + ].map((x) => { return { ...x, options: x.options.map((y) => { @@ -78,35 +82,38 @@ export function LayoutRecognizeFormField({ name={name} render={({ field }) => { return ( - -
- + +
- {label || t('layoutRecognize')} - -
- - - + + {label || t('layoutRecognize')} + +
+ + + +
-
-
-
- -
-
+
+
+ +
+ + {showMineruOptions && } + ); }} /> diff --git a/web/src/components/layout-recognize.tsx b/web/src/components/layout-recognize.tsx deleted file mode 100644 index b5642c37fb8..00000000000 --- a/web/src/components/layout-recognize.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import { LlmModelType } from '@/constants/knowledge'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useSelectLlmOptionsByModelType } from '@/hooks/llm-hooks'; -import { Form, Select } from 'antd'; -import { camelCase } from 'lodash'; -import { useMemo } from 'react'; - -const enum DocumentType { - DeepDOC = 'DeepDOC', - PlainText = 'Plain Text', -} - -const LayoutRecognize = () => { - const { t } = useTranslate('knowledgeDetails'); - const allOptions = useSelectLlmOptionsByModelType(); - - const options = useMemo(() => { - const list = [DocumentType.DeepDOC, DocumentType.PlainText].map((x) => ({ - label: x === DocumentType.PlainText ? t(camelCase(x)) : 'DeepDoc', - value: x, - })); - - const image2TextList = allOptions[LlmModelType.Image2text].map((x) => { - return { - ...x, - options: x.options.map((y) => { - return { - ...y, - label: ( -
- {y.label} - Experimental -
- ), - }; - }), - }; - }); - - return [...list, ...image2TextList]; - }, [allOptions, t]); - - return ( - - ( - - {option.label} - {option.data.description} - - )} - onChange={onChange} - value={value} - disabled={disabled} - > - ); -}; - -export default LLMToolsSelect; diff --git a/web/src/components/logical-operator.tsx b/web/src/components/logical-operator.tsx new file mode 100644 index 00000000000..7b37a256760 --- /dev/null +++ b/web/src/components/logical-operator.tsx @@ -0,0 +1,24 @@ +import { useBuildSwitchLogicOperatorOptions } from '@/hooks/logic-hooks/use-build-options'; +import { RAGFlowFormItem } from './ragflow-form'; +import { RAGFlowSelect } from './ui/select'; + +type LogicalOperatorProps = { name: string }; + +export function LogicalOperator({ name }: LogicalOperatorProps) { + const switchLogicOperatorOptions = useBuildSwitchLogicOperatorOptions(); + + return ( +
+ + + +
+
+ ); +} diff --git a/web/src/components/markdown-content/image-carousel.tsx b/web/src/components/markdown-content/image-carousel.tsx new file mode 100644 index 00000000000..f1b000da15b --- /dev/null +++ b/web/src/components/markdown-content/image-carousel.tsx @@ -0,0 +1,139 @@ +import Image from '@/components/image'; +import { + Carousel, + CarouselContent, + CarouselItem, + CarouselNext, + CarouselPrevious, +} from '@/components/ui/carousel'; +import { IReference, IReferenceChunk } from '@/interfaces/database/chat'; +import { getExtension } from '@/utils/document-util'; +import { useCallback } from 'react'; + +interface ImageCarouselProps { + group: Array<{ + id: string; + fullMatch: string; + start: number; + }>; + reference: IReference; + fileThumbnails: Record; + onImageClick: ( + documentId: string, + chunk: IReferenceChunk, + isPdf: boolean, + documentUrl?: string, + ) => void; +} + +interface ReferenceInfo { + documentUrl?: string; + fileThumbnail?: string; + fileExtension?: string; + imageId?: string; + chunkItem?: IReferenceChunk; + documentId?: string; + document?: any; +} + +const getReferenceInfo = ( + chunkIndex: number, + reference: IReference, + fileThumbnails: Record, +): ReferenceInfo => { + const chunks = reference?.chunks ?? []; + const chunkItem = chunks[chunkIndex]; + const document = reference?.doc_aggs?.find( + (x) => x?.doc_id === chunkItem?.document_id, + ); + const documentId = document?.doc_id; + const documentUrl = document?.url; + const fileThumbnail = documentId ? fileThumbnails[documentId] : ''; + const fileExtension = documentId ? getExtension(document?.doc_name) : ''; + const imageId = chunkItem?.image_id; + + return { + documentUrl, + fileThumbnail, + fileExtension, + imageId, + chunkItem, + documentId, + document, + }; +}; + +/** + * Component to render image carousel for a group of consecutive image references + */ +export const ImageCarousel = ({ + group, + reference, + fileThumbnails, + onImageClick, +}: ImageCarouselProps) => { + const getChunkIndex = (match: string) => Number(match); + + const handleImageClick = useCallback( + ( + imageId: string, + chunkItem: IReferenceChunk, + documentId: string, + fileExtension: string, + documentUrl?: string, + ) => + () => + onImageClick( + documentId, + chunkItem, + fileExtension === 'pdf', + documentUrl, + ), + [onImageClick], + ); + + return ( + + + {group.map((ref) => { + const chunkIndex = getChunkIndex(ref.id); + const { documentUrl, fileExtension, imageId, chunkItem, documentId } = + getReferenceInfo(chunkIndex, reference, fileThumbnails); + + return ( + +
+ {} + } + /> + {imageId} +
+
+ ); + })} +
+ + +
+ ); +}; + +export default ImageCarousel; diff --git a/web/src/pages/chat/markdown-content/index.less b/web/src/components/markdown-content/index.less similarity index 98% rename from web/src/pages/chat/markdown-content/index.less rename to web/src/components/markdown-content/index.less index 3a26fa4bf70..2fa7f92f1fa 100644 --- a/web/src/pages/chat/markdown-content/index.less +++ b/web/src/components/markdown-content/index.less @@ -25,7 +25,7 @@ display: block; object-fit: contain; max-width: 100%; - max-height: 6vh; + max-height: 10vh; } .referenceImagePreview { diff --git a/web/src/pages/chat/markdown-content/index.tsx b/web/src/components/markdown-content/index.tsx similarity index 80% rename from web/src/pages/chat/markdown-content/index.tsx rename to web/src/components/markdown-content/index.tsx index adc4f15c8ca..6e93bf134bd 100644 --- a/web/src/pages/chat/markdown-content/index.tsx +++ b/web/src/components/markdown-content/index.tsx @@ -2,12 +2,9 @@ import Image from '@/components/image'; import SvgIcon from '@/components/svg-icon'; import { IReference, IReferenceChunk } from '@/interfaces/database/chat'; import { getExtension } from '@/utils/document-util'; -import { InfoCircleOutlined } from '@ant-design/icons'; -import { Button, Flex, Popover } from 'antd'; import DOMPurify from 'dompurify'; import { useCallback, useEffect, useMemo } from 'react'; import Markdown from 'react-markdown'; -import reactStringReplace from 'react-string-replace'; import SyntaxHighlighter from 'react-syntax-highlighter'; import rehypeKatex from 'rehype-katex'; import rehypeRaw from 'rehype-raw'; @@ -15,24 +12,31 @@ import remarkGfm from 'remark-gfm'; import remarkMath from 'remark-math'; import { visitParents } from 'unist-util-visit-parents'; -import { useFetchDocumentThumbnailsByIds } from '@/hooks/document-hooks'; import { useTranslation } from 'react-i18next'; import 'katex/dist/katex.min.css'; // `rehype-katex` does not import the CSS for you +import { useFetchDocumentThumbnailsByIds } from '@/hooks/use-document-request'; import { + currentReg, preprocessLaTeX, + replaceTextByOldReg, replaceThinkToSection, - showImage, } from '@/utils/chat'; -import { currentReg, replaceTextByOldReg } from '../utils'; - import classNames from 'classnames'; import { omit } from 'lodash'; import { pipe } from 'lodash/fp'; +import reactStringReplace from 'react-string-replace'; +import { Button } from '../ui/button'; +import { + HoverCard, + HoverCardContent, + HoverCardTrigger, +} from '../ui/hover-card'; import styles from './index.less'; const getChunkIndex = (match: string) => Number(match); + // TODO: The display of the table is inconsistent with the display previously placed in the MessageItem. const MarkdownContent = ({ reference, @@ -144,20 +148,20 @@ const MarkdownContent = ({ return (
{imageId && ( - + + + + - } - > - - + + )}
{documentId && ( - +
{fileThumbnail ? ( )} - +
)}
@@ -206,40 +210,23 @@ const MarkdownContent = ({ let replacedText = reactStringReplace(text, currentReg, (match, i) => { const chunkIndex = getChunkIndex(match); - const { documentUrl, fileExtension, imageId, chunkItem, documentId } = - getReferenceInfo(chunkIndex); - - const docType = chunkItem?.doc_type; - - return showImage(docType) ? ( - {} - } - > - ) : ( - - - + return ( + + + + Fig. {chunkIndex + 1} + + + + {getPopoverContent(chunkIndex)} + + ); }); - // replacedText = reactStringReplace(replacedText, curReg, (match, i) => ( - // - // )); - return replacedText; }, - [getPopoverContent, getReferenceInfo, handleDocumentButtonClick], + [getPopoverContent], ); return ( diff --git a/web/src/components/markdown-content/reference-utils.ts b/web/src/components/markdown-content/reference-utils.ts new file mode 100644 index 00000000000..ffc80fbf4f3 --- /dev/null +++ b/web/src/components/markdown-content/reference-utils.ts @@ -0,0 +1,67 @@ +import { IReference } from '@/interfaces/database/chat'; +import { currentReg, showImage } from '@/utils/chat'; + +export interface ReferenceMatch { + id: string; + fullMatch: string; + start: number; + end: number; +} + +export type ReferenceGroup = ReferenceMatch[]; + +export const findAllReferenceMatches = (text: string): ReferenceMatch[] => { + const matches: ReferenceMatch[] = []; + let match; + while ((match = currentReg.exec(text)) !== null) { + matches.push({ + id: match[1], + fullMatch: match[0], + start: match.index, + end: match.index + match[0].length, + }); + } + return matches; +}; + +/** + * Helper to group consecutive references + */ +export const groupConsecutiveReferences = (text: string): ReferenceGroup[] => { + const matches = findAllReferenceMatches(text); + // Construct a two-dimensional array to distinguish whether images are continuous. + const groups: ReferenceGroup[] = []; + + if (matches.length === 0) return groups; + + let currentGroup: ReferenceGroup = [matches[0]]; + // A group with only one element contains non-contiguous images, + // while a group with multiple elements contains contiguous images. + for (let i = 1; i < matches.length; i++) { + // If the end of the previous element equals the start of the current element, + // it means that they are consecutive images. + if (matches[i].start === currentGroup[currentGroup.length - 1].end) { + currentGroup.push(matches[i]); + } else { + // Save current group and start a new one + groups.push(currentGroup); + currentGroup = [matches[i]]; + } + } + groups.push(currentGroup); + + return groups; +}; + +export const shouldShowCarousel = ( + group: ReferenceGroup, + reference: IReference, +): boolean => { + if (group.length < 2) return false; // Need at least 2 images for carousel + + return group.every((ref) => { + const chunkIndex = Number(ref.id); + const chunk = reference.chunks[chunkIndex]; + return chunk && showImage(chunk.doc_type); + }); +}; diff --git a/web/src/components/max-token-number.tsx b/web/src/components/max-token-number.tsx deleted file mode 100644 index c64b40fac6b..00000000000 --- a/web/src/components/max-token-number.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { Flex, Form, InputNumber, Slider } from 'antd'; - -interface IProps { - initialValue?: number; - max?: number; -} - -const MaxTokenNumber = ({ initialValue = 512, max = 2048 }: IProps) => { - const { t } = useTranslate('knowledgeConfiguration'); - - return ( - - - - - - - - - - - - - ); -}; - -export default MaxTokenNumber; diff --git a/web/src/components/memories-form-field.tsx b/web/src/components/memories-form-field.tsx new file mode 100644 index 00000000000..2d04a492158 --- /dev/null +++ b/web/src/components/memories-form-field.tsx @@ -0,0 +1,33 @@ +import { useFetchAllMemoryList } from '@/hooks/use-memory-request'; +import { useTranslation } from 'react-i18next'; +import { RAGFlowFormItem } from './ragflow-form'; +import { MultiSelect } from './ui/multi-select'; + +type MemoriesFormFieldProps = { + label: string; +}; + +export function MemoriesFormField({ label }: MemoriesFormFieldProps) { + const { t } = useTranslation(); + const memoryList = useFetchAllMemoryList(); + + const options = memoryList.data?.map((memory) => ({ + label: memory.name, + value: memory.id, + })); + + return ( + + {(field) => ( + + )} + + ); +} diff --git a/web/src/components/message-history-window-size-item.tsx b/web/src/components/message-history-window-size-item.tsx index e717076e540..69df072b14d 100644 --- a/web/src/components/message-history-window-size-item.tsx +++ b/web/src/components/message-history-window-size-item.tsx @@ -1,4 +1,3 @@ -import { Form, InputNumber } from 'antd'; import { useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { @@ -10,27 +9,6 @@ import { } from './ui/form'; import { NumberInput } from './ui/input'; -const MessageHistoryWindowSizeItem = ({ - initialValue, -}: { - initialValue: number; -}) => { - const { t } = useTranslation(); - - return ( - - - - ); -}; - -export default MessageHistoryWindowSizeItem; - export function MessageHistoryWindowSizeFormField() { const form = useFormContext(); const { t } = useTranslation(); diff --git a/web/src/components/message-input/index.tsx b/web/src/components/message-input/index.tsx deleted file mode 100644 index 95a4ee195ea..00000000000 --- a/web/src/components/message-input/index.tsx +++ /dev/null @@ -1,372 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { - useDeleteDocument, - useFetchDocumentInfosByIds, - useRemoveNextDocument, - useUploadAndParseDocument, -} from '@/hooks/document-hooks'; -import { cn } from '@/lib/utils'; -import { getExtension } from '@/utils/document-util'; -import { formatBytes } from '@/utils/file-util'; -import { - CloseCircleOutlined, - InfoCircleOutlined, - LoadingOutlined, -} from '@ant-design/icons'; -import type { GetProp, UploadFile } from 'antd'; -import { - Button, - Card, - Divider, - Flex, - Input, - List, - Space, - Spin, - Typography, - Upload, - UploadProps, -} from 'antd'; -import get from 'lodash/get'; -import { CircleStop, Paperclip, SendHorizontal } from 'lucide-react'; -import { - ChangeEventHandler, - memo, - useCallback, - useEffect, - useRef, - useState, -} from 'react'; -import FileIcon from '../file-icon'; -import styles from './index.less'; - -type FileType = Parameters>[0]; -const { Text } = Typography; - -const { TextArea } = Input; - -const getFileId = (file: UploadFile) => get(file, 'response.data.0'); - -const getFileIds = (fileList: UploadFile[]) => { - const ids = fileList.reduce((pre, cur) => { - return pre.concat(get(cur, 'response.data', [])); - }, []); - - return ids; -}; - -const isUploadSuccess = (file: UploadFile) => { - const code = get(file, 'response.code'); - return typeof code === 'number' && code === 0; -}; - -interface IProps { - disabled: boolean; - value: string; - sendDisabled: boolean; - sendLoading: boolean; - onPressEnter(documentIds: string[]): void; - onInputChange: ChangeEventHandler; - conversationId: string; - uploadMethod?: string; - isShared?: boolean; - showUploadIcon?: boolean; - createConversationBeforeUploadDocument?(message: string): Promise; - stopOutputMessage?(): void; -} - -const getBase64 = (file: FileType): Promise => - new Promise((resolve, reject) => { - const reader = new FileReader(); - reader.readAsDataURL(file as any); - reader.onload = () => resolve(reader.result as string); - reader.onerror = (error) => reject(error); - }); - -const MessageInput = ({ - isShared = false, - disabled, - value, - onPressEnter, - sendDisabled, - sendLoading, - onInputChange, - conversationId, - showUploadIcon = true, - createConversationBeforeUploadDocument, - uploadMethod = 'upload_and_parse', - stopOutputMessage, -}: IProps) => { - const { t } = useTranslate('chat'); - const { removeDocument } = useRemoveNextDocument(); - const { deleteDocument } = useDeleteDocument(); - const { data: documentInfos, setDocumentIds } = useFetchDocumentInfosByIds(); - const { uploadAndParseDocument } = useUploadAndParseDocument(uploadMethod); - const conversationIdRef = useRef(conversationId); - - const [fileList, setFileList] = useState([]); - - const handlePreview = async (file: UploadFile) => { - if (!file.url && !file.preview) { - file.preview = await getBase64(file.originFileObj as FileType); - } - }; - - const handleChange: UploadProps['onChange'] = async ({ - // fileList: newFileList, - file, - }) => { - let nextConversationId: string = conversationId; - if (createConversationBeforeUploadDocument) { - const creatingRet = await createConversationBeforeUploadDocument( - file.name, - ); - if (creatingRet?.code === 0) { - nextConversationId = creatingRet.data.id; - } - } - setFileList((list) => { - list.push({ - ...file, - status: 'uploading', - originFileObj: file as any, - }); - return [...list]; - }); - const ret = await uploadAndParseDocument({ - conversationId: nextConversationId, - fileList: [file], - }); - setFileList((list) => { - const nextList = list.filter((x) => x.uid !== file.uid); - nextList.push({ - ...file, - originFileObj: file as any, - response: ret, - percent: 100, - status: ret?.code === 0 ? 'done' : 'error', - }); - return nextList; - }); - }; - - const isUploadingFile = fileList.some((x) => x.status === 'uploading'); - - const handlePressEnter = useCallback(async () => { - if (isUploadingFile) return; - const ids = getFileIds(fileList.filter((x) => isUploadSuccess(x))); - - onPressEnter(ids); - setFileList([]); - }, [fileList, onPressEnter, isUploadingFile]); - - const handleKeyDown = useCallback( - async (event: React.KeyboardEvent) => { - // check if it was shift + enter - if (event.key === 'Enter' && event.shiftKey) return; - if (event.key !== 'Enter') return; - if (sendDisabled || isUploadingFile || sendLoading) return; - - event.preventDefault(); - handlePressEnter(); - }, - [sendDisabled, isUploadingFile, sendLoading, handlePressEnter], - ); - - const handleRemove = useCallback( - async (file: UploadFile) => { - const ids = get(file, 'response.data', []); - // Upload Successfully - if (Array.isArray(ids) && ids.length) { - if (isShared) { - await deleteDocument(ids); - } else { - await removeDocument(ids[0]); - } - setFileList((preList) => { - return preList.filter((x) => getFileId(x) !== ids[0]); - }); - } else { - // Upload failed - setFileList((preList) => { - return preList.filter((x) => x.uid !== file.uid); - }); - } - }, - [removeDocument, deleteDocument, isShared], - ); - - const handleStopOutputMessage = useCallback(() => { - stopOutputMessage?.(); - }, [stopOutputMessage]); - - const getDocumentInfoById = useCallback( - (id: string) => { - return documentInfos.find((x) => x.id === id); - }, - [documentInfos], - ); - - useEffect(() => { - const ids = getFileIds(fileList); - setDocumentIds(ids); - }, [fileList, setDocumentIds]); - - useEffect(() => { - if ( - conversationIdRef.current && - conversationId !== conversationIdRef.current - ) { - setFileList([]); - } - conversationIdRef.current = conversationId; - }, [conversationId, setFileList]); - - return ( - - ; + } + + return ( + + ); + }, + [form, isDarkTheme], + ); + + return ( +
+ + append({ + [keyField]: '', + [valueField]: '', + [operatorField]: TypesWithArray.String, + }) + } + > +
+ {fields.map((field, index) => { + const keyFieldAlias = `${name}.${index}.${keyField}`; + const valueFieldAlias = `${name}.${index}.${valueField}`; + const operatorFieldAlias = `${name}.${index}.${operatorField}`; + + return ( +
+
+
+ + + + + + {(field) => ( + { + handleVariableTypeChange(val, valueFieldAlias); + field.onChange(val); + }} + options={VariableTypeOptions} + > + )} + + +
+ + {renderParameter(operatorFieldAlias)} + +
+ + +
+ ); + })} +
+
+ ); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/index.tsx b/web/src/pages/agent/form/begin-form/webhook/index.tsx new file mode 100644 index 00000000000..36c0f940407 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/index.tsx @@ -0,0 +1,138 @@ +import { Collapse } from '@/components/collapse'; +import { CopyToClipboardWithText } from '@/components/copy-to-clipboard'; +import NumberInput from '@/components/originui/number-input'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Label } from '@/components/ui/label'; +import { MultiSelect } from '@/components/ui/multi-select'; +import { Separator } from '@/components/ui/separator'; +import { Textarea } from '@/components/ui/textarea'; +import { useBuildWebhookUrl } from '@/pages/agent/hooks/use-build-webhook-url'; +import { buildOptions } from '@/utils/form'; +import { upperFirst } from 'lodash'; +import { useCallback } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { + RateLimitPerList, + WebhookMaxBodySize, + WebhookMethod, + WebhookRateLimitPer, + WebhookSecurityAuthType, +} from '../../../constant'; +import { DynamicStringForm } from '../../components/dynamic-string-form'; +import { Auth } from './auth'; +import { WebhookRequestSchema } from './request-schema'; +import { WebhookResponse } from './response'; + +const RateLimitPerOptions = RateLimitPerList.map((x) => ({ + value: x, + label: upperFirst(x), +})); + +const RequestLimitMap = { + [WebhookRateLimitPer.Second]: 100, + [WebhookRateLimitPer.Minute]: 1000, + [WebhookRateLimitPer.Hour]: 10000, + [WebhookRateLimitPer.Day]: 100000, +}; + +export function WebHook() { + const { t } = useTranslation(); + const form = useFormContext(); + + const rateLimitPer = useWatch({ + name: 'security.rate_limit.per', + control: form.control, + }); + + const getLimitRateLimitPerMax = useCallback((rateLimitPer: string) => { + return RequestLimitMap[rateLimitPer as keyof typeof RequestLimitMap] ?? 100; + }, []); + + const text = useBuildWebhookUrl(); + + return ( + <> + + + {(field) => ( + + )} + + + Security
}> +
+ + + + +
+ +
+ + + + + + {(field) => ( + { + field.onChange(val); + form.setValue( + 'security.rate_limit.limit', + getLimitRateLimitPerMax(val), + ); + }} + > + )} + +
+
+ + + + +
+ + + + + + + + + ); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/request-schema.tsx b/web/src/pages/agent/form/begin-form/webhook/request-schema.tsx new file mode 100644 index 00000000000..efd8cb436bd --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/request-schema.tsx @@ -0,0 +1,72 @@ +import { Collapse } from '@/components/collapse'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { + WebhookContentType, + WebhookRequestParameters, +} from '@/pages/agent/constant'; +import { buildOptions } from '@/utils/form'; +import { useMemo } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { DynamicRequest } from './dynamic-request'; + +export function WebhookRequestSchema() { + const { t } = useTranslation(); + const form = useFormContext(); + const contentType = useWatch({ + name: 'content_types', + control: form.control, + }); + const isFormDataContentType = + contentType === WebhookContentType.MultipartFormData; + + const bodyOperatorList = useMemo(() => { + return isFormDataContentType + ? [ + WebhookRequestParameters.String, + WebhookRequestParameters.Number, + WebhookRequestParameters.Boolean, + WebhookRequestParameters.File, + ] + : [ + WebhookRequestParameters.String, + WebhookRequestParameters.Number, + WebhookRequestParameters.Boolean, + ]; + }, [isFormDataContentType]); + + return ( + {t('flow.webhook.schema')}
}> +
+ + + + + + +
+ + ); +} diff --git a/web/src/pages/agent/form/begin-form/webhook/response.tsx b/web/src/pages/agent/form/begin-form/webhook/response.tsx new file mode 100644 index 00000000000..a1a26b4f940 --- /dev/null +++ b/web/src/pages/agent/form/begin-form/webhook/response.tsx @@ -0,0 +1,57 @@ +import { Collapse } from '@/components/collapse'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Textarea } from '@/components/ui/textarea'; +import { WebHookResponseStatusFormField } from '@/components/webhook-response-status'; +import { WebhookExecutionMode } from '@/pages/agent/constant'; +import { buildOptions } from '@/utils/form'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +export function WebhookResponse() { + const { t } = useTranslation(); + + const form = useFormContext(); + + const executionMode = useWatch({ + control: form.control, + name: 'execution_mode', + }); + + return ( + Response
}> +
+ + + + {executionMode === WebhookExecutionMode.Immediately && ( + <> + + {/* */} + {/* */} + + + + + )} +
+ + ); +} diff --git a/web/src/pages/agent/form/components/dynamic-fom-header.tsx b/web/src/pages/agent/form/components/dynamic-fom-header.tsx index 66ff5d1a636..ecc24560c90 100644 --- a/web/src/pages/agent/form/components/dynamic-fom-header.tsx +++ b/web/src/pages/agent/form/components/dynamic-fom-header.tsx @@ -7,17 +7,24 @@ export type FormListHeaderProps = { label: ReactNode; tooltip?: string; onClick?: () => void; + disabled?: boolean; }; export function DynamicFormHeader({ label, tooltip, onClick, + disabled = false, }: FormListHeaderProps) { return (
{label} -
diff --git a/web/src/pages/agent/form/components/dynamic-input-variable.tsx b/web/src/pages/agent/form/components/dynamic-input-variable.tsx deleted file mode 100644 index a5781fd16f9..00000000000 --- a/web/src/pages/agent/form/components/dynamic-input-variable.tsx +++ /dev/null @@ -1,127 +0,0 @@ -import { RAGFlowNodeType } from '@/interfaces/database/flow'; -import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; -import { Button, Collapse, Flex, Form, Input, Select } from 'antd'; -import { PropsWithChildren, useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; -import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; - -import styles from './index.less'; - -interface IProps { - node?: RAGFlowNodeType; -} - -enum VariableType { - Reference = 'reference', - Input = 'input', -} - -const getVariableName = (type: string) => - type === VariableType.Reference ? 'component_id' : 'value'; - -const DynamicVariableForm = ({ node }: IProps) => { - const { t } = useTranslation(); - const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); - const form = Form.useFormInstance(); - - const options = [ - { value: VariableType.Reference, label: t('flow.reference') }, - { value: VariableType.Input, label: t('flow.text') }, - ]; - - const handleTypeChange = useCallback( - (name: number) => () => { - setTimeout(() => { - form.setFieldValue(['query', name, 'component_id'], undefined); - form.setFieldValue(['query', name, 'value'], undefined); - }, 0); - }, - [form], - ); - - return ( - - {(fields, { add, remove }) => ( - <> - {fields.map(({ key, name, ...restField }) => ( - - - - - - {({ getFieldValue }) => { - const type = getFieldValue(['query', name, 'type']); - return ( - - {type === VariableType.Reference ? ( - - ) : ( - - )} - - ); - }} - - remove(name)} /> - - ))} - - - - - )} - - ); -}; - -export function FormCollapse({ - children, - title, -}: PropsWithChildren<{ title: string }>) { - return ( - {title}, - children, - }, - ]} - /> - ); -} - -const DynamicInputVariable = ({ node }: IProps) => { - const { t } = useTranslation(); - return ( - - - - ); -}; - -export default DynamicInputVariable; diff --git a/web/src/pages/agent/form/components/dynamic-string-form.tsx b/web/src/pages/agent/form/components/dynamic-string-form.tsx new file mode 100644 index 00000000000..224e923108b --- /dev/null +++ b/web/src/pages/agent/form/components/dynamic-string-form.tsx @@ -0,0 +1,46 @@ +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Trash2 } from 'lucide-react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header'; + +type DynamicStringFormProps = { name: string } & FormListHeaderProps; +export function DynamicStringForm({ name, label }: DynamicStringFormProps) { + const form = useFormContext(); + + const { fields, append, remove } = useFieldArray({ + name: name, + control: form.control, + }); + + return ( +
+ append({ value: '' })} + > +
+ {fields.map((field, index) => ( +
+ + + + +
+ ))} +
+
+ ); +} diff --git a/web/src/pages/agent/form/components/json-viewer.tsx b/web/src/pages/agent/form/components/json-viewer.tsx new file mode 100644 index 00000000000..3a6d1b558c2 --- /dev/null +++ b/web/src/pages/agent/form/components/json-viewer.tsx @@ -0,0 +1,21 @@ +import JsonView from 'react18-json-view'; + +export function JsonViewer({ + data, + title, +}: { + data: Record; + title: string; +}) { + return ( +
+
{title}
+ +
+ ); +} diff --git a/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx b/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx deleted file mode 100644 index 8b4cbd8a95b..00000000000 --- a/web/src/pages/agent/form/components/next-dynamic-input-variable.tsx +++ /dev/null @@ -1,135 +0,0 @@ -'use client'; - -import { SideDown } from '@/assets/icon/next-icon'; -import { Button } from '@/components/ui/button'; -import { - Collapsible, - CollapsibleContent, - CollapsibleTrigger, -} from '@/components/ui/collapsible'; -import { - FormControl, - FormDescription, - FormField, - FormItem, - FormMessage, -} from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; -import { RAGFlowSelect } from '@/components/ui/select'; -import { RAGFlowNodeType } from '@/interfaces/database/flow'; -import { Plus, Trash2 } from 'lucide-react'; -import { useFieldArray, useFormContext } from 'react-hook-form'; -import { useTranslation } from 'react-i18next'; -import { useBuildVariableOptions } from '../../hooks/use-get-begin-query'; - -interface IProps { - node?: RAGFlowNodeType; -} - -enum VariableType { - Reference = 'reference', - Input = 'input', -} - -const getVariableName = (type: string) => - type === VariableType.Reference ? 'component_id' : 'value'; - -export function DynamicVariableForm({ node }: IProps) { - const { t } = useTranslation(); - const form = useFormContext(); - const { fields, remove, append } = useFieldArray({ - name: 'query', - control: form.control, - }); - - const valueOptions = useBuildVariableOptions(node?.id, node?.parentId); - - const options = [ - { value: VariableType.Reference, label: t('flow.reference') }, - { value: VariableType.Input, label: t('flow.text') }, - ]; - - return ( -
- {fields.map((field, index) => { - const typeField = `query.${index}.type`; - const typeValue = form.watch(typeField); - return ( -
- ( - - - - { - field.onChange(val); - form.resetField(`query.${index}.value`); - form.resetField(`query.${index}.component_id`); - }} - > - - - - )} - /> - ( - - - - {typeValue === VariableType.Reference ? ( - - ) : ( - - )} - - - - )} - /> - remove(index)} - /> -
- ); - })} - -
- ); -} - -export function DynamicInputVariable({ node }: IProps) { - const { t } = useTranslation(); - - return ( - - - - {t('flow.input')} - - - - - - - - ); -} diff --git a/web/src/pages/agent/form/components/output.tsx b/web/src/pages/agent/form/components/output.tsx index 68551ab3cef..73058b67be3 100644 --- a/web/src/pages/agent/form/components/output.tsx +++ b/web/src/pages/agent/form/components/output.tsx @@ -1,6 +1,7 @@ import { RAGFlowFormItem } from '@/components/ragflow-form'; import { Input } from '@/components/ui/input'; import { t } from 'i18next'; +import { PropsWithChildren } from 'react'; import { z } from 'zod'; export type OutputType = { @@ -11,7 +12,7 @@ export type OutputType = { type OutputProps = { list: Array; isFormRequired?: boolean; -}; +} & PropsWithChildren; export function transferOutputs(outputs: Record) { return Object.entries(outputs).map(([key, value]) => ({ @@ -24,10 +25,16 @@ export const OutputSchema = { outputs: z.record(z.any()), }; -export function Output({ list, isFormRequired = false }: OutputProps) { +export function Output({ + list, + isFormRequired = false, + children, +}: OutputProps) { return (
-
{t('flow.output')}
+
+ {t('flow.output')} {children} +
    {list.map((x, idx) => (
  • (''); - let options = useFilterQueryVariableOptionsByTypes(types); + let options = useFilterQueryVariableOptionsByTypes({ types }); if (baseOptions) { options = baseOptions as typeof options; diff --git a/web/src/pages/agent/form/components/query-variable-list.tsx b/web/src/pages/agent/form/components/query-variable-list.tsx index 73734532abe..d2ed52fcede 100644 --- a/web/src/pages/agent/form/components/query-variable-list.tsx +++ b/web/src/pages/agent/form/components/query-variable-list.tsx @@ -2,6 +2,10 @@ import { Button } from '@/components/ui/button'; import { X } from 'lucide-react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { JsonSchemaDataType } from '../../constant'; +import { + flatOptions, + useFilterQueryVariableOptionsByTypes, +} from '../../hooks/use-get-begin-query'; import { DynamicFormHeader, FormListHeaderProps } from './dynamic-fom-header'; import { QueryVariable } from './query-variable'; @@ -16,6 +20,10 @@ export function QueryVariableList({ const form = useFormContext(); const name = 'query'; + let options = useFilterQueryVariableOptionsByTypes({ types }); + + const secondOptions = flatOptions(options); + const { fields, remove, append } = useFieldArray({ name: name, control: form.control, @@ -26,14 +34,15 @@ export function QueryVariableList({ append({ input: '' })} + onClick={() => append({ input: secondOptions.at(0)?.value })} + disabled={!secondOptions.length} >
    {fields.map((field, index) => { const nameField = `${name}.${index}.input`; return ( -
    +
    void; -}; + pureQuery?: boolean; + value?: string; +} & BuildQueryVariableOptions; export function QueryVariable({ name = 'query', @@ -28,11 +33,39 @@ export function QueryVariable({ hideLabel = false, className, onChange, + pureQuery = false, + value, + nodeIds = [], + variablesExceptOperatorOutputs, }: QueryVariableProps) { const { t } = useTranslation(); const form = useFormContext(); - const finalOptions = useFilterQueryVariableOptionsByTypes(types); + const finalOptions = useFilterQueryVariableOptionsByTypes({ + types, + nodeIds, + variablesExceptOperatorOutputs, + }); + + const renderWidget = ( + value?: string, + handleChange?: (value: string) => void, + ) => ( + { + handleChange?.(val); + onChange?.(val); + }} + // allowClear + types={types} + > + ); + + if (pureQuery) { + renderWidget(value, onChange); + } return ( )} - - { - field.onChange(val); - onChange?.(val); - }} - // allowClear - types={types} - > - + {renderWidget(field.value, field.onChange)} )} diff --git a/web/src/pages/agent/form/agent-form/structured-output-dialog.tsx b/web/src/pages/agent/form/components/schema-dialog.tsx similarity index 81% rename from web/src/pages/agent/form/agent-form/structured-output-dialog.tsx rename to web/src/pages/agent/form/components/schema-dialog.tsx index 6ce305bff38..4d67e00c0d6 100644 --- a/web/src/pages/agent/form/agent-form/structured-output-dialog.tsx +++ b/web/src/pages/agent/form/components/schema-dialog.tsx @@ -3,6 +3,7 @@ import { JsonSchemaVisualizer, SchemaVisualEditor, } from '@/components/jsonjoy-builder'; +import { KeyInputProps } from '@/components/jsonjoy-builder/components/schema-editor/interface'; import { Button } from '@/components/ui/button'; import { Dialog, @@ -16,11 +17,12 @@ import { IModalProps } from '@/interfaces/common'; import { useCallback, useState } from 'react'; import { useTranslation } from 'react-i18next'; -export function StructuredOutputDialog({ +export function SchemaDialog({ hideModal, onOk, initialValues, -}: IModalProps) { + pattern, +}: IModalProps & KeyInputProps) { const { t } = useTranslation(); const [schema, setSchema] = useState(initialValues); @@ -36,7 +38,11 @@ export function StructuredOutputDialog({
    - +
    diff --git a/web/src/pages/agent/form/agent-form/structured-output-panel.tsx b/web/src/pages/agent/form/components/schema-panel.tsx similarity index 78% rename from web/src/pages/agent/form/agent-form/structured-output-panel.tsx rename to web/src/pages/agent/form/components/schema-panel.tsx index 64e13c6ebf0..e76ff726e1b 100644 --- a/web/src/pages/agent/form/agent-form/structured-output-panel.tsx +++ b/web/src/pages/agent/form/components/schema-panel.tsx @@ -1,6 +1,6 @@ import { JSONSchema, JsonSchemaVisualizer } from '@/components/jsonjoy-builder'; -export function StructuredOutputPanel({ value }: { value: JSONSchema }) { +export function SchemaPanel({ value }: { value: JSONSchema }) { return (
    x === dataType) || + (types?.some((x) => x === compositeDataType) || hasSpecificTypeChild(value ?? {}, types))) ) { return (
  • {key} - {dataType} + + {compositeDataType} +
    {[JsonSchemaDataType.Object, JsonSchemaDataType.Array].some( (x) => x === dataType, @@ -101,8 +105,9 @@ export function StructuredOutputSecondaryMenu({ ); if ( - !hasJsonSchemaChild(structuredOutput) || - (!isEmpty(types) && !hasSpecificTypeChild(structuredOutput, types)) + !isEmpty(types) && + !hasSpecificTypeChild(structuredOutput, types) && + !types.some((x) => x === JsonSchemaDataType.Object) ) { return null; } @@ -124,7 +129,7 @@ export function StructuredOutputSecondaryMenu({ side="left" align="start" className={cn( - 'min-w-[140px] border border-border rounded-md shadow-lg p-0', + 'min-w-72 border border-border rounded-md shadow-lg p-0', )} >
    diff --git a/web/src/pages/agent/form/data-operations-form/index.tsx b/web/src/pages/agent/form/data-operations-form/index.tsx index dd7fc1fbe4a..6663161c082 100644 --- a/web/src/pages/agent/form/data-operations-form/index.tsx +++ b/web/src/pages/agent/form/data-operations-form/index.tsx @@ -4,6 +4,7 @@ import { Form } from '@/components/ui/form'; import { Separator } from '@/components/ui/separator'; import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; +import { t } from 'i18next'; import { memo } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; @@ -25,7 +26,11 @@ import { SelectKeys } from './select-keys'; import { Updates } from './updates'; export const RetrievalPartialSchema = { - query: z.array(z.object({ input: z.string().optional() })), + query: z.array( + z.object({ + input: z.string().min(1, { message: t('flow.queryRequired') }), + }), + ), operations: z.string(), select_keys: z.array(z.object({ name: z.string().optional() })).optional(), remove_keys: z.array(z.object({ name: z.string().optional() })).optional(), diff --git a/web/src/pages/agent/form/extractor-form/index.tsx b/web/src/pages/agent/form/extractor-form/index.tsx index 391d8c09ef7..33ab5eff2e0 100644 --- a/web/src/pages/agent/form/extractor-form/index.tsx +++ b/web/src/pages/agent/form/extractor-form/index.tsx @@ -59,6 +59,8 @@ const ExtractorForm = ({ node }: INextOperatorForm) => { useWatchFormChange(node?.id, form); + const isToc = form.getValues('field_name') === 'toc'; + return (
    @@ -76,19 +78,27 @@ const ExtractorForm = ({ node }: INextOperatorForm) => { > )} - - - - + + {!isToc && ( + + + + )} + + + {visible && ( diff --git a/web/src/pages/agent/form/invoke-form/variable-table.tsx b/web/src/pages/agent/form/invoke-form/variable-table.tsx index 8ca794bde6c..68fbf0a9c56 100644 --- a/web/src/pages/agent/form/invoke-form/variable-table.tsx +++ b/web/src/pages/agent/form/invoke-form/variable-table.tsx @@ -49,7 +49,7 @@ export function VariableTable({ nodeId, }: IProps) { const { t } = useTranslation(); - const { getLabel } = useGetVariableLabelOrTypeByValue(nodeId!); + const { getLabel } = useGetVariableLabelOrTypeByValue({ nodeId: nodeId! }); const [sorting, setSorting] = React.useState([]); const [columnFilters, setColumnFilters] = React.useState( diff --git a/web/src/pages/agent/form/iteration-form/dynamic-output.tsx b/web/src/pages/agent/form/iteration-form/dynamic-output.tsx index c31be8fd062..8cb8a4b4823 100644 --- a/web/src/pages/agent/form/iteration-form/dynamic-output.tsx +++ b/web/src/pages/agent/form/iteration-form/dynamic-output.tsx @@ -1,7 +1,7 @@ 'use client'; import { FormContainer } from '@/components/form-container'; -import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { KeyInput } from '@/components/key-input'; import { BlockButton, Button } from '@/components/ui/button'; import { FormControl, @@ -9,15 +9,18 @@ import { FormItem, FormMessage, } from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; import { Separator } from '@/components/ui/separator'; +import { Operator } from '@/constants/agent'; import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { t } from 'i18next'; +import { isEmpty } from 'lodash'; import { X } from 'lucide-react'; -import { ReactNode, useCallback, useMemo } from 'react'; +import { ReactNode } from 'react'; import { useFieldArray, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import { useBuildSubNodeOutputOptions } from './use-build-options'; +import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query'; +import useGraphStore from '../../store'; +import { QueryVariable } from '../components/query-variable'; interface IProps { node?: RAGFlowNodeType; @@ -26,28 +29,22 @@ interface IProps { export function DynamicOutputForm({ node }: IProps) { const { t } = useTranslation(); const form = useFormContext(); - const options = useBuildSubNodeOutputOptions(node?.id); - const name = 'outputs'; + const { nodes } = useGraphStore((state) => state); - const flatOptions = useMemo(() => { - return options.reduce<{ label: string; value: string; type: string }[]>( - (pre, cur) => { - pre.push(...cur.options); - return pre; - }, - [], - ); - }, [options]); + const childNodeIds = nodes + .filter( + (x) => + x.parentId === node?.id && + x.data.label !== Operator.IterationStart && + !isEmpty(x.data?.form?.outputs), + ) + .map((x) => x.id); - const findType = useCallback( - (val: string) => { - const type = flatOptions.find((x) => x.value === val)?.type; - if (type) { - return `Array<${type}>`; - } - }, - [flatOptions], - ); + const name = 'outputs'; + + const { getType } = useGetVariableLabelOrTypeByValue({ + nodeIds: childNodeIds, + }); const { fields, remove, append } = useFieldArray({ name: name, @@ -67,35 +64,25 @@ export function DynamicOutputForm({ node }: IProps) { render={({ field }) => ( - + > )} /> - ( - - - { - form.setValue(typeField, findType(val)); - field.onChange(val); - }} - > - - - - )} - /> + hideLabel + className="w-2/5" + onChange={(val) => { + form.setValue(typeField, `Array<${getType(val)}>`); + }} + nodeIds={childNodeIds} + > - - - + diff --git a/web/src/pages/agent/form/iteration-form/use-build-logical-options.ts b/web/src/pages/agent/form/iteration-form/use-build-logical-options.ts new file mode 100644 index 00000000000..a7f960e98e4 --- /dev/null +++ b/web/src/pages/agent/form/iteration-form/use-build-logical-options.ts @@ -0,0 +1,59 @@ +import { buildOptions } from '@/utils/form'; +import { camelCase } from 'lodash'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + JsonSchemaDataType, + VariableAssignerLogicalArrayOperator, + VariableAssignerLogicalNumberOperator, + VariableAssignerLogicalNumberOperatorLabelMap, + VariableAssignerLogicalOperator, +} from '../../constant'; + +export function useBuildLogicalOptions() { + const { t } = useTranslation(); + + const buildVariableAssignerLogicalOptions = useCallback( + (record: Record) => { + return buildOptions( + record, + t, + 'flow.variableAssignerLogicalOperatorOptions', + true, + ); + }, + [t], + ); + + const buildLogicalOptions = useCallback( + (type: string) => { + if ( + type?.toLowerCase().startsWith(JsonSchemaDataType.Array.toLowerCase()) + ) { + return buildVariableAssignerLogicalOptions( + VariableAssignerLogicalArrayOperator, + ); + } + + if (type === JsonSchemaDataType.Number) { + return Object.values(VariableAssignerLogicalNumberOperator).map( + (val) => ({ + label: t( + `flow.variableAssignerLogicalOperatorOptions.${camelCase(VariableAssignerLogicalNumberOperatorLabelMap[val as keyof typeof VariableAssignerLogicalNumberOperatorLabelMap] || val)}`, + ), + value: val, + }), + ); + } + + return buildVariableAssignerLogicalOptions( + VariableAssignerLogicalOperator, + ); + }, + [buildVariableAssignerLogicalOptions, t], + ); + + return { + buildLogicalOptions, + }; +} diff --git a/web/src/pages/agent/form/iteration-form/use-build-options.ts b/web/src/pages/agent/form/iteration-form/use-build-options.ts deleted file mode 100644 index 46fa7ee308c..00000000000 --- a/web/src/pages/agent/form/iteration-form/use-build-options.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { buildOutputOptions } from '@/utils/canvas-util'; -import { isEmpty } from 'lodash'; -import { useMemo } from 'react'; -import { Operator } from '../../constant'; -import useGraphStore from '../../store'; - -export function useBuildSubNodeOutputOptions(nodeId?: string) { - const { nodes } = useGraphStore((state) => state); - - const nodeOutputOptions = useMemo(() => { - if (!nodeId) { - return []; - } - - const subNodeWithOutputList = nodes.filter( - (x) => - x.parentId === nodeId && - x.data.label !== Operator.IterationStart && - !isEmpty(x.data?.form?.outputs), - ); - - return subNodeWithOutputList.map((x) => ({ - label: x.data.name, - value: x.id, - title: x.data.name, - options: buildOutputOptions(x.data.form.outputs, x.id), - })); - }, [nodeId, nodes]); - - return nodeOutputOptions; -} diff --git a/web/src/pages/agent/form/jin10-form/index.tsx b/web/src/pages/agent/form/jin10-form/index.tsx deleted file mode 100644 index 2bc6d774a38..00000000000 --- a/web/src/pages/agent/form/jin10-form/index.tsx +++ /dev/null @@ -1,145 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { Form, Input, Select } from 'antd'; -import { useMemo } from 'react'; -import { IOperatorForm } from '../../interface'; -import { - Jin10CalendarDatashapeOptions, - Jin10CalendarTypeOptions, - Jin10FlashTypeOptions, - Jin10SymbolsDatatypeOptions, - Jin10SymbolsTypeOptions, - Jin10TypeOptions, -} from '../../options'; -import DynamicInputVariable from '../components/dynamic-input-variable'; - -const Jin10Form = ({ onValuesChange, form, node }: IOperatorForm) => { - const { t } = useTranslate('flow'); - - const jin10TypeOptions = useMemo(() => { - return Jin10TypeOptions.map((x) => ({ - value: x, - label: t(`jin10TypeOptions.${x}`), - })); - }, [t]); - - const jin10FlashTypeOptions = useMemo(() => { - return Jin10FlashTypeOptions.map((x) => ({ - value: x, - label: t(`jin10FlashTypeOptions.${x}`), - })); - }, [t]); - - const jin10CalendarTypeOptions = useMemo(() => { - return Jin10CalendarTypeOptions.map((x) => ({ - value: x, - label: t(`jin10CalendarTypeOptions.${x}`), - })); - }, [t]); - - const jin10CalendarDatashapeOptions = useMemo(() => { - return Jin10CalendarDatashapeOptions.map((x) => ({ - value: x, - label: t(`jin10CalendarDatashapeOptions.${x}`), - })); - }, [t]); - - const jin10SymbolsTypeOptions = useMemo(() => { - return Jin10SymbolsTypeOptions.map((x) => ({ - value: x, - label: t(`jin10SymbolsTypeOptions.${x}`), - })); - }, [t]); - - const jin10SymbolsDatatypeOptions = useMemo(() => { - return Jin10SymbolsDatatypeOptions.map((x) => ({ - value: x, - label: t(`jin10SymbolsDatatypeOptions.${x}`), - })); - }, [t]); - - return ( - - - - - - - - - - {({ getFieldValue }) => { - const type = getFieldValue('type'); - switch (type) { - case 'flash': - return ( - <> - - - - - - - - - - - ); - - case 'calendar': - return ( - <> - - - - - - - - ); - - case 'symbols': - return ( - <> - - - - - - - - ); - - case 'news': - return ( - <> - - - - - - - - ); - - default: - return <>; - } - }} - - - ); -}; - -export default Jin10Form; diff --git a/web/src/pages/agent/form/keyword-extract-form/index.tsx b/web/src/pages/agent/form/keyword-extract-form/index.tsx deleted file mode 100644 index bda5d44f510..00000000000 --- a/web/src/pages/agent/form/keyword-extract-form/index.tsx +++ /dev/null @@ -1,48 +0,0 @@ -import { NextLLMSelect } from '@/components/llm-select/next'; -import { TopNFormField } from '@/components/top-n-item'; -import { - Form, - FormControl, - FormField, - FormItem, - FormLabel, - FormMessage, -} from '@/components/ui/form'; -import { useTranslation } from 'react-i18next'; -import { INextOperatorForm } from '../../interface'; -import { DynamicInputVariable } from '../components/next-dynamic-input-variable'; - -const KeywordExtractForm = ({ form, node }: INextOperatorForm) => { - const { t } = useTranslation(); - - return ( -
    - { - e.preventDefault(); - }} - > - - ( - - - {t('chat.model')} - - - - - - - )} - /> - - - - ); -}; - -export default KeywordExtractForm; diff --git a/web/src/pages/agent/form/list-operations-form/index.tsx b/web/src/pages/agent/form/list-operations-form/index.tsx new file mode 100644 index 00000000000..afc44e9075c --- /dev/null +++ b/web/src/pages/agent/form/list-operations-form/index.tsx @@ -0,0 +1,223 @@ +import NumberInput from '@/components/originui/number-input'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Separator } from '@/components/ui/separator'; +import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options'; +import { buildOptions } from '@/utils/form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { useForm, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { z } from 'zod'; +import { + ArrayFields, + DataOperationsOperatorOptions, + ListOperations, + SortMethod, + initialListOperationsValues, +} from '../../constant'; +import { useFormValues } from '../../hooks/use-form-values'; +import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query'; +import { useWatchFormChange } from '../../hooks/use-watch-form-change'; +import { INextOperatorForm } from '../../interface'; +import { getArrayElementType } from '../../utils'; +import { buildOutputList } from '../../utils/build-output-list'; +import { FormWrapper } from '../components/form-wrapper'; +import { Output, OutputSchema } from '../components/output'; +import { PromptEditor } from '../components/prompt-editor'; +import { QueryVariable } from '../components/query-variable'; + +export const RetrievalPartialSchema = { + query: z.string(), + operations: z.string(), + n: z.number().int().min(1).optional(), + sort_method: z.string().optional(), + filter: z + .object({ + value: z.string().optional(), + operator: z.string().optional(), + }) + .optional(), + ...OutputSchema, +}; + +const NumFields = [ + ListOperations.TopN, + ListOperations.Head, + ListOperations.Tail, +]; + +function showField(operations: string) { + const showNum = NumFields.includes(operations as ListOperations); + const showSortMethod = [ListOperations.Sort].includes( + operations as ListOperations, + ); + const showFilter = [ListOperations.Filter].includes( + operations as ListOperations, + ); + + return { + showNum, + showSortMethod, + showFilter, + }; +} + +export const FormSchema = z.object(RetrievalPartialSchema); + +export type ListOperationsFormSchemaType = z.infer; + +function ListOperationsForm({ node }: INextOperatorForm) { + const { t } = useTranslation(); + + const { getType } = useGetVariableLabelOrTypeByValue(); + + const defaultValues = useFormValues(initialListOperationsValues, node); + + const form = useForm({ + defaultValues: defaultValues, + mode: 'onChange', + resolver: zodResolver(FormSchema), + // shouldUnregister: true, + }); + + const operations = useWatch({ control: form.control, name: 'operations' }); + + const query = useWatch({ control: form.control, name: 'query' }); + + const subType = getArrayElementType(getType(query)); + + const currentOutputs = useMemo(() => { + return { + result: { + type: `Array<${subType}>`, + }, + first: { + type: subType, + }, + last: { + type: subType, + }, + }; + }, [subType]); + + const outputList = buildOutputList(currentOutputs); + + const ListOperationsOptions = buildOptions( + ListOperations, + t, + `flow.ListOperationsOptions`, + true, + ); + const SortMethodOptions = buildOptions( + SortMethod, + t, + `flow.SortMethodOptions`, + true, + ); + + const operatorOptions = useBuildSwitchOperatorOptions( + DataOperationsOperatorOptions, + ); + + const { showFilter, showNum, showSortMethod } = showField(operations); + + const handleOperationsChange = useCallback( + (operations: string) => { + const { showFilter, showNum, showSortMethod } = showField(operations); + + if (showNum) { + form.setValue('n', 1, { shouldDirty: true }); + } + + if (showSortMethod) { + form.setValue('sort_method', SortMethodOptions.at(0)?.value, { + shouldDirty: true, + }); + } + if (showFilter) { + form.setValue('filter.operator', operatorOptions.at(0)?.value, { + shouldDirty: true, + }); + } + }, + [SortMethodOptions, form, operatorOptions], + ); + + useEffect(() => { + form.setValue('outputs', currentOutputs, { shouldDirty: true }); + }, [currentOutputs, form]); + + useWatchFormChange(node?.id, form, true); + + return ( +
    + + + + + {(field) => ( + { + handleOperationsChange(val); + field.onChange(val); + }} + /> + )} + + {showNum && ( + ( + + {t('flow.flowNum')} + + + + + + )} + /> + )} + {showSortMethod && ( + + + + )} + {showFilter && ( +
    + + + + + + + +
    + )} + +
    +
    + ); +} + +export default memo(ListOperationsForm); diff --git a/web/src/pages/agent/form/loop-form/dynamic-variables.tsx b/web/src/pages/agent/form/loop-form/dynamic-variables.tsx new file mode 100644 index 00000000000..04318af108b --- /dev/null +++ b/web/src/pages/agent/form/loop-form/dynamic-variables.tsx @@ -0,0 +1,229 @@ +import { BoolSegmented } from '@/components/bool-segmented'; +import { KeyInput } from '@/components/key-input'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { useIsDarkTheme } from '@/components/theme-provider'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Separator } from '@/components/ui/separator'; +import { Textarea } from '@/components/ui/textarea'; +import { Editor, loader } from '@monaco-editor/react'; +import { X } from 'lucide-react'; +import { ReactNode, useCallback } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { InputMode, TypesWithArray } from '../../constant'; +import { + InputModeOptions, + buildConversationVariableSelectOptions, +} from '../../utils'; +import { DynamicFormHeader } from '../components/dynamic-fom-header'; +import { QueryVariable } from '../components/query-variable'; +import { useInitializeConditions } from './use-watch-form-change'; + +loader.config({ paths: { vs: '/vs' } }); + +type SelectKeysProps = { + name: string; + label: ReactNode; + tooltip?: string; + keyField?: string; + valueField?: string; + operatorField?: string; + nodeId?: string; +}; + +const VariableTypeOptions = buildConversationVariableSelectOptions(); + +const modeField = 'input_mode'; + +const ConstantValueMap = { + [TypesWithArray.Boolean]: true, + [TypesWithArray.Number]: 0, + [TypesWithArray.String]: '', + [TypesWithArray.ArrayBoolean]: '[]', + [TypesWithArray.ArrayNumber]: '[]', + [TypesWithArray.ArrayString]: '[]', + [TypesWithArray.ArrayObject]: '[]', + [TypesWithArray.Object]: '{}', +}; + +export function DynamicVariables({ + name, + label, + tooltip, + keyField = 'variable', + valueField = 'value', + operatorField = 'type', + nodeId, +}: SelectKeysProps) { + const form = useFormContext(); + const isDarkTheme = useIsDarkTheme(); + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); + + const { initializeVariableRelatedConditions } = + useInitializeConditions(nodeId); + + const initializeValue = useCallback( + (mode: string, variableType: string, valueFieldAlias: string) => { + if (mode === InputMode.Variable) { + form.setValue(valueFieldAlias, '', { shouldDirty: true }); + } else { + const val = ConstantValueMap[variableType as TypesWithArray]; + form.setValue(valueFieldAlias, val, { shouldDirty: true }); + } + }, + [form], + ); + + const handleModeChange = useCallback( + (mode: string, valueFieldAlias: string, operatorFieldAlias: string) => { + const variableType = form.getValues(operatorFieldAlias); + initializeValue(mode, variableType, valueFieldAlias); + }, + [form, initializeValue], + ); + + const handleVariableTypeChange = useCallback( + ( + variableType: string, + valueFieldAlias: string, + modeFieldAlias: string, + keyFieldAlias: string, + ) => { + const mode = form.getValues(modeFieldAlias); + + initializeVariableRelatedConditions( + form.getValues(keyFieldAlias), + variableType, + ); + + initializeValue(mode, variableType, valueFieldAlias); + }, + [form, initializeValue, initializeVariableRelatedConditions], + ); + + const renderParameter = useCallback( + (operatorFieldName: string, modeFieldName: string) => { + const mode = form.getValues(modeFieldName); + const logicalOperator = form.getValues(operatorFieldName); + + if (mode === InputMode.Constant) { + if (logicalOperator === TypesWithArray.Boolean) { + return ; + } + + if (logicalOperator === TypesWithArray.Number) { + return ; + } + + if (logicalOperator === TypesWithArray.String) { + return ; + } + + return ( + + ); + } + + return ( + + ); + }, + [form, isDarkTheme], + ); + + return ( +
    + + append({ + [keyField]: '', + [valueField]: '', + [modeField]: InputMode.Constant, + [operatorField]: TypesWithArray.String, + }) + } + > +
    + {fields.map((field, index) => { + const keyFieldAlias = `${name}.${index}.${keyField}`; + const valueFieldAlias = `${name}.${index}.${valueField}`; + const operatorFieldAlias = `${name}.${index}.${operatorField}`; + const modeFieldAlias = `${name}.${index}.${modeField}`; + + return ( +
    +
    +
    + + + + + + {(field) => ( + { + handleVariableTypeChange( + val, + valueFieldAlias, + modeFieldAlias, + keyFieldAlias, + ); + field.onChange(val); + }} + options={VariableTypeOptions} + > + )} + + + + {(field) => ( + { + handleModeChange( + val, + valueFieldAlias, + operatorFieldAlias, + ); + field.onChange(val); + }} + options={InputModeOptions} + > + )} + +
    + + {renderParameter(operatorFieldAlias, modeFieldAlias)} + +
    + + +
    + ); + })} +
    +
    + ); +} diff --git a/web/src/pages/agent/form/loop-form/index.tsx b/web/src/pages/agent/form/loop-form/index.tsx new file mode 100644 index 00000000000..6a5e30f383f --- /dev/null +++ b/web/src/pages/agent/form/loop-form/index.tsx @@ -0,0 +1,52 @@ +import { SliderInputFormField } from '@/components/slider-input-form-field'; +import { Form } from '@/components/ui/form'; +import { FormLayout } from '@/constants/form'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { memo } from 'react'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { initialLoopValues } from '../../constant'; +import { INextOperatorForm } from '../../interface'; +import { FormWrapper } from '../components/form-wrapper'; +import { DynamicVariables } from './dynamic-variables'; +import { LoopTerminationCondition } from './loop-termination-condition'; +import { FormSchema, LoopFormSchemaType } from './schema'; +import { useFormValues } from './use-values'; +import { useWatchFormChange } from './use-watch-form-change'; + +function LoopForm({ node }: INextOperatorForm) { + const defaultValues = useFormValues(initialLoopValues, node); + const { t } = useTranslation(); + + const form = useForm({ + defaultValues: defaultValues, + resolver: zodResolver(FormSchema), + }); + + useWatchFormChange(node?.id, form); + + return ( +
    + + + + + +
    + ); +} + +export default memo(LoopForm); diff --git a/web/src/pages/agent/form/loop-form/loop-termination-condition.tsx b/web/src/pages/agent/form/loop-form/loop-termination-condition.tsx new file mode 100644 index 00000000000..52d59871199 --- /dev/null +++ b/web/src/pages/agent/form/loop-form/loop-termination-condition.tsx @@ -0,0 +1,316 @@ +import { BoolSegmented } from '@/components/bool-segmented'; +import { LogicalOperator } from '@/components/logical-operator'; +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Separator } from '@/components/ui/separator'; +import { ComparisonOperator, SwitchLogicOperator } from '@/constants/agent'; +import { loader } from '@monaco-editor/react'; +import { toLower } from 'lodash'; +import { X } from 'lucide-react'; +import { ReactNode, useCallback, useMemo } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { + AgentVariableType, + InputMode, + JsonSchemaDataType, +} from '../../constant'; +import { useFilterChildNodeIds } from '../../hooks/use-filter-child-node-ids'; +import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query'; +import { InputModeOptions } from '../../utils'; +import { DynamicFormHeader } from '../components/dynamic-fom-header'; +import { QueryVariable } from '../components/query-variable'; +import { LoopFormSchemaType } from './schema'; +import { useBuildLogicalOptions } from './use-build-logical-options'; +import { + ConditionKeyType, + ConditionModeType, + ConditionOperatorType, + ConditionValueType, + useInitializeConditions, +} from './use-watch-form-change'; + +loader.config({ paths: { vs: '/vs' } }); + +const VariablesExceptOperatorOutputs = [AgentVariableType.Conversation]; + +type LoopTerminationConditionProps = { + label: ReactNode; + tooltip?: string; + keyField?: string; + valueField?: string; + operatorField?: string; + modeField?: string; + nodeId?: string; +}; + +const EmptyFields = [ComparisonOperator.Empty, ComparisonOperator.NotEmpty]; + +const LogicalOperatorFieldName = 'logical_operator'; + +const name = 'loop_termination_condition'; + +export function LoopTerminationCondition({ + label, + tooltip, + keyField = 'variable', + valueField = 'value', + operatorField = 'operator', + modeField = 'input_mode', + nodeId, +}: LoopTerminationConditionProps) { + const form = useFormContext(); + const childNodeIds = useFilterChildNodeIds(nodeId); + + const nodeIds = useMemo(() => { + if (!nodeId) return []; + return [nodeId, ...childNodeIds]; + }, [childNodeIds, nodeId]); + + const { getType } = useGetVariableLabelOrTypeByValue({ + nodeIds: nodeIds, + variablesExceptOperatorOutputs: VariablesExceptOperatorOutputs, + }); + + const { + initializeConditionMode, + initializeConditionOperator, + initializeConditionValue, + } = useInitializeConditions(nodeId); + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); + + const { buildLogicalOptions } = useBuildLogicalOptions(); + + const getVariableType = useCallback( + (keyFieldName: ConditionKeyType) => { + const key = form.getValues(keyFieldName); + return toLower(getType(key)); + }, + [form, getType], + ); + + const initializeMode = useCallback( + (modeFieldAlias: ConditionModeType, keyFieldAlias: ConditionKeyType) => { + const keyType = getVariableType(keyFieldAlias); + + initializeConditionMode(modeFieldAlias, keyType); + }, + [getVariableType, initializeConditionMode], + ); + + const initializeValue = useCallback( + (valueFieldAlias: ConditionValueType, keyFieldAlias: ConditionKeyType) => { + const keyType = getVariableType(keyFieldAlias); + + initializeConditionValue(valueFieldAlias, keyType); + }, + [getVariableType, initializeConditionValue], + ); + + const handleVariableChange = useCallback( + ( + operatorFieldAlias: ConditionOperatorType, + valueFieldAlias: ConditionValueType, + keyFieldAlias: ConditionKeyType, + modeFieldAlias: ConditionModeType, + ) => { + return () => { + initializeConditionOperator( + operatorFieldAlias, + getVariableType(keyFieldAlias), + ); + + initializeMode(modeFieldAlias, keyFieldAlias); + + initializeValue(valueFieldAlias, keyFieldAlias); + }; + }, + [ + getVariableType, + initializeConditionOperator, + initializeMode, + initializeValue, + ], + ); + + const handleOperatorChange = useCallback( + ( + valueFieldAlias: ConditionValueType, + keyFieldAlias: ConditionKeyType, + modeFieldAlias: ConditionModeType, + ) => { + initializeMode(modeFieldAlias, keyFieldAlias); + initializeValue(valueFieldAlias, keyFieldAlias); + }, + [initializeMode, initializeValue], + ); + + const handleModeChange = useCallback( + (mode: string, valueFieldAlias: ConditionValueType) => { + form.setValue(valueFieldAlias, mode === InputMode.Constant ? 0 : '', { + shouldDirty: true, + }); + }, + [form], + ); + + const renderParameterPanel = useCallback( + ( + keyFieldName: ConditionKeyType, + valueFieldAlias: ConditionValueType, + modeFieldAlias: ConditionModeType, + operatorFieldAlias: ConditionOperatorType, + ) => { + const type = getVariableType(keyFieldName); + const mode = form.getValues(modeFieldAlias); + const operator = form.getValues(operatorFieldAlias); + + if (EmptyFields.includes(operator as ComparisonOperator)) { + return null; + } + + if (type === JsonSchemaDataType.Number) { + return ( +
    + + {(field) => ( + { + handleModeChange(val, valueFieldAlias); + field.onChange(val); + }} + options={InputModeOptions} + > + )} + + + {mode === InputMode.Constant ? ( + + + + ) : ( + + )} +
    + ); + } + + if (type === JsonSchemaDataType.Boolean) { + return ( + + + + ); + } + + return ( + + + + ); + }, + [form, getVariableType, handleModeChange], + ); + + return ( +
    + { + if (fields.length === 1) { + form.setValue(LogicalOperatorFieldName, SwitchLogicOperator.And); + } + append({ [keyField]: '', [valueField]: '' }); + }} + > +
    + {fields.length > 1 && ( + + )} +
    + {fields.map((field, index) => { + const keyFieldAlias = + `${name}.${index}.${keyField}` as ConditionKeyType; + const valueFieldAlias = + `${name}.${index}.${valueField}` as ConditionValueType; + const operatorFieldAlias = + `${name}.${index}.${operatorField}` as ConditionOperatorType; + const modeFieldAlias = + `${name}.${index}.${modeField}` as ConditionModeType; + + return ( +
    +
    +
    + + + + + + {({ onChange, value }) => ( + { + handleOperatorChange( + valueFieldAlias, + keyFieldAlias, + modeFieldAlias, + ); + onChange(val); + }} + options={buildLogicalOptions( + getVariableType(keyFieldAlias), + )} + > + )} + +
    + {renderParameterPanel( + keyFieldAlias, + valueFieldAlias, + modeFieldAlias, + operatorFieldAlias, + )} +
    + + +
    + ); + })} +
    +
    +
    + ); +} diff --git a/web/src/pages/agent/form/loop-form/schema.ts b/web/src/pages/agent/form/loop-form/schema.ts new file mode 100644 index 00000000000..982cbb30583 --- /dev/null +++ b/web/src/pages/agent/form/loop-form/schema.ts @@ -0,0 +1,24 @@ +import { z } from 'zod'; + +export const FormSchema = z.object({ + loop_variables: z.array( + z.object({ + variable: z.string().optional(), + type: z.string().optional(), + value: z.string().or(z.number()).or(z.boolean()).optional(), + input_mode: z.string(), + }), + ), + logical_operator: z.string(), + loop_termination_condition: z.array( + z.object({ + variable: z.string().optional(), + operator: z.string().optional(), + value: z.string().or(z.number()).or(z.boolean()).optional(), + input_mode: z.string().optional(), + }), + ), + maximum_loop_count: z.number(), +}); + +export type LoopFormSchemaType = z.infer; diff --git a/web/src/pages/agent/form/loop-form/use-build-logical-options.ts b/web/src/pages/agent/form/loop-form/use-build-logical-options.ts new file mode 100644 index 00000000000..35aae3f8d9d --- /dev/null +++ b/web/src/pages/agent/form/loop-form/use-build-logical-options.ts @@ -0,0 +1,27 @@ +import { SwitchOperatorOptions } from '@/constants/agent'; +import { camelCase, toLower } from 'lodash'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { LoopTerminationStringComparisonOperatorMap } from '../../constant'; + +export function useBuildLogicalOptions() { + const { t } = useTranslation(); + + const buildLogicalOptions = useCallback( + (type: string) => { + return LoopTerminationStringComparisonOperatorMap[ + toLower(type) as keyof typeof LoopTerminationStringComparisonOperatorMap + ]?.map((x) => ({ + label: t( + `flow.switchOperatorOptions.${camelCase(SwitchOperatorOptions.find((y) => y.value === x)?.label || x)}`, + ), + value: x, + })); + }, + [t], + ); + + return { + buildLogicalOptions, + }; +} diff --git a/web/src/pages/agent/form/loop-form/use-values.ts b/web/src/pages/agent/form/loop-form/use-values.ts new file mode 100644 index 00000000000..cf7a1054a9d --- /dev/null +++ b/web/src/pages/agent/form/loop-form/use-values.ts @@ -0,0 +1,20 @@ +import { RAGFlowNodeType } from '@/interfaces/database/flow'; +import { isEmpty, omit } from 'lodash'; +import { useMemo } from 'react'; + +export function useFormValues( + defaultValues: Record, + node?: RAGFlowNodeType, +) { + const values = useMemo(() => { + const formData = node?.data?.form; + + if (isEmpty(formData)) { + return omit(defaultValues, 'outputs'); + } + + return omit(formData, 'outputs'); + }, [defaultValues, node?.data?.form]); + + return values; +} diff --git a/web/src/pages/agent/form/loop-form/use-watch-form-change.ts b/web/src/pages/agent/form/loop-form/use-watch-form-change.ts new file mode 100644 index 00000000000..f3b707c44f7 --- /dev/null +++ b/web/src/pages/agent/form/loop-form/use-watch-form-change.ts @@ -0,0 +1,116 @@ +import { JsonSchemaDataType } from '@/constants/agent'; +import { buildVariableValue } from '@/utils/canvas-util'; +import { useCallback, useEffect } from 'react'; +import { UseFormReturn, useFormContext, useWatch } from 'react-hook-form'; +import { InputMode } from '../../constant'; +import { IOutputs } from '../../interface'; +import useGraphStore from '../../store'; +import { LoopFormSchemaType } from './schema'; +import { useBuildLogicalOptions } from './use-build-logical-options'; + +export function useWatchFormChange( + id?: string, + form?: UseFormReturn, +) { + let values = useWatch({ control: form?.control }); + const { replaceNodeForm } = useGraphStore((state) => state); + + useEffect(() => { + if (id) { + let nextValues = { + ...values, + outputs: values.loop_variables?.reduce((pre, cur) => { + const variable = cur.variable; + if (variable) { + pre[variable] = { + type: cur.type, + value: '', + }; + } + return pre; + }, {} as IOutputs), + }; + + replaceNodeForm(id, nextValues); + } + }, [form?.formState.isDirty, id, replaceNodeForm, values]); +} + +type ConditionPrefixType = `loop_termination_condition.${number}.`; +export type ConditionKeyType = `${ConditionPrefixType}variable`; +export type ConditionModeType = `${ConditionPrefixType}input_mode`; +export type ConditionValueType = `${ConditionPrefixType}value`; +export type ConditionOperatorType = `${ConditionPrefixType}operator`; +export function useInitializeConditions(id?: string) { + const form = useFormContext(); + const { buildLogicalOptions } = useBuildLogicalOptions(); + + const initializeConditionMode = useCallback( + (modeFieldAlias: ConditionModeType, keyType: string) => { + if (keyType === JsonSchemaDataType.Number) { + form.setValue(modeFieldAlias, InputMode.Constant, { + shouldDirty: true, + shouldValidate: true, + }); + } + }, + [form], + ); + + const initializeConditionValue = useCallback( + (valueFieldAlias: ConditionValueType, keyType: string) => { + let initialValue: string | boolean | number = ''; + + if (keyType === JsonSchemaDataType.Number) { + initialValue = 0; + } else if (keyType === JsonSchemaDataType.Boolean) { + initialValue = true; + } + + form.setValue(valueFieldAlias, initialValue, { + shouldDirty: true, + shouldValidate: true, + }); + }, + [form], + ); + + const initializeConditionOperator = useCallback( + (operatorFieldAlias: ConditionOperatorType, keyType: string) => { + const logicalOptions = buildLogicalOptions(keyType); + + form.setValue(operatorFieldAlias, logicalOptions?.at(0)?.value, { + shouldDirty: true, + shouldValidate: true, + }); + }, + [buildLogicalOptions, form], + ); + + const initializeVariableRelatedConditions = useCallback( + (variable: string, variableType: string) => { + form?.getValues('loop_termination_condition').forEach((x, idx) => { + if (variable && x.variable === buildVariableValue(variable, id)) { + const prefix: ConditionPrefixType = `loop_termination_condition.${idx}.`; + initializeConditionMode(`${prefix}input_mode`, variableType); + initializeConditionValue(`${prefix}value`, variableType); + initializeConditionOperator(`${prefix}operator`, variableType); + } + }); + }, + [ + form, + id, + initializeConditionMode, + initializeConditionOperator, + initializeConditionValue, + ], + ); + + return { + initializeVariableRelatedConditions, + initializeConditionMode, + initializeConditionValue, + initializeConditionOperator, + }; +} diff --git a/web/src/pages/agent/form/message-form/index.tsx b/web/src/pages/agent/form/message-form/index.tsx index e93735ee740..87071e5780d 100644 --- a/web/src/pages/agent/form/message-form/index.tsx +++ b/web/src/pages/agent/form/message-form/index.tsx @@ -1,4 +1,4 @@ -import { FormContainer } from '@/components/form-container'; +import { MemoriesFormField } from '@/components/memories-form-field'; import { BlockButton, Button } from '@/components/ui/button'; import { Form, @@ -8,15 +8,20 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { WebHookResponseStatusFormField } from '@/components/webhook-response-status'; import { zodResolver } from '@hookform/resolvers/zod'; import { X } from 'lucide-react'; import { memo } from 'react'; import { useFieldArray, useForm } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; +import { ExportFileType } from '../../constant'; import { INextOperatorForm } from '../../interface'; import { FormWrapper } from '../components/form-wrapper'; import { PromptEditor } from '../components/prompt-editor'; +import { useShowWebhookResponseStatus } from './use-show-response-status'; import { useValues } from './use-values'; import { useWatchFormChange } from './use-watch-change'; @@ -33,10 +38,18 @@ function MessageForm({ node }: INextOperatorForm) { }), ) .optional(), + output_format: z.string().optional(), + auto_play: z.boolean().optional(), + status: z.number().optional(), + memory_ids: z.array(z.string()).optional(), }); const form = useForm({ - defaultValues: values, + defaultValues: { + ...values, + output_format: values.output_format, + auto_play: values.auto_play, + }, resolver: zodResolver(FormSchema), }); @@ -47,51 +60,109 @@ function MessageForm({ node }: INextOperatorForm) { control: form.control, }); + const { showWebhookResponseStatus, isWebhookMode } = + useShowWebhookResponseStatus(form); + return (
    - - - {t('flow.msg')} -
    - {fields.map((field, index) => ( -
    - ( - - - - - - )} - /> - {fields.length > 1 && ( - + {showWebhookResponseStatus && ( + + )} + + {t('flow.msg')} +
    + {fields.map((field, index) => ( +
    + ( + + + + + )} -
    - ))} + /> + {fields.length > 1 && ( + + )} +
    + ))} - append({ value: '' })} // "" will cause the inability to add, refer to: https://github.com/orgs/react-hook-form/discussions/8485#discussioncomment-2961861 - > - {t('flow.addMessage')} - -
    - - - + append({ value: '' })} // "" will cause the inability to add, refer to: https://github.com/orgs/react-hook-form/discussions/8485#discussioncomment-2961861 + > + {t('flow.addMessage')} + +
    + +
    + {!isWebhookMode && ( + <> + + + {t('flow.downloadFileType')} + + ( + + + { + return { + value: + ExportFileType[ + key as keyof typeof ExportFileType + ], + label: key, + }; + }, + )} + {...field} + onValueChange={field.onChange} + placeholder={t('common.selectPlaceholder')} + allowClear + > + + + )} + /> + + + {t('flow.autoPlay')} + ( + + + + + + )} + /> + + + )} +
    ); diff --git a/web/src/pages/agent/form/message-form/use-show-response-status.ts b/web/src/pages/agent/form/message-form/use-show-response-status.ts new file mode 100644 index 00000000000..830fffff153 --- /dev/null +++ b/web/src/pages/agent/form/message-form/use-show-response-status.ts @@ -0,0 +1,32 @@ +import { isEmpty } from 'lodash'; +import { useEffect, useMemo } from 'react'; +import { UseFormReturn } from 'react-hook-form'; +import { + AgentDialogueMode, + BeginId, + WebhookExecutionMode, +} from '../../constant'; +import useGraphStore from '../../store'; + +export function useShowWebhookResponseStatus(form: UseFormReturn) { + const getNode = useGraphStore((state) => state.getNode); + + const formData = getNode(BeginId)?.data.form; + + const isWebhookMode = formData?.mode === AgentDialogueMode.Webhook; + + const showWebhookResponseStatus = useMemo(() => { + return ( + isWebhookMode && + formData?.execution_mode === WebhookExecutionMode.Streaming + ); + }, [formData?.execution_mode, isWebhookMode]); + + useEffect(() => { + if (showWebhookResponseStatus && isEmpty(form.getValues('status'))) { + form.setValue('status', 200, { shouldValidate: true, shouldDirty: true }); + } + }, [form, showWebhookResponseStatus]); + + return { showWebhookResponseStatus, isWebhookMode }; +} diff --git a/web/src/pages/agent/form/parser-form/index.tsx b/web/src/pages/agent/form/parser-form/index.tsx index 2584c796010..1942b2d05ac 100644 --- a/web/src/pages/agent/form/parser-form/index.tsx +++ b/web/src/pages/agent/form/parser-form/index.tsx @@ -34,6 +34,8 @@ import { OutputFormatFormField } from './common-form-fields'; import { EmailFormFields } from './email-form-fields'; import { ImageFormFields } from './image-form-fields'; import { PdfFormFields } from './pdf-form-fields'; +import { PptFormFields } from './ppt-form-fields'; +import { SpreadsheetFormFields } from './spreadsheet-form-fields'; import { buildFieldNameWithPrefix } from './utils'; import { AudioFormFields, VideoFormFields } from './video-form-fields'; @@ -41,6 +43,8 @@ const outputList = buildOutputList(initialParserValues.outputs); const FileFormatWidgetMap = { [FileType.PDF]: PdfFormFields, + [FileType.Spreadsheet]: SpreadsheetFormFields, + [FileType.PowerPoint]: PptFormFields, [FileType.Video]: VideoFormFields, [FileType.Audio]: AudioFormFields, [FileType.Email]: EmailFormFields, @@ -65,6 +69,8 @@ export const FormSchema = z.object({ fields: z.array(z.string()).optional(), llm_id: z.string().optional(), system_prompt: z.string().optional(), + table_result_type: z.string().optional(), + markdown_image_response_type: z.string().optional(), }), ), }); @@ -184,6 +190,8 @@ const ParserForm = ({ node }: INextOperatorForm) => { lang: '', fields: [], llm_id: '', + table_result_type: '', + markdown_image_response_type: '', }); }, [append]); diff --git a/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx b/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx index 020032c5c56..82c976f0f4d 100644 --- a/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/pdf-form-fields.tsx @@ -1,13 +1,30 @@ import { ParseDocumentType } from '@/components/layout-recognize-form-field'; +import { + SelectWithSearch, + SelectWithSearchFlagOptionType, +} from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; import { isEmpty } from 'lodash'; import { useEffect, useMemo } from 'react'; import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; import { LanguageFormField, ParserMethodFormField } from './common-form-fields'; import { CommonProps } from './interface'; import { useSetInitialLanguage } from './use-set-initial-language'; import { buildFieldNameWithPrefix } from './utils'; +const tableResultTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'Markdown', value: '0' }, + { label: 'HTML', value: '1' }, +]; + +const markdownImageResponseTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'URL', value: '0' }, + { label: 'Text', value: '1' }, +]; + export function PdfFormFields({ prefix }: CommonProps) { + const { t } = useTranslation(); const form = useFormContext(); const parseMethodName = buildFieldNameWithPrefix('parse_method', prefix); @@ -25,6 +42,12 @@ export function PdfFormFields({ prefix }: CommonProps) { ); }, [parseMethod]); + const tcadpOptionsShown = useMemo(() => { + return ( + !isEmpty(parseMethod) && parseMethod === ParseDocumentType.TCADPParser + ); + }, [parseMethod]); + useSetInitialLanguage({ prefix, languageShown }); useEffect(() => { @@ -36,10 +59,68 @@ export function PdfFormFields({ prefix }: CommonProps) { } }, [form, parseMethodName]); + // Set default values for TCADP options when TCADP is selected + useEffect(() => { + if (tcadpOptionsShown) { + const tableResultTypeName = buildFieldNameWithPrefix( + 'table_result_type', + prefix, + ); + const markdownImageResponseTypeName = buildFieldNameWithPrefix( + 'markdown_image_response_type', + prefix, + ); + + if (isEmpty(form.getValues(tableResultTypeName))) { + form.setValue(tableResultTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + if (isEmpty(form.getValues(markdownImageResponseTypeName))) { + form.setValue(markdownImageResponseTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + } + }, [tcadpOptionsShown, form, prefix]); + return ( <> {languageShown && } + {tcadpOptionsShown && ( + <> + + {(field) => ( + + )} + + + {(field) => ( + + )} + + + )} ); } diff --git a/web/src/pages/agent/form/parser-form/ppt-form-fields.tsx b/web/src/pages/agent/form/parser-form/ppt-form-fields.tsx new file mode 100644 index 00000000000..18f924959cf --- /dev/null +++ b/web/src/pages/agent/form/parser-form/ppt-form-fields.tsx @@ -0,0 +1,125 @@ +import { ParseDocumentType } from '@/components/layout-recognize-form-field'; +import { + SelectWithSearch, + SelectWithSearchFlagOptionType, +} from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { isEmpty } from 'lodash'; +import { useEffect, useMemo } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { ParserMethodFormField } from './common-form-fields'; +import { CommonProps } from './interface'; +import { buildFieldNameWithPrefix } from './utils'; + +const tableResultTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'Markdown', value: '0' }, + { label: 'HTML', value: '1' }, +]; + +const markdownImageResponseTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'URL', value: '0' }, + { label: 'Text', value: '1' }, +]; + +export function PptFormFields({ prefix }: CommonProps) { + const { t } = useTranslation(); + const form = useFormContext(); + + const parseMethodName = buildFieldNameWithPrefix('parse_method', prefix); + + const parseMethod = useWatch({ + name: parseMethodName, + }); + + // PPT only supports DeepDOC and TCADPParser + const optionsWithoutLLM = [ + { label: ParseDocumentType.DeepDOC, value: ParseDocumentType.DeepDOC }, + { + label: ParseDocumentType.TCADPParser, + value: ParseDocumentType.TCADPParser, + }, + ]; + + const tcadpOptionsShown = useMemo(() => { + return ( + !isEmpty(parseMethod) && parseMethod === ParseDocumentType.TCADPParser + ); + }, [parseMethod]); + + useEffect(() => { + if (isEmpty(form.getValues(parseMethodName))) { + form.setValue(parseMethodName, ParseDocumentType.DeepDOC, { + shouldValidate: true, + shouldDirty: true, + }); + } + }, [form, parseMethodName]); + + // Set default values for TCADP options when TCADP is selected + useEffect(() => { + if (tcadpOptionsShown) { + const tableResultTypeName = buildFieldNameWithPrefix( + 'table_result_type', + prefix, + ); + const markdownImageResponseTypeName = buildFieldNameWithPrefix( + 'markdown_image_response_type', + prefix, + ); + + if (isEmpty(form.getValues(tableResultTypeName))) { + form.setValue(tableResultTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + if (isEmpty(form.getValues(markdownImageResponseTypeName))) { + form.setValue(markdownImageResponseTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + } + }, [tcadpOptionsShown, form, prefix]); + + return ( + <> + + {tcadpOptionsShown && ( + <> + + {(field) => ( + + )} + + + {(field) => ( + + )} + + + )} + + ); +} diff --git a/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx b/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx new file mode 100644 index 00000000000..40715099174 --- /dev/null +++ b/web/src/pages/agent/form/parser-form/spreadsheet-form-fields.tsx @@ -0,0 +1,125 @@ +import { ParseDocumentType } from '@/components/layout-recognize-form-field'; +import { + SelectWithSearch, + SelectWithSearchFlagOptionType, +} from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { isEmpty } from 'lodash'; +import { useEffect, useMemo } from 'react'; +import { useFormContext, useWatch } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { ParserMethodFormField } from './common-form-fields'; +import { CommonProps } from './interface'; +import { buildFieldNameWithPrefix } from './utils'; + +const tableResultTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'Markdown', value: '0' }, + { label: 'HTML', value: '1' }, +]; + +const markdownImageResponseTypeOptions: SelectWithSearchFlagOptionType[] = [ + { label: 'URL', value: '0' }, + { label: 'Text', value: '1' }, +]; + +export function SpreadsheetFormFields({ prefix }: CommonProps) { + const { t } = useTranslation(); + const form = useFormContext(); + + const parseMethodName = buildFieldNameWithPrefix('parse_method', prefix); + + const parseMethod = useWatch({ + name: parseMethodName, + }); + + // Spreadsheet only supports DeepDOC and TCADPParser + const optionsWithoutLLM = [ + { label: ParseDocumentType.DeepDOC, value: ParseDocumentType.DeepDOC }, + { + label: ParseDocumentType.TCADPParser, + value: ParseDocumentType.TCADPParser, + }, + ]; + + const tcadpOptionsShown = useMemo(() => { + return ( + !isEmpty(parseMethod) && parseMethod === ParseDocumentType.TCADPParser + ); + }, [parseMethod]); + + useEffect(() => { + if (isEmpty(form.getValues(parseMethodName))) { + form.setValue(parseMethodName, ParseDocumentType.DeepDOC, { + shouldValidate: true, + shouldDirty: true, + }); + } + }, [form, parseMethodName]); + + // Set default values for TCADP options when TCADP is selected + useEffect(() => { + if (tcadpOptionsShown) { + const tableResultTypeName = buildFieldNameWithPrefix( + 'table_result_type', + prefix, + ); + const markdownImageResponseTypeName = buildFieldNameWithPrefix( + 'markdown_image_response_type', + prefix, + ); + + if (isEmpty(form.getValues(tableResultTypeName))) { + form.setValue(tableResultTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + if (isEmpty(form.getValues(markdownImageResponseTypeName))) { + form.setValue(markdownImageResponseTypeName, '1', { + shouldValidate: true, + shouldDirty: true, + }); + } + } + }, [tcadpOptionsShown, form, prefix]); + + return ( + <> + + {tcadpOptionsShown && ( + <> + + {(field) => ( + + )} + + + {(field) => ( + + )} + + + )} + + ); +} diff --git a/web/src/pages/agent/form/parser-form/video-form-fields.tsx b/web/src/pages/agent/form/parser-form/video-form-fields.tsx index 628b8dc9c69..3f37a9864eb 100644 --- a/web/src/pages/agent/form/parser-form/video-form-fields.tsx +++ b/web/src/pages/agent/form/parser-form/video-form-fields.tsx @@ -1,5 +1,5 @@ import { LlmModelType } from '@/constants/knowledge'; -import { useComposeLlmOptionsByModelTypes } from '@/hooks/llm-hooks'; +import { useComposeLlmOptionsByModelTypes } from '@/hooks/use-llm-request'; import { LargeModelFormField, OutputFormatFormFieldProps, diff --git a/web/src/pages/agent/form/pdf-generator-form/index.tsx b/web/src/pages/agent/form/pdf-generator-form/index.tsx new file mode 100644 index 00000000000..110bb63691b --- /dev/null +++ b/web/src/pages/agent/form/pdf-generator-form/index.tsx @@ -0,0 +1,535 @@ +import { FormContainer } from '@/components/form-container'; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { RAGFlowSelect } from '@/components/ui/select'; +import { Switch } from '@/components/ui/switch'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { t } from 'i18next'; +import { memo, useMemo } from 'react'; +import { useForm } from 'react-hook-form'; +import { z } from 'zod'; +import { + PDFGeneratorFontFamily, + PDFGeneratorLogoPosition, + PDFGeneratorOrientation, + PDFGeneratorPageSize, +} from '../../constant'; +import { INextOperatorForm } from '../../interface'; +import { FormWrapper } from '../components/form-wrapper'; +import { Output, transferOutputs } from '../components/output'; +import { PromptEditor } from '../components/prompt-editor'; +import { useValues } from './use-values'; +import { useWatchFormChange } from './use-watch-form-change'; + +function PDFGeneratorForm({ node }: INextOperatorForm) { + const values = useValues(node); + + const FormSchema = z.object({ + output_format: z.string().default('pdf'), + content: z.string().min(1, 'Content is required'), + title: z.string().optional(), + subtitle: z.string().optional(), + header_text: z.string().optional(), + footer_text: z.string().optional(), + logo_image: z.string().optional(), + logo_position: z.string(), + logo_width: z.number(), + logo_height: z.number(), + font_family: z.string(), + font_size: z.number(), + title_font_size: z.number(), + heading1_font_size: z.number(), + heading2_font_size: z.number(), + heading3_font_size: z.number(), + text_color: z.string(), + title_color: z.string(), + page_size: z.string(), + orientation: z.string(), + margin_top: z.number(), + margin_bottom: z.number(), + margin_left: z.number(), + margin_right: z.number(), + line_spacing: z.number(), + filename: z.string().optional(), + output_directory: z.string(), + add_page_numbers: z.boolean(), + add_timestamp: z.boolean(), + watermark_text: z.string().optional(), + enable_toc: z.boolean(), + outputs: z + .object({ + file_path: z.object({ type: z.string() }), + pdf_base64: z.object({ type: z.string() }), + success: z.object({ type: z.string() }), + }) + .optional(), + }); + + const form = useForm>({ + defaultValues: values, + resolver: zodResolver(FormSchema), + }); + + const outputList = useMemo(() => { + return transferOutputs(values.outputs); + }, [values.outputs]); + + useWatchFormChange(node?.id, form); + + return ( +
    + + + {/* Output Format Selection */} + ( + + Output Format + + + + + Choose the output document format + + + + )} + /> + + {/* Content Section */} + ( + + {t('flow.content')} + + + + +
    +
    + Markdown support: **bold**, *italic*, + `code`, # Heading 1, ## Heading 2 +
    +
    + Lists: - bullet or 1. numbered +
    +
    + Tables: | Column 1 | Column 2 | (use | to + separate columns, <br> or \n for line breaks in + cells) +
    +
    + Other: --- for horizontal line, ``` for + code blocks +
    +
    +
    + +
    + )} + /> + + {/* Title & Subtitle */} + ( + + {t('flow.title')} + + + + + + )} + /> + + ( + + {t('flow.subtitle')} + + + + + + )} + /> + + {/* Logo Settings */} + ( + + {t('flow.logoImage')} + +
    + { + const file = e.target.files?.[0]; + if (file) { + const reader = new FileReader(); + reader.onloadend = () => { + field.onChange(reader.result as string); + }; + reader.readAsDataURL(file); + } + }} + className="cursor-pointer" + /> + +
    +
    + + Upload an image file or paste a file path/URL/base64 + + +
    + )} + /> + + ( + + {t('flow.logoPosition')} + + ({ label: val, value: val }), + )} + > + + + + )} + /> + +
    + ( + + {t('flow.logoWidth')} (inches) + + + field.onChange(parseFloat(e.target.value)) + } + /> + + + + )} + /> + + ( + + {t('flow.logoHeight')} (inches) + + + field.onChange(parseFloat(e.target.value)) + } + /> + + + + )} + /> +
    + + {/* Font Settings */} + ( + + {t('flow.fontFamily')} + + ({ label: val, value: val }), + )} + > + + + + )} + /> + +
    + ( + + {t('flow.fontSize')} + + field.onChange(parseInt(e.target.value))} + /> + + + + )} + /> + + ( + + {t('flow.titleFontSize')} + + field.onChange(parseInt(e.target.value))} + /> + + + + )} + /> +
    + + {/* Page Settings */} + ( + + {t('flow.pageSize')} + + ({ + label: val, + value: val, + }))} + > + + + + )} + /> + + ( + + {t('flow.orientation')} + + ({ label: val, value: val }), + )} + > + + + + )} + /> + + {/* Margins */} +
    + ( + + {t('flow.marginTop')} (inches) + + + field.onChange(parseFloat(e.target.value)) + } + /> + + + + )} + /> + + ( + + {t('flow.marginBottom')} (inches) + + + field.onChange(parseFloat(e.target.value)) + } + /> + + + + )} + /> +
    + + {/* Output Settings */} + ( + + {t('flow.filename')} + + + + + + )} + /> + + ( + + {t('flow.outputDirectory')} + + + + + + )} + /> + + {/* Additional Options */} + ( + +
    + {t('flow.addPageNumbers')} + + Add page numbers to the document + +
    + + + +
    + )} + /> + + ( + +
    + {t('flow.addTimestamp')} + + Add generation timestamp to the document + +
    + + + +
    + )} + /> + + ( + + {t('flow.watermarkText')} + + + + + + )} + /> + +
    } + /> +
    +
    +
    + +
    +
    + ); +} + +export default memo(PDFGeneratorForm); diff --git a/web/src/pages/agent/form/pdf-generator-form/use-values.ts b/web/src/pages/agent/form/pdf-generator-form/use-values.ts new file mode 100644 index 00000000000..1ecd8290893 --- /dev/null +++ b/web/src/pages/agent/form/pdf-generator-form/use-values.ts @@ -0,0 +1,11 @@ +import { useMemo } from 'react'; +import { Node } from 'reactflow'; +import { initialPDFGeneratorValues } from '../../constant'; + +export const useValues = (node?: Node) => { + const values = useMemo(() => { + return node?.data.form ?? initialPDFGeneratorValues; + }, [node?.data.form]); + + return values; +}; diff --git a/web/src/pages/agent/form/pdf-generator-form/use-watch-form-change.ts b/web/src/pages/agent/form/pdf-generator-form/use-watch-form-change.ts new file mode 100644 index 00000000000..f8f4de3db62 --- /dev/null +++ b/web/src/pages/agent/form/pdf-generator-form/use-watch-form-change.ts @@ -0,0 +1,19 @@ +import { useEffect } from 'react'; +import { UseFormReturn } from 'react-hook-form'; +import useGraphStore from '../../store'; + +export const useWatchFormChange = ( + nodeId: string | undefined, + form: UseFormReturn, +) => { + const updateNodeForm = useGraphStore((state) => state.updateNodeForm); + + useEffect(() => { + const { unsubscribe } = form.watch((value) => { + if (nodeId) { + updateNodeForm(nodeId, value); + } + }); + return () => unsubscribe(); + }, [form, nodeId, updateNodeForm]); +}; diff --git a/web/src/pages/agent/form/qweather-form/index.tsx b/web/src/pages/agent/form/qweather-form/index.tsx deleted file mode 100644 index eee088762ad..00000000000 --- a/web/src/pages/agent/form/qweather-form/index.tsx +++ /dev/null @@ -1,157 +0,0 @@ -import { - Form, - FormControl, - FormField, - FormItem, - FormLabel, - FormMessage, -} from '@/components/ui/form'; -import { Input } from '@/components/ui/input'; -import { RAGFlowSelect } from '@/components/ui/select'; -import { useCallback, useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; -import { INextOperatorForm } from '../../interface'; -import { - QWeatherLangOptions, - QWeatherTimePeriodOptions, - QWeatherTypeOptions, - QWeatherUserTypeOptions, -} from '../../options'; -import { DynamicInputVariable } from '../components/next-dynamic-input-variable'; - -enum FormFieldName { - Type = 'type', - UserType = 'user_type', -} - -const QWeatherForm = ({ form, node }: INextOperatorForm) => { - const { t } = useTranslation(); - const typeValue = form.watch(FormFieldName.Type); - - const qWeatherLangOptions = useMemo(() => { - return QWeatherLangOptions.map((x) => ({ - value: x, - label: t(`flow.qWeatherLangOptions.${x}`), - })); - }, [t]); - - const qWeatherTypeOptions = useMemo(() => { - return QWeatherTypeOptions.map((x) => ({ - value: x, - label: t(`flow.qWeatherTypeOptions.${x}`), - })); - }, [t]); - - const qWeatherUserTypeOptions = useMemo(() => { - return QWeatherUserTypeOptions.map((x) => ({ - value: x, - label: t(`flow.qWeatherUserTypeOptions.${x}`), - })); - }, [t]); - - const getQWeatherTimePeriodOptions = useCallback(() => { - let options = QWeatherTimePeriodOptions; - const userType = form.getValues(FormFieldName.UserType); - if (userType === 'free') { - options = options.slice(0, 3); - } - return options.map((x) => ({ - value: x, - label: t(`flow.qWeatherTimePeriodOptions.${x}`), - })); - }, [form, t]); - - return ( -
    - { - e.preventDefault(); - }} - > - - ( - - {t('flow.webApiKey')} - - - - - - )} - /> - ( - - {t('flow.lang')} - - - - - - )} - /> - ( - - {t('flow.type')} - - - - - - )} - /> - ( - - {t('flow.userType')} - - - - - - )} - /> - {typeValue === 'weather' && ( - ( - - {t('flow.timePeriod')} - - - - - - )} - /> - )} - - - ); -}; - -export default QWeatherForm; diff --git a/web/src/pages/agent/form/relevant-form/hooks.ts b/web/src/pages/agent/form/relevant-form/hooks.ts deleted file mode 100644 index 413a0ac3834..00000000000 --- a/web/src/pages/agent/form/relevant-form/hooks.ts +++ /dev/null @@ -1,41 +0,0 @@ -import pick from 'lodash/pick'; -import { useCallback, useEffect } from 'react'; -import { IOperatorForm } from '../../interface'; -import useGraphStore from '../../store'; - -export const useBuildRelevantOptions = () => { - const nodes = useGraphStore((state) => state.nodes); - - const buildRelevantOptions = useCallback( - (toList: string[]) => { - return nodes - .filter( - (x) => !toList.some((y) => y === x.id), // filter out selected values ​​in other to fields from the current drop-down box options - ) - .map((x) => ({ label: x.data.name, value: x.id })); - }, - [nodes], - ); - - return buildRelevantOptions; -}; - -/** - * monitor changes in the connection and synchronize the target to the yes and no fields of the form - * similar to the categorize-form's useHandleFormValuesChange method - * @param param0 - */ -export const useWatchConnectionChanges = ({ nodeId, form }: IOperatorForm) => { - const getNode = useGraphStore((state) => state.getNode); - const node = getNode(nodeId); - - const watchFormChanges = useCallback(() => { - if (node) { - form?.setFieldsValue(pick(node, ['yes', 'no'])); - } - }, [node, form]); - - useEffect(() => { - watchFormChanges(); - }, [watchFormChanges]); -}; diff --git a/web/src/pages/agent/form/relevant-form/index.tsx b/web/src/pages/agent/form/relevant-form/index.tsx deleted file mode 100644 index e2366f6f05e..00000000000 --- a/web/src/pages/agent/form/relevant-form/index.tsx +++ /dev/null @@ -1,49 +0,0 @@ -import LLMSelect from '@/components/llm-select'; -import { useTranslate } from '@/hooks/common-hooks'; -import { Form, Select } from 'antd'; -import { Operator } from '../../constant'; -import { useBuildFormSelectOptions } from '../../form-hooks'; -import { IOperatorForm } from '../../interface'; -import { useWatchConnectionChanges } from './hooks'; - -const RelevantForm = ({ onValuesChange, form, node }: IOperatorForm) => { - const { t } = useTranslate('flow'); - const buildRelevantOptions = useBuildFormSelectOptions( - Operator.Relevant, - node?.id, - ); - useWatchConnectionChanges({ nodeId: node?.id, form }); - - return ( -
    - - - - - - -
    - ); -}; - -export default RelevantForm; diff --git a/web/src/pages/agent/form/retrieval-form/next.tsx b/web/src/pages/agent/form/retrieval-form/next.tsx index 848c9496789..345efe43abf 100644 --- a/web/src/pages/agent/form/retrieval-form/next.tsx +++ b/web/src/pages/agent/form/retrieval-form/next.tsx @@ -2,6 +2,7 @@ import { Collapse } from '@/components/collapse'; import { CrossLanguageFormField } from '@/components/cross-language-form-field'; import { FormContainer } from '@/components/form-container'; import { KnowledgeBaseFormField } from '@/components/knowledge-base-item'; +import { MemoriesFormField } from '@/components/memories-form-field'; import { MetadataFilter, MetadataFilterSchema, @@ -19,14 +20,20 @@ import { FormLabel, FormMessage, } from '@/components/ui/form'; +import { Radio } from '@/components/ui/radio'; import { Textarea } from '@/components/ui/textarea'; import { UseKnowledgeGraphFormField } from '@/components/use-knowledge-graph-item'; import { zodResolver } from '@hookform/resolvers/zod'; import { memo, useMemo } from 'react'; -import { useForm, useFormContext } from 'react-hook-form'; +import { + UseFormReturn, + useForm, + useFormContext, + useWatch, +} from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { initialRetrievalValues } from '../../constant'; +import { RetrievalFrom, initialRetrievalValues } from '../../constant'; import { useWatchFormChange } from '../../hooks/use-watch-form-change'; import { INextOperatorForm } from '../../interface'; import { FormWrapper } from '../components/form-wrapper'; @@ -46,6 +53,8 @@ export const RetrievalPartialSchema = { use_kg: z.boolean(), toc_enhance: z.boolean(), ...MetadataFilterSchema, + memory_ids: z.array(z.string()).optional(), + retrieval_from: z.string(), }; export const FormSchema = z.object({ @@ -53,6 +62,44 @@ export const FormSchema = z.object({ ...RetrievalPartialSchema, }); +export type RetrievalFormSchemaType = z.infer; + +export function MemoryDatasetForm() { + const { t } = useTranslation(); + const form = useFormContext(); + const retrievalFrom = useWatch({ + control: form.control, + name: 'retrieval_from', + }); + + return ( + <> + + + + {t('knowledgeDetails.dataset')} + + {t('header.memories')} + + + {retrievalFrom === RetrievalFrom.Memory ? ( + + ) : ( + + )} + + ); +} + +export function useHideKnowledgeGraphField(form: UseFormReturn) { + const retrievalFrom = useWatch({ + control: form.control, + name: 'retrieval_from', + }); + + return retrievalFrom === RetrievalFrom.Memory; +} + export function EmptyResponseField() { const { t } = useTranslation(); const form = useFormContext(); @@ -104,17 +151,17 @@ function RetrievalForm({ node }: INextOperatorForm) { resolver: zodResolver(FormSchema), }); + const hideKnowledgeGraphField = useHideKnowledgeGraphField(form); + useWatchFormChange(node?.id, form); return (
    - - - - - - + + + + {t('flow.advancedSettings')}
  • }> - - + {hideKnowledgeGraphField || ( + <> + + + + )} - - - + {hideKnowledgeGraphField || ( + <> + + + + + )} diff --git a/web/src/pages/agent/form/splitter-form/index.tsx b/web/src/pages/agent/form/splitter-form/index.tsx index 0438dcf8dca..f4dcb741883 100644 --- a/web/src/pages/agent/form/splitter-form/index.tsx +++ b/web/src/pages/agent/form/splitter-form/index.tsx @@ -2,7 +2,8 @@ import { DelimiterInput } from '@/components/delimiter-form-field'; import { RAGFlowFormItem } from '@/components/ragflow-form'; import { SliderInputFormField } from '@/components/slider-input-form-field'; import { BlockButton, Button } from '@/components/ui/button'; -import { Form } from '@/components/ui/form'; +import { Form, FormControl, FormField, FormItem } from '@/components/ui/form'; +import { Switch } from '@/components/ui/switch'; import { zodResolver } from '@hookform/resolvers/zod'; import { Trash2 } from 'lucide-react'; import { memo } from 'react'; @@ -21,11 +22,18 @@ const outputList = buildOutputList(initialSplitterValues.outputs); export const FormSchema = z.object({ chunk_token_size: z.number(), + image_table_context_window: z.number(), delimiters: z.array( z.object({ value: z.string().optional(), }), ), + enable_children: z.boolean(), + children_delimiters: z.array( + z.object({ + value: z.string().optional(), + }), + ), overlapped_percent: z.number(), // 0.0 - 0.3 , 0% - 30% }); @@ -46,6 +54,11 @@ const SplitterForm = ({ node }: INextOperatorForm) => { control: form.control, }); + const childrenDelimiters = useFieldArray({ + name: 'children_delimiters', + control: form.control, + }); + useWatchFormChange(node?.id, form); return ( @@ -62,6 +75,13 @@ const SplitterForm = ({ node }: INextOperatorForm) => { min={0} label={t('flow.overlappedPercent')} > +
    {t('flow.delimiters')}
    @@ -90,6 +110,59 @@ const SplitterForm = ({ node }: INextOperatorForm) => { append({ value: '\n' })}> {t('common.add')} + +
    +
    + {t('flow.enableChildrenDelimiters')} + + ( + + + + + + )} + /> +
    + + {form.getValues('enable_children') && ( +
    + {childrenDelimiters.fields.map((field, index) => ( +
    + + + + + +
    + ))} + + childrenDelimiters.append({ value: '\n' })} + > + {t('common.add')} + +
    + )} +
    diff --git a/web/src/pages/agent/form/switch-form/index.tsx b/web/src/pages/agent/form/switch-form/index.tsx index 6d6849147c7..53f4995afc0 100644 --- a/web/src/pages/agent/form/switch-form/index.tsx +++ b/web/src/pages/agent/form/switch-form/index.tsx @@ -1,5 +1,4 @@ import { FormContainer } from '@/components/form-container'; -import { SelectWithSearch } from '@/components/originui/select-with-search'; import { BlockButton, Button } from '@/components/ui/button'; import { Card, CardContent } from '@/components/ui/card'; import { @@ -12,20 +11,20 @@ import { import { RAGFlowSelect } from '@/components/ui/select'; import { Separator } from '@/components/ui/separator'; import { Textarea } from '@/components/ui/textarea'; +import { SwitchLogicOperator } from '@/constants/agent'; import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options'; +import { useBuildSwitchLogicOperatorOptions } from '@/hooks/logic-hooks/use-build-options'; import { cn } from '@/lib/utils'; import { zodResolver } from '@hookform/resolvers/zod'; import { t } from 'i18next'; -import { toLower } from 'lodash'; import { X } from 'lucide-react'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import { useFieldArray, useForm, useFormContext } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { SwitchLogicOperatorOptions, VariableType } from '../../constant'; -import { useBuildQueryVariableOptions } from '../../hooks/use-get-begin-query'; import { IOperatorForm } from '../../interface'; import { FormWrapper } from '../components/form-wrapper'; +import { QueryVariable } from '../components/query-variable'; import { useValues } from './use-values'; import { useWatchFormChange } from './use-watch-change'; @@ -47,19 +46,6 @@ function ConditionCards({ }: ConditionCardsProps) { const form = useFormContext(); - const nextOptions = useBuildQueryVariableOptions(); - - const finalOptions = useMemo(() => { - return nextOptions.map((x) => { - return { - ...x, - options: x.options.filter( - (y) => !toLower(y.type).includes(VariableType.Array), - ), - }; - }); - }, [nextOptions]); - const switchOperatorOptions = useBuildSwitchOperatorOptions(); const name = `${parentName}.${ItemKey}`; @@ -101,11 +87,11 @@ function ConditionCards({ render={({ field }) => ( - + hideLabel + > @@ -200,12 +186,7 @@ function SwitchForm({ node }: IOperatorForm) { control: form.control, }); - const switchLogicOperatorOptions = useMemo(() => { - return SwitchLogicOperatorOptions.map((x) => ({ - value: x, - label: t(`flow.switchLogicOperatorOptions.${x}`), - })); - }, [t]); + const switchLogicOperatorOptions = useBuildSwitchLogicOperatorOptions(); useWatchFormChange(node?.id, form); @@ -268,7 +249,7 @@ function SwitchForm({ node }: IOperatorForm) { append({ - logical_operator: SwitchLogicOperatorOptions[0], + logical_operator: SwitchLogicOperator.And, [ItemKey]: [ { operator: switchOperatorOptions[0].value, diff --git a/web/src/pages/agent/form/tool-form/constant.tsx b/web/src/pages/agent/form/tool-form/constant.tsx index fc5f4e94c10..4f93ddb50d0 100644 --- a/web/src/pages/agent/form/tool-form/constant.tsx +++ b/web/src/pages/agent/form/tool-form/constant.tsx @@ -1,5 +1,4 @@ import { Operator } from '../../constant'; -import AkShareForm from '../akshare-form'; import ArXivForm from './arxiv-form'; import BingForm from './bing-form'; import CrawlerForm from './crawler-form'; @@ -29,7 +28,6 @@ export const ToolFormConfigMap = { [Operator.GoogleScholar]: GoogleScholarForm, [Operator.GitHub]: GithubForm, [Operator.ExeSQL]: ExeSQLForm, - [Operator.AkShare]: AkShareForm, [Operator.YahooFinance]: YahooFinanceForm, [Operator.Crawler]: CrawlerForm, [Operator.Email]: EmailForm, diff --git a/web/src/pages/agent/form/tool-form/index.tsx b/web/src/pages/agent/form/tool-form/index.tsx index 9c03870b4db..639c55c8688 100644 --- a/web/src/pages/agent/form/tool-form/index.tsx +++ b/web/src/pages/agent/form/tool-form/index.tsx @@ -7,9 +7,11 @@ const EmptyContent = () =>
    ; function ToolForm() { const clickedToolId = useGraphStore((state) => state.clickedToolId); + const { getAgentToolById } = useGraphStore(); + const tool = getAgentToolById(clickedToolId); const ToolForm = - ToolFormConfigMap[clickedToolId as keyof typeof ToolFormConfigMap] ?? + ToolFormConfigMap[tool?.component_name as keyof typeof ToolFormConfigMap] ?? MCPForm ?? EmptyContent; diff --git a/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx b/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx index 5ee53179d8d..84cb0896c0c 100644 --- a/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx +++ b/web/src/pages/agent/form/tool-form/retrieval-form/index.tsx @@ -1,7 +1,6 @@ import { Collapse } from '@/components/collapse'; import { CrossLanguageFormField } from '@/components/cross-language-form-field'; import { FormContainer } from '@/components/form-container'; -import { KnowledgeBaseFormField } from '@/components/knowledge-base-item'; import { MetadataFilter } from '@/components/metadata-filter'; import { RerankFormFields } from '@/components/rerank'; import { SimilaritySliderFormField } from '@/components/similarity-slider'; @@ -17,7 +16,9 @@ import { DescriptionField } from '../../components/description-field'; import { FormWrapper } from '../../components/form-wrapper'; import { EmptyResponseField, + MemoryDatasetForm, RetrievalPartialSchema, + useHideKnowledgeGraphField, } from '../../retrieval-form/next'; import { useValues } from '../use-values'; import { useWatchFormChange } from '../use-watch-change'; @@ -35,15 +36,15 @@ const RetrievalForm = () => { resolver: zodResolver(FormSchema), }); + const hideKnowledgeGraphField = useHideKnowledgeGraphField(form); + useWatchFormChange(form); return ( - - - - + + {t('flow.advancedSettings')}
    }> { isTooltipShown > - - + {hideKnowledgeGraphField || ( + <> + + + + )} + - - - + {hideKnowledgeGraphField || ( + <> + + + + + )} diff --git a/web/src/pages/agent/form/tool-form/use-values.ts b/web/src/pages/agent/form/tool-form/use-values.ts index 59a2e090fdb..6000b6c070e 100644 --- a/web/src/pages/agent/form/tool-form/use-values.ts +++ b/web/src/pages/agent/form/tool-form/use-values.ts @@ -3,7 +3,6 @@ import { useMemo } from 'react'; import { Operator } from '../../constant'; import { useAgentToolInitialValues } from '../../hooks/use-agent-tool-initial-values'; import useGraphStore from '../../store'; -import { getAgentNodeTools } from '../../utils'; export enum SearchDepth { Basic = 'basic', @@ -16,22 +15,23 @@ export enum Topic { } export function useValues() { - const { clickedToolId, clickedNodeId, findUpstreamNodeById } = useGraphStore( - (state) => state, - ); + const { + clickedToolId, + clickedNodeId, + findUpstreamNodeById, + getAgentToolById, + } = useGraphStore(); + const { initializeAgentToolValues } = useAgentToolInitialValues(); const values = useMemo(() => { const agentNode = findUpstreamNodeById(clickedNodeId); - const tools = getAgentNodeTools(agentNode); - - const formData = tools.find( - (x) => x.component_name === clickedToolId, - )?.params; + const tool = getAgentToolById(clickedToolId, agentNode!); + const formData = tool?.params; if (isEmpty(formData)) { const defaultValues = initializeAgentToolValues( - clickedNodeId as Operator, + (tool?.component_name || clickedNodeId) as Operator, ); return defaultValues; @@ -44,6 +44,7 @@ export function useValues() { clickedNodeId, clickedToolId, findUpstreamNodeById, + getAgentToolById, initializeAgentToolValues, ]); diff --git a/web/src/pages/agent/form/tool-form/use-watch-change.ts b/web/src/pages/agent/form/tool-form/use-watch-change.ts index 81a70a235f1..3807592c569 100644 --- a/web/src/pages/agent/form/tool-form/use-watch-change.ts +++ b/web/src/pages/agent/form/tool-form/use-watch-change.ts @@ -1,39 +1,38 @@ import { useEffect } from 'react'; import { UseFormReturn, useWatch } from 'react-hook-form'; import useGraphStore from '../../store'; -import { getAgentNodeTools } from '../../utils'; export function useWatchFormChange(form?: UseFormReturn) { let values = useWatch({ control: form?.control }); - const { clickedToolId, clickedNodeId, findUpstreamNodeById, updateNodeForm } = - useGraphStore((state) => state); + + const { + clickedToolId, + clickedNodeId, + findUpstreamNodeById, + getAgentToolById, + updateAgentToolById, + updateNodeForm, + } = useGraphStore(); useEffect(() => { const agentNode = findUpstreamNodeById(clickedNodeId); // Manually triggered form updates are synchronized to the canvas if (agentNode && form?.formState.isDirty) { - const agentNodeId = agentNode?.id; - const tools = getAgentNodeTools(agentNode); - - values = form?.getValues(); - const nextTools = tools.map((x) => { - if (x.component_name === clickedToolId) { - return { - ...x, - params: { - ...values, - }, - }; - } - return x; + updateAgentToolById(agentNode, clickedToolId, { + params: { + ...(values ?? {}), + }, }); - - const nextValues = { - ...(agentNode?.data?.form ?? {}), - tools: nextTools, - }; - - updateNodeForm(agentNodeId, nextValues); } - }, [form?.formState.isDirty, updateNodeForm, values]); + }, [ + clickedNodeId, + clickedToolId, + findUpstreamNodeById, + form, + form?.formState.isDirty, + getAgentToolById, + updateAgentToolById, + updateNodeForm, + values, + ]); } diff --git a/web/src/pages/agent/form/tushare-form/index.tsx b/web/src/pages/agent/form/tushare-form/index.tsx deleted file mode 100644 index a64bf25bf86..00000000000 --- a/web/src/pages/agent/form/tushare-form/index.tsx +++ /dev/null @@ -1,83 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { DatePicker, DatePickerProps, Form, Input, Select } from 'antd'; -import dayjs from 'dayjs'; -import { useCallback, useMemo } from 'react'; -import { IOperatorForm } from '../../interface'; -import { TuShareSrcOptions } from '../../options'; -import DynamicInputVariable from '../components/dynamic-input-variable'; - -const DateTimePicker = ({ - onChange, - value, -}: { - onChange?: (val: number | undefined) => void; - value?: number | undefined; -}) => { - const handleChange: DatePickerProps['onChange'] = useCallback( - (val: any) => { - const nextVal = val?.format('YYYY-MM-DD HH:mm:ss'); - onChange?.(nextVal ? nextVal : undefined); - }, - [onChange], - ); - // The value needs to be converted into a string and saved to the backend - const nextValue = useMemo(() => { - if (value) { - return dayjs(value); - } - return undefined; - }, [value]); - - return ( - - ); -}; - -const TuShareForm = ({ onValuesChange, form, node }: IOperatorForm) => { - const { t } = useTranslate('flow'); - - const tuShareSrcOptions = useMemo(() => { - return TuShareSrcOptions.map((x) => ({ - value: x, - label: t(`tuShareSrcOptions.${x}`), - })); - }, [t]); - - return ( - - - - - - - - - - - - - - - - - - - ); -}; - -export default TuShareForm; diff --git a/web/src/pages/agent/form/variable-assigner-form/dynamic-variables.tsx b/web/src/pages/agent/form/variable-assigner-form/dynamic-variables.tsx new file mode 100644 index 00000000000..4e6810a96d2 --- /dev/null +++ b/web/src/pages/agent/form/variable-assigner-form/dynamic-variables.tsx @@ -0,0 +1,270 @@ +import { SelectWithSearch } from '@/components/originui/select-with-search'; +import { RAGFlowFormItem } from '@/components/ragflow-form'; +import { useIsDarkTheme } from '@/components/theme-provider'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'; +import { Separator } from '@/components/ui/separator'; +import { Textarea } from '@/components/ui/textarea'; +import Editor, { loader } from '@monaco-editor/react'; +import * as RadioGroupPrimitive from '@radix-ui/react-radio-group'; +import { X } from 'lucide-react'; +import { ReactNode, useCallback } from 'react'; +import { useFieldArray, useFormContext } from 'react-hook-form'; +import { + JsonSchemaDataType, + VariableAssignerLogicalArrayOperator, + VariableAssignerLogicalNumberOperator, + VariableAssignerLogicalOperator, +} from '../../constant'; +import { useGetVariableLabelOrTypeByValue } from '../../hooks/use-get-begin-query'; +import { getArrayElementType } from '../../utils'; +import { DynamicFormHeader } from '../components/dynamic-fom-header'; +import { QueryVariable } from '../components/query-variable'; +import { useBuildLogicalOptions } from './use-build-logical-options'; + +loader.config({ paths: { vs: '/vs' } }); + +type SelectKeysProps = { + name: string; + label: ReactNode; + tooltip?: string; + keyField?: string; + valueField?: string; + operatorField?: string; +}; + +type RadioGroupProps = React.ComponentProps; + +type RadioButtonProps = Partial< + Omit & { + onChange: RadioGroupProps['onValueChange']; + } +>; + +function RadioButton({ value, onChange }: RadioButtonProps) { + return ( + +
    + + +
    +
    + + +
    +
    + ); +} + +const EmptyFields = [ + VariableAssignerLogicalOperator.Clear, + VariableAssignerLogicalArrayOperator.RemoveFirst, + VariableAssignerLogicalArrayOperator.RemoveLast, +]; + +const EmptyValueMap = { + [JsonSchemaDataType.String]: '', + [JsonSchemaDataType.Number]: 0, + [JsonSchemaDataType.Boolean]: 'yes', + [JsonSchemaDataType.Object]: '{}', + [JsonSchemaDataType.Array]: [], +}; + +export function DynamicVariables({ + name, + label, + tooltip, + keyField = 'variable', + valueField = 'parameter', + operatorField = 'operator', +}: SelectKeysProps) { + const form = useFormContext(); + const { getType } = useGetVariableLabelOrTypeByValue(); + const isDarkTheme = useIsDarkTheme(); + + const { fields, remove, append } = useFieldArray({ + name: name, + control: form.control, + }); + + const { buildLogicalOptions } = useBuildLogicalOptions(); + + const getVariableType = useCallback( + (keyFieldName: string) => { + const key = form.getValues(keyFieldName); + return getType(key); + }, + [form, getType], + ); + + const renderParameter = useCallback( + (keyFieldName: string, operatorFieldName: string) => { + const logicalOperator = form.getValues(operatorFieldName); + const type = getVariableType(keyFieldName); + + if (EmptyFields.includes(logicalOperator)) { + return null; + } else if ( + logicalOperator === VariableAssignerLogicalOperator.Overwrite || + VariableAssignerLogicalArrayOperator.Extend === logicalOperator + ) { + return ( + + ); + } else if (logicalOperator === VariableAssignerLogicalOperator.Set) { + if (type === JsonSchemaDataType.Boolean) { + return ; + } + + if (type === JsonSchemaDataType.Number) { + return ; + } + + if (type === JsonSchemaDataType.Object) { + return ( + + ); + } + + if (type === JsonSchemaDataType.String) { + return ; + } + } else if ( + Object.values(VariableAssignerLogicalNumberOperator).some( + (x) => logicalOperator === x, + ) + ) { + return ; + } else if ( + logicalOperator === VariableAssignerLogicalArrayOperator.Append + ) { + const subType = getArrayElementType(type); + return ( + + ); + } + }, + [form, getVariableType, isDarkTheme], + ); + + const handleVariableChange = useCallback( + (operatorFieldAlias: string, valueFieldAlias: string) => { + return () => { + form.setValue( + operatorFieldAlias, + VariableAssignerLogicalOperator.Overwrite, + { shouldDirty: true, shouldValidate: true }, + ); + + form.setValue(valueFieldAlias, '', { + shouldDirty: true, + shouldValidate: true, + }); + }; + }, + [form], + ); + + const handleOperatorChange = useCallback( + (valueFieldAlias: string, keyFieldAlias: string, value: string) => { + const type = getVariableType(keyFieldAlias); + + let parameter = EmptyValueMap[type as keyof typeof EmptyValueMap]; + + if (value === VariableAssignerLogicalOperator.Overwrite) { + parameter = ''; + } + + if (value !== VariableAssignerLogicalOperator.Clear) { + form.setValue(valueFieldAlias, parameter, { + shouldDirty: true, + shouldValidate: true, + }); + } + }, + [form, getVariableType], + ); + + return ( +
    + append({ [keyField]: '', [valueField]: '' })} + > + +
    + {fields.map((field, index) => { + const keyFieldAlias = `${name}.${index}.${keyField}`; + const valueFieldAlias = `${name}.${index}.${valueField}`; + const operatorFieldAlias = `${name}.${index}.${operatorField}`; + + return ( +
    +
    +
    + + + + + + {({ onChange, value }) => ( + { + handleOperatorChange( + valueFieldAlias, + keyFieldAlias, + val, + ); + onChange(val); + }} + options={buildLogicalOptions( + getVariableType(keyFieldAlias), + )} + > + )} + +
    + + {renderParameter(keyFieldAlias, operatorFieldAlias)} + +
    + + +
    + ); + })} +
    +
    + ); +} diff --git a/web/src/pages/agent/form/variable-assigner-form/index.tsx b/web/src/pages/agent/form/variable-assigner-form/index.tsx index 97695f877ea..931351e7e34 100644 --- a/web/src/pages/agent/form/variable-assigner-form/index.tsx +++ b/web/src/pages/agent/form/variable-assigner-form/index.tsx @@ -1,97 +1,44 @@ -import { SelectWithSearch } from '@/components/originui/select-with-search'; -import { RAGFlowFormItem } from '@/components/ragflow-form'; import { Form } from '@/components/ui/form'; -import { Separator } from '@/components/ui/separator'; -import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; import { memo } from 'react'; import { useForm } from 'react-hook-form'; -import { useTranslation } from 'react-i18next'; import { z } from 'zod'; -import { - JsonSchemaDataType, - Operations, - initialDataOperationsValues, -} from '../../constant'; +import { initialDataOperationsValues } from '../../constant'; import { useFormValues } from '../../hooks/use-form-values'; import { useWatchFormChange } from '../../hooks/use-watch-form-change'; import { INextOperatorForm } from '../../interface'; -import { buildOutputList } from '../../utils/build-output-list'; import { FormWrapper } from '../components/form-wrapper'; -import { Output, OutputSchema } from '../components/output'; -import { QueryVariableList } from '../components/query-variable-list'; - -export const RetrievalPartialSchema = { - query: z.array(z.object({ input: z.string().optional() })), - operations: z.string(), - select_keys: z.array(z.object({ name: z.string().optional() })).optional(), - remove_keys: z.array(z.object({ name: z.string().optional() })).optional(), - updates: z - .array( - z.object({ key: z.string().optional(), value: z.string().optional() }), - ) - .optional(), - rename_keys: z - .array( - z.object({ - old_key: z.string().optional(), - new_key: z.string().optional(), - }), - ) - .optional(), - filter_values: z - .array( - z.object({ - key: z.string().optional(), - value: z.string().optional(), - operator: z.string().optional(), - }), - ) - .optional(), - ...OutputSchema, +import { DynamicVariables } from './dynamic-variables'; + +export const VariableAssignerSchema = { + variables: z.array( + z.object({ + variable: z.string().optional(), + operator: z.string().optional(), + parameter: z.string().or(z.number()).or(z.boolean()).optional(), + }), + ), }; -export const FormSchema = z.object(RetrievalPartialSchema); - -export type DataOperationsFormSchemaType = z.infer; +export const FormSchema = z.object(VariableAssignerSchema); -const outputList = buildOutputList(initialDataOperationsValues.outputs); +export type VariableAssignerFormSchemaType = z.infer; function VariableAssignerForm({ node }: INextOperatorForm) { - const { t } = useTranslation(); - const defaultValues = useFormValues(initialDataOperationsValues, node); - const form = useForm({ + const form = useForm({ defaultValues: defaultValues, mode: 'onChange', resolver: zodResolver(FormSchema), - shouldUnregister: true, }); - const OperationsOptions = buildOptions( - Operations, - t, - `flow.operationsOptions`, - true, - ); - useWatchFormChange(node?.id, form, true); return (
    - - - - - - - +
    ); diff --git a/web/src/pages/agent/form/variable-assigner-form/use-build-logical-options.ts b/web/src/pages/agent/form/variable-assigner-form/use-build-logical-options.ts new file mode 100644 index 00000000000..a7f960e98e4 --- /dev/null +++ b/web/src/pages/agent/form/variable-assigner-form/use-build-logical-options.ts @@ -0,0 +1,59 @@ +import { buildOptions } from '@/utils/form'; +import { camelCase } from 'lodash'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { + JsonSchemaDataType, + VariableAssignerLogicalArrayOperator, + VariableAssignerLogicalNumberOperator, + VariableAssignerLogicalNumberOperatorLabelMap, + VariableAssignerLogicalOperator, +} from '../../constant'; + +export function useBuildLogicalOptions() { + const { t } = useTranslation(); + + const buildVariableAssignerLogicalOptions = useCallback( + (record: Record) => { + return buildOptions( + record, + t, + 'flow.variableAssignerLogicalOperatorOptions', + true, + ); + }, + [t], + ); + + const buildLogicalOptions = useCallback( + (type: string) => { + if ( + type?.toLowerCase().startsWith(JsonSchemaDataType.Array.toLowerCase()) + ) { + return buildVariableAssignerLogicalOptions( + VariableAssignerLogicalArrayOperator, + ); + } + + if (type === JsonSchemaDataType.Number) { + return Object.values(VariableAssignerLogicalNumberOperator).map( + (val) => ({ + label: t( + `flow.variableAssignerLogicalOperatorOptions.${camelCase(VariableAssignerLogicalNumberOperatorLabelMap[val as keyof typeof VariableAssignerLogicalNumberOperatorLabelMap] || val)}`, + ), + value: val, + }), + ); + } + + return buildVariableAssignerLogicalOptions( + VariableAssignerLogicalOperator, + ); + }, + [buildVariableAssignerLogicalOptions, t], + ); + + return { + buildLogicalOptions, + }; +} diff --git a/web/src/pages/agent/gobal-variable-sheet/component/add-variable-modal.tsx b/web/src/pages/agent/gobal-variable-sheet/component/add-variable-modal.tsx new file mode 100644 index 00000000000..8ba82f52693 --- /dev/null +++ b/web/src/pages/agent/gobal-variable-sheet/component/add-variable-modal.tsx @@ -0,0 +1,134 @@ +import { + DynamicForm, + DynamicFormRef, + FormFieldConfig, +} from '@/components/dynamic-form'; +import { Modal } from '@/components/ui/modal/modal'; +import { t } from 'i18next'; +import { useEffect, useRef } from 'react'; +import { FieldValues } from 'react-hook-form'; +import { TypeMaps, TypesWithArray } from '../constant'; +import { useHandleForm } from '../hooks/use-form'; +import { useObjectFields } from '../hooks/use-object-fields'; + +export const AddVariableModal = (props: { + fields?: FormFieldConfig[]; + setFields: (value: any) => void; + visible?: boolean; + hideModal: () => void; + defaultValues?: FieldValues; + setDefaultValues?: (value: FieldValues) => void; +}) => { + const { + fields, + setFields, + visible, + hideModal, + defaultValues, + setDefaultValues, + } = props; + + const { handleSubmit: submitForm, loading } = useHandleForm(); + + const { handleCustomValidate, handleCustomSchema, handleRender } = + useObjectFields(); + + const formRef = useRef(null); + + const handleFieldUpdate = ( + fieldName: string, + updatedField: Partial, + ) => { + setFields((prevFields: any) => + prevFields.map((field: any) => + field.name === fieldName ? { ...field, ...updatedField } : field, + ), + ); + }; + + useEffect(() => { + const typeField = fields?.find((item) => item.name === 'type'); + + if (typeField) { + typeField.onChange = (value) => { + handleFieldUpdate('value', { + type: TypeMaps[value as keyof typeof TypeMaps], + render: handleRender(value), + customValidate: handleCustomValidate(value), + schema: handleCustomSchema(value), + }); + const values = formRef.current?.getValues(); + // setTimeout(() => { + switch (value) { + case TypesWithArray.Boolean: + setDefaultValues?.({ ...values, value: false }); + break; + case TypesWithArray.Number: + setDefaultValues?.({ ...values, value: 0 }); + break; + case TypesWithArray.Object: + setDefaultValues?.({ ...values, value: {} }); + break; + case TypesWithArray.ArrayString: + setDefaultValues?.({ ...values, value: [''] }); + break; + case TypesWithArray.ArrayNumber: + setDefaultValues?.({ ...values, value: [''] }); + break; + case TypesWithArray.ArrayBoolean: + setDefaultValues?.({ ...values, value: [false] }); + break; + case TypesWithArray.ArrayObject: + setDefaultValues?.({ ...values, value: [] }); + break; + default: + setDefaultValues?.({ ...values, value: '' }); + break; + } + // }, 0); + }; + } + }, [fields]); + + const handleSubmit = async (fieldValue: FieldValues) => { + await submitForm(fieldValue); + hideModal(); + }; + + return ( + + { + console.log(data); + }} + defaultValues={defaultValues} + onFieldUpdate={handleFieldUpdate} + > +
    + { + hideModal?.(); + }} + /> + { + handleSubmit(values); + // console.log(values); + // console.log(nodes, edges); + // handleOk(values); + }} + /> +
    +
    +
    + ); +}; diff --git a/web/src/pages/agent/gobal-variable-sheet/contant.ts b/web/src/pages/agent/gobal-variable-sheet/constant.ts similarity index 59% rename from web/src/pages/agent/gobal-variable-sheet/contant.ts rename to web/src/pages/agent/gobal-variable-sheet/constant.ts index 2f3bd395f3e..8470ffa86d9 100644 --- a/web/src/pages/agent/gobal-variable-sheet/contant.ts +++ b/web/src/pages/agent/gobal-variable-sheet/constant.ts @@ -1,6 +1,8 @@ import { FormFieldConfig, FormFieldType } from '@/components/dynamic-form'; -import { buildSelectOptions } from '@/utils/component-util'; import { t } from 'i18next'; +import { TypesWithArray } from '../constant'; +import { buildConversationVariableSelectOptions } from '../utils'; +export { TypesWithArray } from '../constant'; // const TypesWithoutArray = Object.values(JsonSchemaDataType).filter( // (item) => item !== JsonSchemaDataType.Array, // ); @@ -9,25 +11,14 @@ import { t } from 'i18next'; // ...TypesWithoutArray.map((item) => `array<${item}>`), // ]; -export enum TypesWithArray { - String = 'string', - Number = 'number', - Boolean = 'boolean', - // Object = 'object', - // ArrayString = 'array', - // ArrayNumber = 'array', - // ArrayBoolean = 'array', - // ArrayObject = 'array', -} - -export const GobalFormFields = [ +export const GlobalFormFields = [ { label: t('flow.name'), name: 'name', placeholder: t('common.namePlaceholder'), required: true, validation: { - pattern: /^[a-zA-Z_]+$/, + pattern: /^[a-zA-Z_0-9]+$/, message: t('flow.variableNameMessage'), }, type: FormFieldType.Text, @@ -38,7 +29,7 @@ export const GobalFormFields = [ placeholder: '', required: true, type: FormFieldType.Select, - options: buildSelectOptions(Object.values(TypesWithArray)), + options: buildConversationVariableSelectOptions(), }, { label: t('flow.defaultValue'), @@ -50,11 +41,11 @@ export const GobalFormFields = [ label: t('flow.description'), name: 'description', placeholder: t('flow.variableDescription'), - type: 'textarea', + type: FormFieldType.Textarea, }, ] as FormFieldConfig[]; -export const GobalVariableFormDefaultValues = { +export const GlobalVariableFormDefaultValues = { name: '', type: TypesWithArray.String, value: '', @@ -65,9 +56,9 @@ export const TypeMaps = { [TypesWithArray.String]: FormFieldType.Textarea, [TypesWithArray.Number]: FormFieldType.Number, [TypesWithArray.Boolean]: FormFieldType.Checkbox, - // [TypesWithArray.Object]: FormFieldType.Textarea, - // [TypesWithArray.ArrayString]: FormFieldType.Textarea, - // [TypesWithArray.ArrayNumber]: FormFieldType.Textarea, - // [TypesWithArray.ArrayBoolean]: FormFieldType.Textarea, - // [TypesWithArray.ArrayObject]: FormFieldType.Textarea, + [TypesWithArray.Object]: FormFieldType.Textarea, + [TypesWithArray.ArrayString]: FormFieldType.Textarea, + [TypesWithArray.ArrayNumber]: FormFieldType.Textarea, + [TypesWithArray.ArrayBoolean]: FormFieldType.Textarea, + [TypesWithArray.ArrayObject]: FormFieldType.Textarea, }; diff --git a/web/src/pages/agent/gobal-variable-sheet/hooks/use-form.tsx b/web/src/pages/agent/gobal-variable-sheet/hooks/use-form.tsx new file mode 100644 index 00000000000..54d37957a2a --- /dev/null +++ b/web/src/pages/agent/gobal-variable-sheet/hooks/use-form.tsx @@ -0,0 +1,44 @@ +import { useFetchAgent } from '@/hooks/use-agent-request'; +import { GlobalVariableType } from '@/interfaces/database/agent'; +import { useCallback } from 'react'; +import { FieldValues } from 'react-hook-form'; +import { useSaveGraph } from '../../hooks/use-save-graph'; +import { TypesWithArray } from '../constant'; + +export const useHandleForm = () => { + const { data, refetch } = useFetchAgent(); + const { saveGraph, loading } = useSaveGraph(); + const handleObjectData = (value: any) => { + try { + return JSON.parse(value); + } catch (error) { + return value; + } + }; + const handleSubmit = useCallback( + async (fieldValue: FieldValues) => { + const param = { + ...(data.dsl?.variables || {}), + [fieldValue.name]: { + ...fieldValue, + value: + fieldValue.type === TypesWithArray.Object || + fieldValue.type === TypesWithArray.ArrayObject + ? handleObjectData(fieldValue.value) + : fieldValue.value, + }, + } as Record; + + const res = await saveGraph(undefined, { + globalVariables: param, + }); + + if (res.code === 0) { + refetch(); + } + }, + [data.dsl?.variables, refetch, saveGraph], + ); + + return { handleSubmit, loading }; +}; diff --git a/web/src/pages/agent/gobal-variable-sheet/hooks/use-object-fields.tsx b/web/src/pages/agent/gobal-variable-sheet/hooks/use-object-fields.tsx new file mode 100644 index 00000000000..07cc48764d6 --- /dev/null +++ b/web/src/pages/agent/gobal-variable-sheet/hooks/use-object-fields.tsx @@ -0,0 +1,321 @@ +import { BoolSegmented } from '@/components/bool-segmented'; +import JsonEditor from '@/components/json-edit'; +import { BlockButton, Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { t } from 'i18next'; +import { isEmpty } from 'lodash'; +import { Trash2, X } from 'lucide-react'; +import { useCallback } from 'react'; +import { FieldValues } from 'react-hook-form'; +import { z } from 'zod'; +import { TypesWithArray } from '../constant'; + +export const useObjectFields = () => { + const booleanRender = useCallback( + (field: FieldValues, className?: string) => { + const fieldValue = field.value ? true : false; + return ( + + ); + }, + [], + ); + const validateKeys = ( + obj: any, + path: (string | number)[] = [], + ): Array<{ path: (string | number)[]; message: string }> => { + const errors: Array<{ path: (string | number)[]; message: string }> = []; + if (typeof obj === 'object' && !Array.isArray(obj)) { + if (isEmpty(obj)) { + errors.push({ + path: [...path], + message: 'No empty parameters are allowed.', + }); + } + for (const key in obj) { + if (obj.hasOwnProperty(key)) { + if (!/^[a-zA-Z_0-9]+$/.test(key)) { + errors.push({ + path: [...path, key], + message: `Key "${key}" is invalid. Keys can only contain letters and underscores and numbers.`, + }); + } + const nestedErrors = validateKeys(obj[key], [...path, key]); + errors.push(...nestedErrors); + } + } + } else if (Array.isArray(obj)) { + obj.forEach((item, index) => { + const nestedErrors = validateKeys(item, [...path, index]); + errors.push(...nestedErrors); + }); + } + + return errors; + }; + const objectRender = useCallback((field: FieldValues) => { + // const fieldValue = + // typeof field.value === 'object' + // ? JSON.stringify(field.value, null, 2) + // : JSON.stringify({}, null, 2); + // console.log('object-render-field', field, fieldValue); + return ( + // + { + return validateKeys(json); + }, + }} + /> + ); + }, []); + + const objectValidate = useCallback((value: any) => { + try { + if (validateKeys(value, [])?.length > 0) { + throw new Error(t('flow.formatTypeError')); + } + if (!z.object({}).safeParse(value).success) { + throw new Error(t('flow.formatTypeError')); + } + if (value && typeof value === 'string' && !JSON.parse(value)) { + throw new Error(t('flow.formatTypeError')); + } + return true; + } catch (e) { + console.log('object-render-error', e, value); + throw new Error(t('flow.formatTypeError')); + } + }, []); + + const arrayObjectValidate = useCallback((value: any) => { + try { + if (validateKeys(value, [])?.length > 0) { + throw new Error(t('flow.formatTypeError')); + } + if (value && typeof value === 'string' && !JSON.parse(value)) { + throw new Error(t('flow.formatTypeError')); + } + return true; + } catch (e) { + console.log('object-render-error', e, value); + throw new Error(t('flow.formatTypeError')); + } + }, []); + + const arrayStringRender = useCallback((field: FieldValues, type = 'text') => { + const values = Array.isArray(field.value) + ? field.value + : [type === 'number' ? 0 : '']; + return ( + <> + {values?.map((item: any, index: number) => ( +
    + { + const newValues = [...values]; + newValues[index] = e.target.value; + field.onChange(newValues); + }} + /> + +
    + ))} + { + field.onChange([...field.value, '']); + }} + > + {t('flow.add')} + + + ); + }, []); + + const arrayBooleanRender = useCallback( + (field: FieldValues) => { + // const values = field.value || [false]; + const values = Array.isArray(field.value) ? field.value : [false]; + return ( +
    + {values?.map((item: any, index: number) => ( +
    + {booleanRender( + { + value: item, + onChange: (value) => { + values[index] = !!value; + field.onChange(values); + }, + }, + 'bg-transparent', + )} + +
    + ))} + { + field.onChange([...field.value, false]); + }} + > + {t('flow.add')} + +
    + ); + }, + [booleanRender], + ); + + const arrayNumberRender = useCallback( + (field: FieldValues) => { + return arrayStringRender(field, 'number'); + }, + [arrayStringRender], + ); + + const arrayValidate = useCallback((value: any, type: string = 'string') => { + if (!Array.isArray(value) || !value.every((item) => typeof item === type)) { + throw new Error(t('flow.formatTypeError')); + } + return true; + }, []); + + const arrayStringValidate = useCallback( + (value: any) => { + return arrayValidate(value, 'string'); + }, + [arrayValidate], + ); + + const arrayNumberValidate = useCallback( + (value: any) => { + return arrayValidate(value, 'number'); + }, + [arrayValidate], + ); + + const arrayBooleanValidate = useCallback( + (value: any) => { + return arrayValidate(value, 'boolean'); + }, + [arrayValidate], + ); + + const handleRender = (value: TypesWithArray) => { + switch (value) { + case TypesWithArray.Boolean: + return booleanRender; + case TypesWithArray.Object: + case TypesWithArray.ArrayObject: + return objectRender; + case TypesWithArray.ArrayString: + return arrayStringRender; + case TypesWithArray.ArrayNumber: + return arrayNumberRender; + case TypesWithArray.ArrayBoolean: + return arrayBooleanRender; + default: + return undefined; + } + }; + const handleCustomValidate = (value: TypesWithArray) => { + switch (value) { + case TypesWithArray.Object: + return objectValidate; + case TypesWithArray.ArrayObject: + return arrayObjectValidate; + case TypesWithArray.ArrayString: + return arrayStringValidate; + case TypesWithArray.ArrayNumber: + return arrayNumberValidate; + case TypesWithArray.ArrayBoolean: + return arrayBooleanValidate; + default: + return undefined; + } + }; + const handleCustomSchema = (value: TypesWithArray) => { + switch (value) { + case TypesWithArray.Object: + return z.object({}); + case TypesWithArray.ArrayObject: + return z.array(z.object({})); + case TypesWithArray.ArrayString: + return z.array(z.string()); + case TypesWithArray.ArrayNumber: + return z.array(z.number()); + case TypesWithArray.ArrayBoolean: + return z.array(z.boolean()); + default: + return undefined; + } + }; + return { + objectRender, + objectValidate, + arrayObjectValidate, + arrayStringRender, + arrayStringValidate, + arrayNumberRender, + booleanRender, + arrayBooleanRender, + arrayNumberValidate, + arrayBooleanValidate, + handleRender, + handleCustomValidate, + handleCustomSchema, + }; +}; diff --git a/web/src/pages/agent/gobal-variable-sheet/index.tsx b/web/src/pages/agent/gobal-variable-sheet/index.tsx index 4541316385b..51648b8d1e4 100644 --- a/web/src/pages/agent/gobal-variable-sheet/index.tsx +++ b/web/src/pages/agent/gobal-variable-sheet/index.tsx @@ -1,12 +1,6 @@ import { ConfirmDeleteDialog } from '@/components/confirm-delete-dialog'; -import { - DynamicForm, - DynamicFormRef, - FormFieldConfig, - FormFieldType, -} from '@/components/dynamic-form'; +import { FormFieldConfig } from '@/components/dynamic-form'; import { BlockButton, Button } from '@/components/ui/button'; -import { Modal } from '@/components/ui/modal/modal'; import { Sheet, SheetContent, @@ -19,117 +13,65 @@ import { GlobalVariableType } from '@/interfaces/database/agent'; import { cn } from '@/lib/utils'; import { t } from 'i18next'; import { Trash2 } from 'lucide-react'; -import { useEffect, useRef, useState } from 'react'; +import { useState } from 'react'; import { FieldValues } from 'react-hook-form'; import { useSaveGraph } from '../hooks/use-save-graph'; +import { AddVariableModal } from './component/add-variable-modal'; import { - GobalFormFields, - GobalVariableFormDefaultValues, + GlobalFormFields, + GlobalVariableFormDefaultValues, TypeMaps, TypesWithArray, -} from './contant'; +} from './constant'; +import { useObjectFields } from './hooks/use-object-fields'; -export type IGobalParamModalProps = { +export type IGlobalParamModalProps = { data: any; hideModal: (open: boolean) => void; }; -export const GobalParamSheet = (props: IGobalParamModalProps) => { +export const GlobalParamSheet = (props: IGlobalParamModalProps) => { const { hideModal } = props; const { data, refetch } = useFetchAgent(); - const [fields, setFields] = useState(GobalFormFields); const { visible, showModal, hideModal: hideAddModal } = useSetModalState(); + const [fields, setFields] = useState(GlobalFormFields); const [defaultValues, setDefaultValues] = useState( - GobalVariableFormDefaultValues, + GlobalVariableFormDefaultValues, ); - const formRef = useRef(null); - - const handleFieldUpdate = ( - fieldName: string, - updatedField: Partial, - ) => { - setFields((prevFields) => - prevFields.map((field) => - field.name === fieldName ? { ...field, ...updatedField } : field, - ), - ); - }; - - useEffect(() => { - const typefileld = fields.find((item) => item.name === 'type'); - - if (typefileld) { - typefileld.onChange = (value) => { - // setWatchType(value); - handleFieldUpdate('value', { - type: TypeMaps[value as keyof typeof TypeMaps], - }); - const values = formRef.current?.getValues(); - setTimeout(() => { - switch (value) { - case TypesWithArray.Boolean: - setDefaultValues({ ...values, value: false }); - break; - case TypesWithArray.Number: - setDefaultValues({ ...values, value: 0 }); - break; - default: - setDefaultValues({ ...values, value: '' }); - } - }, 0); - }; - } - }, [fields]); + const { handleCustomValidate, handleCustomSchema, handleRender } = + useObjectFields(); + const { saveGraph } = useSaveGraph(); - const { saveGraph, loading } = useSaveGraph(); - - const handleSubmit = async (value: FieldValues) => { - const param = { - ...(data.dsl?.variables || {}), - [value.name]: value, - } as Record; - - const res = await saveGraph(undefined, { - gobalVariables: param, - }); - - if (res.code === 0) { - refetch(); - } - hideAddModal(); - }; - - const handleDeleteGobalVariable = async (key: string) => { + const handleDeleteGlobalVariable = async (key: string) => { const param = { ...(data.dsl?.variables || {}), } as Record; delete param[key]; const res = await saveGraph(undefined, { - gobalVariables: param, + globalVariables: param, }); - console.log('delete gobal variable-->', res); if (res.code === 0) { refetch(); } }; - const handleEditGobalVariable = (item: FieldValues) => { - fields.forEach((field) => { - if (field.name === 'value') { - switch (item.type) { - // [TypesWithArray.String]: FormFieldType.Textarea, - // [TypesWithArray.Number]: FormFieldType.Number, - // [TypesWithArray.Boolean]: FormFieldType.Checkbox, - case TypesWithArray.Boolean: - field.type = FormFieldType.Checkbox; - break; - case TypesWithArray.Number: - field.type = FormFieldType.Number; - break; - default: - field.type = FormFieldType.Textarea; - } + const handleEditGlobalVariable = (item: FieldValues) => { + const newFields = fields.map((field) => { + let newField = field; + newField.render = undefined; + newField.schema = undefined; + newField.customValidate = undefined; + if (newField.name === 'value') { + newField = { + ...newField, + type: TypeMaps[item.type as keyof typeof TypeMaps], + render: handleRender(item.type), + customValidate: handleCustomValidate(item.type), + schema: handleCustomSchema(item.type), + }; } + return newField; }); + setFields(newFields); setDefaultValues(item); showModal(); }; @@ -149,8 +91,8 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => {
    { - setFields(GobalFormFields); - setDefaultValues(GobalVariableFormDefaultValues); + setFields(GlobalFormFields); + setDefaultValues(GlobalVariableFormDefaultValues); showModal(); }} > @@ -167,7 +109,7 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => { key={key} className="flex items-center gap-3 min-h-14 justify-between px-5 py-3 border border-border-default rounded-lg hover:bg-bg-card group" onClick={() => { - handleEditGobalVariable(item); + handleEditGlobalVariable(item); }} >
    @@ -177,13 +119,23 @@ export const GobalParamSheet = (props: IGobalParamModalProps) => { {item.type}
    -
    - {item.value} -
    + {![ + TypesWithArray.Object, + TypesWithArray.ArrayObject, + TypesWithArray.ArrayString, + TypesWithArray.ArrayNumber, + TypesWithArray.ArrayBoolean, + ].includes(item.type as TypesWithArray) && ( +
    + + {item.value} + +
    + )}
    handleDeleteGobalVariable(key)} + onOk={() => handleDeleteGlobalVariable(key)} >
    - - { - console.log(data); - }} - defaultValues={defaultValues} - onFieldUpdate={handleFieldUpdate} - > -
    - { - hideAddModal?.(); - }} - /> - { - handleSubmit(values); - // console.log(values); - // console.log(nodes, edges); - // handleOk(values); - }} - /> -
    -
    -
    + ); diff --git a/web/src/pages/agent/hooks/use-add-node.ts b/web/src/pages/agent/hooks/use-add-node.ts index ed092a01be5..53f99e51ca9 100644 --- a/web/src/pages/agent/hooks/use-add-node.ts +++ b/web/src/pages/agent/hooks/use-add-node.ts @@ -10,7 +10,6 @@ import { NodeMap, Operator, initialAgentValues, - initialAkShareValues, initialArXivValues, initialBeginValues, initialBingValues, @@ -29,14 +28,12 @@ import { initialInvokeValues, initialIterationStartValues, initialIterationValues, - initialJin10Values, - initialKeywordExtractValues, + initialListOperationsValues, + initialLoopValues, initialMessageValues, initialNoteValues, initialParserValues, initialPubMedValues, - initialQWeatherValues, - initialRelevantValues, initialRetrievalValues, initialRewriteQuestionValues, initialSearXNGValues, @@ -46,7 +43,6 @@ import { initialTavilyExtractValues, initialTavilyValues, initialTokenizerValues, - initialTuShareValues, initialUserFillUpValues, initialVariableAggregatorValues, initialVariableAssignerValues, @@ -67,6 +63,63 @@ function isBottomSubAgent(type: string, position: Position) { type === Operator.Tool ); } + +const GroupStartNodeMap = { + [Operator.Iteration]: { + id: `${Operator.IterationStart}:${humanId()}`, + type: 'iterationStartNode', + position: { x: 50, y: 100 }, + data: { + label: Operator.IterationStart, + name: Operator.IterationStart, + form: initialIterationStartValues, + }, + extent: 'parent' as 'parent', + }, + [Operator.Loop]: { + id: `${Operator.LoopStart}:${humanId()}`, + type: 'loopStartNode', + position: { x: 50, y: 100 }, + data: { + label: Operator.LoopStart, + name: Operator.LoopStart, + form: {}, + }, + extent: 'parent' as 'parent', + }, +}; + +function useAddGroupNode() { + const { addEdge, addNode } = useGraphStore((state) => state); + + const addGroupNode = useCallback( + (operatorType: string, newNode: Node, nodeId?: string) => { + newNode.width = 500; + newNode.height = 250; + + const startNode: Node = + GroupStartNodeMap[operatorType as keyof typeof GroupStartNodeMap]; + + startNode.parentId = newNode.id; + + addNode(newNode); + addNode(startNode); + + if (nodeId) { + addEdge({ + source: nodeId, + target: newNode.id, + sourceHandle: NodeHandleId.Start, + targetHandle: NodeHandleId.End, + }); + } + return newNode.id; + }, + [addEdge, addNode], + ); + + return { addGroupNode }; +} export const useInitializeOperatorParams = () => { const llmId = useFetchModelId(); @@ -75,16 +128,11 @@ export const useInitializeOperatorParams = () => { [Operator.Begin]: initialBeginValues, [Operator.Retrieval]: initialRetrievalValues, [Operator.Categorize]: { ...initialCategorizeValues, llm_id: llmId }, - [Operator.Relevant]: { ...initialRelevantValues, llm_id: llmId }, [Operator.RewriteQuestion]: { ...initialRewriteQuestionValues, llm_id: llmId, }, [Operator.Message]: initialMessageValues, - [Operator.KeywordExtract]: { - ...initialKeywordExtractValues, - llm_id: llmId, - }, [Operator.DuckDuckGo]: initialDuckValues, [Operator.Wikipedia]: initialWikipediaValues, [Operator.PubMed]: initialPubMedValues, @@ -94,14 +142,10 @@ export const useInitializeOperatorParams = () => { [Operator.GoogleScholar]: initialGoogleScholarValues, [Operator.SearXNG]: initialSearXNGValues, [Operator.GitHub]: initialGithubValues, - [Operator.QWeather]: initialQWeatherValues, [Operator.ExeSQL]: initialExeSqlValues, [Operator.Switch]: initialSwitchValues, [Operator.WenCai]: initialWenCaiValues, - [Operator.AkShare]: initialAkShareValues, [Operator.YahooFinance]: initialYahooFinanceValues, - [Operator.Jin10]: initialJin10Values, - [Operator.TuShare]: initialTuShareValues, [Operator.Note]: initialNoteValues, [Operator.Crawler]: initialCrawlerValues, [Operator.Invoke]: initialInvokeValues, @@ -129,8 +173,14 @@ export const useInitializeOperatorParams = () => { prompts: t('flow.prompts.user.summary'), }, [Operator.DataOperations]: initialDataOperationsValues, + [Operator.ListOperations]: initialListOperationsValues, [Operator.VariableAssigner]: initialVariableAssignerValues, [Operator.VariableAggregator]: initialVariableAggregatorValues, + [Operator.Loop]: initialLoopValues, + [Operator.LoopStart]: {}, + [Operator.ExitLoop]: {}, + [Operator.PDFGenerator]: {}, + [Operator.ExcelProcessor]: {}, }; }, [llmId]); @@ -309,6 +359,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { const { addChildEdge } = useAddChildEdge(); const { addToolNode } = useAddToolNode(); const { resizeIterationNode } = useResizeIterationNode(); + const { addGroupNode } = useAddGroupNode(); // const [reactFlowInstance, setReactFlowInstance] = // useState>(); @@ -374,33 +425,8 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { } } - if (type === Operator.Iteration) { - newNode.width = 500; - newNode.height = 250; - const iterationStartNode: Node = { - id: `${Operator.IterationStart}:${humanId()}`, - type: 'iterationStartNode', - position: { x: 50, y: 100 }, - // draggable: false, - data: { - label: Operator.IterationStart, - name: Operator.IterationStart, - form: initialIterationStartValues, - }, - parentId: newNode.id, - extent: 'parent', - }; - addNode(newNode); - addNode(iterationStartNode); - if (nodeId) { - addEdge({ - source: nodeId, - target: newNode.id, - sourceHandle: NodeHandleId.Start, - targetHandle: NodeHandleId.End, - }); - } - return newNode.id; + if ([Operator.Iteration, Operator.Loop].includes(type as Operator)) { + return addGroupNode(type, newNode, nodeId); } else if ( type === Operator.Agent && params.position === Position.Bottom @@ -454,6 +480,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance) { [ addChildEdge, addEdge, + addGroupNode, addNode, addToolNode, calculateNewlyBackChildPosition, diff --git a/web/src/pages/agent/hooks/use-build-dsl.ts b/web/src/pages/agent/hooks/use-build-dsl.ts index 1a856963618..47ec1c22591 100644 --- a/web/src/pages/agent/hooks/use-build-dsl.ts +++ b/web/src/pages/agent/hooks/use-build-dsl.ts @@ -4,7 +4,7 @@ import { RAGFlowNodeType } from '@/interfaces/database/flow'; import { useCallback } from 'react'; import { Operator } from '../constant'; import useGraphStore from '../store'; -import { buildDslComponentsByGraph, buildDslGobalVariables } from '../utils'; +import { buildDslComponentsByGraph, buildDslGlobalVariables } from '../utils'; export const useBuildDslData = () => { const { data } = useFetchAgent(); @@ -13,7 +13,7 @@ export const useBuildDslData = () => { const buildDslData = useCallback( ( currentNodes?: RAGFlowNodeType[], - otherParam?: { gobalVariables: Record }, + otherParam?: { globalVariables: Record }, ) => { const nodesToProcess = currentNodes ?? nodes; @@ -41,13 +41,13 @@ export const useBuildDslData = () => { data.dsl.components, ); - const gobalVariables = buildDslGobalVariables( + const globalVariables = buildDslGlobalVariables( data.dsl, - otherParam?.gobalVariables, + otherParam?.globalVariables, ); return { ...data.dsl, - ...gobalVariables, + ...globalVariables, graph: { nodes: filteredNodes, edges: filteredEdges }, components: dslComponents, }; diff --git a/web/src/pages/agent/hooks/use-build-options.tsx b/web/src/pages/agent/hooks/use-build-options.tsx index d1214f7299c..10523178bc5 100644 --- a/web/src/pages/agent/hooks/use-build-options.tsx +++ b/web/src/pages/agent/hooks/use-build-options.tsx @@ -1,4 +1,4 @@ -import { buildNodeOutputOptions } from '@/utils/canvas-util'; +import { buildUpstreamNodeOutputOptions } from '@/utils/canvas-util'; import { useMemo } from 'react'; import { Operator } from '../constant'; import OperatorIcon from '../operator-icon'; @@ -9,7 +9,7 @@ export function useBuildNodeOutputOptions(nodeId?: string) { const edges = useGraphStore((state) => state.edges); return useMemo(() => { - return buildNodeOutputOptions({ + return buildUpstreamNodeOutputOptions({ nodes, edges, nodeId, diff --git a/web/src/pages/agent/hooks/use-build-structured-output.ts b/web/src/pages/agent/hooks/use-build-structured-output.ts index 94597d348f0..c9d34bd43d6 100644 --- a/web/src/pages/agent/hooks/use-build-structured-output.ts +++ b/web/src/pages/agent/hooks/use-build-structured-output.ts @@ -1,3 +1,4 @@ +import { getStructuredDatatype } from '@/utils/canvas-util'; import { get, isPlainObject } from 'lodash'; import { ReactNode, useCallback } from 'react'; import { @@ -7,8 +8,11 @@ import { } from '../constant'; import useGraphStore from '../store'; +function splitValue(value?: string) { + return typeof value === 'string' ? value?.split('@') : []; +} function getNodeId(value: string) { - return value.split('@').at(0); + return splitValue(value).at(0); } export function useShowSecondaryMenu() { @@ -63,7 +67,7 @@ export function useFindAgentStructuredOutputLabel() { }>, ) => { // agent structured output - const fields = value.split('@'); + const fields = splitValue(value); if ( getOperatorTypeFromId(fields.at(0)) === Operator.Agent && fields.at(1)?.startsWith(AgentStructuredOutputField) @@ -103,10 +107,10 @@ export function useFindAgentStructuredOutputTypeByValue() { if (isPlainObject(values) && properties) { for (const [key, value] of Object.entries(properties)) { const nextPath = path ? `${path}.${key}` : key; - const dataType = get(value, 'type'); + const { dataType, compositeDataType } = getStructuredDatatype(value); if (nextPath === target) { - return dataType; + return compositeDataType; } if ( @@ -130,7 +134,7 @@ export function useFindAgentStructuredOutputTypeByValue() { if (!value) { return; } - const fields = value.split('@'); + const fields = splitValue(value); const nodeId = fields.at(0); const jsonSchema = filterStructuredOutput(value); @@ -163,7 +167,7 @@ export function useFindAgentStructuredOutputLabelByValue() { const operatorName = getNode(getNodeId(value ?? ''))?.data.name; if (operatorName) { - return operatorName + ' / ' + value?.split('@').at(1); + return operatorName + ' / ' + splitValue(value).at(1); } } diff --git a/web/src/pages/agent/hooks/use-build-webhook-url.ts b/web/src/pages/agent/hooks/use-build-webhook-url.ts new file mode 100644 index 00000000000..e8d7f13e607 --- /dev/null +++ b/web/src/pages/agent/hooks/use-build-webhook-url.ts @@ -0,0 +1,8 @@ +import { useParams } from 'umi'; + +export function useBuildWebhookUrl() { + const { id } = useParams(); + + const text = `${location.protocol}//${location.host}/api/v1/webhook/${id}`; + return text; +} diff --git a/web/src/pages/agent/hooks/use-change-node-name.ts b/web/src/pages/agent/hooks/use-change-node-name.ts index 61a5653d732..9d6112c2d4c 100644 --- a/web/src/pages/agent/hooks/use-change-node-name.ts +++ b/web/src/pages/agent/hooks/use-change-node-name.ts @@ -6,14 +6,13 @@ import { SetStateAction, useCallback, useEffect, - useMemo, useState, } from 'react'; import { Operator } from '../constant'; import useGraphStore from '../store'; import { getAgentNodeTools } from '../utils'; -export function useHandleTooNodeNameChange({ +export function useHandleToolNodeNameChange({ id, name, setName, @@ -22,48 +21,44 @@ export function useHandleTooNodeNameChange({ name?: string; setName: Dispatch>; }) { - const { clickedToolId, findUpstreamNodeById, updateNodeForm } = useGraphStore( - (state) => state, - ); - const agentNode = findUpstreamNodeById(id); + const { + clickedToolId, + findUpstreamNodeById, + getAgentToolById, + updateAgentToolById, + } = useGraphStore((state) => state); + const agentNode = findUpstreamNodeById(id)!; const tools = getAgentNodeTools(agentNode); - - const previousName = useMemo(() => { - const tool = tools.find((x) => x.component_name === clickedToolId); - return tool?.name || tool?.component_name; - }, [clickedToolId, tools]); + const previousName = getAgentToolById(clickedToolId, agentNode)?.name; const handleToolNameBlur = useCallback(() => { const trimmedName = trim(name); const existsSameName = tools.some((x) => x.name === trimmedName); - if (trimmedName === '' || existsSameName) { - if (existsSameName && previousName !== name) { - message.error('The name cannot be repeated'); - } + + // Not changed + if (trimmedName === '') { setName(previousName || ''); - return; + return true; + } + + if (existsSameName && previousName !== name) { + message.error('The name cannot be repeated'); + return false; } if (agentNode?.id) { - const nextTools = tools.map((x) => { - if (x.component_name === clickedToolId) { - return { - ...x, - name, - }; - } - return x; - }); - updateNodeForm(agentNode?.id, nextTools, ['tools']); + updateAgentToolById(agentNode, clickedToolId, { name }); } + + return true; }, [ - agentNode?.id, + agentNode, clickedToolId, name, previousName, setName, tools, - updateNodeForm, + updateAgentToolById, ]); return { handleToolNameBlur, previousToolName: previousName }; @@ -83,28 +78,35 @@ export const useHandleNodeNameChange = ({ const previousName = data?.name; const isToolNode = getOperatorTypeFromId(id) === Operator.Tool; - const { handleToolNameBlur, previousToolName } = useHandleTooNodeNameChange({ + const { handleToolNameBlur, previousToolName } = useHandleToolNodeNameChange({ id, name, setName, }); const handleNameBlur = useCallback(() => { + const trimmedName = trim(name); const existsSameName = nodes.some((x) => x.data.name === name); - if (trim(name) === '' || existsSameName) { - if (existsSameName && previousName !== name) { - message.error('The name cannot be repeated'); - } - setName(previousName); - return; + + // Not changed + if (!trimmedName) { + setName(previousName || ''); + return true; + } + + if (existsSameName && previousName !== name) { + message.error('The name cannot be repeated'); + return false; } if (id) { updateNodeName(id, name); } + + return true; }, [name, id, updateNodeName, previousName, nodes]); - const handleNameChange = useCallback((e: ChangeEvent) => { + const handleNameChange = useCallback((e: ChangeEvent) => { setName(e.target.value); }, []); diff --git a/web/src/pages/agent/hooks/use-chat-logic.ts b/web/src/pages/agent/hooks/use-chat-logic.ts index 42b0533cb76..3c62ae4d1d1 100644 --- a/web/src/pages/agent/hooks/use-chat-logic.ts +++ b/web/src/pages/agent/hooks/use-chat-logic.ts @@ -1,6 +1,5 @@ import { MessageType } from '@/constants/chat'; -import { Message } from '@/interfaces/database/chat'; -import { IMessage } from '@/pages/chat/interface'; +import { IMessage, Message } from '@/interfaces/database/chat'; import { get } from 'lodash'; import { useCallback, useMemo } from 'react'; import { BeginQuery } from '../interface'; diff --git a/web/src/pages/agent/hooks/use-filter-child-node-ids.ts b/web/src/pages/agent/hooks/use-filter-child-node-ids.ts new file mode 100644 index 00000000000..a60622a70c1 --- /dev/null +++ b/web/src/pages/agent/hooks/use-filter-child-node-ids.ts @@ -0,0 +1,10 @@ +import { filterChildNodeIds } from '@/utils/canvas-util'; +import useGraphStore from '../store'; + +export function useFilterChildNodeIds(nodeId?: string) { + const nodes = useGraphStore((state) => state.nodes); + + const childNodeIds = filterChildNodeIds(nodes, nodeId); + + return childNodeIds ?? []; +} diff --git a/web/src/pages/agent/hooks/use-get-begin-query.tsx b/web/src/pages/agent/hooks/use-get-begin-query.tsx index 9b9a6f6bad4..5de22e0e978 100644 --- a/web/src/pages/agent/hooks/use-get-begin-query.tsx +++ b/web/src/pages/agent/hooks/use-get-begin-query.tsx @@ -1,15 +1,21 @@ -import { AgentGlobals } from '@/constants/agent'; +import { AgentGlobals, AgentStructuredOutputField } from '@/constants/agent'; import { useFetchAgent } from '@/hooks/use-agent-request'; import { RAGFlowNodeType } from '@/interfaces/database/flow'; -import { buildNodeOutputOptions } from '@/utils/canvas-util'; +import { + buildNodeOutputOptions, + buildOutputOptions, + buildUpstreamNodeOutputOptions, + isAgentStructured, +} from '@/utils/canvas-util'; import { DefaultOptionType } from 'antd/es/select'; import { t } from 'i18next'; -import { isEmpty, toLower } from 'lodash'; +import { flatten, isEmpty, toLower } from 'lodash'; import get from 'lodash/get'; import { MessageSquareCode } from 'lucide-react'; import { useCallback, useContext, useEffect, useMemo, useState } from 'react'; import { AgentDialogueMode, + AgentVariableType, BeginId, BeginQueryType, JsonSchemaDataType, @@ -85,27 +91,41 @@ export const useGetBeginNodeDataQueryIsSafe = () => { return isBeginNodeDataQuerySafe; }; -export function useBuildNodeOutputOptions(nodeId?: string) { +export function useBuildUpstreamNodeOutputOptions(nodeId?: string) { const nodes = useGraphStore((state) => state.nodes); const edges = useGraphStore((state) => state.edges); return useMemo(() => { - return buildNodeOutputOptions({ + return buildUpstreamNodeOutputOptions({ nodes, edges, nodeId, - Icon: ({ name }) => , }); }, [edges, nodeId, nodes]); } +export function useBuildParentOutputOptions(parentId?: string) { + const { getNode, getOperatorTypeFromId } = useGraphStore((state) => state); + const parentNode = getNode(parentId); + + const parentType = getOperatorTypeFromId(parentId); + + if ( + parentType && + [Operator.Loop].includes(parentType as Operator) && + parentNode + ) { + const options = buildOutputOptions(parentNode); + if (options) { + return [options]; + } + } + + return []; +} + // exclude nodes with branches -const ExcludedNodes = [ - Operator.Categorize, - Operator.Relevant, - Operator.Begin, - Operator.Note, -]; +const ExcludedNodes = [Operator.Categorize, Operator.Begin, Operator.Note]; const StringList = [ BeginQueryType.Line, @@ -120,7 +140,7 @@ function transferToVariableType(type: string) { return type; } -export function useBuildBeginVariableOptions() { +export function useBuildBeginDynamicVariableOptions() { const inputs = useSelectBeginNodeDataInputs(); const options = useMemo(() => { @@ -144,6 +164,30 @@ export function useBuildBeginVariableOptions() { const Env = 'env.'; +export function useBuildGlobalWithBeginVariableOptions() { + const { data } = useFetchAgent(); + const dynamicBeginOptions = useBuildBeginDynamicVariableOptions(); + const globals = data?.dsl?.globals ?? {}; + const globalOptions = Object.entries(globals) + .filter(([key]) => !key.startsWith(Env)) + .map(([key, value]) => ({ + label: key, + value: key, + icon: , + parentLabel: {t('flow.beginInput')}, + type: Array.isArray(value) + ? `${VariableType.Array}${key === AgentGlobals.SysFiles ? '' : ''}` + : typeof value, + })); + + return [ + { + ...dynamicBeginOptions[0], + options: [...(dynamicBeginOptions[0]?.options ?? []), ...globalOptions], + }, + ]; +} + export function useBuildConversationVariableOptions() { const { data } = useFetchAgent(); @@ -175,55 +219,88 @@ export function useBuildConversationVariableOptions() { } export const useBuildVariableOptions = (nodeId?: string, parentId?: string) => { - const nodeOutputOptions = useBuildNodeOutputOptions(nodeId); - const parentNodeOutputOptions = useBuildNodeOutputOptions(parentId); - const beginOptions = useBuildBeginVariableOptions(); + const upstreamNodeOutputOptions = useBuildUpstreamNodeOutputOptions(nodeId); + const parentNodeOutputOptions = useBuildParentOutputOptions(parentId); + const parentUpstreamNodeOutputOptions = + useBuildUpstreamNodeOutputOptions(parentId); const options = useMemo(() => { - return [...beginOptions, ...nodeOutputOptions, ...parentNodeOutputOptions]; - }, [beginOptions, nodeOutputOptions, parentNodeOutputOptions]); + return [ + ...upstreamNodeOutputOptions, + ...parentNodeOutputOptions, + ...parentUpstreamNodeOutputOptions, + ]; + }, [ + upstreamNodeOutputOptions, + parentNodeOutputOptions, + parentUpstreamNodeOutputOptions, + ]); return options; }; -export function useBuildQueryVariableOptions(n?: RAGFlowNodeType) { - const { data } = useFetchAgent(); +export type BuildQueryVariableOptions = { + nodeIds?: string[]; + variablesExceptOperatorOutputs?: AgentVariableType[]; +}; + +export function useBuildQueryVariableOptions({ + n, + nodeIds = [], + variablesExceptOperatorOutputs, // Variables other than operator output variables +}: { + n?: RAGFlowNodeType; +} & BuildQueryVariableOptions = {}) { const node = useContext(AgentFormContext) || n; + const nodes = useGraphStore((state) => state.nodes); + const options = useBuildVariableOptions(node?.id, node?.parentId); const conversationOptions = useBuildConversationVariableOptions(); - const nextOptions = useMemo(() => { - const globals = data?.dsl?.globals ?? {}; - const globalOptions = Object.entries(globals) - .filter(([key]) => !key.startsWith(Env)) - .map(([key, value]) => ({ - label: key, - value: key, - icon: , - parentLabel: {t('flow.beginInput')}, - type: Array.isArray(value) - ? `${VariableType.Array}${key === AgentGlobals.SysFiles ? '' : ''}` - : typeof value, - })); + const globalWithBeginVariableOptions = + useBuildGlobalWithBeginVariableOptions(); + const AgentVariableOptionsMap = { + [AgentVariableType.Begin]: globalWithBeginVariableOptions, + [AgentVariableType.Conversation]: conversationOptions, + }; + + const nextOptions = useMemo(() => { return [ - { - ...options[0], - options: [...options[0]?.options, ...globalOptions], - }, - ...options.slice(1), + ...globalWithBeginVariableOptions, ...conversationOptions, + ...options, ]; - }, [conversationOptions, data?.dsl?.globals, options]); + }, [conversationOptions, globalWithBeginVariableOptions, options]); + + // Which options are entirely under external control? + if (!isEmpty(nodeIds) || !isEmpty(variablesExceptOperatorOutputs)) { + const nodeOutputOptions = buildNodeOutputOptions({ nodes, nodeIds }); + const variablesExceptOperatorOutputsOptions = + variablesExceptOperatorOutputs?.map((x) => AgentVariableOptionsMap[x]) ?? + []; + + return [ + ...flatten(variablesExceptOperatorOutputsOptions), + ...nodeOutputOptions, + ]; + } return nextOptions; } -export function useFilterQueryVariableOptionsByTypes( - types?: JsonSchemaDataType[], -) { - const nextOptions = useBuildQueryVariableOptions(); +export function useFilterQueryVariableOptionsByTypes({ + types, + nodeIds = [], + variablesExceptOperatorOutputs, +}: { + types?: JsonSchemaDataType[]; +} & BuildQueryVariableOptions) { + const nextOptions = useBuildQueryVariableOptions({ + nodeIds, + variablesExceptOperatorOutputs, + }); const filteredOptions = useMemo(() => { return !isEmpty(types) @@ -232,8 +309,16 @@ export function useFilterQueryVariableOptionsByTypes( ...x, options: x.options.filter( (y) => - types?.some((x) => toLower(y.type).includes(x)) || - y.type === undefined, // agent structured output + types?.some((x) => + toLower(x).startsWith('array') + ? toLower(y.type).includes(toLower(x)) + : toLower(y.type) === toLower(x), + ) || + // agent structured output + isAgentStructured( + y.value, + y.value.slice(-AgentStructuredOutputField.length), + ), ), }; }) @@ -287,7 +372,7 @@ export function useBuildComponentIdAndBeginOptions( parentId?: string, ) { const componentIdOptions = useBuildComponentIdOptions(nodeId, parentId); - const beginOptions = useBuildBeginVariableOptions(); + const beginOptions = useBuildBeginDynamicVariableOptions(); return [...beginOptions, ...componentIdOptions]; } @@ -310,21 +395,45 @@ export const useGetComponentLabelByValue = (nodeId: string) => { return getLabel; }; -export function useFlattenQueryVariableOptions(nodeId?: string) { +export function flatOptions(options: DefaultOptionType[]) { + return options.reduce((pre, cur) => { + return [...pre, ...cur.options]; + }, []); +} + +export function useFlattenQueryVariableOptions({ + nodeId, + nodeIds = [], + variablesExceptOperatorOutputs, +}: { + nodeId?: string; +} & BuildQueryVariableOptions = {}) { const { getNode } = useGraphStore((state) => state); - const nextOptions = useBuildQueryVariableOptions(getNode(nodeId)); + const nextOptions = useBuildQueryVariableOptions({ + n: getNode(nodeId), + nodeIds, + variablesExceptOperatorOutputs, + }); const flattenOptions = useMemo(() => { - return nextOptions.reduce((pre, cur) => { - return [...pre, ...cur.options]; - }, []); + return flatOptions(nextOptions); }, [nextOptions]); return flattenOptions; } -export function useGetVariableLabelOrTypeByValue(nodeId?: string) { - const flattenOptions = useFlattenQueryVariableOptions(nodeId); +export function useGetVariableLabelOrTypeByValue({ + nodeId, + nodeIds = [], + variablesExceptOperatorOutputs, +}: { + nodeId?: string; +} & BuildQueryVariableOptions = {}) { + const flattenOptions = useFlattenQueryVariableOptions({ + nodeId, + nodeIds, + variablesExceptOperatorOutputs, + }); const findAgentStructuredOutputTypeByValue = useFindAgentStructuredOutputTypeByValue(); const findAgentStructuredOutputLabel = diff --git a/web/src/pages/agent/hooks/use-is-mcp.ts b/web/src/pages/agent/hooks/use-is-mcp.ts index 5cec7448726..bd76b62f468 100644 --- a/web/src/pages/agent/hooks/use-is-mcp.ts +++ b/web/src/pages/agent/hooks/use-is-mcp.ts @@ -2,10 +2,12 @@ import { Operator } from '../constant'; import useGraphStore from '../store'; export function useIsMcp(operatorName: Operator) { - const clickedToolId = useGraphStore((state) => state.clickedToolId); + const { clickedToolId, getAgentToolById } = useGraphStore(); + + const { component_name: toolName } = getAgentToolById(clickedToolId) ?? {}; return ( operatorName === Operator.Tool && - Object.values(Operator).every((x) => x !== clickedToolId) + Object.values(Operator).every((x) => x !== toolName) ); } diff --git a/web/src/pages/agent/hooks/use-is-webhook.ts b/web/src/pages/agent/hooks/use-is-webhook.ts new file mode 100644 index 00000000000..297a511f080 --- /dev/null +++ b/web/src/pages/agent/hooks/use-is-webhook.ts @@ -0,0 +1,10 @@ +import { AgentDialogueMode, BeginId } from '../constant'; +import useGraphStore from '../store'; + +export function useIsWebhookMode() { + const getNode = useGraphStore((state) => state.getNode); + + const beginNode = getNode(BeginId); + + return beginNode?.data.form?.mode === AgentDialogueMode.Webhook; +} diff --git a/web/src/pages/agent/hooks/use-node-loading.ts b/web/src/pages/agent/hooks/use-node-loading.ts new file mode 100644 index 00000000000..d92702f5655 --- /dev/null +++ b/web/src/pages/agent/hooks/use-node-loading.ts @@ -0,0 +1,88 @@ +import { + INodeData, + INodeEvent, + MessageEventType, +} from '@/hooks/use-send-message'; +import { IMessage } from '@/interfaces/database/chat'; +import { useCallback, useMemo, useState } from 'react'; + +export const useNodeLoading = ({ + currentEventListWithoutMessageById, +}: { + currentEventListWithoutMessageById: (messageId: string) => INodeEvent[]; +}) => { + const [derivedMessages, setDerivedMessages] = useState(); + + const lastMessageId = useMemo(() => { + return derivedMessages?.[derivedMessages?.length - 1]?.id; + }, [derivedMessages]); + + const currentEventListWithoutMessage = useMemo(() => { + if (!lastMessageId) { + return []; + } + return currentEventListWithoutMessageById(lastMessageId); + }, [currentEventListWithoutMessageById, lastMessageId]); + + const startedNodeList = useMemo(() => { + const duplicateList = currentEventListWithoutMessage?.filter( + (x) => x.event === MessageEventType.NodeStarted, + ) as INodeEvent[]; + + // Remove duplicate nodes + return duplicateList?.reduce>((pre, cur) => { + if (pre.every((x) => x.data.component_id !== cur.data.component_id)) { + pre.push(cur); + } + return pre; + }, []); + }, [currentEventListWithoutMessage]); + + const filterFinishedNodeList = useCallback(() => { + const nodeEventList = currentEventListWithoutMessage + .filter( + (x) => x.event === MessageEventType.NodeFinished, + // x.event === MessageEventType.NodeFinished && + // (x.data as INodeData)?.component_id === componentId, + ) + .map((x) => x.data); + + return nodeEventList; + }, [currentEventListWithoutMessage]); + + const lastNode = useMemo(() => { + if (!startedNodeList) { + return null; + } + return startedNodeList[startedNodeList.length - 1]; + }, [startedNodeList]); + + const startNodeIds = useMemo(() => { + if (!startedNodeList) { + return []; + } + return startedNodeList.map((x) => x.data.component_id); + }, [startedNodeList]); + + const finishNodeIds = useMemo(() => { + if (!lastNode) { + return []; + } + const nodeDataList = filterFinishedNodeList(); + const finishNodeIdsTemp = nodeDataList.map( + (x: INodeData) => x.component_id, + ); + return Array.from(new Set(finishNodeIdsTemp)); + }, [lastNode, filterFinishedNodeList]); + + const startButNotFinishedNodeIds = useMemo(() => { + return startNodeIds.filter((x) => !finishNodeIds.includes(x)); + }, [finishNodeIds, startNodeIds]); + + return { + lastNode, + startButNotFinishedNodeIds, + filterFinishedNodeList, + setDerivedMessages, + }; +}; diff --git a/web/src/pages/agent/hooks/use-save-graph.ts b/web/src/pages/agent/hooks/use-save-graph.ts index e59b99193cf..500baf7167f 100644 --- a/web/src/pages/agent/hooks/use-save-graph.ts +++ b/web/src/pages/agent/hooks/use-save-graph.ts @@ -21,7 +21,7 @@ export const useSaveGraph = (showMessage: boolean = true) => { const saveGraph = useCallback( async ( currentNodes?: RAGFlowNodeType[], - otherParam?: { gobalVariables: Record }, + otherParam?: { globalVariables: Record }, ) => { return setAgent({ id, diff --git a/web/src/pages/agent/hooks/use-send-shared-message.ts b/web/src/pages/agent/hooks/use-send-shared-message.ts index ab7f7d2fce1..fe1e34d62a6 100644 --- a/web/src/pages/agent/hooks/use-send-shared-message.ts +++ b/web/src/pages/agent/hooks/use-send-shared-message.ts @@ -29,6 +29,7 @@ export const useGetSharedChatSearchParams = () => { from: searchParams.get('from') as SharedFrom, sharedId: searchParams.get('shared_id'), locale: searchParams.get('locale'), + theme: searchParams.get('theme'), data: data, visibleAvatar: searchParams.get('visible_avatar') ? searchParams.get('visible_avatar') !== '1' diff --git a/web/src/pages/agent/hooks/use-show-dialog.ts b/web/src/pages/agent/hooks/use-show-dialog.ts index 6178e3fbc69..454ee086b7b 100644 --- a/web/src/pages/agent/hooks/use-show-dialog.ts +++ b/web/src/pages/agent/hooks/use-show-dialog.ts @@ -5,7 +5,7 @@ import { useCreateSystemToken, useFetchSystemTokenList, useRemoveSystemToken, -} from '@/hooks/user-setting-hooks'; +} from '@/hooks/use-user-setting-request'; import { IStats } from '@/interfaces/database/chat'; import { useQueryClient } from '@tanstack/react-query'; import { useCallback } from 'react'; diff --git a/web/src/pages/agent/hooks/use-show-drawer.tsx b/web/src/pages/agent/hooks/use-show-drawer.tsx index a350af074f0..3a15b29b894 100644 --- a/web/src/pages/agent/hooks/use-show-drawer.tsx +++ b/web/src/pages/agent/hooks/use-show-drawer.tsx @@ -14,6 +14,7 @@ export const useShowFormDrawer = () => { setClickedNodeId, getNode, setClickedToolId, + getOperatorTypeFromId, } = useGraphStore((state) => state); const { visible: formDrawerVisible, @@ -23,16 +24,25 @@ export const useShowFormDrawer = () => { const handleShow = useCallback( (e: React.MouseEvent, nodeId: string) => { - const tool = get(e.target, 'dataset.tool'); + const toolId = (e.target as HTMLElement).dataset.toolId; + const tool = (e.target as HTMLElement).dataset.tool; + // TODO: Operator type judgment should be used - if (nodeId.startsWith(Operator.Tool) && !tool) { + const operatorType = getOperatorTypeFromId(nodeId); + if ( + (operatorType === Operator.Tool && !tool) || + [Operator.LoopStart, Operator.ExitLoop].includes( + operatorType as Operator, + ) + ) { return; } setClickedNodeId(nodeId); - setClickedToolId(tool); + // Guess this could gracefully handle the case where the tool id is not provided? + setClickedToolId(toolId || tool); showFormDrawer(); }, - [setClickedNodeId, setClickedToolId, showFormDrawer], + [getOperatorTypeFromId, setClickedNodeId, setClickedToolId, showFormDrawer], ); return { diff --git a/web/src/pages/agent/index.tsx b/web/src/pages/agent/index.tsx index 21ecb22e7ca..92bd3a0fe8b 100644 --- a/web/src/pages/agent/index.tsx +++ b/web/src/pages/agent/index.tsx @@ -39,13 +39,14 @@ import { useParams } from 'umi'; import AgentCanvas from './canvas'; import { DropdownProvider } from './canvas/context'; import { Operator } from './constant'; -import { GobalParamSheet } from './gobal-variable-sheet'; +import { GlobalParamSheet } from './gobal-variable-sheet'; import { useCancelCurrentDataflow } from './hooks/use-cancel-dataflow'; import { useHandleExportJsonFile } from './hooks/use-export-json'; import { useFetchDataOnMount } from './hooks/use-fetch-data'; import { useFetchPipelineLog } from './hooks/use-fetch-pipeline-log'; import { useGetBeginNodeDataInputs } from './hooks/use-get-begin-query'; import { useIsPipeline } from './hooks/use-is-pipeline'; +import { useIsWebhookMode } from './hooks/use-is-webhook'; import { useRunDataflow } from './hooks/use-run-dataflow'; import { useSaveGraph, @@ -58,6 +59,7 @@ import { SettingDialog } from './setting-dialog'; import useGraphStore from './store'; import { useAgentHistoryManager } from './use-agent-history-manager'; import { VersionDialog } from './version-dialog'; +import WebhookSheet from './webhook-sheet'; function AgentDropdownMenuItem({ children, @@ -110,6 +112,7 @@ export default function Agent() { useShowEmbedModal(); const { navigateToAgentLogs } = useNavigatePage(); const time = useWatchAgentChange(chatDrawerVisible); + const isWebhookMode = useIsWebhookMode(); // pipeline @@ -119,6 +122,12 @@ export default function Agent() { showModal: showPipelineRunSheet, } = useSetModalState(); + const { + visible: webhookTestSheetVisible, + hideModal: hideWebhookTestSheet, + showModal: showWebhookTestSheet, + } = useSetModalState(); + const { visible: pipelineLogSheetVisible, showModal: showPipelineLogSheet, @@ -126,9 +135,9 @@ export default function Agent() { } = useSetModalState(); const { - visible: gobalParamSheetVisible, - showModal: showGobalParamSheet, - hideModal: hideGobalParamSheet, + visible: globalParamSheetVisible, + showModal: showGlobalParamSheet, + hideModal: hideGlobalParamSheet, } = useSetModalState(); const { @@ -172,12 +181,22 @@ export default function Agent() { }); const handleButtonRunClick = useCallback(() => { - if (isPipeline) { + if (isWebhookMode) { + saveGraph(); + showWebhookTestSheet(); + } else if (isPipeline) { handleRunPipeline(); } else { handleRunAgent(); } - }, [handleRunAgent, handleRunPipeline, isPipeline]); + }, [ + handleRunAgent, + handleRunPipeline, + isPipeline, + isWebhookMode, + saveGraph, + showWebhookTestSheet, + ]); const { run: runPipeline, @@ -216,7 +235,7 @@ export default function Agent() { showGobalParamSheet()} + onClick={() => showGlobalParamSheet()} loading={loading} > {t('flow.conversationVariable')} @@ -227,7 +246,7 @@ export default function Agent() { {isPipeline || ( - - - - - {t('flow.createFromBlank')} - - - - {t('flow.createFromTemplate')} - - - - {t('flow.importJsonFile')} - - - - - -
    - - {data.map((x) => { - return ( - - ); - })} - -
    -
    - -
    - {agentRenameVisible && ( - + <> + {(!data?.length || data?.length <= 0) && !searchString && ( +
    + showCreatingModal()} + /> +
    )} - {creatingVisible && ( - - )} - {fileUploadVisible && ( - - )} - +
    + {(!!data?.length || searchString) && ( + <> +
    + + + + + + + + + {t('flow.createFromBlank')} + + + + {t('flow.createFromTemplate')} + + + + {t('flow.importJsonFile')} + + + + +
    + {(!data?.length || data?.length <= 0) && searchString && ( +
    + showCreatingModal()} + /> +
    + )} +
    + + {data.map((x) => { + return ( + + ); + })} + +
    +
    + +
    + + )} + {agentRenameVisible && ( + + )} + {creatingVisible && ( + + )} + {fileUploadVisible && ( + + )} +
    + ); } diff --git a/web/src/pages/agents/upload-agent-dialog/index.tsx b/web/src/pages/agents/upload-agent-dialog/index.tsx index b2084bc0f9c..83a51dbdf69 100644 --- a/web/src/pages/agents/upload-agent-dialog/index.tsx +++ b/web/src/pages/agents/upload-agent-dialog/index.tsx @@ -6,8 +6,8 @@ import { DialogHeader, DialogTitle, } from '@/components/ui/dialog'; +import { TagRenameId } from '@/constants/knowledge'; import { IModalProps } from '@/interfaces/common'; -import { TagRenameId } from '@/pages/add-knowledge/constant'; import { useTranslation } from 'react-i18next'; import { UploadAgentForm } from './upload-agent-form'; diff --git a/web/src/pages/agents/upload-agent-dialog/upload-agent-form.tsx b/web/src/pages/agents/upload-agent-dialog/upload-agent-form.tsx index f7711b70d5c..1cabecd9353 100644 --- a/web/src/pages/agents/upload-agent-dialog/upload-agent-form.tsx +++ b/web/src/pages/agents/upload-agent-dialog/upload-agent-form.tsx @@ -14,8 +14,8 @@ import { FormMessage, } from '@/components/ui/form'; import { FileMimeType } from '@/constants/common'; +import { TagRenameId } from '@/constants/knowledge'; import { IModalProps } from '@/interfaces/common'; -import { TagRenameId } from '@/pages/add-knowledge/constant'; import { NameFormField, NameFormSchema } from '../name-form-field'; export const FormSchema = z.object({ diff --git a/web/src/pages/chat/chat-configuration-modal/assistant-setting.tsx b/web/src/pages/chat/chat-configuration-modal/assistant-setting.tsx deleted file mode 100644 index 1e955c100ba..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/assistant-setting.tsx +++ /dev/null @@ -1,193 +0,0 @@ -import KnowledgeBaseItem from '@/components/knowledge-base-item'; -import { TavilyItem } from '@/components/tavily-item'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useFetchTenantInfo } from '@/hooks/user-setting-hooks'; -import { PlusOutlined } from '@ant-design/icons'; -import { Form, Input, message, Select, Switch, Upload } from 'antd'; -import classNames from 'classnames'; -import { useCallback } from 'react'; -import { ISegmentedContentProps } from '../interface'; - -import { DatasetMetadata } from '@/constants/chat'; -import styles from './index.less'; -import { MetadataFilterConditions } from './metadata-filter-conditions'; - -const emptyResponseField = ['prompt_config', 'empty_response']; - -const AssistantSetting = ({ - show, - form, - setHasError, -}: ISegmentedContentProps) => { - const { t } = useTranslate('chat'); - const { data } = useFetchTenantInfo(true); - - const MetadataOptions = Object.values(DatasetMetadata).map((x) => { - return { - value: x, - label: t(`meta.${x}`), - }; - }); - - const metadata = Form.useWatch(['meta_data_filter', 'method'], form); - const kbIds = Form.useWatch(['kb_ids'], form); - - const hasKnowledge = Array.isArray(kbIds) && kbIds.length > 0; - - const handleChange = useCallback(() => { - const kbIds = form.getFieldValue('kb_ids'); - const emptyResponse = form.getFieldValue(emptyResponseField); - - const required = - emptyResponse && ((Array.isArray(kbIds) && kbIds.length === 0) || !kbIds); - - setHasError(required); - form.setFields([ - { - name: emptyResponseField, - errors: required ? [t('emptyResponseMessage')] : [], - }, - ]); - }, [form, setHasError, t]); - - const normFile = (e: any) => { - if (Array.isArray(e)) { - return e; - } - return e?.fileList; - }; - - const handleTtsChange = useCallback( - (checked: boolean) => { - if (checked && !data.tts_id) { - message.error(`Please set TTS model firstly. - Setting >> Model providers >> System model settings`); - form.setFieldValue(['prompt_config', 'tts'], false); - } - }, - [data, form], - ); - - const uploadButton = ( - - ); - - return ( -
    - - - - - - - - false} - showUploadList={{ showPreviewIcon: false, showRemoveIcon: false }} - > - {show ? uploadButton : null} - - - - - - - - - - - - - - - - - - - - {hasKnowledge && ( - - - - ) : ( -
    - {children} -
    - ); - } - - return {childNode}; -}; diff --git a/web/src/pages/chat/chat-configuration-modal/index.less b/web/src/pages/chat/chat-configuration-modal/index.less deleted file mode 100644 index 706725ccfa0..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/index.less +++ /dev/null @@ -1,57 +0,0 @@ -.chatConfigurationDescription { - font-size: 14px; -} - -.variableContainer { - padding-bottom: 20px; - .variableAlign { - text-align: end; - } - - .variableLabel { - margin-right: 14px; - } - - .variableIcon { - margin-inline-start: 4px; - color: rgba(0, 0, 0, 0.45); - cursor: help; - writing-mode: horizontal-tb; - } - - .variableTable { - margin-top: 14px; - } - .editableRow { - :global(.editable-cell) { - position: relative; - } - - :global(.editable-cell-value-wrap) { - padding: 5px 12px; - cursor: pointer; - height: 22px !important; - } - &:hover { - :global(.editable-cell-value-wrap) { - padding: 4px 11px; - border: 1px solid #d9d9d9; - border-radius: 2px; - } - } - } -} - -.segmentedHidden { - opacity: 0; - height: 0; - width: 0; - margin: 0; -} - -.sliderInputNumber { - width: 80px; -} -.variableSlider { - width: 100%; -} diff --git a/web/src/pages/chat/chat-configuration-modal/index.tsx b/web/src/pages/chat/chat-configuration-modal/index.tsx deleted file mode 100644 index be87cd13e46..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/index.tsx +++ /dev/null @@ -1,207 +0,0 @@ -import { ReactComponent as ChatConfigurationAtom } from '@/assets/svg/chat-configuration-atom.svg'; -import { IModalManagerChildrenProps } from '@/components/modal-manager'; -import { - ModelVariableType, - settledModelVariableMap, -} from '@/constants/knowledge'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useFetchModelId } from '@/hooks/logic-hooks'; -import { IDialog } from '@/interfaces/database/chat'; -import { getBase64FromUploadFileList } from '@/utils/file-util'; -import { removeUselessFieldsFromValues } from '@/utils/form'; -import { Divider, Flex, Form, Modal, Segmented, UploadFile } from 'antd'; -import { SegmentedValue } from 'antd/es/segmented'; -import camelCase from 'lodash/camelCase'; -import { useEffect, useRef, useState } from 'react'; -import { IPromptConfigParameters } from '../interface'; -import AssistantSetting from './assistant-setting'; -import ModelSetting from './model-setting'; -import PromptEngine from './prompt-engine'; - -import styles from './index.less'; - -const layout = { - labelCol: { span: 9 }, - wrapperCol: { span: 15 }, -}; - -const validateMessages = { - required: '${label} is required!', - types: { - email: '${label} is not a valid email!', - number: '${label} is not a valid number!', - }, - number: { - range: '${label} must be between ${min} and ${max}', - }, -}; - -enum ConfigurationSegmented { - AssistantSetting = 'Assistant Setting', - PromptEngine = 'Prompt Engine', - ModelSetting = 'Model Setting', -} - -const segmentedMap = { - [ConfigurationSegmented.AssistantSetting]: AssistantSetting, - [ConfigurationSegmented.ModelSetting]: ModelSetting, - [ConfigurationSegmented.PromptEngine]: PromptEngine, -}; - -interface IProps extends IModalManagerChildrenProps { - initialDialog: IDialog; - loading: boolean; - onOk: (dialog: IDialog) => void; - clearDialog: () => void; -} - -const ChatConfigurationModal = ({ - visible, - hideModal, - initialDialog, - loading, - onOk, - clearDialog, -}: IProps) => { - const [form] = Form.useForm(); - const [hasError, setHasError] = useState(false); - - const [value, setValue] = useState( - ConfigurationSegmented.AssistantSetting, - ); - const promptEngineRef = useRef>([]); - const modelId = useFetchModelId(); - const { t } = useTranslate('chat'); - - const handleOk = async () => { - const values = await form.validateFields(); - if (hasError) { - return; - } - const nextValues: any = removeUselessFieldsFromValues( - values, - 'llm_setting.', - ); - const emptyResponse = nextValues.prompt_config?.empty_response ?? ''; - - const icon = await getBase64FromUploadFileList(values.icon); - - const finalValues = { - dialog_id: initialDialog.id, - ...nextValues, - vector_similarity_weight: 1 - nextValues.vector_similarity_weight, - prompt_config: { - ...nextValues.prompt_config, - parameters: promptEngineRef.current, - empty_response: emptyResponse, - }, - icon, - }; - onOk(finalValues); - }; - - const handleSegmentedChange = (val: SegmentedValue) => { - setValue(val as ConfigurationSegmented); - }; - - const handleModalAfterClose = () => { - clearDialog(); - form.resetFields(); - }; - - const title = ( - - -
    - {t('chatConfiguration')} -
    - {t('chatConfigurationDescription')} -
    -
    -
    - ); - - useEffect(() => { - if (visible) { - const icon = initialDialog.icon; - let fileList: UploadFile[] = []; - - if (icon) { - fileList = [{ uid: '1', name: 'file', thumbUrl: icon, status: 'done' }]; - } - form.setFieldsValue({ - ...initialDialog, - llm_setting: - initialDialog.llm_setting ?? - settledModelVariableMap[ModelVariableType.Precise], - icon: fileList, - llm_id: initialDialog.llm_id ?? modelId, - vector_similarity_weight: - 1 - (initialDialog.vector_similarity_weight ?? 0.3), - }); - } - }, [initialDialog, form, visible, modelId]); - - const handleKeyDown = (e: React.KeyboardEvent) => { - // Allow Enter in textareas - if (e.target instanceof HTMLTextAreaElement) { - return; - } - - if (e.key === 'Enter' && !e.shiftKey) { - e.preventDefault(); - handleOk(); - } - }; - - return ( - - ({ - label: t(camelCase(x)), - value: x, - }))} - block - /> - -
    - {Object.entries(segmentedMap).map(([key, Element]) => ( - - ))} -
    -
    - ); -}; - -export default ChatConfigurationModal; diff --git a/web/src/pages/chat/chat-configuration-modal/metadata-filter-conditions.tsx b/web/src/pages/chat/chat-configuration-modal/metadata-filter-conditions.tsx deleted file mode 100644 index ce2a0b59a5a..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/metadata-filter-conditions.tsx +++ /dev/null @@ -1,88 +0,0 @@ -import { SwitchOperatorOptions } from '@/constants/agent'; -import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options'; -import { useFetchKnowledgeMetadata } from '@/hooks/use-knowledge-request'; -import { MinusCircleOutlined, PlusOutlined } from '@ant-design/icons'; -import { - Button, - Dropdown, - Empty, - Form, - FormListOperation, - Input, - Select, - Space, -} from 'antd'; -import { useCallback } from 'react'; -import { useTranslation } from 'react-i18next'; - -export function MetadataFilterConditions({ kbIds }: { kbIds: string[] }) { - const metadata = useFetchKnowledgeMetadata(kbIds); - const { t } = useTranslation(); - const switchOperatorOptions = useBuildSwitchOperatorOptions(); - - const renderItems = useCallback( - (add: FormListOperation['add']) => { - if (Object.keys(metadata.data).length === 0) { - return [{ key: 'noData', label: }]; - } - return Object.keys(metadata.data).map((key) => { - return { - key, - onClick: () => { - add({ - key, - value: '', - op: SwitchOperatorOptions[0].value, - }); - }, - label: key, - }; - }); - }, - [metadata], - ); - return ( - - {(fields, { add, remove }) => ( - <> - {fields.map(({ key, name, ...restField }) => ( - - - - - - - - remove(name)} /> - - ))} - - - - - - - )} - - ); -} diff --git a/web/src/pages/chat/chat-configuration-modal/model-setting.tsx b/web/src/pages/chat/chat-configuration-modal/model-setting.tsx deleted file mode 100644 index a890bbbbe91..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/model-setting.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import classNames from 'classnames'; -import { useEffect } from 'react'; -import { ISegmentedContentProps } from '../interface'; - -import LlmSettingItems from '@/components/llm-setting-items'; -import { - ChatVariableEnabledField, - variableEnabledFieldMap, -} from '@/constants/chat'; -import { Variable } from '@/interfaces/database/chat'; -import { setInitialChatVariableEnabledFieldValue } from '@/utils/chat'; -import styles from './index.less'; - -const ModelSetting = ({ - show, - form, - initialLlmSetting, - visible, -}: ISegmentedContentProps & { - initialLlmSetting?: Variable; - visible?: boolean; -}) => { - useEffect(() => { - if (visible) { - const values = Object.keys(variableEnabledFieldMap).reduce< - Record - >((pre, field) => { - pre[field] = - initialLlmSetting === undefined - ? setInitialChatVariableEnabledFieldValue( - field as ChatVariableEnabledField, - ) - : !!initialLlmSetting[ - variableEnabledFieldMap[ - field as keyof typeof variableEnabledFieldMap - ] as keyof Variable - ]; - return pre; - }, {}); - form.setFieldsValue(values); - } - }, [form, initialLlmSetting, visible]); - - return ( -
    - {visible && } -
    - ); -}; - -export default ModelSetting; diff --git a/web/src/pages/chat/chat-configuration-modal/prompt-engine.tsx b/web/src/pages/chat/chat-configuration-modal/prompt-engine.tsx deleted file mode 100644 index 71bfb2cfe11..00000000000 --- a/web/src/pages/chat/chat-configuration-modal/prompt-engine.tsx +++ /dev/null @@ -1,222 +0,0 @@ -import SimilaritySlider from '@/components/similarity-slider'; -import { DeleteOutlined, QuestionCircleOutlined } from '@ant-design/icons'; -import { - Button, - Col, - Divider, - Form, - Input, - Row, - Switch, - Table, - TableProps, - Tooltip, -} from 'antd'; -import classNames from 'classnames'; -import { - ForwardedRef, - forwardRef, - useEffect, - useImperativeHandle, - useState, -} from 'react'; -import { v4 as uuid } from 'uuid'; -import { - VariableTableDataType as DataType, - IPromptConfigParameters, - ISegmentedContentProps, -} from '../interface'; -import { EditableCell, EditableRow } from './editable-cell'; - -import { CrossLanguageItem } from '@/components/cross-language-item'; -import Rerank from '@/components/rerank'; -import TopNItem from '@/components/top-n-item'; -import { UseKnowledgeGraphItem } from '@/components/use-knowledge-graph-item'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useSelectPromptConfigParameters } from '../hooks'; -import styles from './index.less'; - -const PromptEngine = ( - { show }: ISegmentedContentProps, - ref: ForwardedRef>, -) => { - const [dataSource, setDataSource] = useState([]); - const parameters = useSelectPromptConfigParameters(); - const { t } = useTranslate('chat'); - - const components = { - body: { - row: EditableRow, - cell: EditableCell, - }, - }; - - const handleRemove = (key: string) => () => { - const newData = dataSource.filter((item) => item.key !== key); - setDataSource(newData); - }; - - const handleSave = (row: DataType) => { - const newData = [...dataSource]; - const index = newData.findIndex((item) => row.key === item.key); - const item = newData[index]; - newData.splice(index, 1, { - ...item, - ...row, - }); - setDataSource(newData); - }; - - const handleAdd = () => { - setDataSource((state) => [ - ...state, - { - key: uuid(), - variable: '', - optional: true, - }, - ]); - }; - - const handleOptionalChange = (row: DataType) => (checked: boolean) => { - const newData = [...dataSource]; - const index = newData.findIndex((item) => row.key === item.key); - const item = newData[index]; - newData.splice(index, 1, { - ...item, - optional: checked, - }); - setDataSource(newData); - }; - - useImperativeHandle( - ref, - () => { - return dataSource - .filter((x) => x.variable.trim() !== '') - .map((x) => ({ key: x.variable, optional: x.optional })); - }, - [dataSource], - ); - - const columns: TableProps['columns'] = [ - { - title: t('key'), - dataIndex: 'variable', - key: 'variable', - onCell: (record: DataType) => ({ - record, - editable: true, - dataIndex: 'variable', - title: 'key', - handleSave, - }), - }, - { - title: t('optional'), - dataIndex: 'optional', - key: 'optional', - width: 40, - align: 'center', - render(text, record) { - return ( - - ); - }, - }, - { - title: t('operation'), - dataIndex: 'operation', - width: 30, - key: 'operation', - align: 'center', - render(_, record) { - return ; - }, - }, - ]; - - useEffect(() => { - setDataSource(parameters); - }, [parameters]); - - return ( -
    - - - - - - - - - - - - - - - -
    - - - - - - - - - {dataSource.length > 0 && ( - - - - styles.editableRow} - /> - - - )} - - - ); -}; - -export default forwardRef(PromptEngine); diff --git a/web/src/pages/chat/chat-container/index.less b/web/src/pages/chat/chat-container/index.less deleted file mode 100644 index 8430b1ef64d..00000000000 --- a/web/src/pages/chat/chat-container/index.less +++ /dev/null @@ -1,7 +0,0 @@ -.chatContainer { - padding: 0 0 24px 24px; - .messageContainer { - overflow-y: auto; - padding-right: 24px; - } -} diff --git a/web/src/pages/chat/chat-container/index.tsx b/web/src/pages/chat/chat-container/index.tsx deleted file mode 100644 index de5b41faf52..00000000000 --- a/web/src/pages/chat/chat-container/index.tsx +++ /dev/null @@ -1,125 +0,0 @@ -import MessageItem from '@/components/message-item'; -import { MessageType } from '@/constants/chat'; -import { Flex, Spin } from 'antd'; -import { - useCreateConversationBeforeUploadDocument, - useGetFileIcon, - useGetSendButtonDisabled, - useSendButtonDisabled, - useSendNextMessage, -} from '../hooks'; -import { buildMessageItemReference } from '../utils'; - -import MessageInput from '@/components/message-input'; -import PdfDrawer from '@/components/pdf-drawer'; -import { useClickDrawer } from '@/components/pdf-drawer/hooks'; -import { - useFetchNextConversation, - useFetchNextDialog, - useGetChatSearchParams, -} from '@/hooks/chat-hooks'; -import { useFetchUserInfo } from '@/hooks/user-setting-hooks'; -import { buildMessageUuidWithRole } from '@/utils/chat'; -import { memo } from 'react'; -import styles from './index.less'; - -interface IProps { - controller: AbortController; -} - -const ChatContainer = ({ controller }: IProps) => { - const { conversationId } = useGetChatSearchParams(); - const { data: conversation } = useFetchNextConversation(); - const { data: currentDialog } = useFetchNextDialog(); - - const { - value, - scrollRef, - messageContainerRef, - loading, - sendLoading, - derivedMessages, - handleInputChange, - handlePressEnter, - regenerateMessage, - removeMessageById, - stopOutputMessage, - } = useSendNextMessage(controller); - - const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = - useClickDrawer(); - const disabled = useGetSendButtonDisabled(); - const sendDisabled = useSendButtonDisabled(value); - useGetFileIcon(); - const { data: userInfo } = useFetchUserInfo(); - const { createConversationBeforeUploadDocument } = - useCreateConversationBeforeUploadDocument(); - - return ( - <> - - -
    - - {derivedMessages?.map((message, i) => { - return ( - - ); - })} - -
    -
    - - - - - - ); -}; - -export default memo(ChatContainer); diff --git a/web/src/pages/chat/chat-id-modal/index.less b/web/src/pages/chat/chat-id-modal/index.less deleted file mode 100644 index c95b34c95aa..00000000000 --- a/web/src/pages/chat/chat-id-modal/index.less +++ /dev/null @@ -1,3 +0,0 @@ -.id { - .linkText(); -} diff --git a/web/src/pages/chat/chat-id-modal/index.tsx b/web/src/pages/chat/chat-id-modal/index.tsx deleted file mode 100644 index d640be0550d..00000000000 --- a/web/src/pages/chat/chat-id-modal/index.tsx +++ /dev/null @@ -1,38 +0,0 @@ -import { useTranslate } from '@/hooks/common-hooks'; -import { IModalProps } from '@/interfaces/common'; -import { Modal, Typography } from 'antd'; - -import styles from './index.less'; - -const { Paragraph, Link } = Typography; - -const ChatIdModal = ({ - visible, - hideModal, - id, -}: IModalProps & { id: string; name?: string; idKey: string }) => { - const { t } = useTranslate('chat'); - - return ( - - - {id} - - - {t('howUseId')} - - - ); -}; - -export default ChatIdModal; diff --git a/web/src/pages/chat/constants.ts b/web/src/pages/chat/constants.ts deleted file mode 100644 index 8c9a965f465..00000000000 --- a/web/src/pages/chat/constants.ts +++ /dev/null @@ -1 +0,0 @@ -export const EmptyConversationId = 'empty'; diff --git a/web/src/pages/chat/context.ts b/web/src/pages/chat/context.ts deleted file mode 100644 index 117641da5bf..00000000000 --- a/web/src/pages/chat/context.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { createContext } from 'react'; - -export const ConversationContext = createContext< - null | ((isPlaying: boolean) => void) ->(null); diff --git a/web/src/pages/chat/hooks.ts b/web/src/pages/chat/hooks.ts deleted file mode 100644 index ce0d657cdd0..00000000000 --- a/web/src/pages/chat/hooks.ts +++ /dev/null @@ -1,620 +0,0 @@ -import { ChatSearchParams, MessageType } from '@/constants/chat'; -import { fileIconMap } from '@/constants/common'; -import { - useFetchManualConversation, - useFetchManualDialog, - useFetchNextConversation, - useFetchNextConversationList, - useFetchNextDialog, - useFetchNextDialogList, - useGetChatSearchParams, - useRemoveNextConversation, - useRemoveNextDialog, - useSetNextDialog, - useUpdateNextConversation, -} from '@/hooks/chat-hooks'; -import { - useSetModalState, - useShowDeleteConfirm, - useTranslate, -} from '@/hooks/common-hooks'; -import { - useRegenerateMessage, - useSelectDerivedMessages, - useSendMessageWithSse, -} from '@/hooks/logic-hooks'; -import { IConversation, IDialog, Message } from '@/interfaces/database/chat'; -import { getFileExtension } from '@/utils'; -import api from '@/utils/api'; -import { getConversationId } from '@/utils/chat'; -import { useMutationState } from '@tanstack/react-query'; -import { get } from 'lodash'; -import trim from 'lodash/trim'; -import { - ChangeEventHandler, - useCallback, - useEffect, - useMemo, - useState, -} from 'react'; -import { useSearchParams } from 'umi'; -import { v4 as uuid } from 'uuid'; -import { - IClientConversation, - IMessage, - VariableTableDataType, -} from './interface'; - -export const useSetChatRouteParams = () => { - const [currentQueryParameters, setSearchParams] = useSearchParams(); - const newQueryParameters: URLSearchParams = useMemo( - () => new URLSearchParams(currentQueryParameters.toString()), - [currentQueryParameters], - ); - - const setConversationIsNew = useCallback( - (value: string) => { - newQueryParameters.set(ChatSearchParams.isNew, value); - setSearchParams(newQueryParameters); - }, - [newQueryParameters, setSearchParams], - ); - - const getConversationIsNew = useCallback(() => { - return newQueryParameters.get(ChatSearchParams.isNew); - }, [newQueryParameters]); - - return { setConversationIsNew, getConversationIsNew }; -}; - -export const useSetNewConversationRouteParams = () => { - const [currentQueryParameters, setSearchParams] = useSearchParams(); - const newQueryParameters: URLSearchParams = useMemo( - () => new URLSearchParams(currentQueryParameters.toString()), - [currentQueryParameters], - ); - - const setNewConversationRouteParams = useCallback( - (conversationId: string, isNew: string) => { - newQueryParameters.set(ChatSearchParams.ConversationId, conversationId); - newQueryParameters.set(ChatSearchParams.isNew, isNew); - setSearchParams(newQueryParameters); - }, - [newQueryParameters, setSearchParams], - ); - - return { setNewConversationRouteParams }; -}; - -export const useSelectCurrentDialog = () => { - const data = useMutationState({ - filters: { mutationKey: ['fetchDialog'] }, - select: (mutation) => { - return get(mutation, 'state.data.data', {}); - }, - }); - - return (data.at(-1) ?? {}) as IDialog; -}; - -export const useSelectPromptConfigParameters = (): VariableTableDataType[] => { - const { data: currentDialog } = useFetchNextDialog(); - - const finalParameters: VariableTableDataType[] = useMemo(() => { - const parameters = currentDialog?.prompt_config?.parameters ?? []; - if (!currentDialog.id) { - // The newly created chat has a default parameter - return [{ key: uuid(), variable: 'knowledge', optional: false }]; - } - return parameters.map((x) => ({ - key: uuid(), - variable: x.key, - optional: x.optional, - })); - }, [currentDialog]); - - return finalParameters; -}; - -export const useDeleteDialog = () => { - const showDeleteConfirm = useShowDeleteConfirm(); - - const { removeDialog } = useRemoveNextDialog(); - - const onRemoveDialog = (dialogIds: Array) => { - showDeleteConfirm({ onOk: () => removeDialog(dialogIds) }); - }; - - return { onRemoveDialog }; -}; - -export const useHandleItemHover = () => { - const [activated, setActivated] = useState(''); - - const handleItemEnter = (id: string) => { - setActivated(id); - }; - - const handleItemLeave = () => { - setActivated(''); - }; - - return { - activated, - handleItemEnter, - handleItemLeave, - }; -}; - -export const useEditDialog = () => { - const [dialog, setDialog] = useState({} as IDialog); - const { fetchDialog } = useFetchManualDialog(); - const { setDialog: submitDialog, loading } = useSetNextDialog(); - - const { - visible: dialogEditVisible, - hideModal: hideDialogEditModal, - showModal: showDialogEditModal, - } = useSetModalState(); - - const hideModal = useCallback(() => { - setDialog({} as IDialog); - hideDialogEditModal(); - }, [hideDialogEditModal]); - - const onDialogEditOk = useCallback( - async (dialog: IDialog) => { - const ret = await submitDialog(dialog); - - if (ret === 0) { - hideModal(); - } - }, - [submitDialog, hideModal], - ); - - const handleShowDialogEditModal = useCallback( - async (dialogId?: string) => { - if (dialogId) { - const ret = await fetchDialog(dialogId); - if (ret.code === 0) { - setDialog(ret.data); - } - } - showDialogEditModal(); - }, - [showDialogEditModal, fetchDialog], - ); - - const clearDialog = useCallback(() => { - setDialog({} as IDialog); - }, []); - - return { - dialogSettingLoading: loading, - initialDialog: dialog, - onDialogEditOk, - dialogEditVisible, - hideDialogEditModal: hideModal, - showDialogEditModal: handleShowDialogEditModal, - clearDialog, - }; -}; - -//#region conversation - -const useFindPrologueFromDialogList = () => { - const { dialogId } = useGetChatSearchParams(); - const { data: dialogList } = useFetchNextDialogList(true); - const prologue = useMemo(() => { - return dialogList.find((x) => x.id === dialogId)?.prompt_config.prologue; - }, [dialogId, dialogList]); - - return prologue; -}; - -export const useSelectDerivedConversationList = () => { - const { t } = useTranslate('chat'); - - const [list, setList] = useState>([]); - const { data: conversationList, loading } = useFetchNextConversationList(); - const { dialogId } = useGetChatSearchParams(); - const { setNewConversationRouteParams } = useSetNewConversationRouteParams(); - const prologue = useFindPrologueFromDialogList(); - - const addTemporaryConversation = useCallback(() => { - const conversationId = getConversationId(); - setList((pre) => { - if (dialogId) { - setNewConversationRouteParams(conversationId, 'true'); - const nextList = [ - { - id: conversationId, - name: t('newConversation'), - dialog_id: dialogId, - is_new: true, - message: [ - { - content: prologue, - role: MessageType.Assistant, - }, - ], - } as any, - ...conversationList, - ]; - return nextList; - } - - return pre; - }); - }, [conversationList, dialogId, prologue, t, setNewConversationRouteParams]); - - // When you first enter the page, select the top conversation card - - useEffect(() => { - setList([...conversationList]); - }, [conversationList]); - - return { list, addTemporaryConversation, loading }; -}; - -export const useSetConversation = () => { - const { dialogId } = useGetChatSearchParams(); - const { updateConversation } = useUpdateNextConversation(); - - const setConversation = useCallback( - async ( - message: string, - isNew: boolean = false, - conversationId?: string, - ) => { - const data = await updateConversation({ - dialog_id: dialogId, - name: message, - is_new: isNew, - conversation_id: conversationId, - message: [ - { - role: MessageType.Assistant, - content: message, - }, - ], - }); - - return data; - }, - [updateConversation, dialogId], - ); - - return { setConversation }; -}; - -export const useSelectNextMessages = () => { - const { - scrollRef, - messageContainerRef, - setDerivedMessages, - derivedMessages, - addNewestAnswer, - addNewestQuestion, - removeLatestMessage, - removeMessageById, - removeMessagesAfterCurrentMessage, - } = useSelectDerivedMessages(); - const { data: conversation, loading } = useFetchNextConversation(); - const { conversationId, dialogId, isNew } = useGetChatSearchParams(); - const prologue = useFindPrologueFromDialogList(); - - const addPrologue = useCallback(() => { - if (dialogId !== '' && isNew === 'true') { - const nextMessage = { - role: MessageType.Assistant, - content: prologue, - id: uuid(), - } as IMessage; - - setDerivedMessages([nextMessage]); - } - }, [dialogId, isNew, prologue, setDerivedMessages]); - - useEffect(() => { - addPrologue(); - }, [addPrologue]); - - useEffect(() => { - if ( - conversationId && - isNew !== 'true' && - conversation.message?.length > 0 - ) { - setDerivedMessages(conversation.message); - } - - if (!conversationId) { - setDerivedMessages([]); - } - }, [conversation.message, conversationId, setDerivedMessages, isNew]); - - return { - scrollRef, - messageContainerRef, - derivedMessages, - loading, - addNewestAnswer, - addNewestQuestion, - removeLatestMessage, - removeMessageById, - removeMessagesAfterCurrentMessage, - }; -}; - -export const useHandleMessageInputChange = () => { - const [value, setValue] = useState(''); - - const handleInputChange: ChangeEventHandler = (e) => { - const value = e.target.value; - // const nextValue = value.replaceAll('\\n', '\n').replaceAll('\\t', '\t'); - setValue(value); - }; - - return { - handleInputChange, - value, - setValue, - }; -}; - -export const useSendNextMessage = (controller: AbortController) => { - const { setConversation } = useSetConversation(); - const { conversationId, isNew } = useGetChatSearchParams(); - const { handleInputChange, value, setValue } = useHandleMessageInputChange(); - - const { send, answer, done } = useSendMessageWithSse( - api.completeConversation, - ); - const { - scrollRef, - messageContainerRef, - derivedMessages, - loading, - addNewestAnswer, - addNewestQuestion, - removeLatestMessage, - removeMessageById, - removeMessagesAfterCurrentMessage, - } = useSelectNextMessages(); - const { setConversationIsNew, getConversationIsNew } = - useSetChatRouteParams(); - - const stopOutputMessage = useCallback(() => { - controller.abort(); - }, [controller]); - - const sendMessage = useCallback( - async ({ - message, - currentConversationId, - messages, - }: { - message: Message; - currentConversationId?: string; - messages?: Message[]; - }) => { - const res = await send( - { - conversation_id: currentConversationId ?? conversationId, - messages: [...(messages ?? derivedMessages ?? []), message], - }, - controller, - ); - - if (res && (res?.response.status !== 200 || res?.data?.code !== 0)) { - // cancel loading - setValue(message.content); - console.info('removeLatestMessage111'); - removeLatestMessage(); - } - }, - [ - derivedMessages, - conversationId, - removeLatestMessage, - setValue, - send, - controller, - ], - ); - - const handleSendMessage = useCallback( - async (message: Message) => { - const isNew = getConversationIsNew(); - if (isNew !== 'true') { - sendMessage({ message }); - } else { - const data = await setConversation( - message.content, - true, - conversationId, - ); - if (data.code === 0) { - setConversationIsNew(''); - const id = data.data.id; - // currentConversationIdRef.current = id; - sendMessage({ - message, - currentConversationId: id, - messages: data.data.message, - }); - } - } - }, - [ - setConversation, - sendMessage, - setConversationIsNew, - getConversationIsNew, - conversationId, - ], - ); - - const { regenerateMessage } = useRegenerateMessage({ - removeMessagesAfterCurrentMessage, - sendMessage, - messages: derivedMessages, - }); - - useEffect(() => { - // #1289 - if (answer.answer && conversationId && isNew !== 'true') { - addNewestAnswer(answer); - } - }, [answer, addNewestAnswer, conversationId, isNew]); - - const handlePressEnter = useCallback( - (documentIds: string[]) => { - if (trim(value) === '') return; - const id = uuid(); - - addNewestQuestion({ - content: value, - doc_ids: documentIds, - id, - role: MessageType.User, - }); - if (done) { - setValue(''); - handleSendMessage({ - id, - content: value.trim(), - role: MessageType.User, - doc_ids: documentIds, - }); - } - }, - [addNewestQuestion, handleSendMessage, done, setValue, value], - ); - - return { - handlePressEnter, - handleInputChange, - value, - setValue, - regenerateMessage, - sendLoading: !done, - loading, - scrollRef, - messageContainerRef, - derivedMessages, - removeMessageById, - stopOutputMessage, - }; -}; - -export const useGetFileIcon = () => { - const getFileIcon = (filename: string) => { - const ext: string = getFileExtension(filename); - const iconPath = fileIconMap[ext as keyof typeof fileIconMap]; - return `@/assets/svg/file-icon/${iconPath}`; - }; - - return getFileIcon; -}; - -export const useDeleteConversation = () => { - const showDeleteConfirm = useShowDeleteConfirm(); - const { removeConversation } = useRemoveNextConversation(); - - const deleteConversation = (conversationIds: Array) => async () => { - const ret = await removeConversation(conversationIds); - - return ret; - }; - - const onRemoveConversation = (conversationIds: Array) => { - showDeleteConfirm({ onOk: deleteConversation(conversationIds) }); - }; - - return { onRemoveConversation }; -}; - -export const useRenameConversation = () => { - const [conversation, setConversation] = useState( - {} as IClientConversation, - ); - const { fetchConversation } = useFetchManualConversation(); - const { - visible: conversationRenameVisible, - hideModal: hideConversationRenameModal, - showModal: showConversationRenameModal, - } = useSetModalState(); - const { updateConversation, loading } = useUpdateNextConversation(); - - const onConversationRenameOk = useCallback( - async (name: string) => { - const ret = await updateConversation({ - conversation_id: conversation.id, - name, - is_new: false, - }); - - if (ret.code === 0) { - hideConversationRenameModal(); - } - }, - [updateConversation, conversation, hideConversationRenameModal], - ); - - const handleShowConversationRenameModal = useCallback( - async (conversationId: string) => { - const ret = await fetchConversation(conversationId); - if (ret.code === 0) { - setConversation(ret.data); - } - showConversationRenameModal(); - }, - [showConversationRenameModal, fetchConversation], - ); - - return { - conversationRenameLoading: loading, - initialConversationName: conversation.name, - onConversationRenameOk, - conversationRenameVisible, - hideConversationRenameModal, - showConversationRenameModal: handleShowConversationRenameModal, - }; -}; - -export const useGetSendButtonDisabled = () => { - const { dialogId, conversationId } = useGetChatSearchParams(); - - return dialogId === '' || conversationId === ''; -}; - -export const useSendButtonDisabled = (value: string) => { - return trim(value) === ''; -}; - -export const useCreateConversationBeforeUploadDocument = () => { - const { setConversation } = useSetConversation(); - const { dialogId } = useGetChatSearchParams(); - const { getConversationIsNew } = useSetChatRouteParams(); - - const createConversationBeforeUploadDocument = useCallback( - async (message: string) => { - const isNew = getConversationIsNew(); - if (isNew === 'true') { - const data = await setConversation(message, true); - - return data; - } - }, - [setConversation, getConversationIsNew], - ); - - return { - createConversationBeforeUploadDocument, - dialogId, - }; -}; -//#endregion diff --git a/web/src/pages/chat/index.less b/web/src/pages/chat/index.less deleted file mode 100644 index 58bcb26dcc7..00000000000 --- a/web/src/pages/chat/index.less +++ /dev/null @@ -1,82 +0,0 @@ -.chatWrapper { - height: 100%; - width: 100%; - - .chatAppWrapper { - width: 288px; - padding: 26px; - - .chatAppContent { - overflow-y: auto; - width: 100%; - } - - .chatAppCard { - :global(.ant-card-body) { - padding: 10px; - } - .cubeIcon { - &:hover { - cursor: pointer; - } - } - } - .chatAppCardSelected { - :global(.ant-card-body) { - background-color: @gray11; - border-radius: 8px; - } - } - .chatAppCardSelectedDark { - :global(.ant-card-body) { - background-color: rgba(255, 255, 255, 0.1); - border-radius: 8px; - } - } - } - .chatTitleWrapper { - width: 220px; - padding: 26px 0; - } - - .chatTitle { - padding: 5px 15px; - } - - .chatTitleContent { - padding: 5px 10px; - overflow: auto; - } - - .chatSpin { - :global(.ant-spin-container) { - display: flex; - flex-direction: column; - gap: 10px; - } - } - - .chatTitleCard { - :global(.ant-card-body) { - padding: 8px; - } - } - - .chatTitleCardSelected { - :global(.ant-card-body) { - background-color: @gray11; - border-radius: 8px; - } - } - .chatTitleCardSelectedDark { - :global(.ant-card-body) { - background-color: rgba(255, 255, 255, 0.1); - border-radius: 8px; - } - } - - .divider { - margin: 0; - height: 100%; - } -} diff --git a/web/src/pages/chat/index.tsx b/web/src/pages/chat/index.tsx deleted file mode 100644 index 9c7d67e9237..00000000000 --- a/web/src/pages/chat/index.tsx +++ /dev/null @@ -1,393 +0,0 @@ -import { ReactComponent as ChatAppCube } from '@/assets/svg/chat-app-cube.svg'; -import RenameModal from '@/components/rename-modal'; -import { DeleteOutlined, EditOutlined } from '@ant-design/icons'; -import { - Avatar, - Button, - Card, - Divider, - Dropdown, - Flex, - MenuProps, - Space, - Spin, - Tag, - Tooltip, - Typography, -} from 'antd'; -import { MenuItemProps } from 'antd/lib/menu/MenuItem'; -import classNames from 'classnames'; -import { useCallback, useState } from 'react'; -import ChatConfigurationModal from './chat-configuration-modal'; -import ChatContainer from './chat-container'; -import { - useDeleteConversation, - useDeleteDialog, - useEditDialog, - useHandleItemHover, - useRenameConversation, - useSelectDerivedConversationList, -} from './hooks'; - -import EmbedModal from '@/components/api-service/embed-modal'; -import { useShowEmbedModal } from '@/components/api-service/hooks'; -import SvgIcon from '@/components/svg-icon'; -import { useTheme } from '@/components/theme-provider'; -import { SharedFrom } from '@/constants/chat'; -import { - useClickConversationCard, - useClickDialogCard, - useFetchNextDialogList, - useGetChatSearchParams, -} from '@/hooks/chat-hooks'; -import { useTranslate } from '@/hooks/common-hooks'; -import { useSetSelectedRecord } from '@/hooks/logic-hooks'; -import { IDialog } from '@/interfaces/database/chat'; -import { PictureInPicture2 } from 'lucide-react'; -import styles from './index.less'; - -const { Text } = Typography; - -const Chat = () => { - const { data: dialogList, loading: dialogLoading } = useFetchNextDialogList(); - const { onRemoveDialog } = useDeleteDialog(); - const { onRemoveConversation } = useDeleteConversation(); - const { handleClickDialog } = useClickDialogCard(); - const { handleClickConversation } = useClickConversationCard(); - const { dialogId, conversationId } = useGetChatSearchParams(); - const { theme } = useTheme(); - const { - list: conversationList, - addTemporaryConversation, - loading: conversationLoading, - } = useSelectDerivedConversationList(); - const { activated, handleItemEnter, handleItemLeave } = useHandleItemHover(); - const { - activated: conversationActivated, - handleItemEnter: handleConversationItemEnter, - handleItemLeave: handleConversationItemLeave, - } = useHandleItemHover(); - const { - conversationRenameLoading, - initialConversationName, - onConversationRenameOk, - conversationRenameVisible, - hideConversationRenameModal, - showConversationRenameModal, - } = useRenameConversation(); - const { - dialogSettingLoading, - initialDialog, - onDialogEditOk, - dialogEditVisible, - clearDialog, - hideDialogEditModal, - showDialogEditModal, - } = useEditDialog(); - const { t } = useTranslate('chat'); - const { currentRecord, setRecord } = useSetSelectedRecord(); - const [controller, setController] = useState(new AbortController()); - const { showEmbedModal, hideEmbedModal, embedVisible, beta } = - useShowEmbedModal(); - - const handleAppCardEnter = (id: string) => () => { - handleItemEnter(id); - }; - - const handleConversationCardEnter = (id: string) => () => { - handleConversationItemEnter(id); - }; - - const handleShowChatConfigurationModal = - (dialogId?: string): any => - (info: any) => { - info?.domEvent?.preventDefault(); - info?.domEvent?.stopPropagation(); - showDialogEditModal(dialogId); - }; - - const handleRemoveDialog = - (dialogId: string): MenuItemProps['onClick'] => - ({ domEvent }) => { - domEvent.preventDefault(); - domEvent.stopPropagation(); - onRemoveDialog([dialogId]); - }; - - const handleShowOverviewModal = - (dialog: IDialog): any => - (info: any) => { - info?.domEvent?.preventDefault(); - info?.domEvent?.stopPropagation(); - setRecord(dialog); - showEmbedModal(); - }; - - const handleRemoveConversation = - (conversationId: string): MenuItemProps['onClick'] => - ({ domEvent }) => { - domEvent.preventDefault(); - domEvent.stopPropagation(); - onRemoveConversation([conversationId]); - }; - - const handleShowConversationRenameModal = - (conversationId: string): MenuItemProps['onClick'] => - ({ domEvent }) => { - domEvent.preventDefault(); - domEvent.stopPropagation(); - showConversationRenameModal(conversationId); - }; - - const handleDialogCardClick = useCallback( - (dialogId: string) => () => { - handleClickDialog(dialogId); - }, - [handleClickDialog], - ); - - const handleConversationCardClick = useCallback( - (conversationId: string, isNew: boolean) => () => { - handleClickConversation(conversationId, isNew ? 'true' : ''); - setController((pre) => { - pre.abort(); - return new AbortController(); - }); - }, - [handleClickConversation], - ); - - const handleCreateTemporaryConversation = useCallback(() => { - addTemporaryConversation(); - }, [addTemporaryConversation]); - - const buildAppItems = (dialog: IDialog) => { - const dialogId = dialog.id; - - const appItems: MenuProps['items'] = [ - { - key: '1', - onClick: handleShowChatConfigurationModal(dialogId), - label: ( - - - {t('edit', { keyPrefix: 'common' })} - - ), - }, - { type: 'divider' }, - { - key: '2', - onClick: handleRemoveDialog(dialogId), - label: ( - - - {t('delete', { keyPrefix: 'common' })} - - ), - }, - { type: 'divider' }, - { - key: '3', - onClick: handleShowOverviewModal(dialog), - label: ( - - {/* */} - - {t('embedIntoSite', { keyPrefix: 'common' })} - - ), - }, - ]; - - return appItems; - }; - - const buildConversationItems = (conversationId: string) => { - const appItems: MenuProps['items'] = [ - { - key: '1', - onClick: handleShowConversationRenameModal(conversationId), - label: ( - - - {t('rename', { keyPrefix: 'common' })} - - ), - }, - { type: 'divider' }, - { - key: '2', - onClick: handleRemoveConversation(conversationId), - label: ( - - - {t('delete', { keyPrefix: 'common' })} - - ), - }, - ]; - - return appItems; - }; - - return ( - - - - - - - - {dialogList.map((x) => ( - - - - -
    - - - {x.name} - - -
    {x.description}
    -
    -
    - {activated === x.id && ( -
    - - - -
    - )} -
    -
    - ))} -
    -
    -
    -
    - - - - - - {t('chat')} - {conversationList.length} - - -
    - -
    -
    -
    - - - - {conversationList.map((x) => ( - - -
    - - {x.name} - -
    - {conversationActivated === x.id && - x.id !== '' && - !x.is_new && ( -
    - - - -
    - )} -
    -
    - ))} -
    -
    -
    -
    - - - {dialogEditVisible && ( - - )} - - - {embedVisible && ( - - )} -
    - ); -}; - -export default Chat; diff --git a/web/src/pages/chat/interface.ts b/web/src/pages/chat/interface.ts deleted file mode 100644 index 570c6da219e..00000000000 --- a/web/src/pages/chat/interface.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { IConversation, IReference, Message } from '@/interfaces/database/chat'; -import { FormInstance } from 'antd'; - -export interface ISegmentedContentProps { - show: boolean; - form: FormInstance; - setHasError: (hasError: boolean) => void; -} - -export interface IVariable { - temperature: number; - top_p: number; - frequency_penalty: number; - presence_penalty: number; - max_tokens: number; -} - -export interface VariableTableDataType { - key: string; - variable: string; - optional: boolean; -} - -export type IPromptConfigParameters = Omit; - -export interface IMessage extends Message { - id: string; - reference?: IReference; // the latest news has reference -} - -export interface IClientConversation extends IConversation { - message: IMessage[]; -} diff --git a/web/src/pages/chat/share/index.less b/web/src/pages/chat/share/index.less deleted file mode 100644 index 01e090061e9..00000000000 --- a/web/src/pages/chat/share/index.less +++ /dev/null @@ -1,13 +0,0 @@ -.chatWrapper { - height: 100vh; -} - -.chatContainer { - padding: 10px; - box-sizing: border-box; - height: 100%; - .messageContainer { - overflow-y: auto; - padding-right: 6px; - } -} diff --git a/web/src/pages/chat/share/index.tsx b/web/src/pages/chat/share/index.tsx deleted file mode 100644 index acaadcbf944..00000000000 --- a/web/src/pages/chat/share/index.tsx +++ /dev/null @@ -1,13 +0,0 @@ -import ChatContainer from './large'; - -import styles from './index.less'; - -const SharedChat = () => { - return ( -
    - -
    - ); -}; - -export default SharedChat; diff --git a/web/src/pages/chat/share/large.tsx b/web/src/pages/chat/share/large.tsx deleted file mode 100644 index dfdd8c5a7c3..00000000000 --- a/web/src/pages/chat/share/large.tsx +++ /dev/null @@ -1,124 +0,0 @@ -import MessageInput from '@/components/message-input'; -import MessageItem from '@/components/message-item'; -import { useClickDrawer } from '@/components/pdf-drawer/hooks'; -import { MessageType, SharedFrom } from '@/constants/chat'; -import { useSendButtonDisabled } from '@/pages/chat/hooks'; -import { Flex, Spin } from 'antd'; -import React, { forwardRef, useMemo } from 'react'; -import { - useGetSharedChatSearchParams, - useSendSharedMessage, -} from '../shared-hooks'; -import { buildMessageItemReference } from '../utils'; - -import PdfDrawer from '@/components/pdf-drawer'; -import { useFetchNextConversationSSE } from '@/hooks/chat-hooks'; -import { useFetchFlowSSE } from '@/hooks/flow-hooks'; -import i18n from '@/locales/config'; -import { buildMessageUuidWithRole } from '@/utils/chat'; -import styles from './index.less'; - -const ChatContainer = () => { - const { - sharedId: conversationId, - from, - locale, - visibleAvatar, - } = useGetSharedChatSearchParams(); - const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = - useClickDrawer(); - - const { - handlePressEnter, - handleInputChange, - value, - sendLoading, - loading, - ref, - derivedMessages, - hasError, - stopOutputMessage, - } = useSendSharedMessage(); - const sendDisabled = useSendButtonDisabled(value); - - const useFetchAvatar = useMemo(() => { - return from === SharedFrom.Agent - ? useFetchFlowSSE - : useFetchNextConversationSSE; - }, [from]); - React.useEffect(() => { - if (locale && i18n.language !== locale) { - i18n.changeLanguage(locale); - } - }, [locale, visibleAvatar]); - const { data: avatarData } = useFetchAvatar(); - - if (!conversationId) { - return
    empty
    ; - } - - return ( - <> - - -
    - - {derivedMessages?.map((message, i) => { - return ( - - ); - })} - -
    -
    - - - - - {visible && ( - - )} - - ); -}; - -export default forwardRef(ChatContainer); diff --git a/web/src/pages/chat/shared-hooks.ts b/web/src/pages/chat/shared-hooks.ts deleted file mode 100644 index be5d49ffe57..00000000000 --- a/web/src/pages/chat/shared-hooks.ts +++ /dev/null @@ -1,149 +0,0 @@ -import { MessageType, SharedFrom } from '@/constants/chat'; -import { useCreateNextSharedConversation } from '@/hooks/chat-hooks'; -import { - useSelectDerivedMessages, - useSendMessageWithSse, -} from '@/hooks/logic-hooks'; -import { Message } from '@/interfaces/database/chat'; -import { message } from 'antd'; -import { get } from 'lodash'; -import trim from 'lodash/trim'; -import { useCallback, useEffect, useState } from 'react'; -import { useSearchParams } from 'umi'; -import { v4 as uuid } from 'uuid'; -import { useHandleMessageInputChange } from './hooks'; - -const isCompletionError = (res: any) => - res && (res?.response.status !== 200 || res?.data?.code !== 0); - -export const useSendButtonDisabled = (value: string) => { - return trim(value) === ''; -}; - -export const useGetSharedChatSearchParams = () => { - const [searchParams] = useSearchParams(); - const data_prefix = 'data_'; - const data = Object.fromEntries( - searchParams - .entries() - .filter(([key]) => key.startsWith(data_prefix)) - .map(([key, value]) => [key.replace(data_prefix, ''), value]), - ); - return { - from: searchParams.get('from') as SharedFrom, - sharedId: searchParams.get('shared_id'), - locale: searchParams.get('locale'), - data: data, - visibleAvatar: searchParams.get('visible_avatar') - ? searchParams.get('visible_avatar') !== '1' - : true, - }; -}; - -export const useSendSharedMessage = () => { - const { - from, - sharedId: conversationId, - data: data, - } = useGetSharedChatSearchParams(); - const { createSharedConversation: setConversation } = - useCreateNextSharedConversation(); - const { handleInputChange, value, setValue } = useHandleMessageInputChange(); - const { send, answer, done, stopOutputMessage } = useSendMessageWithSse( - `/api/v1/${from === SharedFrom.Agent ? 'agentbots' : 'chatbots'}/${conversationId}/completions`, - ); - const { - derivedMessages, - ref, - removeLatestMessage, - addNewestAnswer, - addNewestQuestion, - } = useSelectDerivedMessages(); - const [hasError, setHasError] = useState(false); - - const sendMessage = useCallback( - async (message: Message, id?: string) => { - const res = await send({ - conversation_id: id ?? conversationId, - quote: true, - question: message.content, - session_id: get(derivedMessages, '0.session_id'), - }); - - if (isCompletionError(res)) { - // cancel loading - setValue(message.content); - removeLatestMessage(); - } - }, - [send, conversationId, derivedMessages, setValue, removeLatestMessage], - ); - - const handleSendMessage = useCallback( - async (message: Message) => { - if (conversationId !== '') { - sendMessage(message); - } else { - const data = await setConversation('user id'); - if (data.code === 0) { - const id = data.data.id; - sendMessage(message, id); - } - } - }, - [conversationId, setConversation, sendMessage], - ); - - const fetchSessionId = useCallback(async () => { - const payload = { question: '' }; - const ret = await send({ ...payload, ...data }); - if (isCompletionError(ret)) { - message.error(ret?.data.message); - setHasError(true); - } - }, [send]); - - useEffect(() => { - fetchSessionId(); - }, [fetchSessionId, send]); - - useEffect(() => { - if (answer.answer) { - addNewestAnswer(answer); - } - }, [answer, addNewestAnswer]); - - const handlePressEnter = useCallback( - (documentIds: string[]) => { - if (trim(value) === '') return; - const id = uuid(); - if (done) { - setValue(''); - addNewestQuestion({ - content: value, - doc_ids: documentIds, - id, - role: MessageType.User, - }); - handleSendMessage({ - content: value.trim(), - id, - role: MessageType.User, - }); - } - }, - [addNewestQuestion, done, handleSendMessage, setValue, value], - ); - - return { - handlePressEnter, - handleInputChange, - value, - sendLoading: !done, - ref, - loading: false, - derivedMessages, - hasError, - stopOutputMessage, - }; -}; diff --git a/web/src/pages/chat/utils.ts b/web/src/pages/chat/utils.ts deleted file mode 100644 index e3e4f5ff205..00000000000 --- a/web/src/pages/chat/utils.ts +++ /dev/null @@ -1,53 +0,0 @@ -import { MessageType } from '@/constants/chat'; -import { IConversation, IReference } from '@/interfaces/database/chat'; -import { isEmpty } from 'lodash'; -import { EmptyConversationId } from './constants'; -import { IMessage } from './interface'; - -export const isConversationIdExist = (conversationId: string) => { - return conversationId !== EmptyConversationId && conversationId !== ''; -}; - -export const getDocumentIdsFromConversionReference = (data: IConversation) => { - const documentIds = data.reference.reduce( - (pre: Array, cur: IReference) => { - cur.doc_aggs - ?.map((x) => x.doc_id) - .forEach((x) => { - if (pre.every((y) => y !== x)) { - pre.push(x); - } - }); - return pre; - }, - [], - ); - return documentIds.join(','); -}; - -export const buildMessageItemReference = ( - conversation: { message: IMessage[]; reference: IReference[] }, - message: IMessage, -) => { - const assistantMessages = conversation.message - ?.filter((x) => x.role === MessageType.Assistant) - .slice(1); - const referenceIndex = assistantMessages.findIndex( - (x) => x.id === message.id, - ); - const reference = !isEmpty(message?.reference) - ? message?.reference - : (conversation?.reference ?? [])[referenceIndex]; - - return reference ?? { doc_aggs: [], chunks: [], total: 0 }; -}; - -const oldReg = /(#{2}\d+\${2})/g; -export const currentReg = /\[ID:(\d+)\]/g; - -// To be compatible with the old index matching mode -export const replaceTextByOldReg = (text: string) => { - return text?.replace(oldReg, (substring: string) => { - return `[ID:${substring.slice(2, -2)}]`; - }); -}; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.less b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.less index aac7724af4f..622e3b9e5e3 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.less +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.less @@ -4,6 +4,8 @@ } .imagePreview { + width: 100%; + height: 100%; max-width: 50vw; max-height: 50vh; object-fit: contain; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx index 69859ebfd4d..97a5af71431 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-card/index.tsx @@ -2,17 +2,19 @@ import Image from '@/components/image'; import { useTheme } from '@/components/theme-provider'; import { Card } from '@/components/ui/card'; import { Checkbox } from '@/components/ui/checkbox'; -import { - Popover, - PopoverContent, - PopoverTrigger, -} from '@/components/ui/popover'; import { Switch } from '@/components/ui/switch'; -import { IChunk } from '@/interfaces/database/knowledge'; +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from '@/components/ui/tooltip'; +import type { ChunkDocType, IChunk } from '@/interfaces/database/knowledge'; +import { cn } from '@/lib/utils'; import { CheckedState } from '@radix-ui/react-checkbox'; import classNames from 'classnames'; import DOMPurify from 'dompurify'; import { useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; import { ChunkTextMode } from '../../constant'; import styles from './index.less'; @@ -25,6 +27,7 @@ interface IProps { selected: boolean; clickChunkCard: (chunkId: string) => void; textMode: ChunkTextMode; + t?: string | number; // Cache-busting key for images } const ChunkCard = ({ @@ -36,7 +39,9 @@ const ChunkCard = ({ selected, clickChunkCard, textMode, + t: imageCacheKey, }: IProps) => { + const { t } = useTranslation(); const available = Number(item.available_int); const [enabled, setEnabled] = useState(false); const { theme } = useTheme(); @@ -61,46 +66,61 @@ const ChunkCard = ({ useEffect(() => { setEnabled(available === 1); }, [available]); - const [open, setOpen] = useState(false); + + const chunkType = + ((item.doc_type_kwd && + String(item.doc_type_kwd)?.toLowerCase()) as ChunkDocType) || 'text'; + return ( + + {t(`chunk.docType.${chunkType}`)} + +
    + + {/* Using instead of to avoid flickering when hovering over the image */} {item.image_id && ( - - setOpen(true)} - onMouseLeave={() => setOpen(false)} - > -
    - -
    -
    - + + + + -
    - -
    -
    -
    + + +
    )} +
    -
    + +
    ((resolve, reject) => { + const fr = new FileReader(); + fr.addEventListener('load', () => { + resolve((fr.result?.toString() ?? '').replace(/^.*,/, '')); + }); + fr.onerror = reject; + fr.readAsDataURL(file); + }); +} + const ChunkCreatingModal: React.FC & kFProps> = ({ doc_id, chunkId, @@ -52,21 +71,28 @@ const ChunkCreatingModal: React.FC & kFProps> = ({ question_kwd: [], important_kwd: [], tag_feas: [], + image: [], }, }); const [checked, setChecked] = useState(false); const { removeChunk } = useDeleteChunkByIds(); const { data } = useFetchChunk(chunkId); const { t } = useTranslation(); + const isEditMode = !!chunkId; const isTagParser = parserId === 'tag'; const onSubmit = useCallback( - (values: FieldValues) => { - onOk?.({ + async (values: FieldValues) => { + const prunedValues = { ...values, + image_base64: await fileToBase64(values.image?.[0] as File), tag_feas: transformTagFeaturesArrayToObject(values.tag_feas), available_int: checked ? 1 : 0, - }); + } as FieldValues; + + Reflect.deleteProperty(prunedValues, 'image'); + + onOk?.(prunedValues); }, [checked, onOk], ); @@ -86,6 +112,7 @@ const ChunkCreatingModal: React.FC & kFProps> = ({ useEffect(() => { if (data?.code === 0) { const { available_int, tag_feas } = data.data; + form.reset({ ...data.data, tag_feas: transformTagFeaturesObjectToArray(tag_feas), @@ -119,6 +146,74 @@ const ChunkCreatingModal: React.FC & kFProps> = ({ )} /> + + {/* Do not display the type field in create mode */} + {isEditMode && ( + { + const chunkType = + ((field.value && + String(field.value)?.toLowerCase()) as ChunkDocType) || + 'text'; + + return ( + + {t(`chunk.type`)} + + + + + ); + }} + /> + )} + + {isEditMode && form.getValues('doc_type_kwd') === 'image' && ( + ( + + {t('chunk.image')} + +
    + {data?.data?.img_id && ( + + )} + +
    + + } + /> + +
    +
    +
    + )} + /> + )} + { }, [selectedChunkIds]); return ( -
    +
    -
    - {textSelectOptions.map((option) => ( -
    changeTextSelectValue(option.value)} - > - {option.label} -
    - ))} -
    -
    - } - onChange={handleInputChange} - value={searchString} + -
    - - - - - - {filterContent} - - -
    - +
    +
    + } + onChange={handleInputChange} + value={searchString} + /> + + + + + + {filterContent} + + + +
    + {/*
    +
    */}
    ); }; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-toolbar/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-toolbar/index.tsx index 6a513ba78d9..e1c7c6ae5cc 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-toolbar/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/chunk-toolbar/index.tsx @@ -1,8 +1,11 @@ import { ReactComponent as FilterIcon } from '@/assets/filter.svg'; import { KnowledgeRouteKey } from '@/constants/knowledge'; -import { IChunkListResult, useSelectChunkList } from '@/hooks/chunk-hooks'; import { useTranslate } from '@/hooks/common-hooks'; -import { useKnowledgeBaseId } from '@/hooks/knowledge-hooks'; +import { + IChunkListResult, + useSelectChunkList, +} from '@/hooks/use-chunk-request'; +import { useKnowledgeBaseId } from '@/hooks/use-knowledge-request'; import { ArrowLeftOutlined, CheckCircleOutlined, diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/doc-preview.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/doc-preview.tsx deleted file mode 100644 index 845f0374e38..00000000000 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/doc-preview.tsx +++ /dev/null @@ -1,70 +0,0 @@ -import message from '@/components/ui/message'; -import { Spin } from '@/components/ui/spin'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import mammoth from 'mammoth'; -import { useEffect, useState } from 'react'; - -interface DocPreviewerProps { - className?: string; - url: string; -} - -export const DocPreviewer: React.FC = ({ - className, - url, -}) => { - // const url = useGetDocumentUrl(); - const [htmlContent, setHtmlContent] = useState(''); - const [loading, setLoading] = useState(false); - const fetchDocument = async () => { - setLoading(true); - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: () => { - message.error('Document parsing failed'); - console.error('Error loading document:', url); - }, - }); - try { - const arrayBuffer = await res.data.arrayBuffer(); - const result = await mammoth.convertToHtml( - { arrayBuffer }, - { includeDefaultStyleMap: true }, - ); - - const styledContent = result.value - .replace(/

    /g, '

    ') - .replace(//g, ''); - - setHtmlContent(styledContent); - } catch (err) { - message.error('Document parsing failed'); - console.error('Error parsing document:', err); - } - setLoading(false); - }; - - useEffect(() => { - if (url) { - fetchDocument(); - } - }, [url]); - return ( -

    - {loading && ( -
    - -
    - )} - - {!loading &&
    } -
    - ); -}; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/document-header.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/document-header.tsx deleted file mode 100644 index 88391f8c096..00000000000 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/document-header.tsx +++ /dev/null @@ -1,21 +0,0 @@ -import { formatDate } from '@/utils/date'; -import { formatBytes } from '@/utils/file-util'; - -type Props = { - size: number; - name: string; - create_date: string; -}; - -export default ({ size, name, create_date }: Props) => { - const sizeName = formatBytes(size); - const dateStr = formatDate(create_date); - return ( -
    -

    {name}

    -
    - Size:{sizeName} Uploaded Time:{dateStr} -
    -
    - ); -}; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/hooks.ts b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/hooks.ts deleted file mode 100644 index fcf6a01babd..00000000000 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/hooks.ts +++ /dev/null @@ -1,55 +0,0 @@ -import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; -import { api_host } from '@/utils/api'; -import { useSize } from 'ahooks'; -import { CustomTextRenderer } from 'node_modules/react-pdf/dist/esm/shared/types'; -import { useCallback, useEffect, useMemo, useState } from 'react'; - -export const useDocumentResizeObserver = () => { - const [containerWidth, setContainerWidth] = useState(); - const [containerRef, setContainerRef] = useState(null); - const size = useSize(containerRef); - - const onResize = useCallback((width?: number) => { - if (width) { - setContainerWidth(width); - } - }, []); - - useEffect(() => { - onResize(size?.width); - }, [size?.width, onResize]); - - return { containerWidth, setContainerRef }; -}; - -function highlightPattern(text: string, pattern: string, pageNumber: number) { - if (pageNumber === 2) { - return `${text}`; - } - if (text.trim() !== '' && pattern.match(text)) { - // return pattern.replace(text, (value) => `${value}`); - return `${text}`; - } - return text.replace(pattern, (value) => `${value}`); -} - -export const useHighlightText = (searchText: string = '') => { - const textRenderer: CustomTextRenderer = useCallback( - (textItem) => { - return highlightPattern(textItem.str, searchText, textItem.pageNumber); - }, - [searchText], - ); - - return textRenderer; -}; - -export const useGetDocumentUrl = () => { - const { documentId } = useGetKnowledgeSearchParams(); - - const url = useMemo(() => { - return `${api_host}/document/get/${documentId}`; - }, [documentId]); - - return url; -}; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/image-preview.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/image-preview.tsx deleted file mode 100644 index bec4ae75e87..00000000000 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/image-preview.tsx +++ /dev/null @@ -1,74 +0,0 @@ -import message from '@/components/ui/message'; -import { Spin } from '@/components/ui/spin'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import { useCallback, useEffect, useState } from 'react'; - -interface ImagePreviewerProps { - className?: string; - url: string; -} - -export const ImagePreviewer: React.FC = ({ - className, - url, -}) => { - // const url = useGetDocumentUrl(); - const [imageSrc, setImageSrc] = useState(null); - const [isLoading, setIsLoading] = useState(true); - - const fetchImage = useCallback(async () => { - setIsLoading(true); - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: () => { - message.error('Failed to load image'); - setIsLoading(false); - }, - }); - const objectUrl = URL.createObjectURL(res.data); - setImageSrc(objectUrl); - setIsLoading(false); - }, [url]); - - useEffect(() => { - if (url) { - fetchImage(); - } - }, [url, fetchImage]); - - useEffect(() => { - return () => { - if (imageSrc) { - URL.revokeObjectURL(imageSrc); - } - }; - }, [imageSrc]); - - return ( -
    - {isLoading && ( -
    - -
    - )} - - {!isLoading && imageSrc && ( -
    - {'image'} URL.revokeObjectURL(imageSrc!)} - /> -
    - )} -
    - ); -}; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/pdf-preview.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/pdf-preview.tsx deleted file mode 100644 index 79b1c54ae9d..00000000000 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/components/document-preview/pdf-preview.tsx +++ /dev/null @@ -1,127 +0,0 @@ -import { memo, useEffect, useRef } from 'react'; -import { - AreaHighlight, - Highlight, - IHighlight, - PdfHighlighter, - PdfLoader, - Popup, -} from 'react-pdf-highlighter'; - -import { useCatchDocumentError } from '@/components/pdf-previewer/hooks'; -import { Spin } from '@/components/ui/spin'; -import FileError from '@/pages/document-viewer/file-error'; -import styles from './index.less'; - -export interface IProps { - highlights: IHighlight[]; - setWidthAndHeight: (width: number, height: number) => void; - url: string; -} -const HighlightPopup = ({ - comment, -}: { - comment: { text: string; emoji: string }; -}) => - comment.text ? ( -
    - {comment.emoji} {comment.text} -
    - ) : null; - -// TODO: merge with DocumentPreviewer -const PdfPreview = ({ highlights: state, setWidthAndHeight, url }: IProps) => { - // const url = useGetDocumentUrl(); - - const ref = useRef<(highlight: IHighlight) => void>(() => {}); - const error = useCatchDocumentError(url); - - const resetHash = () => {}; - - useEffect(() => { - if (state.length > 0) { - ref?.current(state[0]); - } - }, [state]); - - return ( -
    - - -
    - } - workerSrc="/pdfjs-dist/pdf.worker.min.js" - errorMessage={{error}} - > - {(pdfDocument) => { - pdfDocument.getPage(1).then((page) => { - const viewport = page.getViewport({ scale: 1 }); - const width = viewport.width; - const height = viewport.height; - setWidthAndHeight(width, height); - }); - - return ( - event.altKey} - onScrollChange={resetHash} - scrollRef={(scrollTo) => { - ref.current = scrollTo; - }} - onSelectionFinished={() => null} - highlightTransform={( - highlight, - index, - setTip, - hideTip, - viewportToScaled, - screenshot, - isScrolledTo, - ) => { - const isTextHighlight = !Boolean( - highlight.content && highlight.content.image, - ); - - const component = isTextHighlight ? ( - - ) : ( - {}} - /> - ); - - return ( - } - onMouseOver={(popupContent) => - setTip(highlight, () => popupContent) - } - onMouseOut={hideTip} - key={index} - > - {component} - - ); - }} - highlights={state} - /> - ); - }} - -
    - ); -}; - -export default memo(PdfPreview); diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/hooks.ts b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/hooks.ts index cb546b226f6..790fced3938 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/hooks.ts +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/hooks.ts @@ -1,10 +1,10 @@ +import { useSetModalState, useShowDeleteConfirm } from '@/hooks/common-hooks'; +import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; import { useCreateChunk, useDeleteChunk, useSelectChunkList, -} from '@/hooks/chunk-hooks'; -import { useSetModalState, useShowDeleteConfirm } from '@/hooks/common-hooks'; -import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; +} from '@/hooks/use-chunk-request'; import { IChunk } from '@/interfaces/database/knowledge'; import { buildChunkHighlights } from '@/utils/document-util'; import { useCallback, useMemo, useState } from 'react'; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx index 7aff9e993ea..a73960d145c 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/components/knowledge-chunk/index.tsx @@ -7,7 +7,6 @@ import { useCallback, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import ChunkCard from './components/chunk-card'; import CreatingModal from './components/chunk-creating-modal'; -import DocumentPreview from './components/document-preview'; import { useChangeChunkTextMode, useDeleteChunkByIds, @@ -18,8 +17,11 @@ import { import ChunkResultBar from './components/chunk-result-bar'; import CheckboxSets from './components/chunk-result-bar/checkbox-sets'; -import DocumentHeader from './components/document-preview/document-header'; +// import DocumentHeader from './components/document-preview/document-header'; +import DocumentPreview from '@/components/document-preview'; +import DocumentHeader from '@/components/document-preview/document-header'; +import { useGetDocumentUrl } from '@/components/document-preview/hooks'; import { PageHeader } from '@/components/page-header'; import { Breadcrumb, @@ -40,7 +42,6 @@ import { useNavigatePage, } from '@/hooks/logic-hooks/navigate-hooks'; import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; -import { useGetDocumentUrl } from './components/document-preview/hooks'; import styles from './index.less'; const Chunk = () => { @@ -54,6 +55,7 @@ const Chunk = () => { handleInputChange, available, handleSetAvailable, + dataUpdatedAt, } = useFetchNextChunkList(); const { handleChunkCardClick, selectedChunkId } = useHandleChunkCardClick(); const isPdf = documentInfo?.type === 'pdf'; @@ -74,7 +76,7 @@ const Chunk = () => { } = useUpdateChunk(); const { navigateToDataFile, getQueryString, navigateToDatasetList } = useNavigatePage(); - const fileUrl = useGetDocumentUrl(); + const fileUrl = useGetDocumentUrl(false); useEffect(() => { setChunkList(data); }, [data]); @@ -170,6 +172,7 @@ const Chunk = () => { case 'docx': case 'txt': case 'md': + case 'mdx': case 'pdf': return documentInfo?.type; } @@ -276,6 +279,7 @@ const Chunk = () => { clickChunkCard={handleChunkCardClick} selected={item.chunk_id === selectedChunkId} textMode={textMode} + t={dataUpdatedAt} > ))}
    diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/constant.ts b/web/src/pages/chunk/parsed-result/add-knowledge/constant.ts index 71df266c8e6..fb4eecc3f23 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/constant.ts +++ b/web/src/pages/chunk/parsed-result/add-knowledge/constant.ts @@ -1,21 +1,4 @@ -import { KnowledgeRouteKey } from '@/constants/knowledge'; - -export const routeMap = { - [KnowledgeRouteKey.Dataset]: 'Dataset', - [KnowledgeRouteKey.Testing]: 'Retrieval testing', - [KnowledgeRouteKey.Configuration]: 'Configuration', -}; - export enum KnowledgeDatasetRouteKey { Chunk = 'chunk', File = 'file', } - -export const datasetRouteMap = { - [KnowledgeDatasetRouteKey.Chunk]: 'Chunk', - [KnowledgeDatasetRouteKey.File]: 'File Upload', -}; - -export * from '@/constants/knowledge'; - -export const TagRenameId = 'tagRename'; diff --git a/web/src/pages/chunk/parsed-result/add-knowledge/index.tsx b/web/src/pages/chunk/parsed-result/add-knowledge/index.tsx index 18187cf4e3d..7434edc618b 100644 --- a/web/src/pages/chunk/parsed-result/add-knowledge/index.tsx +++ b/web/src/pages/chunk/parsed-result/add-knowledge/index.tsx @@ -1,9 +1,9 @@ -import { useKnowledgeBaseId } from '@/hooks/knowledge-hooks'; import { useNavigateWithFromState, useSecondPathName, useThirdPathName, } from '@/hooks/route-hook'; +import { useKnowledgeBaseId } from '@/hooks/use-knowledge-request'; import { Breadcrumb } from 'antd'; import { ItemType } from 'antd/es/breadcrumb/Breadcrumb'; import { useMemo } from 'react'; diff --git a/web/src/pages/data-flows/index.tsx b/web/src/pages/data-flows/index.tsx deleted file mode 100644 index 13ee9742ce2..00000000000 --- a/web/src/pages/data-flows/index.tsx +++ /dev/null @@ -1,3 +0,0 @@ -export default function DataFlows() { - return
    xx
    ; -} diff --git a/web/src/pages/dataflow-result/components/chunk-creating-modal/index.tsx b/web/src/pages/dataflow-result/components/chunk-creating-modal/index.tsx index 66cf4d619c5..7cad7eec1c9 100644 --- a/web/src/pages/dataflow-result/components/chunk-creating-modal/index.tsx +++ b/web/src/pages/dataflow-result/components/chunk-creating-modal/index.tsx @@ -17,7 +17,7 @@ import { Modal } from '@/components/ui/modal/modal'; import Space from '@/components/ui/space'; import { Switch } from '@/components/ui/switch'; import { Textarea } from '@/components/ui/textarea'; -import { useFetchChunk } from '@/hooks/chunk-hooks'; +import { useFetchChunk } from '@/hooks/use-chunk-request'; import { IModalProps } from '@/interfaces/common'; import { Trash2 } from 'lucide-react'; import React, { useCallback, useEffect, useState } from 'react'; diff --git a/web/src/pages/dataflow-result/components/chunk-creating-modal/tag-feature-item.tsx b/web/src/pages/dataflow-result/components/chunk-creating-modal/tag-feature-item.tsx index 3c9f92c78d8..77219ce7cd0 100644 --- a/web/src/pages/dataflow-result/components/chunk-creating-modal/tag-feature-item.tsx +++ b/web/src/pages/dataflow-result/components/chunk-creating-modal/tag-feature-item.tsx @@ -8,8 +8,10 @@ import { FormMessage, } from '@/components/ui/form'; import { NumberInput } from '@/components/ui/input'; -import { useFetchTagListByKnowledgeIds } from '@/hooks/knowledge-hooks'; -import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; +import { + useFetchKnowledgeBaseConfiguration, + useFetchTagListByKnowledgeIds, +} from '@/hooks/use-knowledge-request'; import { CircleMinus, Plus } from 'lucide-react'; import { useCallback, useEffect, useMemo } from 'react'; import { useFieldArray, useFormContext } from 'react-hook-form'; diff --git a/web/src/pages/dataflow-result/components/chunk-toolbar/index.tsx b/web/src/pages/dataflow-result/components/chunk-toolbar/index.tsx index 6a513ba78d9..e1c7c6ae5cc 100644 --- a/web/src/pages/dataflow-result/components/chunk-toolbar/index.tsx +++ b/web/src/pages/dataflow-result/components/chunk-toolbar/index.tsx @@ -1,8 +1,11 @@ import { ReactComponent as FilterIcon } from '@/assets/filter.svg'; import { KnowledgeRouteKey } from '@/constants/knowledge'; -import { IChunkListResult, useSelectChunkList } from '@/hooks/chunk-hooks'; import { useTranslate } from '@/hooks/common-hooks'; -import { useKnowledgeBaseId } from '@/hooks/knowledge-hooks'; +import { + IChunkListResult, + useSelectChunkList, +} from '@/hooks/use-chunk-request'; +import { useKnowledgeBaseId } from '@/hooks/use-knowledge-request'; import { ArrowLeftOutlined, CheckCircleOutlined, diff --git a/web/src/pages/dataflow-result/components/document-preview/csv-preview.tsx b/web/src/pages/dataflow-result/components/document-preview/csv-preview.tsx deleted file mode 100644 index 45b05454e1b..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/csv-preview.tsx +++ /dev/null @@ -1,114 +0,0 @@ -import message from '@/components/ui/message'; -import { Spin } from '@/components/ui/spin'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import React, { useEffect, useRef, useState } from 'react'; - -interface CSVData { - rows: string[][]; - headers: string[]; -} - -interface FileViewerProps { - className?: string; - url: string; -} - -const CSVFileViewer: React.FC = ({ url }) => { - const [data, setData] = useState(null); - const [isLoading, setIsLoading] = useState(true); - const containerRef = useRef(null); - // const url = useGetDocumentUrl(); - const parseCSV = (csvText: string): CSVData => { - console.log('Parsing CSV data:', csvText); - const lines = csvText.split('\n'); - const headers = lines[0].split(',').map((header) => header.trim()); - const rows = lines - .slice(1) - .map((line) => line.split(',').map((cell) => cell.trim())); - - return { headers, rows }; - }; - - useEffect(() => { - const loadCSV = async () => { - try { - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: () => { - message.error('file load failed'); - setIsLoading(false); - }, - }); - - // parse CSV file - const reader = new FileReader(); - reader.readAsText(res.data); - reader.onload = () => { - const parsedData = parseCSV(reader.result as string); - console.log('file loaded successfully', reader.result); - setData(parsedData); - }; - } catch (error) { - message.error('CSV file parse failed'); - console.error('Error loading CSV file:', error); - } finally { - setIsLoading(false); - } - }; - - loadCSV(); - - return () => { - setData(null); - }; - }, [url]); - - return ( -
    - {isLoading ? ( -
    - -
    - ) : data ? ( -
    - - - {data.headers.map((header, index) => ( - - ))} - - - - {data.rows.map((row, rowIndex) => ( - - {row.map((cell, cellIndex) => ( - - ))} - - ))} - -
    - {header} -
    - {cell || '-'} -
    - ) : null} - - ); -}; - -export default CSVFileViewer; diff --git a/web/src/pages/dataflow-result/components/document-preview/doc-preview.tsx b/web/src/pages/dataflow-result/components/document-preview/doc-preview.tsx deleted file mode 100644 index 845f0374e38..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/doc-preview.tsx +++ /dev/null @@ -1,70 +0,0 @@ -import message from '@/components/ui/message'; -import { Spin } from '@/components/ui/spin'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import mammoth from 'mammoth'; -import { useEffect, useState } from 'react'; - -interface DocPreviewerProps { - className?: string; - url: string; -} - -export const DocPreviewer: React.FC = ({ - className, - url, -}) => { - // const url = useGetDocumentUrl(); - const [htmlContent, setHtmlContent] = useState(''); - const [loading, setLoading] = useState(false); - const fetchDocument = async () => { - setLoading(true); - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: () => { - message.error('Document parsing failed'); - console.error('Error loading document:', url); - }, - }); - try { - const arrayBuffer = await res.data.arrayBuffer(); - const result = await mammoth.convertToHtml( - { arrayBuffer }, - { includeDefaultStyleMap: true }, - ); - - const styledContent = result.value - .replace(/

    /g, '

    ') - .replace(//g, ''); - - setHtmlContent(styledContent); - } catch (err) { - message.error('Document parsing failed'); - console.error('Error parsing document:', err); - } - setLoading(false); - }; - - useEffect(() => { - if (url) { - fetchDocument(); - } - }, [url]); - return ( -

    - {loading && ( -
    - -
    - )} - - {!loading &&
    } -
    - ); -}; diff --git a/web/src/pages/dataflow-result/components/document-preview/excel-preview.tsx b/web/src/pages/dataflow-result/components/document-preview/excel-preview.tsx deleted file mode 100644 index c86e0462c08..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/excel-preview.tsx +++ /dev/null @@ -1,25 +0,0 @@ -import { useFetchExcel } from '@/pages/document-viewer/hooks'; -import classNames from 'classnames'; - -interface ExcelCsvPreviewerProps { - className?: string; - url: string; -} - -export const ExcelCsvPreviewer: React.FC = ({ - className, - url, -}) => { - // const url = useGetDocumentUrl(); - const { containerRef } = useFetchExcel(url); - - return ( -
    - ); -}; diff --git a/web/src/pages/dataflow-result/components/document-preview/hooks.ts b/web/src/pages/dataflow-result/components/document-preview/hooks.ts deleted file mode 100644 index 30e5887d848..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/hooks.ts +++ /dev/null @@ -1,60 +0,0 @@ -import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; -import api, { api_host } from '@/utils/api'; -import { useSize } from 'ahooks'; -import { CustomTextRenderer } from 'node_modules/react-pdf/dist/esm/shared/types'; -import { useCallback, useEffect, useMemo, useState } from 'react'; -import { useGetPipelineResultSearchParams } from '../../hooks'; - -export const useDocumentResizeObserver = () => { - const [containerWidth, setContainerWidth] = useState(); - const [containerRef, setContainerRef] = useState(null); - const size = useSize(containerRef); - - const onResize = useCallback((width?: number) => { - if (width) { - setContainerWidth(width); - } - }, []); - - useEffect(() => { - onResize(size?.width); - }, [size?.width, onResize]); - - return { containerWidth, setContainerRef }; -}; - -function highlightPattern(text: string, pattern: string, pageNumber: number) { - if (pageNumber === 2) { - return `${text}`; - } - if (text.trim() !== '' && pattern.match(text)) { - // return pattern.replace(text, (value) => `${value}`); - return `${text}`; - } - return text.replace(pattern, (value) => `${value}`); -} - -export const useHighlightText = (searchText: string = '') => { - const textRenderer: CustomTextRenderer = useCallback( - (textItem) => { - return highlightPattern(textItem.str, searchText, textItem.pageNumber); - }, - [searchText], - ); - - return textRenderer; -}; - -export const useGetDocumentUrl = (isAgent: boolean) => { - const { documentId } = useGetKnowledgeSearchParams(); - const { createdBy, documentId: id } = useGetPipelineResultSearchParams(); - - const url = useMemo(() => { - if (isAgent) { - return api.downloadFile + `?id=${id}&created_by=${createdBy}`; - } - return `${api_host}/document/get/${documentId}`; - }, [createdBy, documentId, id, isAgent]); - - return url; -}; diff --git a/web/src/pages/dataflow-result/components/document-preview/index.less b/web/src/pages/dataflow-result/components/document-preview/index.less deleted file mode 100644 index 8f456af5a9b..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/index.less +++ /dev/null @@ -1,13 +0,0 @@ -.documentContainer { - width: 100%; - // height: calc(100vh - 284px); - height: calc(100vh - 180px); - position: relative; - :global(.PdfHighlighter) { - overflow-x: hidden; - } - :global(.Highlight--scrolledTo .Highlight__part) { - overflow-x: hidden; - background-color: rgba(255, 226, 143, 1); - } -} diff --git a/web/src/pages/dataflow-result/components/document-preview/index.tsx b/web/src/pages/dataflow-result/components/document-preview/index.tsx deleted file mode 100644 index 0a5cf08e81b..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/index.tsx +++ /dev/null @@ -1,67 +0,0 @@ -import { memo } from 'react'; - -import CSVFileViewer from './csv-preview'; -import { DocPreviewer } from './doc-preview'; -import { ExcelCsvPreviewer } from './excel-preview'; -import { ImagePreviewer } from './image-preview'; -import PdfPreviewer, { IProps } from './pdf-preview'; -import { PptPreviewer } from './ppt-preview'; -import { TxtPreviewer } from './txt-preview'; - -type PreviewProps = { - fileType: string; - className?: string; - url: string; -}; -const Preview = ({ - fileType, - className, - highlights, - setWidthAndHeight, - url, -}: PreviewProps & Partial) => { - return ( - <> - {fileType === 'pdf' && highlights && setWidthAndHeight && ( -
    - -
    - )} - {['doc', 'docx'].indexOf(fileType) > -1 && ( -
    - -
    - )} - {['txt', 'md'].indexOf(fileType) > -1 && ( -
    - -
    - )} - {['visual'].indexOf(fileType) > -1 && ( -
    - -
    - )} - {['pptx'].indexOf(fileType) > -1 && ( -
    - -
    - )} - {['xlsx'].indexOf(fileType) > -1 && ( -
    - -
    - )} - {['csv'].indexOf(fileType) > -1 && ( -
    - -
    - )} - - ); -}; -export default memo(Preview); diff --git a/web/src/pages/dataflow-result/components/document-preview/ppt-preview.tsx b/web/src/pages/dataflow-result/components/document-preview/ppt-preview.tsx deleted file mode 100644 index 7786c48c3b2..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/ppt-preview.tsx +++ /dev/null @@ -1,70 +0,0 @@ -import message from '@/components/ui/message'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import { init } from 'pptx-preview'; -import { useEffect, useRef } from 'react'; -interface PptPreviewerProps { - className?: string; - url: string; -} - -export const PptPreviewer: React.FC = ({ - className, - url, -}) => { - // const url = useGetDocumentUrl(); - const wrapper = useRef(null); - const containerRef = useRef(null); - const fetchDocument = async () => { - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: () => { - message.error('Document parsing failed'); - console.error('Error loading document:', url); - }, - }); - console.log(res); - try { - const arrayBuffer = await res.data.arrayBuffer(); - - if (containerRef.current) { - let width = 500; - let height = 900; - if (containerRef.current) { - width = containerRef.current.clientWidth - 50; - height = containerRef.current.clientHeight - 50; - } - let pptxPrviewer = init(containerRef.current, { - width: width, - height: height, - }); - pptxPrviewer.preview(arrayBuffer); - } - } catch (err) { - message.error('ppt parse failed'); - } - }; - - useEffect(() => { - if (url) { - fetchDocument(); - } - }, [url]); - - return ( -
    -
    -
    -
    -
    -
    -
    - ); -}; diff --git a/web/src/pages/dataflow-result/components/document-preview/txt-preview.tsx b/web/src/pages/dataflow-result/components/document-preview/txt-preview.tsx deleted file mode 100644 index cf6649e3432..00000000000 --- a/web/src/pages/dataflow-result/components/document-preview/txt-preview.tsx +++ /dev/null @@ -1,56 +0,0 @@ -import message from '@/components/ui/message'; -import { Spin } from '@/components/ui/spin'; -import request from '@/utils/request'; -import classNames from 'classnames'; -import { useEffect, useState } from 'react'; - -type TxtPreviewerProps = { className?: string; url: string }; -export const TxtPreviewer = ({ className, url }: TxtPreviewerProps) => { - // const url = useGetDocumentUrl(); - const [loading, setLoading] = useState(false); - const [data, setData] = useState(''); - const fetchTxt = async () => { - setLoading(true); - const res = await request(url, { - method: 'GET', - responseType: 'blob', - onError: (err: any) => { - message.error('Failed to load file'); - console.error('Error loading file:', err); - }, - }); - // blob to string - const reader = new FileReader(); - reader.readAsText(res.data); - reader.onload = () => { - setData(reader.result as string); - setLoading(false); - console.log('file loaded successfully', reader.result); - }; - console.log('file data:', res); - }; - useEffect(() => { - if (url) { - fetchTxt(); - } else { - setLoading(false); - setData(''); - } - }, [url]); - return ( -
    - {loading && ( -
    - -
    - )} - - {!loading &&
    {data}
    } -
    - ); -}; diff --git a/web/src/pages/dataflow-result/hooks.ts b/web/src/pages/dataflow-result/hooks.ts index 4c4ad590a4f..07d4da1491f 100644 --- a/web/src/pages/dataflow-result/hooks.ts +++ b/web/src/pages/dataflow-result/hooks.ts @@ -1,9 +1,9 @@ import { TimelineNode } from '@/components/originui/timeline'; import message from '@/components/ui/message'; -import { useCreateChunk, useDeleteChunk } from '@/hooks/chunk-hooks'; import { useSetModalState, useShowDeleteConfirm } from '@/hooks/common-hooks'; import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; import { useFetchMessageTrace } from '@/hooks/use-agent-request'; +import { useCreateChunk, useDeleteChunk } from '@/hooks/use-chunk-request'; import kbService from '@/services/knowledge-service'; import { formatSecondsToHumanReadable } from '@/utils/date'; import { buildChunkHighlights } from '@/utils/document-util'; diff --git a/web/src/pages/dataflow-result/index.tsx b/web/src/pages/dataflow-result/index.tsx index c15f71c6605..8a2780bd87e 100644 --- a/web/src/pages/dataflow-result/index.tsx +++ b/web/src/pages/dataflow-result/index.tsx @@ -1,7 +1,7 @@ +import DocumentPreview from '@/components/document-preview'; import { useFetchNextChunkList } from '@/hooks/use-chunk-request'; import { useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import DocumentPreview from './components/document-preview'; import { useFetchPipelineFileLogDetail, useFetchPipelineResult, @@ -13,8 +13,9 @@ import { useTimelineDataFlow, } from './hooks'; -import DocumentHeader from './components/document-preview/document-header'; +import DocumentHeader from '@/components/document-preview/document-header'; +import { useGetDocumentUrl } from '@/components/document-preview/hooks'; import { TimelineNode } from '@/components/originui/timeline'; import { PageHeader } from '@/components/page-header'; import Spotlight from '@/components/spotlight'; @@ -32,7 +33,6 @@ import { AgentCategory } from '@/constants/agent'; import { Images } from '@/constants/common'; import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks'; import { useGetKnowledgeSearchParams } from '@/hooks/route-hook'; -import { useGetDocumentUrl } from './components/document-preview/hooks'; import TimelineDataFlow from './components/time-line'; import { TimelineNodeType } from './constant'; import styles from './index.less'; @@ -76,16 +76,18 @@ const Chunk = () => { const fileType = useMemo(() => { if (isAgent) { return Images.some((x) => x === documentExtension) - ? 'visual' + ? documentInfo?.name.split('.').pop() || 'visual' : documentExtension; } switch (documentInfo?.type) { case 'doc': return documentInfo?.name.split('.').pop() || 'doc'; case 'visual': + return documentInfo?.name.split('.').pop() || 'visual'; case 'docx': case 'txt': case 'md': + case 'mdx': case 'pdf': return documentInfo?.type; } diff --git a/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts b/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts new file mode 100644 index 00000000000..2858ebb7658 --- /dev/null +++ b/web/src/pages/dataset/components/metedata/hooks/use-manage-modal.ts @@ -0,0 +1,500 @@ +import message from '@/components/ui/message'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { + DocumentApiAction, + useSetDocumentMeta, +} from '@/hooks/use-document-request'; +import kbService, { + getMetaDataService, + updateMetaData, +} from '@/services/knowledge-service'; +import { useQuery, useQueryClient } from '@tanstack/react-query'; +import { TFunction } from 'i18next'; +import { useCallback, useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useParams } from 'umi'; +import { + IMetaDataReturnJSONSettings, + IMetaDataReturnJSONType, + IMetaDataReturnType, + IMetaDataTableData, + MetadataOperations, + ShowManageMetadataModalProps, +} from '../interface'; +export enum MetadataType { + Manage = 1, + UpdateSingle = 2, + Setting = 3, + SingleFileSetting = 4, +} + +export const MetadataDeleteMap = ( + t: TFunction<'translation', undefined>, +): Record< + MetadataType, + { + title: string; + warnFieldText: string; + warnValueText: string; + warnFieldName: string; + warnValueName: string; + } +> => { + return { + [MetadataType.Manage]: { + title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'), + warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldAllWarn'), + warnValueText: t('knowledgeDetails.metadata.deleteManageValueAllWarn'), + warnFieldName: t('knowledgeDetails.metadata.fieldNameExists'), + warnValueName: t('knowledgeDetails.metadata.valueExists'), + }, + [MetadataType.Setting]: { + title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'), + warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'), + warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'), + warnFieldName: t('knowledgeDetails.metadata.fieldExists'), + warnValueName: t('knowledgeDetails.metadata.valueExists'), + }, + [MetadataType.UpdateSingle]: { + title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'), + warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldSingleWarn'), + warnValueText: t('knowledgeDetails.metadata.deleteManageValueSingleWarn'), + warnFieldName: t('knowledgeDetails.metadata.fieldSingleNameExists'), + warnValueName: t('knowledgeDetails.metadata.valueSingleExists'), + }, + [MetadataType.SingleFileSetting]: { + title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'), + warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'), + warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'), + warnFieldName: t('knowledgeDetails.metadata.fieldExists'), + warnValueName: t('knowledgeDetails.metadata.valueSingleExists'), + }, + }; +}; +export const util = { + changeToMetaDataTableData(data: IMetaDataReturnType): IMetaDataTableData[] { + return Object.entries(data).map(([key, value]) => { + const values = value.map(([v]) => v.toString()); + console.log('values', values); + return { + field: key, + description: '', + values: values, + } as IMetaDataTableData; + }); + }, + + JSONToMetaDataTableData( + data: Record, + ): IMetaDataTableData[] { + return Object.entries(data).map(([key, value]) => { + let thisValue = [] as string[]; + if (value && Array.isArray(value)) { + thisValue = value; + } else if (value && typeof value === 'string') { + thisValue = [value]; + } else if (value && typeof value === 'object') { + thisValue = [JSON.stringify(value)]; + } else if (value) { + thisValue = [value.toString()]; + } + + return { + field: key, + description: '', + values: thisValue, + } as IMetaDataTableData; + }); + }, + + tableDataToMetaDataJSON(data: IMetaDataTableData[]): IMetaDataReturnJSONType { + return data.reduce((pre, cur) => { + pre[cur.field] = cur.values; + return pre; + }, {}); + }, + + tableDataToMetaDataSettingJSON( + data: IMetaDataTableData[], + ): IMetaDataReturnJSONSettings { + return data.map((item) => { + return { + key: item.field, + description: item.description, + enum: item.values, + }; + }); + }, + + metaDataSettingJSONToMetaDataTableData( + data: IMetaDataReturnJSONSettings, + ): IMetaDataTableData[] { + if (!Array.isArray(data)) return []; + return data.map((item) => { + return { + field: item.key, + description: item.description, + values: item.enum, + restrictDefinedValues: !!item.enum?.length, + } as IMetaDataTableData; + }); + }, +}; + +export const useMetadataOperations = () => { + const [operations, setOperations] = useState({ + deletes: [], + updates: [], + }); + + const addDeleteRow = useCallback((key: string) => { + setOperations((prev) => ({ + ...prev, + deletes: [...prev.deletes, { key }], + })); + }, []); + + const addDeleteValue = useCallback((key: string, value: string) => { + setOperations((prev) => ({ + ...prev, + deletes: [...prev.deletes, { key, value }], + })); + }, []); + + // const addUpdateValue = useCallback( + // (key: string, value: string | string[]) => { + // setOperations((prev) => ({ + // ...prev, + // updates: [...prev.updates, { key, value }], + // })); + // }, + // [], + // ); + const addUpdateValue = useCallback( + (key: string, originalValue: string, newValue: string) => { + setOperations((prev) => { + const existsIndex = prev.updates.findIndex( + (update) => update.key === key && update.match === originalValue, + ); + + if (existsIndex > -1) { + const updatedUpdates = [...prev.updates]; + updatedUpdates[existsIndex] = { + key, + match: originalValue, + value: newValue, + }; + return { + ...prev, + updates: updatedUpdates, + }; + } + return { + ...prev, + updates: [ + ...prev.updates, + { key, match: originalValue, value: newValue }, + ], + }; + }); + }, + [], + ); + + const resetOperations = useCallback(() => { + setOperations({ + deletes: [], + updates: [], + }); + }, []); + + return { + operations, + addDeleteRow, + addDeleteValue, + addUpdateValue, + resetOperations, + }; +}; + +export const useFetchMetaDataManageData = ( + type: MetadataType = MetadataType.Manage, +) => { + const { id } = useParams(); + // const [data, setData] = useState([]); + // const [loading, setLoading] = useState(false); + // const fetchData = useCallback(async (): Promise => { + // setLoading(true); + // const { data } = await getMetaDataService({ + // kb_id: id as string, + // }); + // setLoading(false); + // if (data?.data?.summary) { + // return util.changeToMetaDataTableData(data.data.summary); + // } + // return []; + // }, [id]); + // useEffect(() => { + // if (type === MetadataType.Manage) { + // fetchData() + // .then((res) => { + // setData(res); + // }) + // .catch((res) => { + // console.error(res); + // }); + // } + // }, [type, fetchData]); + + const { + data, + isFetching: loading, + refetch, + } = useQuery({ + queryKey: ['fetchMetaData', id], + enabled: !!id && type === MetadataType.Manage, + initialData: [], + gcTime: 1000, + queryFn: async () => { + const { data } = await getMetaDataService({ + kb_id: id as string, + }); + if (data?.data?.summary) { + return util.changeToMetaDataTableData(data.data.summary); + } + return []; + }, + }); + return { + data, + loading, + refetch, + }; +}; + +export const useManageMetaDataModal = ( + metaData: IMetaDataTableData[] = [], + type: MetadataType = MetadataType.Manage, + otherData?: Record, +) => { + const { id } = useParams(); + const { t } = useTranslation(); + const { data, loading } = useFetchMetaDataManageData(type); + + const [tableData, setTableData] = useState(metaData); + const queryClient = useQueryClient(); + const { + operations, + addDeleteRow, + addDeleteValue, + addUpdateValue, + resetOperations, + } = useMetadataOperations(); + + const { setDocumentMeta } = useSetDocumentMeta(); + + useEffect(() => { + if (type === MetadataType.Manage) { + if (data) { + setTableData(data); + } else { + setTableData([]); + } + } + }, [data, type]); + + useEffect(() => { + if (type !== MetadataType.Manage) { + if (metaData) { + setTableData(metaData); + } else { + setTableData([]); + } + } + }, [metaData, type]); + + const handleDeleteSingleValue = useCallback( + (field: string, value: string) => { + addDeleteValue(field, value); + + setTableData((prevTableData) => { + const newTableData = prevTableData.map((item) => { + if (item.field === field) { + return { + ...item, + values: item.values.filter((v) => v !== value), + }; + } + return item; + }); + // console.log('newTableData', newTableData, prevTableData); + return newTableData; + }); + }, + [addDeleteValue], + ); + + const handleDeleteSingleRow = useCallback( + (field: string) => { + addDeleteRow(field); + setTableData((prevTableData) => { + const newTableData = prevTableData.filter( + (item) => item.field !== field, + ); + // console.log('newTableData', newTableData, prevTableData); + return newTableData; + }); + }, + [addDeleteRow], + ); + + const handleSaveManage = useCallback( + async (callback: () => void) => { + const { data: res } = await updateMetaData({ + kb_id: id as string, + data: operations, + }); + if (res.code === 0) { + queryClient.invalidateQueries({ + queryKey: [DocumentApiAction.FetchDocumentList], + }); + resetOperations(); + message.success(t('message.operated')); + callback(); + } + }, + [operations, id, t, queryClient, resetOperations], + ); + + const handleSaveUpdateSingle = useCallback( + async (callback: () => void) => { + const reqData = util.tableDataToMetaDataJSON(tableData); + if (otherData?.id) { + const ret = await setDocumentMeta({ + documentId: otherData?.id, + meta: JSON.stringify(reqData), + }); + if (ret === 0) { + // message.success(t('message.success')); + callback(); + } + } + }, + [tableData, otherData, setDocumentMeta], + ); + + const handleSaveSettings = useCallback( + async (callback: () => void) => { + const data = util.tableDataToMetaDataSettingJSON(tableData); + const { data: res } = await kbService.kbUpdateMetaData({ + kb_id: id, + metadata: data, + }); + if (res.code === 0) { + message.success(t('message.operated')); + callback?.(); + } + + return data; + }, + [tableData, id, t], + ); + + const handleSaveSingleFileSettings = useCallback( + async (callback: () => void) => { + const data = util.tableDataToMetaDataSettingJSON(tableData); + if (otherData?.documentId) { + const { data: res } = await kbService.documentUpdateMetaData({ + doc_id: otherData.documentId, + metadata: data, + }); + if (res.code === 0) { + message.success(t('message.operated')); + callback?.(); + } + } + + return data; + }, + [tableData, t, otherData], + ); + + const handleSave = useCallback( + async ({ callback }: { callback: () => void }) => { + switch (type) { + case MetadataType.UpdateSingle: + handleSaveUpdateSingle(callback); + break; + case MetadataType.Manage: + handleSaveManage(callback); + break; + case MetadataType.Setting: + return handleSaveSettings(callback); + case MetadataType.SingleFileSetting: + return handleSaveSingleFileSettings(callback); + default: + handleSaveManage(callback); + break; + } + }, + [ + handleSaveManage, + type, + handleSaveUpdateSingle, + handleSaveSettings, + handleSaveSingleFileSettings, + ], + ); + + return { + tableData, + setTableData, + handleDeleteSingleValue, + handleDeleteSingleRow, + loading, + handleSave, + addUpdateValue, + addDeleteValue, + }; +}; + +export const useManageMetadata = () => { + const [tableData, setTableData] = useState([]); + const [config, setConfig] = useState( + {} as ShowManageMetadataModalProps, + ); + const { + visible: manageMetadataVisible, + showModal, + hideModal: hideManageMetadataModal, + } = useSetModalState(); + const showManageMetadataModal = useCallback( + (config?: ShowManageMetadataModalProps) => { + const { metadata } = config || {}; + if (metadata) { + // const dataTemp = Object.entries(metadata).map(([key, value]) => { + // return { + // field: key, + // description: '', + // values: Array.isArray(value) ? value : [value], + // } as IMetaDataTableData; + // }); + setTableData(metadata); + console.log('metadata-2', metadata); + } + console.log('metadata-3', metadata); + if (config) { + setConfig(config); + } + showModal(); + }, + [showModal], + ); + return { + manageMetadataVisible, + showManageMetadataModal, + hideManageMetadataModal, + tableData, + config, + }; +}; diff --git a/web/src/pages/dataset/components/metedata/hooks/use-manage-values-modal.ts b/web/src/pages/dataset/components/metedata/hooks/use-manage-values-modal.ts new file mode 100644 index 00000000000..38608109df8 --- /dev/null +++ b/web/src/pages/dataset/components/metedata/hooks/use-manage-values-modal.ts @@ -0,0 +1,210 @@ +import { useCallback, useEffect, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { MetadataDeleteMap, MetadataType } from '../hooks/use-manage-modal'; +import { IManageValuesProps, IMetaDataTableData } from '../interface'; + +export const useManageValues = (props: IManageValuesProps) => { + const { + data, + + isShowValueSwitch, + hideModal, + onSave, + addUpdateValue, + addDeleteValue, + existsKeys, + type, + } = props; + const { t } = useTranslation(); + const [metaData, setMetaData] = useState(data); + const [valueError, setValueError] = useState>({ + field: '', + values: '', + }); + const [deleteDialogContent, setDeleteDialogContent] = useState({ + visible: false, + title: '', + name: '', + warnText: '', + onOk: () => {}, + onCancel: () => {}, + }); + const hideDeleteModal = () => { + setDeleteDialogContent({ + visible: false, + title: '', + name: '', + warnText: '', + onOk: () => {}, + onCancel: () => {}, + }); + }; + + // Use functional update to avoid closure issues + const handleChange = useCallback( + (field: string, value: any) => { + if (field === 'field' && existsKeys.includes(value)) { + setValueError((prev) => { + return { + ...prev, + field: MetadataDeleteMap(t)[type as MetadataType].warnFieldName, + // type === MetadataType.Setting + // ? t('knowledgeDetails.metadata.fieldExists') + // : t('knowledgeDetails.metadata.fieldNameExists'), + }; + }); + } else if (field === 'field' && !existsKeys.includes(value)) { + setValueError((prev) => { + return { + ...prev, + field: '', + }; + }); + } + setMetaData((prev) => ({ + ...prev, + [field]: value, + })); + }, + [existsKeys, type, t], + ); + + // Maintain separate state for each input box + const [tempValues, setTempValues] = useState([...data.values]); + + useEffect(() => { + setTempValues([...data.values]); + setMetaData(data); + }, [data]); + + const handleHideModal = useCallback(() => { + hideModal(); + setMetaData({} as IMetaDataTableData); + }, [hideModal]); + + const handleSave = useCallback(() => { + if (type === MetadataType.Setting && valueError.field) { + return; + } + if (!metaData.restrictDefinedValues && isShowValueSwitch) { + const newMetaData = { ...metaData, values: [] }; + onSave(newMetaData); + } else { + onSave(metaData); + } + handleHideModal(); + }, [metaData, onSave, handleHideModal, isShowValueSwitch, type, valueError]); + + // Handle value changes, only update temporary state + const handleValueChange = useCallback( + (index: number, value: string) => { + setTempValues((prev) => { + if (prev.includes(value)) { + setValueError((prev) => { + return { + ...prev, + values: MetadataDeleteMap(t)[type as MetadataType].warnValueName, + // t('knowledgeDetails.metadata.valueExists'), + }; + }); + } else { + setValueError((prev) => { + return { + ...prev, + values: '', + }; + }); + } + const newValues = [...prev]; + newValues[index] = value; + + return newValues; + }); + }, + [t, type], + ); + + // Handle blur event, synchronize to main state + const handleValueBlur = useCallback(() => { + if (data.values.length > 0) { + tempValues.forEach((newValue, index) => { + if (index < data.values.length) { + const originalValue = data.values[index]; + if (originalValue !== newValue) { + addUpdateValue(metaData.field, originalValue, newValue); + } + } else { + if (newValue) { + addUpdateValue(metaData.field, '', newValue); + } + } + }); + } + handleChange('values', [...new Set([...tempValues])]); + }, [handleChange, tempValues, metaData, data, addUpdateValue]); + + // Handle delete operation + const handleDelete = useCallback( + (index: number) => { + setTempValues((prev) => { + const newTempValues = [...prev]; + addDeleteValue(metaData.field, newTempValues[index]); + newTempValues.splice(index, 1); + return newTempValues; + }); + + // Synchronize to main state + setMetaData((prev) => { + const newMetaDataValues = [...prev.values]; + newMetaDataValues.splice(index, 1); + return { + ...prev, + values: newMetaDataValues, + }; + }); + }, + [addDeleteValue, metaData], + ); + + const showDeleteModal = (item: string, callback: () => void) => { + setDeleteDialogContent({ + visible: true, + title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.value'), + name: item, + warnText: MetadataDeleteMap(t)[type as MetadataType].warnValueText, + onOk: () => { + hideDeleteModal(); + callback(); + }, + onCancel: () => { + hideDeleteModal(); + }, + }); + }; + + // Handle adding new value + const handleAddValue = useCallback(() => { + setTempValues((prev) => [...new Set([...prev, ''])]); + + // Synchronize to main state + setMetaData((prev) => ({ + ...prev, + values: [...new Set([...prev.values, ''])], + })); + }, []); + + return { + metaData, + tempValues, + valueError, + deleteDialogContent, + handleChange, + handleValueChange, + handleValueBlur, + handleDelete, + handleAddValue, + showDeleteModal, + handleSave, + handleHideModal, + }; +}; diff --git a/web/src/pages/dataset/components/metedata/interface.ts b/web/src/pages/dataset/components/metedata/interface.ts new file mode 100644 index 00000000000..ef299036657 --- /dev/null +++ b/web/src/pages/dataset/components/metedata/interface.ts @@ -0,0 +1,87 @@ +import { ReactNode } from 'react'; +import { MetadataType } from './hook'; +export type IMetaDataReturnType = Record>>; +export type IMetaDataReturnJSONType = Record< + string, + Array | string +>; + +export interface IMetaDataReturnJSONSettingItem { + key: string; + description?: string; + enum?: string[]; +} +export type IMetaDataReturnJSONSettings = Array; + +export type IMetaDataTableData = { + field: string; + description: string; + restrictDefinedValues?: boolean; + values: string[]; +}; + +export type IManageModalProps = { + title: ReactNode; + isShowDescription?: boolean; + isDeleteSingleValue?: boolean; + visible: boolean; + hideModal: () => void; + tableData?: IMetaDataTableData[]; + isCanAdd: boolean; + type: MetadataType; + otherData?: Record; + isEditField?: boolean; + isAddValue?: boolean; + isShowValueSwitch?: boolean; + isVerticalShowValue?: boolean; + success?: (data: any) => void; +}; + +export interface IManageValuesProps { + title: ReactNode; + existsKeys: string[]; + visible: boolean; + isEditField?: boolean; + isAddValue?: boolean; + isShowDescription?: boolean; + isShowValueSwitch?: boolean; + isVerticalShowValue?: boolean; + data: IMetaDataTableData; + type: MetadataType; + hideModal: () => void; + onSave: (data: IMetaDataTableData) => void; + addUpdateValue: ( + key: string, + originalValue: string, + newValue: string, + ) => void; + addDeleteValue: (key: string, value: string) => void; +} + +interface DeleteOperation { + key: string; + value?: string; +} + +interface UpdateOperation { + key: string; + match: string; + value: string; +} + +export interface MetadataOperations { + deletes: DeleteOperation[]; + updates: UpdateOperation[]; +} +export interface ShowManageMetadataModalOptions { + title?: ReactNode | string; +} +export type ShowManageMetadataModalProps = Partial & { + metadata?: IMetaDataTableData[]; + isCanAdd: boolean; + type: MetadataType; + record?: Record; + options?: ShowManageMetadataModalOptions; + title?: ReactNode | string; + isDeleteSingleValue?: boolean; +}; diff --git a/web/src/pages/dataset/components/metedata/manage-modal.tsx b/web/src/pages/dataset/components/metedata/manage-modal.tsx new file mode 100644 index 00000000000..9b21db765e0 --- /dev/null +++ b/web/src/pages/dataset/components/metedata/manage-modal.tsx @@ -0,0 +1,468 @@ +import { + ConfirmDeleteDialog, + ConfirmDeleteDialogNode, +} from '@/components/confirm-delete-dialog'; +import { EmptyType } from '@/components/empty/constant'; +import Empty from '@/components/empty/empty'; +import { Button } from '@/components/ui/button'; +import { Modal } from '@/components/ui/modal/modal'; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from '@/components/ui/table'; +import { useSetModalState } from '@/hooks/common-hooks'; +import { Routes } from '@/routes'; +import { + ColumnDef, + flexRender, + getCoreRowModel, + getFilteredRowModel, + getPaginationRowModel, + getSortedRowModel, + useReactTable, +} from '@tanstack/react-table'; +import { Plus, Settings, Trash2 } from 'lucide-react'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useHandleMenuClick } from '../../sidebar/hooks'; +import { + MetadataDeleteMap, + MetadataType, + useManageMetaDataModal, +} from './hooks/use-manage-modal'; +import { IManageModalProps, IMetaDataTableData } from './interface'; +import { ManageValuesModal } from './manage-values-modal'; + +export const ManageMetadataModal = (props: IManageModalProps) => { + const { + title, + visible, + hideModal, + isDeleteSingleValue, + tableData: originalTableData, + isCanAdd, + type: metadataType, + otherData, + isEditField, + isAddValue, + isShowDescription = false, + isShowValueSwitch = false, + isVerticalShowValue = true, + success, + } = props; + const { t } = useTranslation(); + const [valueData, setValueData] = useState({ + field: '', + description: '', + values: [], + }); + + const [currentValueIndex, setCurrentValueIndex] = useState(0); + const [deleteDialogContent, setDeleteDialogContent] = useState({ + visible: false, + title: '', + name: '', + warnText: '', + onOk: () => {}, + onCancel: () => {}, + }); + + const { + tableData, + setTableData, + handleDeleteSingleValue, + handleDeleteSingleRow, + handleSave, + addUpdateValue, + addDeleteValue, + } = useManageMetaDataModal(originalTableData, metadataType, otherData); + const { handleMenuClick } = useHandleMenuClick(); + const { + visible: manageValuesVisible, + showModal: showManageValuesModal, + hideModal: hideManageValuesModal, + } = useSetModalState(); + const hideDeleteModal = () => { + setDeleteDialogContent({ + visible: false, + title: '', + name: '', + warnText: '', + onOk: () => {}, + onCancel: () => {}, + }); + }; + const handAddValueRow = () => { + setValueData({ + field: '', + description: '', + values: [], + }); + setCurrentValueIndex(tableData.length || 0); + showManageValuesModal(); + }; + const handleEditValueRow = useCallback( + (data: IMetaDataTableData, index: number) => { + setCurrentValueIndex(index); + setValueData(data); + showManageValuesModal(); + }, + [showManageValuesModal], + ); + + const columns: ColumnDef[] = useMemo(() => { + const cols: ColumnDef[] = [ + { + accessorKey: 'field', + header: () => {t('knowledgeDetails.metadata.field')}, + cell: ({ row }) => ( +
    + {row.getValue('field')} +
    + ), + }, + { + accessorKey: 'description', + header: () => {t('knowledgeDetails.metadata.description')}, + cell: ({ row }) => ( +
    + {row.getValue('description')} +
    + ), + }, + { + accessorKey: 'values', + header: () => {t('knowledgeDetails.metadata.values')}, + cell: ({ row }) => { + const values = row.getValue('values') as Array; + return ( +
    + {Array.isArray(values) && + values.length > 0 && + values + .filter((value: string, index: number) => index < 2) + ?.map((value: string) => { + return ( + + )} +
    + + ); + })} + {Array.isArray(values) && values.length > 2 && ( +
    ...
    + )} +
    + ); + }, + }, + { + accessorKey: 'action', + header: () => {t('knowledgeDetails.metadata.action')}, + meta: { + cellClassName: 'w-12', + }, + cell: ({ row }) => ( +
    + + +
    + ), + }, + ]; + if (!isShowDescription) { + cols.splice(1, 1); + } + return cols; + }, [ + handleDeleteSingleRow, + t, + handleDeleteSingleValue, + isShowDescription, + isDeleteSingleValue, + handleEditValueRow, + metadataType, + ]); + + const table = useReactTable({ + data: tableData as IMetaDataTableData[], + columns, + getCoreRowModel: getCoreRowModel(), + getPaginationRowModel: getPaginationRowModel(), + getSortedRowModel: getSortedRowModel(), + getFilteredRowModel: getFilteredRowModel(), + manualPagination: true, + }); + const [shouldSave, setShouldSave] = useState(false); + const handleSaveValues = (data: IMetaDataTableData) => { + setTableData((prev) => { + let newData; + if (currentValueIndex >= prev.length) { + // Add operation + newData = [...prev, data]; + } else { + // Edit operation + newData = prev.map((item, index) => { + if (index === currentValueIndex) { + return data; + } + return item; + }); + } + + // Deduplicate by field and merge values + const fieldMap = new Map(); + newData.forEach((item) => { + if (fieldMap.has(item.field)) { + // Merge values if field exists + const existingItem = fieldMap.get(item.field)!; + const mergedValues = [ + ...new Set([...existingItem.values, ...item.values]), + ]; + fieldMap.set(item.field, { ...existingItem, values: mergedValues }); + } else { + fieldMap.set(item.field, item); + } + }); + + return Array.from(fieldMap.values()); + }); + setShouldSave(true); + }; + + useEffect(() => { + if (shouldSave) { + const timer = setTimeout(() => { + handleSave({ callback: () => {} }); + setShouldSave(false); + }, 0); + + return () => clearTimeout(timer); + } + }, [tableData, shouldSave, handleSave]); + + const existsKeys = useMemo(() => { + return tableData.map((item) => item.field); + }, [tableData]); + + return ( + <> + { + const res = await handleSave({ callback: hideModal }); + console.log('data', res); + success?.(res); + }} + > + <> +
    +
    +
    {t('knowledgeDetails.metadata.metadata')}
    + {metadataType === MetadataType.Manage && false && ( + + )} + {isCanAdd && ( + + )} +
    + + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + {header.isPlaceholder + ? null + : flexRender( + header.column.columnDef.header, + header.getContext(), + )} + + ))} + + ))} + + + {table.getRowModel().rows?.length ? ( + table.getRowModel().rows.map((row) => ( + + {row.getVisibleCells().map((cell) => ( + + {flexRender( + cell.column.columnDef.cell, + cell.getContext(), + )} + + ))} + + )) + ) : ( + + + + + + )} + +
    +
    + {metadataType === MetadataType.Manage && ( +
    + {t('knowledgeDetails.metadata.toMetadataSettingTip')} +
    + )} + +
    + {manageValuesVisible && ( + + {metadataType === MetadataType.Setting || + metadataType === MetadataType.SingleFileSetting + ? t('knowledgeDetails.metadata.fieldSetting') + : t('knowledgeDetails.metadata.editMetadata')} +
    + } + type={metadataType} + existsKeys={existsKeys} + visible={manageValuesVisible} + hideModal={hideManageValuesModal} + data={valueData} + onSave={handleSaveValues} + addUpdateValue={addUpdateValue} + addDeleteValue={addDeleteValue} + isEditField={isEditField || isCanAdd} + isAddValue={isAddValue || isCanAdd} + isShowDescription={isShowDescription} + isShowValueSwitch={isShowValueSwitch} + isVerticalShowValue={isVerticalShowValue} + // handleDeleteSingleValue={handleDeleteSingleValue} + // handleDeleteSingleRow={handleDeleteSingleRow} + /> + )} + + {deleteDialogContent.visible && ( + + ), + }} + /> + )} + + ); +}; diff --git a/web/src/pages/dataset/components/metedata/manage-values-modal.tsx b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx new file mode 100644 index 00000000000..f1c6343f645 --- /dev/null +++ b/web/src/pages/dataset/components/metedata/manage-values-modal.tsx @@ -0,0 +1,232 @@ +import { + ConfirmDeleteDialog, + ConfirmDeleteDialogNode, +} from '@/components/confirm-delete-dialog'; +import EditTag from '@/components/edit-tag'; +import { Button } from '@/components/ui/button'; +import { FormLabel } from '@/components/ui/form'; +import { Input } from '@/components/ui/input'; +import { Modal } from '@/components/ui/modal/modal'; +import { Switch } from '@/components/ui/switch'; +import { Textarea } from '@/components/ui/textarea'; +import { Plus, Trash2 } from 'lucide-react'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useManageValues } from './hooks/use-manage-values-modal'; +import { IManageValuesProps } from './interface'; + +// Create a separate input component, wrapped with memo to avoid unnecessary re-renders +const ValueInputItem = memo( + ({ + item, + index, + onValueChange, + onDelete, + onBlur, + }: { + item: string; + index: number; + onValueChange: (index: number, value: string) => void; + onDelete: (index: number) => void; + onBlur: (index: number) => void; + }) => { + return ( +
    +
    + onValueChange(index, e.target.value)} + onBlur={() => onBlur(index)} + /> +
    + +
    + ); + }, +); + +export const ManageValuesModal = (props: IManageValuesProps) => { + const { + title, + isEditField, + visible, + isAddValue, + isShowDescription, + isShowValueSwitch, + isVerticalShowValue, + } = props; + const { + metaData, + tempValues, + valueError, + deleteDialogContent, + handleChange, + handleValueChange, + handleValueBlur, + handleDelete, + handleAddValue, + showDeleteModal, + handleSave, + handleHideModal, + } = useManageValues(props); + const { t } = useTranslation(); + + return ( + +
    + {!isEditField && ( +
    + {metaData.field} +
    + )} + {isEditField && ( +
    +
    {t('knowledgeDetails.metadata.fieldName')}
    +
    + { + const value = e.target?.value || ''; + if (/^[a-zA-Z_]*$/.test(value)) { + handleChange('field', value); + } + }} + /> +
    {valueError.field}
    +
    +
    + )} + {isShowDescription && ( +
    + + {t('knowledgeDetails.metadata.description')} + +
    +