+        text = re.sub(r'<(?!br\s*/?>)[^>]*>', '', text, flags=re.IGNORECASE)
+        text = re.sub(r'<br\s*/?>', '\n', text, flags=re.IGNORECASE)
+
+        # First, handle common markdown table artifacts
+        text = re.sub(r'^[|\-\s:]+$', '', text, flags=re.MULTILINE)  # Remove separator lines
+        text = re.sub(r'^\s*\|\s*|\s*\|\s*$', '', text)  # Remove leading/trailing pipes
+        text = re.sub(r'\s*\|\s*', ' | ', text)  # Normalize pipes
+
+        # Remove markdown links, but keep other formatting characters for _format_inline
+        text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)  # Remove markdown links
+
+        # Escape HTML special characters
+        text = text.replace('&', '&amp;')
+        text = text.replace('<', '&lt;')
+        text = text.replace('>', '&gt;')
+
+        # Clean up excessive whitespace
+        text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)  # Multiple blank lines to double
+        text = re.sub(r' +', ' ', text)  # Multiple spaces to single
+
+        return text.strip()
+
+    def _get_cell_style(self, row_idx: int, is_header: bool = False, font_size: int = None) -> 'ParagraphStyle':
+        """Get the appropriate style for a table cell."""
+        styles = getSampleStyleSheet()
+
+        # Helper function to get the correct bold font name
+        def get_bold_font(font_family):
+            font_map = {
+                'Helvetica': 'Helvetica-Bold',
+                'Times-Roman': 'Times-Bold',
+                'Courier': 'Courier-Bold',
+            }
+            if 'Bold' in font_family:
+                return font_family
+            return font_map.get(font_family, 'Helvetica-Bold')
+
+        if is_header:
+            return ParagraphStyle(
+                'TableHeader',
+                parent=styles['Normal'],
+                fontSize=self._param.font_size,
+                fontName=self._get_active_bold_font(),
+                textColor=colors.whitesmoke,
+                alignment=TA_CENTER,
+                leading=self._param.font_size * 1.2,
+                wordWrap='CJK'
+            )
+        else:
+            font_size = font_size or (self._param.font_size - 1)
+            return ParagraphStyle(
+                'TableCell',
+                parent=styles['Normal'],
+                fontSize=font_size,
+                fontName=self._get_active_font(),
+                textColor=colors.black,
+                alignment=TA_LEFT,
+                leading=font_size * 1.15,
+                wordWrap='CJK'
+            )
+
+    def _convert_table_to_definition_list(self, data: list[list[str]]) -> list:
+        """Convert a table to a definition list format for better handling of large content.
+
+        This method handles both simple and complex tables, including those with nested content.
+        It ensures that large cell content is properly wrapped and paginated.
+ """ + elements = [] + styles = getSampleStyleSheet() + + # Base styles + base_font_size = getattr(self._param, 'font_size', 10) + + # Body style + body_style = ParagraphStyle( + 'TableBody', + parent=styles['Normal'], + fontSize=base_font_size, + fontName=self._get_active_font(), + textColor=colors.HexColor(getattr(self._param, 'text_color', '#000000')), + spaceAfter=6, + leading=base_font_size * 1.2 + ) + + # Label style (for field names) + label_style = ParagraphStyle( + 'LabelStyle', + parent=body_style, + fontName=self._get_active_bold_font(), + textColor=colors.HexColor('#2c3e50'), + fontSize=base_font_size, + spaceAfter=4, + leftIndent=0, + leading=base_font_size * 1.3 + ) + + # Value style (for cell content) - clean, no borders + value_style = ParagraphStyle( + 'ValueStyle', + parent=body_style, + leftIndent=15, + rightIndent=0, + spaceAfter=8, + spaceBefore=2, + fontSize=base_font_size, + textColor=colors.HexColor('#333333'), + alignment=TA_JUSTIFY, + leading=base_font_size * 1.4, + # No borders or background - clean text only + ) + + try: + # If we have no data, return empty list + if not data or not any(data): + return elements + + # Get column headers or generate them + headers = [] + if data and len(data) > 0: + headers = [str(h).strip() for h in data[0]] + + # If no headers or empty headers, generate them + if not any(headers): + headers = [f"Column {i+1}" for i in range(len(data[0]) if data and len(data) > 0 else 0)] + + # Process each data row (skip header if it exists) + start_row = 1 if len(data) > 1 and any(data[0]) else 0 + + for row_idx in range(start_row, len(data)): + row = data[row_idx] if row_idx < len(data) else [] + if not row: + continue + + # Create a container for the row + row_elements = [] + + # Process each cell in the row + for col_idx in range(len(headers)): + if col_idx >= len(headers): + continue + + # Get cell content + cell_text = str(row[col_idx]).strip() if col_idx < len(row) and row[col_idx] is not None else "" + + # Skip empty cells + if not cell_text or cell_text.isspace(): + continue + + # Clean up markdown artifacts for regular text content + cell_text = str(cell_text) # Ensure it's a string + + # Remove markdown table formatting + cell_text = re.sub(r'^[|\-\s:]+$', '', cell_text, flags=re.MULTILINE) # Remove separator lines + cell_text = re.sub(r'^\s*\|\s*|\s*\|\s*$', '', cell_text) # Remove leading/trailing pipes + cell_text = re.sub(r'\s*\|\s*', ' | ', cell_text) # Normalize pipes + cell_text = re.sub(r'\s+', ' ', cell_text).strip() # Normalize whitespace + + # Remove any remaining markdown formatting + cell_text = re.sub(r'`(.*?)`', r'\1', cell_text) # Remove code ticks + cell_text = re.sub(r'\*\*(.*?)\*\*', r'\1', cell_text) # Remove bold + cell_text = re.sub(r'\*(.*?)\*', r'\1', cell_text) # Remove italic + + # Clean up any HTML entities or special characters + cell_text = self._escape_html(cell_text) + + # If content still looks like a table, convert it to plain text + if '|' in cell_text and ('--' in cell_text or any(cell_text.count('|') > 2 for line in cell_text.split('\n') if line.strip())): + # Convert to a simple text format + lines = [line.strip() for line in cell_text.split('\n') if line.strip()] + cell_text = ' | '.join(lines[:5]) # Join first 5 lines with pipe + if len(lines) > 5: + cell_text += '...' 
+ + # Process long content with better wrapping + max_chars_per_line = 100 # Reduced for better readability + max_paragraphs = 3 # Maximum number of paragraphs to show initially + + # Split into paragraphs + paragraphs = [p for p in cell_text.split('\n\n') if p.strip()] + + # If content is too long, truncate with "show more" indicator + if len(paragraphs) > max_paragraphs or any(len(p) > max_chars_per_line * 3 for p in paragraphs): + wrapped_paragraphs = [] + + for i, para in enumerate(paragraphs[:max_paragraphs]): + if len(para) > max_chars_per_line * 3: + # Split long paragraphs + words = para.split() + current_line = [] + current_length = 0 + + for word in words: + if current_line and current_length + len(word) + 1 > max_chars_per_line: + wrapped_paragraphs.append(' '.join(current_line)) + current_line = [word] + current_length = len(word) + else: + current_line.append(word) + current_length += len(word) + (1 if current_line else 0) + + if current_line: + wrapped_paragraphs.append(' '.join(current_line)) + else: + wrapped_paragraphs.append(para) + + # Add "show more" indicator if there are more paragraphs + if len(paragraphs) > max_paragraphs: + wrapped_paragraphs.append(f"... and {len(paragraphs) - max_paragraphs} more paragraphs") + + cell_text = '\n\n'.join(wrapped_paragraphs) + + # Add label and content with clean formatting (no borders) + label_para = Paragraph(f"{self._escape_html(headers[col_idx])}:", label_style) + value_para = Paragraph(self._escape_html(cell_text), value_style) + + # Add elements with proper spacing + row_elements.append(label_para) + row_elements.append(Spacer(1, 0.03 * 72)) # Tiny space between label and value + row_elements.append(value_para) + + # Add spacing between rows + if row_elements and row_idx < len(data) - 1: + # Add a subtle horizontal line as separator + row_elements.append(Spacer(1, 0.1 * 72)) + row_elements.append(self._create_horizontal_line(width=0.5, color='#e0e0e0')) + row_elements.append(Spacer(1, 0.15 * 72)) + + elements.extend(row_elements) + + # Add some space after the table + if elements: + elements.append(Spacer(1, 0.3 * 72)) # 0.3 inches in points + + except Exception as e: + # Fallback to simple text representation if something goes wrong + error_style = ParagraphStyle( + 'ErrorStyle', + parent=styles['Normal'], + fontSize=base_font_size - 1, + textColor=colors.red, + backColor=colors.HexColor('#fff0f0'), + borderWidth=1, + borderColor=colors.red, + borderPadding=5 + ) + + error_msg = [ + Paragraph("Error processing table:", error_style), + Paragraph(str(e), error_style), + Spacer(1, 0.2 * 72) + ] + + # Add a simplified version of the table + try: + for row in data[:10]: # Limit to first 10 rows to avoid huge error output + error_msg.append(Paragraph(" | ".join(str(cell) for cell in row), body_style)) + if len(data) > 10: + error_msg.append(Paragraph(f"... and {len(data) - 10} more rows", body_style)) + except Exception: + pass + + elements.extend(error_msg) + + return elements + + def _create_table(self, table_lines: list[str]) -> Optional[list]: + """Create a table from markdown table syntax with robust error handling. + + This method handles simple tables and falls back to a list format for complex cases. + + Returns: + A list of flowables (could be a table or alternative representation) + Returns None if the table cannot be created. 
+ """ + if not table_lines or len(table_lines) < 2: + return None + + try: + # Parse table data + data = [] + max_columns = 0 + + for line in table_lines: + # Skip separator lines (e.g., |---|---|) + if re.match(r'^\|[\s\-:]+\|$', line): + continue + + # Handle empty lines within tables + if not line.strip(): + continue + + # Split by | and clean up cells + cells = [] + in_quotes = False + current_cell = "" + + # Custom split to handle escaped pipes and quoted content + for char in line[1:]: # Skip initial | + if char == '|' and not in_quotes: + cells.append(current_cell.strip()) + current_cell = "" + elif char == '"': + in_quotes = not in_quotes + current_cell += char + elif char == '\\' and not in_quotes: + # Handle escaped characters + pass + else: + current_cell += char + + # Add the last cell + if current_cell.strip() or len(cells) > 0: + cells.append(current_cell.strip()) + + # Remove empty first/last elements if they're empty (from leading/trailing |) + if cells and not cells[0]: + cells = cells[1:] + if cells and not cells[-1]: + cells = cells[:-1] + + if cells: + data.append(cells) + max_columns = max(max_columns, len(cells)) + + if not data or max_columns == 0: + return None + + # Ensure all rows have the same number of columns + for row in data: + while len(row) < max_columns: + row.append('') + + # Calculate available width for table + from reportlab.lib.pagesizes import A4 + page_width = A4[0] if self._param.orientation == 'portrait' else A4[1] + available_width = page_width - (self._param.margin_left + self._param.margin_right) * inch + + # Check if we should use definition list format + max_cell_length = max((len(str(cell)) for row in data for cell in row), default=0) + total_rows = len(data) + + # Use definition list format if: + # - Any cell is too large (> 300 chars), OR + # - More than 6 columns, OR + # - More than 20 rows, OR + # - Contains nested tables or complex structures + has_nested_tables = any('|' in cell and '---' in cell for row in data for cell in row) + has_complex_cells = any(len(str(cell)) > 150 for row in data for cell in row) + + should_use_list_format = ( + max_cell_length > 300 or + max_columns > 6 or + total_rows > 20 or + has_nested_tables or + has_complex_cells + ) + + if should_use_list_format: + return self._convert_table_to_definition_list(data) + + # Process cells for normal table + processed_data = [] + for row_idx, row in enumerate(data): + processed_row = [] + for cell_idx, cell in enumerate(row): + cell_text = str(cell).strip() if cell is not None else "" + + # Handle empty cells + if not cell_text: + processed_row.append("") + continue + + # Clean up markdown table artifacts + cell_text = re.sub(r'\\\|', '|', cell_text) # Unescape pipes + cell_text = re.sub(r'\\n', '\n', cell_text) # Handle explicit newlines + + # Check for nested tables + if '|' in cell_text and '---' in cell_text: + # This cell contains a nested table + nested_lines = [line.strip() for line in cell_text.split('\n') if line.strip()] + nested_table = self._create_table(nested_lines) + if nested_table: + processed_row.append(nested_table[0]) # Add the nested table + continue + + # Process as regular text + font_size = self._param.font_size - 1 if row_idx > 0 else self._param.font_size + try: + style = self._get_cell_style(row_idx, is_header=(row_idx == 0), font_size=font_size) + escaped_text = self._escape_html(cell_text) + processed_row.append(Paragraph(escaped_text, style)) + except Exception: + processed_row.append(self._escape_html(cell_text)) + + 
processed_data.append(processed_row) + + # Calculate column widths + min_col_width = 0.5 * inch + max_cols = int(available_width / min_col_width) + + if max_columns > max_cols: + return self._convert_table_to_definition_list(data) + + col_width = max(min_col_width, available_width / max_columns) + col_widths = [col_width] * max_columns + + # Create the table + try: + table = LongTable(processed_data, colWidths=col_widths, repeatRows=1) + + # Define table style + table_style = [ + ('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c3e50')), # Darker header + ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), + ('ALIGN', (0, 0), (-1, 0), 'CENTER'), + ('FONTNAME', (0, 0), (-1, 0), self._get_active_bold_font()), + ('FONTSIZE', (0, 0), (-1, -1), self._param.font_size - 1), + ('BOTTOMPADDING', (0, 0), (-1, 0), 12), + ('BACKGROUND', (0, 1), (-1, -1), colors.HexColor('#f8f9fa')), # Lighter background + ('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#dee2e6')), # Lighter grid + ('VALIGN', (0, 0), (-1, -1), 'TOP'), + ('TOPPADDING', (0, 0), (-1, -1), 8), + ('BOTTOMPADDING', (0, 0), (-1, -1), 8), + ('LEFTPADDING', (0, 0), (-1, -1), 8), + ('RIGHTPADDING', (0, 0), (-1, -1), 8), + ] + + # Add zebra striping for better readability + for i in range(1, len(processed_data)): + if i % 2 == 0: + table_style.append(('BACKGROUND', (0, i), (-1, i), colors.HexColor('#f1f3f5'))) + + table.setStyle(TableStyle(table_style)) + + # Add a small spacer after the table + return [table, Spacer(1, 0.2 * inch)] + + except Exception as table_error: + print(f"Error creating table: {table_error}") + return self._convert_table_to_definition_list(data) + + except Exception as e: + print(f"Error processing table: {e}") + # Return a simple text representation of the table + try: + text_content = [] + for row in data: + text_content.append(" | ".join(str(cell) for cell in row)) + return [Paragraph("Hi {{email}},
-{{inviter}} has invited you to join their team (ID: {{tenant_id}}).<br/>
-Click the link below to complete your registration:<br/>
-{{invite_url}}<br/>
-If you did not request this, please ignore this email.
+Hi {{email}},
+{{inviter}} has invited you to join their team (ID: {{tenant_id}}).
+Click the link below to complete your registration:
+{{invite_url}}
+If you did not request this, please ignore this email.
 """
 # Password reset code template
 RESET_CODE_EMAIL_TMPL = """
-Hello,<br/>
-Your password reset code is: {{ code }}<br/>
-This code will expire in {{ ttl_min }} minutes.
+Hello, +Your password reset code is: {{ code }} +This code will expire in {{ ttl_min }} minutes. """ # Template registry diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 5f0fa70f451..4cad64c35ce 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -42,7 +42,7 @@ def filename_type(filename): if re.match(r".*\.pdf$", filename): return FileType.PDF.value - if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): + if re.match(r".*\.(msg|eml|doc|docx|ppt|pptx|yml|xml|htm|json|jsonl|ldjson|csv|txt|ini|xls|xlsx|wps|rtf|hlp|pages|numbers|key|md|mdx|py|js|java|c|cpp|h|php|go|ts|sh|cs|kt|html|sql)$", filename): return FileType.DOC.value if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus)$", filename): @@ -164,3 +164,23 @@ def try_open(blob): return repaired return blob + + +def sanitize_path(raw_path: str | None) -> str: + """Normalize and sanitize a user-provided path segment. + + - Converts backslashes to forward slashes + - Strips leading/trailing slashes + - Removes '.' and '..' segments + - Restricts characters to A-Za-z0-9, underscore, dash, and '/' + """ + if not raw_path: + return "" + backslash_re = re.compile(r"[\\]+") + unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]") + normalized = backslash_re.sub("/", raw_path) + normalized = normalized.strip("/") + parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")] + sanitized = "/".join(parts) + sanitized = unsafe_re.sub("", sanitized) + return sanitized diff --git a/api/utils/health_utils.py b/api/utils/health_utils.py index 88e5aaebbee..0a7ab6e7a6f 100644 --- a/api/utils/health_utils.py +++ b/api/utils/health_utils.py @@ -173,7 +173,8 @@ def check_task_executor_alive(): heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats] task_executor_heartbeats[task_executor_id] = heartbeats if task_executor_heartbeats: - return {"status": "alive", "message": task_executor_heartbeats} + status = "alive" if any(task_executor_heartbeats.values()) else "timeout" + return {"status": status, "message": task_executor_heartbeats} else: return {"status": "timeout", "message": "Not found any task executor."} except Exception as e: diff --git a/api/utils/json_encode.py b/api/utils/json_encode.py index b21addd4f9b..fa5ea973aa0 100644 --- a/api/utils/json_encode.py +++ b/api/utils/json_encode.py @@ -1,3 +1,19 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import datetime import json from enum import Enum, IntEnum diff --git a/api/utils/memory_utils.py b/api/utils/memory_utils.py new file mode 100644 index 00000000000..bb78949518b --- /dev/null +++ b/api/utils/memory_utils.py @@ -0,0 +1,54 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import List +from common.constants import MemoryType + +def format_ret_data_from_memory(memory): + return { + "id": memory.id, + "name": memory.name, + "avatar": memory.avatar, + "tenant_id": memory.tenant_id, + "owner_name": memory.owner_name if hasattr(memory, "owner_name") else None, + "memory_type": get_memory_type_human(memory.memory_type), + "storage_type": memory.storage_type, + "embd_id": memory.embd_id, + "llm_id": memory.llm_id, + "permissions": memory.permissions, + "description": memory.description, + "memory_size": memory.memory_size, + "forgetting_policy": memory.forgetting_policy, + "temperature": memory.temperature, + "system_prompt": memory.system_prompt, + "user_prompt": memory.user_prompt, + "create_time": memory.create_time, + "create_date": memory.create_date, + "update_time": memory.update_time, + "update_date": memory.update_date + } + + +def get_memory_type_human(memory_type: int) -> List[str]: + return [mem_type.name.lower() for mem_type in MemoryType if memory_type & mem_type.value] + + +def calculate_memory_type(memory_type_name_list: List[str]) -> int: + memory_type = 0 + type_value_map = {mem_type.name.lower(): mem_type.value for mem_type in MemoryType} + for mem_type in memory_type_name_list: + if mem_type in type_value_map: + memory_type |= type_value_map[mem_type] + return memory_type diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index caf3f0924aa..2dcace53fe9 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -14,10 +14,11 @@ # limitations under the License. # from collections import Counter +import string from typing import Annotated, Any, Literal from uuid import UUID -from flask import Request +from quart import Request from pydantic import ( BaseModel, ConfigDict, @@ -25,6 +26,7 @@ StringConstraints, ValidationError, field_validator, + model_validator, ) from pydantic_core import PydanticCustomError from werkzeug.exceptions import BadRequest, UnsupportedMediaType @@ -32,7 +34,7 @@ from api.constants import DATASET_NAME_LIMIT -def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: +async def validate_and_parse_json_request(request: Request, validator: type[BaseModel], *, extras: dict[str, Any] | None = None, exclude_unset: bool = False) -> tuple[dict[str, Any] | None, str | None]: """ Validates and parses JSON requests through a multi-stage validation pipeline. 
@@ -81,7 +83,7 @@ def validate_and_parse_json_request(request: Request, validator: type[BaseModel] from the final output after validation """ try: - payload = request.get_json() or {} + payload = await request.get_json() or {} except UnsupportedMediaType: return None, f"Unsupported content type: Expected application/json, got {request.content_type}" except BadRequest: @@ -329,6 +331,7 @@ class RaptorConfig(Base): threshold: Annotated[float, Field(default=0.1, ge=0.0, le=1.0)] max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] + auto_disable_for_structured_data: Annotated[bool, Field(default=True)] class GraphragConfig(Base): @@ -361,10 +364,9 @@ class CreateDatasetReq(Base): description: Annotated[str | None, Field(default=None, max_length=65535)] embedding_model: Annotated[str | None, Field(default=None, max_length=255, serialization_alias="embd_id")] permission: Annotated[Literal["me", "team"], Field(default="me", min_length=1, max_length=16)] - chunk_method: Annotated[ - Literal["naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"], - Field(default="naive", min_length=1, max_length=32, serialization_alias="parser_id"), - ] + chunk_method: Annotated[str | None, Field(default=None, serialization_alias="parser_id")] + parse_type: Annotated[int | None, Field(default=None, ge=0, le=64)] + pipeline_id: Annotated[str | None, Field(default=None, min_length=32, max_length=32, serialization_alias="pipeline_id")] parser_config: Annotated[ParserConfig | None, Field(default=None)] @field_validator("avatar", mode="after") @@ -525,6 +527,93 @@ def validate_parser_config_json_length(cls, v: ParserConfig | None) -> ParserCon raise PydanticCustomError("string_too_long", "Parser config exceeds size limit (max 65,535 characters). Current size: {actual}", {"actual": len(json_str)}) return v + @field_validator("pipeline_id", mode="after") + @classmethod + def validate_pipeline_id(cls, v: str | None) -> str | None: + """Validate pipeline_id as 32-char lowercase hex string if provided. + + Rules: + - None or empty string: treat as None (not set) + - Must be exactly length 32 + - Must contain only hex digits (0-9a-fA-F); normalized to lowercase + """ + if v is None: + return None + if v == "": + return None + if len(v) != 32: + raise PydanticCustomError("format_invalid", "pipeline_id must be 32 hex characters") + if any(ch not in string.hexdigits for ch in v): + raise PydanticCustomError("format_invalid", "pipeline_id must be hexadecimal") + return v.lower() + + @model_validator(mode="after") + def validate_parser_dependency(self) -> "CreateDatasetReq": + """ + Mixed conditional validation: + - If parser_id is omitted (field not set): + * If both parse_type and pipeline_id are omitted → default chunk_method = "naive" + * If both parse_type and pipeline_id are provided → allow ingestion pipeline mode + - If parser_id is provided (valid enum) → parse_type and pipeline_id must be None (disallow mixed usage) + + Raises: + PydanticCustomError with code 'dependency_error' on violation. 
+ """ + # Omitted chunk_method (not in fields) logic + if self.chunk_method is None and "chunk_method" not in self.model_fields_set: + # All three absent → default naive + if self.parse_type is None and self.pipeline_id is None: + object.__setattr__(self, "chunk_method", "naive") + return self + # parser_id omitted: require BOTH parse_type & pipeline_id present (no partial allowed) + if self.parse_type is None or self.pipeline_id is None: + missing = [] + if self.parse_type is None: + missing.append("parse_type") + if self.pipeline_id is None: + missing.append("pipeline_id") + raise PydanticCustomError( + "dependency_error", + "parser_id omitted → required fields missing: {fields}", + {"fields": ", ".join(missing)}, + ) + # Both provided → allow pipeline mode + return self + + # parser_id provided (valid): MUST NOT have parse_type or pipeline_id + if isinstance(self.chunk_method, str): + if self.parse_type is not None or self.pipeline_id is not None: + invalid = [] + if self.parse_type is not None: + invalid.append("parse_type") + if self.pipeline_id is not None: + invalid.append("pipeline_id") + raise PydanticCustomError( + "dependency_error", + "parser_id provided → disallowed fields present: {fields}", + {"fields": ", ".join(invalid)}, + ) + return self + + @field_validator("chunk_method", mode="wrap") + @classmethod + def validate_chunk_method(cls, v: Any, handler) -> Any: + """Wrap validation to unify error messages, including type errors (e.g. list).""" + allowed = {"naive", "book", "email", "laws", "manual", "one", "paper", "picture", "presentation", "qa", "table", "tag"} + error_msg = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'table' or 'tag'" + # Omitted field: handler won't be invoked (wrap still gets value); None treated as explicit invalid + if v is None: + raise PydanticCustomError("literal_error", error_msg) + try: + # Run inner validation (type checking) + result = handler(v) + except Exception: + raise PydanticCustomError("literal_error", error_msg) + # After handler, enforce enumeration + if not isinstance(result, str) or result == "" or result not in allowed: + raise PydanticCustomError("literal_error", error_msg) + return result + class UpdateDatasetReq(CreateDatasetReq): dataset_id: Annotated[str, Field(...)] diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index e0e47f472e6..11e8428b77c 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -20,10 +20,11 @@ import re import socket from urllib.parse import urlparse - -from api.apps import smtp_mail_server -from flask_mail import Message -from flask import render_template_string +import aiosmtplib +from email.mime.text import MIMEText +from email.header import Header +from common import settings +from quart import render_template_string from api.utils.email_templates import EMAIL_TEMPLATES from selenium import webdriver from selenium.common.exceptions import TimeoutException @@ -35,11 +36,11 @@ from webdriver_manager.chrome import ChromeDriverManager -OTP_LENGTH = 8 -OTP_TTL_SECONDS = 5 * 60 -ATTEMPT_LIMIT = 5 -ATTEMPT_LOCK_SECONDS = 30 * 60 -RESEND_COOLDOWN_SECONDS = 60 +OTP_LENGTH = 4 +OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes +ATTEMPT_LIMIT = 5 # maximum attempts +ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes +RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute CONTENT_TYPE_MAP = { @@ -68,6 +69,7 @@ # Web "md": "text/markdown", "markdown": "text/markdown", + "mdx": "text/markdown", "htm": "text/html", "html": 
"text/html", "json": "application/json", @@ -183,27 +185,34 @@ def get_float(req: dict, key: str, default: float | int = 10.0) -> float: return parsed if parsed > 0 else default except (TypeError, ValueError): return default + + +async def send_email_html(to_email: str, subject: str, template_key: str, **context): + + body = await render_template_string(EMAIL_TEMPLATES.get(template_key), **context) + msg = MIMEText(body, "plain", "utf-8") + msg["Subject"] = Header(subject, "utf-8") + msg["From"] = f"{settings.MAIL_DEFAULT_SENDER[0]} <{settings.MAIL_DEFAULT_SENDER[1]}>" + msg["To"] = to_email + smtp = aiosmtplib.SMTP( + hostname=settings.MAIL_SERVER, + port=settings.MAIL_PORT, + use_tls=True, + timeout=10, + ) -def send_email_html(subject: str, to_email: str, template_key: str, **context): - """Generic HTML email sender using shared templates. - template_key must exist in EMAIL_TEMPLATES. - """ - from api.apps import app - tmpl = EMAIL_TEMPLATES.get(template_key) - if not tmpl: - raise ValueError(f"Unknown email template: {template_key}") - with app.app_context(): - msg = Message(subject=subject, recipients=[to_email]) - msg.html = render_template_string(tmpl, **context) - smtp_mail_server.send(msg) + await smtp.connect() + await smtp.login(settings.MAIL_USERNAME, settings.MAIL_PASSWORD) + await smtp.send_message(msg) + await smtp.quit() -def send_invite_email(to_email, invite_url, tenant_id, inviter): +async def send_invite_email(to_email, invite_url, tenant_id, inviter): # Reuse the generic HTML sender with 'invite' template - send_email_html( - subject="RAGFlow Invitation", + await send_email_html( to_email=to_email, + subject="RAGFlow Invitation", template_key="invite", email=to_email, invite_url=invite_url, @@ -230,4 +239,4 @@ def hash_code(code: str, salt: bytes) -> str: def captcha_key(email: str) -> str: return f"captcha:{email}" - + \ No newline at end of file diff --git a/check_comment_ascii.py b/check_comment_ascii.py new file mode 100644 index 00000000000..57d188b6c2d --- /dev/null +++ b/check_comment_ascii.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +""" +Check whether given python files contain non-ASCII comments. 
+ +How to check the whole git repo: + +``` +$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py +``` +""" + +import sys +import tokenize +import ast +import pathlib +import re + +ASCII = re.compile(r"^[\n -~]*\Z") # Printable ASCII + newline + + +def check(src: str, name: str) -> int: + """ + docstring line 1 + docstring line 2 + """ + ok = 1 + # A common comment begins with `#` + with tokenize.open(src) as fp: + for tk in tokenize.generate_tokens(fp.readline): + if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string): + print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}") + ok = 0 + # A docstring begins and ends with `'''` + for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)): + if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)): + if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc): + print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}") + ok = 0 + return ok + + +if __name__ == "__main__": + status = 0 + for file in sys.argv[1:]: + if not check(file, file): + status = 1 + sys.exit(status) diff --git a/common/connection_utils.py b/common/connection_utils.py index 618584ae978..86ebc371d8c 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -19,9 +19,8 @@ import threading from typing import Any, Callable, Coroutine, Optional, Type, Union import asyncio -import trio from functools import wraps -from flask import make_response, jsonify +from quart import make_response, jsonify from common.constants import RetCode TimeoutException = Union[Type[BaseException], BaseException] @@ -70,11 +69,10 @@ async def async_wrapper(*args, **kwargs) -> Any: for a in range(attempts): try: if os.environ.get("ENABLE_TIMEOUT_ASSERTION"): - with trio.fail_after(seconds): - return await func(*args, **kwargs) + return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds) else: return await func(*args, **kwargs) - except trio.TooSlowError: + except asyncio.TimeoutError: if a < attempts - 1: continue if on_timeout is not None: @@ -103,7 +101,7 @@ async def async_wrapper(*args, **kwargs) -> Any: return decorator -def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): +async def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): result_dict = {"code": code, "message": message, "data": data} response_dict = {} for key, value in result_dict.items(): @@ -111,7 +109,27 @@ def construct_response(code=RetCode.SUCCESS, message="success", data=None, auth= continue else: response_dict[key] = value - response = make_response(jsonify(response_dict)) + response = await make_response(jsonify(response_dict)) + if auth: + response.headers["Authorization"] = auth + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Method"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + response.headers["Access-Control-Expose-Headers"] = "Authorization" + return response + + +def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, auth=None): + import flask + result_dict = {"code": code, "message": message, "data": data} + response_dict = {} + for key, value in result_dict.items(): + if value is None and key != "code": + continue + else: + response_dict[key] = value + response = flask.make_response(flask.jsonify(response_dict)) if auth: response.headers["Authorization"] = auth 
response.headers["Access-Control-Allow-Origin"] = "*" diff --git a/common/constants.py b/common/constants.py index dd24b4ead7e..23a75505941 100644 --- a/common/constants.py +++ b/common/constants.py @@ -49,10 +49,12 @@ class RetCode(IntEnum, CustomEnum): RUNNING = 106 PERMISSION_ERROR = 108 AUTHENTICATION_ERROR = 109 + BAD_REQUEST = 400 UNAUTHORIZED = 401 SERVER_ERROR = 500 FORBIDDEN = 403 NOT_FOUND = 404 + CONFLICT = 409 class StatusEnum(Enum): @@ -72,6 +74,7 @@ class LLMType(StrEnum): IMAGE2TEXT = 'image2text' RERANK = 'rerank' TTS = 'tts' + OCR = 'ocr' class TaskStatus(StrEnum): @@ -118,7 +121,18 @@ class FileSource(StrEnum): SHAREPOINT = "sharepoint" SLACK = "slack" TEAMS = "teams" - + WEBDAV = "webdav" + MOODLE = "moodle" + DROPBOX = "dropbox" + BOX = "box" + R2 = "r2" + OCI_STORAGE = "oci_storage" + GOOGLE_CLOUD_STORAGE = "google_cloud_storage" + AIRTABLE = "airtable" + ASANA = "asana" + GITHUB = "github" + GITLAB = "gitlab" + IMAP = "imap" class PipelineTaskType(StrEnum): PARSE = "Parse" @@ -126,6 +140,7 @@ class PipelineTaskType(StrEnum): RAPTOR = "RAPTOR" GRAPH_RAG = "GraphRAG" MINDMAP = "Mindmap" + MEMORY = "Memory" VALID_PIPELINE_TASK_TYPES = {PipelineTaskType.PARSE, PipelineTaskType.DOWNLOAD, PipelineTaskType.RAPTOR, @@ -144,6 +159,24 @@ class Storage(Enum): AWS_S3 = 4 OSS = 5 OPENDAL = 6 + GCS = 7 + + +class MemoryType(Enum): + RAW = 0b0001 # 1 << 0 = 1 (0b00000001) + SEMANTIC = 0b0010 # 1 << 1 = 2 (0b00000010) + EPISODIC = 0b0100 # 1 << 2 = 4 (0b00000100) + PROCEDURAL = 0b1000 # 1 << 3 = 8 (0b00001000) + + +class MemoryStorageType(StrEnum): + TABLE = "table" + GRAPH = "graph" + + +class ForgettingPolicy(StrEnum): + FIFO = "FIFO" + # environment # ENV_STRONG_TEST_COUNT = "STRONG_TEST_COUNT" @@ -194,3 +227,13 @@ class Storage(Enum): SVR_QUEUE_NAME = "rag_flow_svr_queue" SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker" TAG_FLD = "tag_feas" + + +MINERU_ENV_KEYS = ["MINERU_APISERVER", "MINERU_OUTPUT_DIR", "MINERU_BACKEND", "MINERU_SERVER_URL", "MINERU_DELETE_OUTPUT"] +MINERU_DEFAULT_CONFIG = { + "MINERU_APISERVER": "", + "MINERU_OUTPUT_DIR": "", + "MINERU_BACKEND": "pipeline", + "MINERU_SERVER_URL": "", + "MINERU_DELETE_OUTPUT": 1, +} diff --git a/common/crypto_utils.py b/common/crypto_utils.py new file mode 100644 index 00000000000..5dcbd2937fa --- /dev/null +++ b/common/crypto_utils.py @@ -0,0 +1,374 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from cryptography.hazmat.primitives import hashes + + +class BaseCrypto: + """Base class for cryptographic algorithms""" + + # Magic header to identify encrypted data + ENCRYPTED_MAGIC = b'RAGF' + + def __init__(self, key, iv=None, block_size=16, key_length=32, iv_length=16): + """ + Initialize cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + block_size: Block size + key_length: Key length + iv_length: Initialization vector length + """ + self.block_size = block_size + self.key_length = key_length + self.iv_length = iv_length + + # Normalize key + self.key = self._normalize_key(key) + self.iv = iv + + def _normalize_key(self, key): + """Normalize key length""" + if isinstance(key, str): + key = key.encode('utf-8') + + # Use PBKDF2 for key derivation to ensure correct key length + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=self.key_length, + salt=b"ragflow_crypto_salt", # Fixed salt to ensure consistent key derivation results + iterations=100000, + backend=default_backend() + ) + + return kdf.derive(key) + + def encrypt(self, data): + """ + Encrypt data (template method) + + Args: + data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes), format: magic_header + iv + encrypted_data + """ + # Generate random IV + iv = os.urandom(self.iv_length) if not self.iv else self.iv + + # Use PKCS7 padding + padder = padding.PKCS7(self.block_size * 8).padder() + padded_data = padder.update(data) + padder.finalize() + + # Delegate to subclass for specific encryption + ciphertext = self._encrypt(padded_data, iv) + + # Return Magic Header + IV + encrypted data + return self.ENCRYPTED_MAGIC + iv + ciphertext + + def decrypt(self, encrypted_data): + """ + Decrypt data (template method) + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # Check if data is encrypted by magic header + if not encrypted_data.startswith(self.ENCRYPTED_MAGIC): + # Not encrypted, return as-is + return encrypted_data + + # Remove magic header + encrypted_data = encrypted_data[len(self.ENCRYPTED_MAGIC):] + + # Separate IV and encrypted data + iv = encrypted_data[:self.iv_length] + ciphertext = encrypted_data[self.iv_length:] + + # Delegate to subclass for specific decryption + padded_data = self._decrypt(ciphertext, iv) + + # Remove padding + unpadder = padding.PKCS7(self.block_size * 8).unpadder() + data = unpadder.update(padded_data) + unpadder.finalize() + + return data + + def _encrypt(self, padded_data, iv): + """ + Encrypt padded data with specific algorithm + + Args: + padded_data: Padded data to encrypt + iv: Initialization vector + + Returns: + Encrypted data + """ + raise NotImplementedError("_encrypt method must be implemented by subclass") + + def _decrypt(self, ciphertext, iv): + """ + Decrypt ciphertext with specific algorithm + + Args: + ciphertext: Ciphertext to decrypt + iv: Initialization vector + + Returns: + Decrypted padded data + """ + raise NotImplementedError("_decrypt method must be implemented by subclass") + + +class AESCrypto(BaseCrypto): + """Base class for AES cryptographic algorithm""" + + def __init__(self, key, iv=None, key_length=32): + """ + Initialize AES cryptographic algorithm + + Args: + key: 
Encryption key + iv: Initialization vector, automatically generated if None + key_length: Key length (16 for AES-128, 32 for AES-256) + """ + super().__init__(key, iv, block_size=16, key_length=key_length, iv_length=16) + + def _encrypt(self, padded_data, iv): + """AES encryption implementation""" + # Create encryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """AES decryption implementation""" + # Create decryptor + cipher = Cipher( + algorithms.AES(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class AES128CBC(AESCrypto): + """AES-128-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-128-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=16) + + +class AES256CBC(AESCrypto): + """AES-256-CBC cryptographic algorithm""" + + def __init__(self, key, iv=None): + """ + Initialize AES-256-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, key_length=32) + + +class SM4CBC(BaseCrypto): + """SM4-CBC cryptographic algorithm using cryptography library for better performance""" + + def __init__(self, key, iv=None): + """ + Initialize SM4-CBC cryptographic algorithm + + Args: + key: Encryption key + iv: Initialization vector, automatically generated if None + """ + super().__init__(key, iv, block_size=16, key_length=16, iv_length=16) + + def _encrypt(self, padded_data, iv): + """SM4 encryption implementation using cryptography library""" + # Create encryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + encryptor = cipher.encryptor() + + # Encrypt data + return encryptor.update(padded_data) + encryptor.finalize() + + def _decrypt(self, ciphertext, iv): + """SM4 decryption implementation using cryptography library""" + # Create decryptor + cipher = Cipher( + algorithms.SM4(self.key), + modes.CBC(iv), + backend=default_backend() + ) + decryptor = cipher.decryptor() + + # Decrypt data + return decryptor.update(ciphertext) + decryptor.finalize() + + +class CryptoUtil: + """Cryptographic utility class, using factory pattern to create cryptographic algorithm instances""" + + # Supported cryptographic algorithms mapping + SUPPORTED_ALGORITHMS = { + "aes-128-cbc": AES128CBC, + "aes-256-cbc": AES256CBC, + "sm4-cbc": SM4CBC + } + + def __init__(self, algorithm="aes-256-cbc", key=None, iv=None): + """ + Initialize cryptographic utility + + Args: + algorithm: Cryptographic algorithm, default is aes-256-cbc + key: Encryption key, uses RAGFLOW_CRYPTO_KEY environment variable if None + iv: Initialization vector, automatically generated if None + """ + if algorithm not in self.SUPPORTED_ALGORITHMS: + raise ValueError(f"Unsupported algorithm: {algorithm}") + + if not key: + raise ValueError("Encryption key not provided and RAGFLOW_CRYPTO_KEY environment variable not set") + + # Create cryptographic algorithm instance + self.algorithm_name = algorithm + self.crypto = self.SUPPORTED_ALGORITHMS[algorithm](key=key, iv=iv) + + def encrypt(self, data): + """ + Encrypt data + + Args: 
+ data: Data to encrypt (bytes) + + Returns: + Encrypted data (bytes) + """ + # import time + # start_time = time.time() + encrypted = self.crypto.encrypt(data) + # end_time = time.time() + # logging.info(f"Encryption completed, data length: {len(data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return encrypted + + def decrypt(self, encrypted_data): + """ + Decrypt data + + Args: + encrypted_data: Encrypted data (bytes) + + Returns: + Decrypted data (bytes) + """ + # import time + # start_time = time.time() + decrypted = self.crypto.decrypt(encrypted_data) + # end_time = time.time() + # logging.info(f"Decryption completed, data length: {len(encrypted_data)} bytes, time: {(end_time - start_time)*1000:.2f} ms") + return decrypted + + +# Test code +if __name__ == "__main__": + # Test AES encryption + crypto = CryptoUtil(algorithm="aes-256-cbc", key="test_key_123456") + test_data = b"Hello, RAGFlow! This is a test for encryption." + + encrypted = crypto.encrypt(test_data) + decrypted = crypto.decrypt(encrypted) + + print("AES Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted}") + print(f"Decrypted: {decrypted}") + print(f"Success: {test_data == decrypted}") + print() + + # Test SM4 encryption + try: + crypto_sm4 = CryptoUtil(algorithm="sm4-cbc", key="test_key_123456") + encrypted_sm4 = crypto_sm4.encrypt(test_data) + decrypted_sm4 = crypto_sm4.decrypt(encrypted_sm4) + + print("SM4 Test:") + print(f"Original: {test_data}") + print(f"Encrypted: {encrypted_sm4}") + print(f"Decrypted: {decrypted_sm4}") + print(f"Success: {test_data == decrypted_sm4}") + except Exception as e: + print(f"SM4 Test Failed: {e}") + import traceback + traceback.print_exc() + + # Test with specific algorithm classes directly + print("\nDirect Algorithm Class Test:") + + # Test AES-128-CBC + aes128 = AES128CBC(key="test_key_123456") + encrypted_aes128 = aes128.encrypt(test_data) + decrypted_aes128 = aes128.decrypt(encrypted_aes128) + print(f"AES-128-CBC test: {'passed' if decrypted_aes128 == test_data else 'failed'}") + + # Test AES-256-CBC + aes256 = AES256CBC(key="test_key_123456") + encrypted_aes256 = aes256.encrypt(test_data) + decrypted_aes256 = aes256.decrypt(encrypted_aes256) + print(f"AES-256-CBC test: {'passed' if decrypted_aes256 == test_data else 'failed'}") + + # Test SM4-CBC + try: + sm4 = SM4CBC(key="test_key_123456") + encrypted_sm4 = sm4.encrypt(test_data) + decrypted_sm4 = sm4.decrypt(encrypted_sm4) + print(f"SM4-CBC test: {'passed' if decrypted_sm4 == test_data else 'failed'}") + except Exception as e: + print(f"SM4-CBC test failed: {e}") diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 0802a52852a..2619e779dcd 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -1,6 +1,26 @@ """ Thanks to https://github.com/onyx-dot-app/onyx + +Content of this directory is under the "MIT Expat" license as defined below. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. """ from .blob_connector import BlobStorageConnector @@ -11,9 +31,14 @@ from .discord_connector import DiscordConnector from .dropbox_connector import DropboxConnector from .google_drive.connector import GoogleDriveConnector -from .jira_connector import JiraConnector +from .jira.connector import JiraConnector from .sharepoint_connector import SharePointConnector from .teams_connector import TeamsConnector +from .webdav_connector import WebDAVConnector +from .moodle_connector import MoodleConnector +from .airtable_connector import AirtableConnector +from .asana_connector import AsanaConnector +from .imap_connector import ImapConnector from .config import BlobType, DocumentSource from .models import Document, TextSection, ImageSection, BasicExpertInfo from .exceptions import ( @@ -36,6 +61,8 @@ "JiraConnector", "SharePointConnector", "TeamsConnector", + "WebDAVConnector", + "MoodleConnector", "BlobType", "DocumentSource", "Document", @@ -46,5 +73,8 @@ "ConnectorValidationError", "CredentialExpiredError", "InsufficientPermissionsError", - "UnexpectedValidationError" + "UnexpectedValidationError", + "AirtableConnector", + "AsanaConnector", + "ImapConnector" ] diff --git a/common/data_source/airtable_connector.py b/common/data_source/airtable_connector.py new file mode 100644 index 00000000000..6f0b5a930cd --- /dev/null +++ b/common/data_source/airtable_connector.py @@ -0,0 +1,169 @@ +from datetime import datetime, timezone +import logging +from typing import Any, Generator + +import requests + +from pyairtable import Api as AirtableApi + +from common.data_source.config import AIRTABLE_CONNECTOR_SIZE_THRESHOLD, INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ConnectorMissingCredentialError +from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.utils import extract_size_bytes, get_file_ext + +class AirtableClientNotSetUpError(PermissionError): + def __init__(self) -> None: + super().__init__( + "Airtable client is not set up. Did you forget to call load_credentials()?" + ) + + +class AirtableConnector(LoadConnector, PollConnector): + """ + Lightweight Airtable connector. + + This connector ingests Airtable attachments as raw blobs without + parsing file content or generating text/image sections. 
+ """ + + def __init__( + self, + base_id: str, + table_name_or_id: str, + batch_size: int = INDEX_BATCH_SIZE, + ) -> None: + self.base_id = base_id + self.table_name_or_id = table_name_or_id + self.batch_size = batch_size + self._airtable_client: AirtableApi | None = None + self.size_threshold = AIRTABLE_CONNECTOR_SIZE_THRESHOLD + + # ------------------------- + # Credentials + # ------------------------- + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self._airtable_client = AirtableApi(credentials["airtable_access_token"]) + return None + + @property + def airtable_client(self) -> AirtableApi: + if not self._airtable_client: + raise AirtableClientNotSetUpError() + return self._airtable_client + + # ------------------------- + # Core logic + # ------------------------- + def load_from_state(self) -> GenerateDocumentsOutput: + """ + Fetch all Airtable records and ingest attachments as raw blobs. + + Each attachment is converted into a single Document(blob=...). + """ + if not self._airtable_client: + raise ConnectorMissingCredentialError("Airtable credentials not loaded") + + table = self.airtable_client.table(self.base_id, self.table_name_or_id) + records = table.all() + + logging.info( + f"Starting Airtable blob ingestion for table {self.table_name_or_id}, " + f"{len(records)} records found." + ) + + batch: list[Document] = [] + + for record in records: + print(record) + record_id = record.get("id") + fields = record.get("fields", {}) + created_time = record.get("createdTime") + + for field_value in fields.values(): + # We only care about attachment fields (lists of dicts with url/filename) + if not isinstance(field_value, list): + continue + + for attachment in field_value: + url = attachment.get("url") + filename = attachment.get("filename") + attachment_id = attachment.get("id") + + if not url or not filename or not attachment_id: + continue + + try: + resp = requests.get(url, timeout=30) + resp.raise_for_status() + content = resp.content + except Exception: + logging.exception( + f"Failed to download attachment {filename} " + f"(record={record_id})" + ) + continue + size_bytes = extract_size_bytes(attachment) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." 
+ ) + continue + batch.append( + Document( + id=f"airtable:{record_id}:{attachment_id}", + blob=content, + source=DocumentSource.AIRTABLE, + semantic_identifier=filename, + extension=get_file_ext(filename), + size_bytes=size_bytes if size_bytes else 0, + doc_updated_at=datetime.strptime(created_time, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) + ) + ) + + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Generator[list[Document], None, None]: + """Poll source to get documents""" + start_dt = datetime.fromtimestamp(start, tz=timezone.utc) + end_dt = datetime.fromtimestamp(end, tz=timezone.utc) + + for batch in self.load_from_state(): + filtered: list[Document] = [] + + for doc in batch: + if not doc.doc_updated_at: + continue + + doc_dt = doc.doc_updated_at.astimezone(timezone.utc) + + if start_dt <= doc_dt < end_dt: + filtered.append(doc) + + if filtered: + yield filtered + +if __name__ == "__main__": + import os + + logging.basicConfig(level=logging.DEBUG) + connector = AirtableConnector("xxx","xxx") + connector.load_credentials({"airtable_access_token": os.environ.get("AIRTABLE_ACCESS_TOKEN")}) + connector.validate_connector_settings() + document_batches = connector.load_from_state() + try: + first_batch = next(document_batches) + print(f"Loaded {len(first_batch)} documents in first batch.") + for doc in first_batch: + print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") + except StopIteration: + print("No documents available in Dropbox.") \ No newline at end of file diff --git a/common/data_source/asana_connector.py b/common/data_source/asana_connector.py new file mode 100644 index 00000000000..1dddcb6df2b --- /dev/null +++ b/common/data_source/asana_connector.py @@ -0,0 +1,454 @@ +from collections.abc import Iterator +import time +from datetime import datetime +import logging +from typing import Any, Dict +import asana +import requests +from common.data_source.config import CONTINUE_ON_CONNECTOR_FAILURE, INDEX_BATCH_SIZE, DocumentSource +from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch +from common.data_source.utils import extract_size_bytes, get_file_ext + + + +# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints +class AsanaTask: + def __init__( + self, + id: str, + title: str, + text: str, + link: str, + last_modified: datetime, + project_gid: str, + project_name: str, + ) -> None: + self.id = id + self.title = title + self.text = text + self.link = link + self.last_modified = last_modified + self.project_gid = project_gid + self.project_name = project_name + + def __str__(self) -> str: + return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}" + + +class AsanaAPI: + def __init__( + self, api_token: str, workspace_gid: str, team_gid: str | None + ) -> None: + self._user = None + self.workspace_gid = workspace_gid + self.team_gid = team_gid + + self.configuration = asana.Configuration() + self.api_client = asana.ApiClient(self.configuration) + self.tasks_api = asana.TasksApi(self.api_client) + self.attachments_api = asana.AttachmentsApi(self.api_client) + self.stories_api = asana.StoriesApi(self.api_client) + self.users_api = asana.UsersApi(self.api_client) + self.project_api = asana.ProjectsApi(self.api_client) + 
self.project_memberships_api = asana.ProjectMembershipsApi(self.api_client) + self.workspaces_api = asana.WorkspacesApi(self.api_client) + + self.api_error_count = 0 + self.configuration.access_token = api_token + self.task_count = 0 + + def get_tasks( + self, project_gids: list[str] | None, start_date: str + ) -> Iterator[AsanaTask]: + """Get all tasks from the projects with the given gids that were modified since the given date. + If project_gids is None, get all tasks from all projects in the workspace.""" + logging.info("Starting to fetch Asana projects") + projects = self.project_api.get_projects( + opts={ + "workspace": self.workspace_gid, + "opt_fields": "gid,name,archived,modified_at", + } + ) + start_seconds = int(time.mktime(datetime.now().timetuple())) + projects_list = [] + project_count = 0 + for project_info in projects: + project_gid = project_info["gid"] + if project_gids is None or project_gid in project_gids: + projects_list.append(project_gid) + else: + logging.debug( + f"Skipping project: {project_gid} - not in accepted project_gids" + ) + project_count += 1 + if project_count % 100 == 0: + logging.info(f"Processed {project_count} projects") + logging.info(f"Found {len(projects_list)} projects to process") + for project_gid in projects_list: + for task in self._get_tasks_for_project( + project_gid, start_date, start_seconds + ): + yield task + logging.info(f"Completed fetching {self.task_count} tasks from Asana") + if self.api_error_count > 0: + logging.warning( + f"Encountered {self.api_error_count} API errors during task fetching" + ) + + def _get_tasks_for_project( + self, project_gid: str, start_date: str, start_seconds: int + ) -> Iterator[AsanaTask]: + project = self.project_api.get_project(project_gid, opts={}) + project_name = project.get("name", project_gid) + team = project.get("team") or {} + team_gid = team.get("gid") + + if project.get("archived"): + logging.info(f"Skipping archived project: {project_name} ({project_gid})") + return + if not team_gid: + logging.info( + f"Skipping project without a team: {project_name} ({project_gid})" + ) + return + if project.get("privacy_setting") == "private": + if self.team_gid and team_gid != self.team_gid: + logging.info( + f"Skipping private project not in configured team: {project_name} ({project_gid})" + ) + return + logging.info( + f"Processing private project in configured team: {project_name} ({project_gid})" + ) + + simple_start_date = start_date.split(".")[0].split("+")[0] + logging.info( + f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})" + ) + + opts = { + "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at," + "created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes," + "modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on," + "workspace,permalink_url", + "modified_since": start_date, + } + tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts) + for data in tasks_from_api: + self.task_count += 1 + if self.task_count % 10 == 0: + end_seconds = time.mktime(datetime.now().timetuple()) + runtime_seconds = end_seconds - start_seconds + if runtime_seconds > 0: + logging.info( + f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds " + f"({self.task_count / runtime_seconds:.2f} tasks/second)" + ) + + logging.debug(f"Processing Asana task: {data['name']}") + + text = self._construct_task_text(data) + + try: + text += 
self._fetch_and_add_comments(data["gid"]) + + last_modified_date = self.format_date(data["modified_at"]) + text += f"Last modified: {last_modified_date}\n" + + task = AsanaTask( + id=data["gid"], + title=data["name"], + text=text, + link=data["permalink_url"], + last_modified=datetime.fromisoformat(data["modified_at"]), + project_gid=project_gid, + project_name=project_name, + ) + yield task + except Exception: + logging.error( + f"Error processing task {data['gid']} in project {project_gid}", + exc_info=True, + ) + self.api_error_count += 1 + + def _construct_task_text(self, data: Dict) -> str: + text = f"{data['name']}\n\n" + + if data["notes"]: + text += f"{data['notes']}\n\n" + + if data["created_by"] and data["created_by"]["gid"]: + creator = self.get_user(data["created_by"]["gid"])["name"] + created_date = self.format_date(data["created_at"]) + text += f"Created by: {creator} on {created_date}\n" + + if data["due_on"]: + due_date = self.format_date(data["due_on"]) + text += f"Due date: {due_date}\n" + + if data["completed_at"]: + completed_date = self.format_date(data["completed_at"]) + text += f"Completed on: {completed_date}\n" + + text += "\n" + return text + + def _fetch_and_add_comments(self, task_gid: str) -> str: + text = "" + stories_opts: Dict[str, str] = {} + story_start = time.time() + stories = self.stories_api.get_stories_for_task(task_gid, stories_opts) + + story_count = 0 + comment_count = 0 + + for story in stories: + story_count += 1 + if story["resource_subtype"] == "comment_added": + comment = self.stories_api.get_story( + story["gid"], opts={"opt_fields": "text,created_by,created_at"} + ) + commenter = self.get_user(comment["created_by"]["gid"])["name"] + text += f"Comment by {commenter}: {comment['text']}\n\n" + comment_count += 1 + + story_duration = time.time() - story_start + logging.debug( + f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds" + ) + + return text + + def get_attachments(self, task_gid: str) -> list[dict]: + """ + Fetch full attachment info (including download_url) for a task. 
+ """ + attachments: list[dict] = [] + + try: + # Step 1: list attachment compact records + for att in self.attachments_api.get_attachments_for_object( + parent=task_gid, + opts={} + ): + gid = att.get("gid") + if not gid: + continue + + try: + # Step 2: expand to full attachment + full = self.attachments_api.get_attachment( + attachment_gid=gid, + opts={ + "opt_fields": "name,download_url,size,created_at" + } + ) + + if full.get("download_url"): + attachments.append(full) + + except Exception: + logging.exception( + f"Failed to fetch attachment detail {gid} for task {task_gid}" + ) + self.api_error_count += 1 + + except Exception: + logging.exception(f"Failed to list attachments for task {task_gid}") + self.api_error_count += 1 + + return attachments + + def get_accessible_emails( + self, + workspace_id: str, + project_ids: list[str] | None, + team_id: str | None, + ): + + ws_users = self.users_api.get_users( + opts={ + "workspace": workspace_id, + "opt_fields": "gid,name,email" + } + ) + + workspace_users = { + u["gid"]: u.get("email") + for u in ws_users + if u.get("email") + } + + if not project_ids: + return set(workspace_users.values()) + + + project_emails = set() + + for pid in project_ids: + project = self.project_api.get_project( + pid, + opts={"opt_fields": "team,privacy_setting"} + ) + + if project["privacy_setting"] == "private": + if team_id and project.get("team", {}).get("gid") != team_id: + continue + + memberships = self.project_memberships_api.get_project_membership( + pid, + opts={"opt_fields": "user.gid,user.email"} + ) + + for m in memberships: + email = m["user"].get("email") + if email: + project_emails.add(email) + + return project_emails + + def get_user(self, user_gid: str) -> Dict: + if self._user is not None: + return self._user + self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"}) + + if not self._user: + logging.warning(f"Unable to fetch user information for user_gid: {user_gid}") + return {"name": "Unknown"} + return self._user + + def format_date(self, date_str: str) -> str: + date = datetime.fromisoformat(date_str) + return time.strftime("%Y-%m-%d", date.timetuple()) + + def get_time(self) -> str: + return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + +class AsanaConnector(LoadConnector, PollConnector): + def __init__( + self, + asana_workspace_id: str, + asana_project_ids: str | None = None, + asana_team_id: str | None = None, + batch_size: int = INDEX_BATCH_SIZE, + continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, + ) -> None: + self.workspace_id = asana_workspace_id + self.project_ids_to_index: list[str] | None = ( + asana_project_ids.split(",") if asana_project_ids else None + ) + self.asana_team_id = asana_team_id if asana_team_id else None + self.batch_size = batch_size + self.continue_on_failure = continue_on_failure + self.size_threshold = None + logging.info( + f"AsanaConnector initialized with workspace_id: {asana_workspace_id}" + ) + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + self.api_token = credentials["asana_api_token_secret"] + self.asana_client = AsanaAPI( + api_token=self.api_token, + workspace_gid=self.workspace_id, + team_gid=self.asana_team_id, + ) + self.workspace_users_email = self.asana_client.get_accessible_emails(self.workspace_id, self.project_ids_to_index, self.asana_team_id) + logging.info("Asana credentials loaded and API client initialized") + return None + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch 
| None + ) -> GenerateDocumentsOutput: + start_time = datetime.fromtimestamp(start).isoformat() + logging.info(f"Starting Asana poll from {start_time}") + docs_batch: list[Document] = [] + tasks = self.asana_client.get_tasks(self.project_ids_to_index, start_time) + for task in tasks: + docs = self._task_to_documents(task) + docs_batch.extend(docs) + + if len(docs_batch) >= self.batch_size: + logging.info(f"Yielding batch of {len(docs_batch)} documents") + yield docs_batch + docs_batch = [] + + if docs_batch: + logging.info(f"Yielding final batch of {len(docs_batch)} documents") + yield docs_batch + + logging.info("Asana poll completed") + + def load_from_state(self) -> GenerateDocumentsOutput: + logging.info("Starting full index of all Asana tasks") + return self.poll_source(start=0, end=None) + + def _task_to_documents(self, task: AsanaTask) -> list[Document]: + docs: list[Document] = [] + + attachments = self.asana_client.get_attachments(task.id) + + for att in attachments: + try: + resp = requests.get(att["download_url"], timeout=30) + resp.raise_for_status() + file_blob = resp.content + filename = att.get("name", "attachment") + size_bytes = extract_size_bytes(att) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{filename} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + docs.append( + Document( + id=f"asana:{task.id}:{att['gid']}", + blob=file_blob, + extension=get_file_ext(filename) or "", + size_bytes=size_bytes, + doc_updated_at=task.last_modified, + source=DocumentSource.ASANA, + semantic_identifier=filename, + primary_owners=list(self.workspace_users_email), + ) + ) + except Exception: + logging.exception( + f"Failed to download attachment {att.get('gid')} for task {task.id}" + ) + + return docs + + + +if __name__ == "__main__": + import time + import os + + logging.info("Starting Asana connector test") + connector = AsanaConnector( + os.environ["WORKSPACE_ID"], + os.environ["PROJECT_IDS"], + os.environ["TEAM_ID"], + ) + connector.load_credentials( + { + "asana_api_token_secret": os.environ["API_TOKEN"], + } + ) + logging.info("Loading all documents from Asana") + all_docs = connector.load_from_state() + current = time.time() + one_day_ago = current - 24 * 60 * 60 # 1 day + logging.info("Polling for documents updated in the last 24 hours") + latest_docs = connector.poll_source(one_day_ago, current) + for docs in all_docs: + for doc in docs: + print(doc.id) + logging.info("Asana connector test completed") \ No newline at end of file diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 0bec7cbe643..1ab39189d79 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -56,7 +56,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None # Validate credentials if self.bucket_type == BlobType.R2: - if not all( + if not all( credentials.get(key) for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"] ): @@ -64,15 +64,23 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None elif self.bucket_type == BlobType.S3: authentication_method = credentials.get("authentication_method", "access_key") + if authentication_method == "access_key": if not all( credentials.get(key) for key in ["aws_access_key_id", "aws_secret_access_key"] ): raise ConnectorMissingCredentialError("Amazon S3") + elif authentication_method == "iam_role": if 
not credentials.get("aws_role_arn"): raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required") + + elif authentication_method == "assume_role": + pass + + else: + raise ConnectorMissingCredentialError("Unsupported S3 authentication method") elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: if not all( @@ -87,6 +95,13 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None ): raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure") + elif self.bucket_type == BlobType.S3_COMPATIBLE: + if not all( + credentials.get(key) + for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key", "addressing_style"] + ): + raise ConnectorMissingCredentialError("S3 Compatible Storage") + else: raise ValueError(f"Unsupported bucket type: {self.bucket_type}") @@ -113,55 +128,72 @@ def _yield_blob_objects( paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) - batch: list[Document] = [] + # Collect all objects first to count filename occurrences + all_objects = [] for page in pages: if "Contents" not in page: continue - for obj in page["Contents"]: if obj["Key"].endswith("/"): continue - last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + if start < last_modified <= end: + all_objects.append(obj) + + # Count filename occurrences to determine which need full paths + filename_counts: dict[str, int] = {} + for obj in all_objects: + file_name = os.path.basename(obj["Key"]) + filename_counts[file_name] = filename_counts.get(file_name, 0) + 1 - if not (start < last_modified <= end): + batch: list[Document] = [] + for obj in all_objects: + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + file_name = os.path.basename(obj["Key"]) + key = obj["Key"] + + size_bytes = extract_size_bytes(obj) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." + ) + continue + + try: + blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) + if blob is None: continue - file_name = os.path.basename(obj["Key"]) - key = obj["Key"] - - size_bytes = extract_size_bytes(obj) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." 
+ # Use full path only if filename appears multiple times + if filename_counts.get(file_name, 0) > 1: + relative_path = key + if self.prefix and key.startswith(self.prefix): + relative_path = key[len(self.prefix):] + semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name + else: + semantic_id = file_name + + batch.append( + Document( + id=f"{self.bucket_type}:{self.bucket_name}:{key}", + blob=blob, + source=DocumentSource(self.bucket_type.value), + semantic_identifier=semantic_id, + extension=get_file_ext(file_name), + doc_updated_at=last_modified, + size_bytes=size_bytes if size_bytes else 0 ) - continue - try: - blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold) - if blob is None: - continue - - batch.append( - Document( - id=f"{self.bucket_type}:{self.bucket_name}:{key}", - blob=blob, - source=DocumentSource(self.bucket_type.value), - semantic_identifier=file_name, - extension=get_file_ext(file_name), - doc_updated_at=last_modified, - size_bytes=size_bytes if size_bytes else 0 - ) - ) - if len(batch) == self.batch_size: - yield batch - batch = [] + ) + if len(batch) == self.batch_size: + yield batch + batch = [] - except Exception: - logging.exception(f"Error decoding object {key}") + except Exception: + logging.exception(f"Error decoding object {key}") if batch: yield batch @@ -269,4 +301,4 @@ def validate_connector_settings(self) -> None: except ConnectorMissingCredentialError as e: print(f"Error: {e}") except Exception as e: - print(f"An unexpected error occurred: {e}") \ No newline at end of file + print(f"An unexpected error occurred: {e}") diff --git a/common/data_source/box_connector.py b/common/data_source/box_connector.py new file mode 100644 index 00000000000..3006e709c9c --- /dev/null +++ b/common/data_source/box_connector.py @@ -0,0 +1,162 @@ +"""Box connector""" +import logging +from datetime import datetime, timezone +from typing import Any + +from box_sdk_gen import BoxClient +from common.data_source.config import DocumentSource, INDEX_BATCH_SIZE +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch +from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.utils import get_file_ext + +class BoxConnector(LoadConnector, PollConnector): + def __init__(self, folder_id: str, batch_size: int = INDEX_BATCH_SIZE, use_marker: bool = True) -> None: + self.batch_size = batch_size + self.folder_id = "0" if not folder_id else folder_id + self.use_marker = use_marker + + + def load_credentials(self, auth: Any): + self.box_client = BoxClient(auth=auth) + return None + + + def validate_connector_settings(self): + if self.box_client is None: + raise ConnectorMissingCredentialError("Box") + + try: + self.box_client.users.get_user_me() + except Exception as e: + logging.exception("[Box]: Failed to validate Box credentials") + raise ConnectorValidationError(f"Unexpected error during Box settings validation: {e}") + + + def _yield_files_recursive( + self, + folder_id, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None + ) -> GenerateDocumentsOutput: + + if self.box_client is None: + raise ConnectorMissingCredentialError("Box") + + result = self.box_client.folders.get_folder_items( + folder_id=folder_id, + limit=self.batch_size, + usemarker=self.use_marker + ) + + while True: + batch: list[Document] = [] + for entry in 
result.entries: + if entry.type == 'file' : + file = self.box_client.files.get_file_by_id( + entry.id + ) + raw_time = ( + getattr(file, "created_at", None) + or getattr(file, "content_created_at", None) + ) + + if raw_time: + modified_time = self._box_datetime_to_epoch_seconds(raw_time) + if start is not None and modified_time <= start: + continue + if end is not None and modified_time > end: + continue + + content_bytes = self.box_client.downloads.download_file(file.id) + + batch.append( + Document( + id=f"box:{file.id}", + blob=content_bytes.read(), + source=DocumentSource.BOX, + semantic_identifier=file.name, + extension=get_file_ext(file.name), + doc_updated_at=modified_time, + size_bytes=file.size, + metadata=file.metadata + ) + ) + elif entry.type == 'folder': + yield from self._yield_files_recursive(folder_id=entry.id, start=start, end=end) + + if batch: + yield batch + + if not result.next_marker: + break + + result = self.box_client.folders.get_folder_items( + folder_id=folder_id, + limit=self.batch_size, + marker=result.next_marker, + usemarker=True + ) + + + def _box_datetime_to_epoch_seconds(self, dt: datetime) -> SecondsSinceUnixEpoch: + """Convert a Box SDK datetime to Unix epoch seconds (UTC). + Only supports datetime; any non-datetime should be filtered out by caller. + """ + if not isinstance(dt, datetime): + raise TypeError(f"box_datetime_to_epoch_seconds expects datetime, got {type(dt)}") + + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + + return SecondsSinceUnixEpoch(int(dt.timestamp())) + + + def poll_source(self, start, end): + return self._yield_files_recursive(folder_id=self.folder_id, start=start, end=end) + + + def load_from_state(self): + return self._yield_files_recursive(folder_id=self.folder_id, start=None, end=None) + + +# from flask import Flask, request, redirect + +# from box_sdk_gen import BoxClient, BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions + +# app = Flask(__name__) + +# AUTH = BoxOAuth( +# OAuthConfig(client_id="8suvn9ik7qezsq2dub0ye6ubox61081z", client_secret="QScvhLgBcZrb2ck1QP1ovkutpRhI2QcN") +# ) + + +# @app.route("/") +# def get_auth(): +# auth_url = AUTH.get_authorize_url( +# options=GetAuthorizeUrlOptions(redirect_uri="http://localhost:4999/oauth2callback") +# ) +# return redirect(auth_url, code=302) + + +# @app.route("/oauth2callback") +# def callback(): +# AUTH.get_tokens_authorization_code_grant(request.args.get("code")) +# box = BoxConnector() +# box.load_credentials({"auth": AUTH}) + +# lst = [] +# for file in box.load_from_state(): +# for f in file: +# lst.append(f.semantic_identifier) + +# return lst + +if __name__ == "__main__": + pass + # app.run(port=4999) \ No newline at end of file diff --git a/common/data_source/config.py b/common/data_source/config.py index 02684dbacc9..bca13b5bed6 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -13,6 +13,7 @@ def get_current_tz_offset() -> int: return round(time_diff.total_seconds() / 3600) +ONE_MINUTE = 60 ONE_HOUR = 3600 ONE_DAY = ONE_HOUR * 24 @@ -31,6 +32,7 @@ class BlobType(str, Enum): R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" + S3_COMPATIBLE = "s3_compatible" class DocumentSource(str, Enum): @@ -42,11 +44,22 @@ class DocumentSource(str, Enum): OCI_STORAGE = "oci_storage" SLACK = "slack" CONFLUENCE = "confluence" + JIRA = "jira" GOOGLE_DRIVE = "google_drive" GMAIL = "gmail" DISCORD = "discord" - - + WEBDAV = "webdav" + MOODLE = "moodle" + S3_COMPATIBLE = 
"s3_compatible" + DROPBOX = "dropbox" + BOX = "box" + AIRTABLE = "airtable" + ASANA = "asana" + GITHUB = "github" + GITLAB = "gitlab" + IMAP = "imap" + + class FileOrigin(str, Enum): """File origins""" CONNECTOR = "connector" @@ -76,6 +89,7 @@ class FileOrigin(str, Enum): "space", "metadata.labels", "history.lastUpdated", + "ancestors", ] @@ -178,6 +192,21 @@ class FileOrigin(str, Enum): os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) ) +JIRA_CONNECTOR_LABELS_TO_SKIP = [ + ignored_tag + for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",") + if ignored_tag +] +JIRA_CONNECTOR_MAX_TICKET_SIZE = int( + os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024) +) +JIRA_SYNC_TIME_BUFFER_SECONDS = int( + os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE) +) +JIRA_TIMEZONE_OFFSET = float( + os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset()) +) + OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get( @@ -195,6 +224,7 @@ class FileOrigin(str, Enum): "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") +GMAIL_WEB_OAUTH_REDIRECT_URI = os.environ.get("GMAIL_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/gmail/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() @@ -204,6 +234,9 @@ class FileOrigin(str, Enum): _PROBLEMATIC_EXPANSIONS = "body.storage.value" _REPLACEMENT_EXPANSIONS = "body.view.value" +BOX_WEB_OAUTH_REDIRECT_URI = os.environ.get("BOX_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/box/oauth/web/callback") + +GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None class HtmlBasedConnectorTransformLinksStrategy(str, Enum): # remove links entirely @@ -226,6 +259,18 @@ class HtmlBasedConnectorTransformLinksStrategy(str, Enum): "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside" ).split(",") +AIRTABLE_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("AIRTABLE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + +ASANA_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("ASANA_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + +IMAP_CONNECTOR_SIZE_THRESHOLD = int( + os.environ.get("IMAP_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024) +) + _USER_NOT_FOUND = "Unknown Confluence User" _COMMENT_EXPANSION_FIELDS = ["body.storage.value"] diff --git a/common/data_source/confluence_connector.py b/common/data_source/confluence_connector.py index aed16ad2b66..d2494c3de74 100644 --- a/common/data_source/confluence_connector.py +++ b/common/data_source/confluence_connector.py @@ -126,7 +126,7 @@ def __init__( def _renew_credentials(self) -> tuple[dict[str, Any], bool]: """credential_json - the current json credentials Returns a tuple - 1. The up to date credentials + 1. The up-to-date credentials 2. True if the credentials were updated This method is intended to be used within a distributed lock. 
@@ -179,14 +179,14 @@ def _renew_credentials(self) -> tuple[dict[str, Any], bool]: credential_json["confluence_refresh_token"], ) - # store the new credentials to redis and to the db thru the provider - # redis: we use a 5 min TTL because we are given a 10 minute grace period + # store the new credentials to redis and to the db through the provider + # redis: we use a 5 min TTL because we are given a 10 minutes grace period # when keys are rotated. it's easier to expire the cached credentials # reasonably frequently rather than trying to handle strong synchronization # between the db and redis everywhere the credentials might be updated new_credential_str = json.dumps(new_credentials) self.redis_client.set( - self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL + self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL ) self._credentials_provider.set_credentials(new_credentials) @@ -690,7 +690,7 @@ def cql_paginate_all_expansions( ) -> Iterator[dict[str, Any]]: """ This function will paginate through the top level query first, then - paginate through all of the expansions. + paginate through all the expansions. """ def _traverse_and_update(data: dict | list) -> None: @@ -717,7 +717,7 @@ def paginated_cql_user_retrieval( """ The search/user endpoint can be used to fetch users. It's a separate endpoint from the content/search endpoint used only for users. - Otherwise it's very similar to the content/search endpoint. + It's very similar to the content/search endpoint. """ # this is needed since there is a live bug with Confluence Server/Data Center @@ -863,7 +863,7 @@ def get_user_email_from_username__server( # For now, we'll just return None and log a warning. This means # we will keep retrying to get the email every group sync. 
email = None - # We may want to just return a string that indicates failure so we dont + # We may want to just return a string that indicates failure so we don't # keep retrying # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" _USER_EMAIL_CACHE[user_name] = email @@ -912,7 +912,7 @@ def extract_text_from_confluence_html( confluence_object: dict[str, Any], fetched_titles: set[str], ) -> str: - """Parse a Confluence html page and replace the 'user Id' by the real + """Parse a Confluence html page and replace the 'user id' by the real User Display Name Args: @@ -1110,7 +1110,10 @@ def _make_attachment_link( ) -> str | None: download_link = "" - if "api.atlassian.com" in confluence_client.url: + from urllib.parse import urlparse + netloc =urlparse(confluence_client.url).hostname + if netloc == "api.atlassian.com" or (netloc and netloc.endswith(".api.atlassian.com")): + # if "api.atlassian.com" in confluence_client.url: # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get if not parent_content_id: logging.warning( @@ -1308,6 +1311,9 @@ def __init__( self._low_timeout_confluence_client: OnyxConfluence | None = None self._fetched_titles: set[str] = set() self.allow_images = False + # Track document names to detect duplicates + self._document_name_counts: dict[str, int] = {} + self._document_name_paths: dict[str, list[str]] = {} # Remove trailing slash from wiki_base if present self.wiki_base = wiki_base.rstrip("/") @@ -1510,6 +1516,40 @@ def _convert_page_to_document( self.wiki_base, page["_links"]["webui"], self.is_cloud ) + # Build hierarchical path for semantic identifier + space_name = page.get("space", {}).get("name", "") + + # Build path from ancestors + path_parts = [] + if space_name: + path_parts.append(space_name) + + # Add ancestor pages to path if available + if "ancestors" in page and page["ancestors"]: + for ancestor in page["ancestors"]: + ancestor_title = ancestor.get("title", "") + if ancestor_title: + path_parts.append(ancestor_title) + + # Add current page title + path_parts.append(page_title) + + # Track page names for duplicate detection + full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title + + # Count occurrences of this page title + if page_title not in self._document_name_counts: + self._document_name_counts[page_title] = 0 + self._document_name_paths[page_title] = [] + self._document_name_counts[page_title] += 1 + self._document_name_paths[page_title].append(full_path) + + # Use simple name if no duplicates, otherwise use full path + if self._document_name_counts[page_title] == 1: + semantic_identifier = page_title + else: + semantic_identifier = full_path + # Get the page content page_content = extract_text_from_confluence_html( self.confluence_client, page, self._fetched_titles @@ -1556,12 +1596,13 @@ def _convert_page_to_document( return Document( id=page_url, source=DocumentSource.CONFLUENCE, - semantic_identifier=page_title, + semantic_identifier=semantic_identifier, extension=".html", # Confluence pages are HTML blob=page_content.encode("utf-8"), # Encode page content as bytes - size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes doc_updated_at=datetime_from_string(page["version"]["when"]), + size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes primary_owners=primary_owners if primary_owners else None, + metadata=metadata if metadata else None, ) except Exception as e: 
logging.error(f"Error converting page {page.get('id', 'unknown')}: {e}") @@ -1597,7 +1638,6 @@ def _fetch_page_attachments( expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), ): media_type: str = attachment.get("metadata", {}).get("mediaType", "") - # TODO(rkuo): this check is partially redundant with validate_attachment_filetype # and checks in convert_attachment_to_content/process_attachment # but doing the check here avoids an unnecessary download. Due for refactoring. @@ -1665,6 +1705,34 @@ def _fetch_page_attachments( self.wiki_base, attachment["_links"]["webui"], self.is_cloud ) + # Build semantic identifier with space and page context + attachment_title = attachment.get("title", object_url) + space_name = page.get("space", {}).get("name", "") + page_title = page.get("title", "") + + # Create hierarchical name: Space / Page / Attachment + attachment_path_parts = [] + if space_name: + attachment_path_parts.append(space_name) + if page_title: + attachment_path_parts.append(page_title) + attachment_path_parts.append(attachment_title) + + full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title + + # Track attachment names for duplicate detection + if attachment_title not in self._document_name_counts: + self._document_name_counts[attachment_title] = 0 + self._document_name_paths[attachment_title] = [] + self._document_name_counts[attachment_title] += 1 + self._document_name_paths[attachment_title].append(full_attachment_path) + + # Use simple name if no duplicates, otherwise use full path + if self._document_name_counts[attachment_title] == 1: + attachment_semantic_identifier = attachment_title + else: + attachment_semantic_identifier = full_attachment_path + primary_owners: list[BasicExpertInfo] | None = None if "version" in attachment and "by" in attachment["version"]: author = attachment["version"]["by"] @@ -1676,11 +1744,12 @@ def _fetch_page_attachments( extension = Path(attachment.get("title", "")).suffix or ".unknown" + attachment_doc = Document( id=attachment_id, # sections=sections, source=DocumentSource.CONFLUENCE, - semantic_identifier=attachment.get("title", object_url), + semantic_identifier=attachment_semantic_identifier, extension=extension, blob=file_blob, size_bytes=len(file_blob), @@ -1737,7 +1806,7 @@ def _fetch_document_batches( start_ts, end, self.batch_size ) logging.debug(f"page_query_url: {page_query_url}") - + # store the next page start for confluence server, cursor for confluence cloud def store_next_page_url(next_page_url: str) -> None: checkpoint.next_page_url = next_page_url @@ -1788,6 +1857,7 @@ def _build_page_retrieval_url( cql_url = self.confluence_client.build_cql_url( page_query, expand=",".join(_PAGE_EXPANSION_FIELDS) ) + logging.info(f"[Confluence Connector] Building CQL URL {cql_url}") return update_param_in_path(cql_url, "limit", str(limit)) @override diff --git a/common/data_source/connector_runner.py b/common/data_source/connector_runner.py new file mode 100644 index 00000000000..d47d6512842 --- /dev/null +++ b/common/data_source/connector_runner.py @@ -0,0 +1,217 @@ +import sys +import time +import logging +from collections.abc import Generator +from datetime import datetime +from typing import Generic +from typing import TypeVar +from common.data_source.interfaces import ( + BaseConnector, + CheckpointedConnector, + CheckpointedConnectorWithPermSync, + CheckpointOutput, + LoadConnector, + PollConnector, +) +from common.data_source.models import ConnectorCheckpoint, ConnectorFailure, Document 
+ + +TimeRange = tuple[datetime, datetime] + +CT = TypeVar("CT", bound=ConnectorCheckpoint) + + +def batched_doc_ids( + checkpoint_connector_generator: CheckpointOutput[CT], + batch_size: int, +) -> Generator[set[str], None, None]: + batch: set[str] = set() + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( + checkpoint_connector_generator + ): + if document is not None: + batch.add(document.id) + elif ( + failure and failure.failed_document and failure.failed_document.document_id + ): + batch.add(failure.failed_document.document_id) + + if len(batch) >= batch_size: + yield batch + batch = set() + if len(batch) > 0: + yield batch + + +class CheckpointOutputWrapper(Generic[CT]): + """ + Wraps a CheckpointOutput generator to give things back in a more digestible format, + specifically for Document outputs. + The connector format is easier for the connector implementor (e.g. it enforces exactly + one new checkpoint is returned AND that the checkpoint is at the end), thus the different + formats. + """ + + def __init__(self) -> None: + self.next_checkpoint: CT | None = None + + def __call__( + self, + checkpoint_connector_generator: CheckpointOutput[CT], + ) -> Generator[ + tuple[Document | None, ConnectorFailure | None, CT | None], + None, + None, + ]: + # grabs the final return value and stores it in the `next_checkpoint` variable + def _inner_wrapper( + checkpoint_connector_generator: CheckpointOutput[CT], + ) -> CheckpointOutput[CT]: + self.next_checkpoint = yield from checkpoint_connector_generator + return self.next_checkpoint # not used + + for document_or_failure in _inner_wrapper(checkpoint_connector_generator): + if isinstance(document_or_failure, Document): + yield document_or_failure, None, None + elif isinstance(document_or_failure, ConnectorFailure): + yield None, document_or_failure, None + else: + raise ValueError( + f"Invalid document_or_failure type: {type(document_or_failure)}" + ) + + if self.next_checkpoint is None: + raise RuntimeError( + "Checkpoint is None. This should never happen - the connector should always return a checkpoint." 
+ ) + + yield None, None, self.next_checkpoint + + +class ConnectorRunner(Generic[CT]): + """ + Handles: + - Batching + - Additional exception logging + - Combining different connector types to a single interface + """ + + def __init__( + self, + connector: BaseConnector, + batch_size: int, + # cannot be True for non-checkpointed connectors + include_permissions: bool, + time_range: TimeRange | None = None, + ): + if not isinstance(connector, CheckpointedConnector) and include_permissions: + raise ValueError( + "include_permissions cannot be True for non-checkpointed connectors" + ) + + self.connector = connector + self.time_range = time_range + self.batch_size = batch_size + self.include_permissions = include_permissions + + self.doc_batch: list[Document] = [] + + def run(self, checkpoint: CT) -> Generator[ + tuple[list[Document] | None, ConnectorFailure | None, CT | None], + None, + None, + ]: + """Adds additional exception logging to the connector.""" + try: + if isinstance(self.connector, CheckpointedConnector): + if self.time_range is None: + raise ValueError("time_range is required for CheckpointedConnector") + + start = time.monotonic() + if self.include_permissions: + if not isinstance( + self.connector, CheckpointedConnectorWithPermSync + ): + raise ValueError( + "Connector does not support permission syncing" + ) + load_from_checkpoint = ( + self.connector.load_from_checkpoint_with_perm_sync + ) + else: + load_from_checkpoint = self.connector.load_from_checkpoint + checkpoint_connector_generator = load_from_checkpoint( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + checkpoint=checkpoint, + ) + next_checkpoint: CT | None = None + # this is guaranteed to always run at least once with next_checkpoint being non-None + for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()( + checkpoint_connector_generator + ): + if document is not None and isinstance(document, Document): + self.doc_batch.append(document) + + if failure is not None: + yield None, failure, None + + if len(self.doc_batch) >= self.batch_size: + yield self.doc_batch, None, None + self.doc_batch = [] + + # yield remaining documents + if len(self.doc_batch) > 0: + yield self.doc_batch, None, None + self.doc_batch = [] + + yield None, None, next_checkpoint + + logging.debug( + f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint." + ) + + else: + finished_checkpoint = self.connector.build_dummy_checkpoint() + finished_checkpoint.has_more = False + + if isinstance(self.connector, PollConnector): + if self.time_range is None: + raise ValueError("time_range is required for PollConnector") + + for document_batch in self.connector.poll_source( + start=self.time_range[0].timestamp(), + end=self.time_range[1].timestamp(), + ): + yield document_batch, None, None + + yield None, None, finished_checkpoint + elif isinstance(self.connector, LoadConnector): + for document_batch in self.connector.load_from_state(): + yield document_batch, None, None + + yield None, None, finished_checkpoint + else: + raise ValueError(f"Invalid connector. 
type: {type(self.connector)}") + except Exception: + exc_type, _, exc_traceback = sys.exc_info() + + # Traverse the traceback to find the last frame where the exception was raised + tb = exc_traceback + if tb is None: + logging.error("No traceback found for exception") + raise + + while tb.tb_next: + tb = tb.tb_next # Move to the next frame in the traceback + + # Get the local variables from the frame where the exception occurred + local_vars = tb.tb_frame.f_locals + local_vars_str = "\n".join( + f"{key}: {value}" for key, value in local_vars.items() + ) + logging.error( + f"Error in connector. type: {exc_type};\n" + f"local_vars below -> \n{local_vars_str[:1024]}" + ) + raise \ No newline at end of file diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 93a0477b078..e65a6324185 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -33,7 +33,7 @@ def _convert_message_to_document( metadata: dict[str, str | list[str]] = {} semantic_substring = "" - # Only messages from TextChannels will make it here but we have to check for it anyways + # Only messages from TextChannels will make it here, but we have to check for it anyway if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name): metadata["Channel"] = channel_name semantic_substring += f" in Channel: #{channel_name}" @@ -65,6 +65,7 @@ def _convert_message_to_document( blob=message.content.encode("utf-8"), extension=".txt", size_bytes=len(message.content.encode("utf-8")), + metadata=metadata if metadata else None, ) @@ -175,7 +176,7 @@ def _manage_async_retrieval( # parse requested_start_date_string to datetime pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None - # Set start_time to the later of start and pull_date, or whichever is provided + # Set start_time to the most recent of start and pull_date, or whichever is provided start_time = max(filter(None, [start, pull_date])) if start or pull_date else None end_time: datetime | None = end @@ -232,8 +233,8 @@ class DiscordConnector(LoadConnector, PollConnector): def __init__( self, - server_ids: list[str] = [], - channel_names: list[str] = [], + server_ids: list[str] | None = None, + channel_names: list[str] | None = None, # YYYY-MM-DD start_date: str | None = None, batch_size: int = INDEX_BATCH_SIZE, diff --git a/common/data_source/dropbox_connector.py b/common/data_source/dropbox_connector.py index fd349baa111..0e7131d8f3b 100644 --- a/common/data_source/dropbox_connector.py +++ b/common/data_source/dropbox_connector.py @@ -1,13 +1,24 @@ """Dropbox connector""" +import logging +from datetime import timezone from typing import Any from dropbox import Dropbox from dropbox.exceptions import ApiError, AuthError +from dropbox.files import FileMetadata, FolderMetadata -from common.data_source.config import INDEX_BATCH_SIZE -from common.data_source.exceptions import ConnectorValidationError, InsufficientPermissionsError, ConnectorMissingCredentialError +from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, + InsufficientPermissionsError, +) from common.data_source.interfaces import LoadConnector, PollConnector, SecondsSinceUnixEpoch +from common.data_source.models import Document, GenerateDocumentsOutput +from common.data_source.utils import 
get_file_ext + +logger = logging.getLogger(__name__) class DropboxConnector(LoadConnector, PollConnector): @@ -19,29 +30,29 @@ def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load Dropbox credentials""" - try: - access_token = credentials.get("dropbox_access_token") - if not access_token: - raise ConnectorMissingCredentialError("Dropbox access token is required") - - self.dropbox_client = Dropbox(access_token) - return None - except Exception as e: - raise ConnectorMissingCredentialError(f"Dropbox: {e}") + access_token = credentials.get("dropbox_access_token") + if not access_token: + raise ConnectorMissingCredentialError("Dropbox access token is required") + + self.dropbox_client = Dropbox(access_token) + return None def validate_connector_settings(self) -> None: """Validate Dropbox connector settings""" - if not self.dropbox_client: + if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") - + try: - # Test connection by getting current account info - self.dropbox_client.users_get_current_account() - except (AuthError, ApiError) as e: - if "invalid_access_token" in str(e).lower(): - raise InsufficientPermissionsError("Invalid Dropbox access token") - else: - raise ConnectorValidationError(f"Dropbox validation error: {e}") + self.dropbox_client.files_list_folder(path="", limit=1) + except AuthError as e: + logger.exception("[Dropbox]: Failed to validate Dropbox credentials") + raise ConnectorValidationError(f"Dropbox credential is invalid: {e}") + except ApiError as e: + if e.error is not None and "insufficient_permissions" in str(e.error).lower(): + raise InsufficientPermissionsError("Your Dropbox token does not have sufficient permissions.") + raise ConnectorValidationError(f"Unexpected Dropbox error during validation: {e.user_message_text or e}") + except Exception as e: + raise ConnectorValidationError(f"Unexpected error during Dropbox settings validation: {e}") def _download_file(self, path: str) -> bytes: """Download a single file from Dropbox.""" @@ -54,26 +65,145 @@ def _get_shared_link(self, path: str) -> str: """Create a shared link for a file in Dropbox.""" if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") - + try: - # Try to get existing shared links first shared_links = self.dropbox_client.sharing_list_shared_links(path=path) if shared_links.links: return shared_links.links[0].url + + link_metadata = self.dropbox_client.sharing_create_shared_link_with_settings(path) + return link_metadata.url + except ApiError as err: + logger.exception(f"[Dropbox]: Failed to create a shared link for {path}: {err}") + return "" + + def _yield_files_recursive( + self, + path: str, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None, + ) -> GenerateDocumentsOutput: + """Yield files in batches from a specified Dropbox folder, including subfolders.""" + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + # Collect all files first to count filename occurrences + all_files = [] + self._collect_files_recursive(path, start, end, all_files) + + # Count filename occurrences + filename_counts: dict[str, int] = {} + for entry, _ in all_files: + filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1 + + # Process files in batches + batch: list[Document] = [] + for entry, downloaded_file in all_files: + modified_time = entry.client_modified + if modified_time.tzinfo is None: + 
modified_time = modified_time.replace(tzinfo=timezone.utc) + else: + modified_time = modified_time.astimezone(timezone.utc) - # Create a new shared link - link_settings = self.dropbox_client.sharing_create_shared_link_with_settings(path) - return link_settings.url - except Exception: - # Fallback to basic link format - return f"https://www.dropbox.com/home{path}" - - def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: + # Use full path only if filename appears multiple times + if filename_counts.get(entry.name, 0) > 1: + # Remove leading slash and replace slashes with ' / ' + relative_path = entry.path_display.lstrip('/') + semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name + else: + semantic_id = entry.name + + batch.append( + Document( + id=f"dropbox:{entry.id}", + blob=downloaded_file, + source=DocumentSource.DROPBOX, + semantic_identifier=semantic_id, + extension=get_file_ext(entry.name), + doc_updated_at=modified_time, + size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file), + ) + ) + + if len(batch) == self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + def _collect_files_recursive( + self, + path: str, + start: SecondsSinceUnixEpoch | None, + end: SecondsSinceUnixEpoch | None, + all_files: list, + ) -> None: + """Recursively collect all files matching time criteria.""" + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + result = self.dropbox_client.files_list_folder( + path, + recursive=False, + include_non_downloadable_files=False, + ) + + while True: + for entry in result.entries: + if isinstance(entry, FileMetadata): + modified_time = entry.client_modified + if modified_time.tzinfo is None: + modified_time = modified_time.replace(tzinfo=timezone.utc) + else: + modified_time = modified_time.astimezone(timezone.utc) + + time_as_seconds = modified_time.timestamp() + if start is not None and time_as_seconds <= start: + continue + if end is not None and time_as_seconds > end: + continue + + try: + downloaded_file = self._download_file(entry.path_display) + all_files.append((entry, downloaded_file)) + except Exception: + logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}") + continue + + elif isinstance(entry, FolderMetadata): + self._collect_files_recursive(entry.path_lower, start, end, all_files) + + if not result.has_more: + break + + result = self.dropbox_client.files_list_folder_continue(result.cursor) + + def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> GenerateDocumentsOutput: """Poll Dropbox for recent file changes""" - # Simplified implementation - in production this would handle actual polling - return [] + if self.dropbox_client is None: + raise ConnectorMissingCredentialError("Dropbox") + + for batch in self._yield_files_recursive("", start, end): + yield batch - def load_from_state(self) -> Any: + def load_from_state(self) -> GenerateDocumentsOutput: """Load files from Dropbox state""" - # Simplified implementation - return [] \ No newline at end of file + return self._yield_files_recursive("", None, None) + + +if __name__ == "__main__": + import os + + logging.basicConfig(level=logging.DEBUG) + connector = DropboxConnector() + connector.load_credentials({"dropbox_access_token": os.environ.get("DROPBOX_ACCESS_TOKEN")}) + connector.validate_connector_settings() + document_batches = connector.load_from_state() + try: + first_batch = next(document_batches) + 
print(f"Loaded {len(first_batch)} documents in first batch.") + for doc in first_batch: + print(f"- {doc.semantic_identifier} ({doc.size_bytes} bytes)") + except StopIteration: + print("No documents available in Dropbox.") diff --git a/common/data_source/file_types.py b/common/data_source/file_types.py index bf7eafaaaba..be4d56d7b5b 100644 --- a/common/data_source/file_types.py +++ b/common/data_source/file_types.py @@ -18,6 +18,7 @@ class UploadMimeTypes: "text/plain", "text/markdown", "text/x-markdown", + "text/mdx", "text/x-config", "text/tab-separated-values", "application/json", diff --git a/web/src/pages/add-knowledge/components/knowledge-dataset/index.less b/common/data_source/github/__init__.py similarity index 100% rename from web/src/pages/add-knowledge/components/knowledge-dataset/index.less rename to common/data_source/github/__init__.py diff --git a/common/data_source/github/connector.py b/common/data_source/github/connector.py new file mode 100644 index 00000000000..2e6d5f2af93 --- /dev/null +++ b/common/data_source/github/connector.py @@ -0,0 +1,973 @@ +import copy +import logging +from collections.abc import Callable +from collections.abc import Generator +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from enum import Enum +from typing import Any +from typing import cast + +from github import Github, Auth +from github import RateLimitExceededException +from github import Repository +from github.GithubException import GithubException +from github.Issue import Issue +from github.NamedUser import NamedUser +from github.PaginatedList import PaginatedList +from github.PullRequest import PullRequest +from pydantic import BaseModel +from typing_extensions import override +from common.data_source.google_util.util import sanitize_filename +from common.data_source.config import DocumentSource, GITHUB_CONNECTOR_BASE_URL +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, + CredentialExpiredError, + InsufficientPermissionsError, + UnexpectedValidationError, +) +from common.data_source.interfaces import CheckpointedConnectorWithPermSyncGH, CheckpointOutput +from common.data_source.models import ( + ConnectorCheckpoint, + ConnectorFailure, + Document, + DocumentFailure, + ExternalAccess, + SecondsSinceUnixEpoch, +) +from common.data_source.connector_runner import ConnectorRunner +from .models import SerializedRepository +from .rate_limit_utils import sleep_after_rate_limit_exception +from .utils import deserialize_repository +from .utils import get_external_access_permission + +ITEMS_PER_PAGE = 100 +CURSOR_LOG_FREQUENCY = 50 + +_MAX_NUM_RATE_LIMIT_RETRIES = 5 + +ONE_DAY = timedelta(days=1) +SLIM_BATCH_SIZE = 100 +# Cases +# X (from start) standard run, no fallback to cursor-based pagination +# X (from start) standard run errors, fallback to cursor-based pagination +# X error in the middle of a page +# X no errors: run to completion +# X (from checkpoint) standard run, no fallback to cursor-based pagination +# X (from checkpoint) continue from cursor-based pagination +# - retrying +# - no retrying + +# things to check: +# checkpoint state on return +# checkpoint progress (no infinite loop) + + +class DocMetadata(BaseModel): + repo: str + + +def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str: + if "_PaginatedList__nextUrl" in pag_list.__dict__: + return "_PaginatedList__nextUrl" + for key in pag_list.__dict__: + if "__nextUrl" in key: + return key + for key in 
pag_list.__dict__: + if "nextUrl" in key: + return key + return "" + + +def get_nextUrl( + pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str +) -> str | None: + return getattr(pag_list, nextUrl_key) if nextUrl_key else None + + +def set_nextUrl( + pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str +) -> None: + if nextUrl_key: + setattr(pag_list, nextUrl_key, nextUrl) + elif nextUrl: + raise ValueError("Next URL key not found: " + str(pag_list.__dict__)) + + +def _paginate_until_error( + git_objs: Callable[[], PaginatedList[PullRequest | Issue]], + cursor_url: str | None, + prev_num_objs: int, + cursor_url_callback: Callable[[str | None, int], None], + retrying: bool = False, +) -> Generator[PullRequest | Issue, None, None]: + num_objs = prev_num_objs + pag_list = git_objs() + nextUrl_key = get_nextUrl_key(pag_list) + if cursor_url: + set_nextUrl(pag_list, nextUrl_key, cursor_url) + elif retrying: + # if we are retrying, we want to skip the objects retrieved + # over previous calls. Unfortunately, this WILL retrieve all + # pages before the one we are resuming from, so we really + # don't want this case to be hit often + logging.warning( + "Retrying from a previous cursor-based pagination call. " + "This will retrieve all pages before the one we are resuming from, " + "which may take a while and consume many API calls." + ) + pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:]) + num_objs = 0 + + try: + # this for loop handles cursor-based pagination + for issue_or_pr in pag_list: + num_objs += 1 + yield issue_or_pr + # used to store the current cursor url in the checkpoint. This value + # is updated during iteration over pag_list. + cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs) + + if num_objs % CURSOR_LOG_FREQUENCY == 0: + logging.info( + f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}" + ) + + except Exception as e: + logging.exception(f"Error during cursor-based pagination: {e}") + if num_objs - prev_num_objs > 0: + raise + + if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying: + logging.info( + "Assuming that this error is due to cursor " + "expiration because no objects were retrieved. " + "Retrying from the first page." + ) + yield from _paginate_until_error( + git_objs, None, prev_num_objs, cursor_url_callback, retrying=True + ) + return + + # for no cursor url or if we reach this point after a retry, raise the error + raise + + +def _get_batch_rate_limited( + # We pass in a callable because we want git_objs to produce a fresh + # PaginatedList each time it's called to avoid using the same object for cursor-based pagination + # from a partial offset-based pagination call. + git_objs: Callable[[], PaginatedList], + page_num: int, + cursor_url: str | None, + prev_num_objs: int, + cursor_url_callback: Callable[[str | None, int], None], + github_client: Github, + attempt_num: int = 0, +) -> Generator[PullRequest | Issue, None, None]: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github" + ) + try: + if cursor_url: + # when this is set, we are resuming from an earlier + # cursor-based pagination call. 
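+            # prev_num_objs lets the cursor fallback skip objects that were already yielded
+            # if the stored cursor has expired and pagination must restart from the first page.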
+ yield from _paginate_until_error( + git_objs, cursor_url, prev_num_objs, cursor_url_callback + ) + return + objs = list(git_objs().get_page(page_num)) + # fetch all data here to disable lazy loading later + # this is needed to capture the rate limit exception here (if one occurs) + for obj in objs: + if hasattr(obj, "raw_data"): + getattr(obj, "raw_data") + yield from objs + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + yield from _get_batch_rate_limited( + git_objs, + page_num, + cursor_url, + prev_num_objs, + cursor_url_callback, + github_client, + attempt_num + 1, + ) + except GithubException as e: + if not ( + e.status == 422 + and ( + "cursor" in (e.message or "") + or "cursor" in (e.data or {}).get("message", "") + ) + ): + raise + # Fallback to a cursor-based pagination strategy + # This can happen for "large datasets," but there's no documentation + # On the error on the web as far as we can tell. + # Error message: + # "Pagination with the page parameter is not supported for large datasets, + # please use cursor based pagination (after/before)" + yield from _paginate_until_error( + git_objs, cursor_url, prev_num_objs, cursor_url_callback + ) + + +def _get_userinfo(user: NamedUser) -> dict[str, str]: + def _safe_get(attr_name: str) -> str | None: + try: + return cast(str | None, getattr(user, attr_name)) + except GithubException: + logging.debug(f"Error getting {attr_name} for user") + return None + + return { + k: v + for k, v in { + "login": _safe_get("login"), + "name": _safe_get("name"), + "email": _safe_get("email"), + }.items() + if v is not None + } + + +def _convert_pr_to_document( + pull_request: PullRequest, repo_external_access: ExternalAccess | None +) -> Document: + repo_name = pull_request.base.repo.full_name if pull_request.base else "" + doc_metadata = DocMetadata(repo=repo_name) + file_content_byte = pull_request.body.encode('utf-8') if pull_request.body else b"" + name = sanitize_filename(pull_request.title, "md") + + return Document( + id=pull_request.html_url, + blob= file_content_byte, + source=DocumentSource.GITHUB, + external_access=repo_external_access, + semantic_identifier=f"{pull_request.number}:{name}", + # updated_at is UTC time but is timezone unaware, explicitly add UTC + # as there is logic in indexing to prevent wrong timestamped docs + # due to local time discrepancies with UTC + doc_updated_at=( + pull_request.updated_at.replace(tzinfo=timezone.utc) + if pull_request.updated_at + else None + ), + extension=".md", + # this metadata is used in perm sync + size_bytes=len(file_content_byte) if file_content_byte else 0, + primary_owners=[], + doc_metadata=doc_metadata.model_dump(), + metadata={ + k: [str(vi) for vi in v] if isinstance(v, list) else str(v) + for k, v in { + "object_type": "PullRequest", + "id": pull_request.number, + "merged": pull_request.merged, + "state": pull_request.state, + "user": _get_userinfo(pull_request.user) if pull_request.user else None, + "assignees": [ + _get_userinfo(assignee) for assignee in pull_request.assignees + ], + "repo": ( + pull_request.base.repo.full_name if pull_request.base else None + ), + "num_commits": str(pull_request.commits), + "num_files_changed": str(pull_request.changed_files), + "labels": [label.name for label in pull_request.labels], + "created_at": ( + pull_request.created_at.replace(tzinfo=timezone.utc) + if pull_request.created_at + else None + ), + "updated_at": ( + pull_request.updated_at.replace(tzinfo=timezone.utc) + if pull_request.updated_at + 
else None + ), + "closed_at": ( + pull_request.closed_at.replace(tzinfo=timezone.utc) + if pull_request.closed_at + else None + ), + "merged_at": ( + pull_request.merged_at.replace(tzinfo=timezone.utc) + if pull_request.merged_at + else None + ), + "merged_by": ( + _get_userinfo(pull_request.merged_by) + if pull_request.merged_by + else None + ), + }.items() + if v is not None + }, + ) + + +def _fetch_issue_comments(issue: Issue) -> str: + comments = issue.get_comments() + return "\nComment: ".join(comment.body for comment in comments) + + +def _convert_issue_to_document( + issue: Issue, repo_external_access: ExternalAccess | None +) -> Document: + repo_name = issue.repository.full_name if issue.repository else "" + doc_metadata = DocMetadata(repo=repo_name) + file_content_byte = issue.body.encode('utf-8') if issue.body else b"" + name = sanitize_filename(issue.title, "md") + + return Document( + id=issue.html_url, + blob=file_content_byte, + source=DocumentSource.GITHUB, + extension=".md", + external_access=repo_external_access, + semantic_identifier=f"{issue.number}:{name}", + # updated_at is UTC time but is timezone unaware + doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc), + # this metadata is used in perm sync + doc_metadata=doc_metadata.model_dump(), + size_bytes=len(file_content_byte) if file_content_byte else 0, + primary_owners=[_get_userinfo(issue.user) if issue.user else None], + metadata={ + k: [str(vi) for vi in v] if isinstance(v, list) else str(v) + for k, v in { + "object_type": "Issue", + "id": issue.number, + "state": issue.state, + "user": _get_userinfo(issue.user) if issue.user else None, + "assignees": [_get_userinfo(assignee) for assignee in issue.assignees], + "repo": issue.repository.full_name if issue.repository else None, + "labels": [label.name for label in issue.labels], + "created_at": ( + issue.created_at.replace(tzinfo=timezone.utc) + if issue.created_at + else None + ), + "updated_at": ( + issue.updated_at.replace(tzinfo=timezone.utc) + if issue.updated_at + else None + ), + "closed_at": ( + issue.closed_at.replace(tzinfo=timezone.utc) + if issue.closed_at + else None + ), + "closed_by": ( + _get_userinfo(issue.closed_by) if issue.closed_by else None + ), + }.items() + if v is not None + }, + ) + + +class GithubConnectorStage(Enum): + START = "start" + PRS = "prs" + ISSUES = "issues" + + +class GithubConnectorCheckpoint(ConnectorCheckpoint): + stage: GithubConnectorStage + curr_page: int + + cached_repo_ids: list[int] | None = None + cached_repo: SerializedRepository | None = None + + # Used for the fallback cursor-based pagination strategy + num_retrieved: int + cursor_url: str | None = None + + def reset(self) -> None: + """ + Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None) + """ + self.curr_page = 0 + self.num_retrieved = 0 + self.cursor_url = None + + +def make_cursor_url_callback( + checkpoint: GithubConnectorCheckpoint, +) -> Callable[[str | None, int], None]: + def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None: + # we want to maintain the old cursor url so code after retrieval + # can determine that we are using the fallback cursor-based pagination strategy + if cursor_url: + checkpoint.cursor_url = cursor_url + checkpoint.num_retrieved = num_objs + + return cursor_url_callback + + +class GithubConnector(CheckpointedConnectorWithPermSyncGH[GithubConnectorCheckpoint]): + def __init__( + self, + repo_owner: str, + repositories: str | None = None, + state_filter: str = "all", + 
include_prs: bool = True, + include_issues: bool = False, + ) -> None: + self.repo_owner = repo_owner + self.repositories = repositories + self.state_filter = state_filter + self.include_prs = include_prs + self.include_issues = include_issues + self.github_client: Github | None = None + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + # defaults to 30 items per page, can be set to as high as 100 + token = credentials["github_access_token"] + auth = Auth.Token(token) + + if GITHUB_CONNECTOR_BASE_URL: + self.github_client = Github( + auth=auth, + base_url=GITHUB_CONNECTOR_BASE_URL, + per_page=ITEMS_PER_PAGE, + ) + else: + self.github_client = Github( + auth=auth, + per_page=ITEMS_PER_PAGE, + ) + + return None + + def get_github_repo( + self, github_client: Github, attempt_num: int = 0 + ) -> Repository.Repository: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github" + ) + + try: + return github_client.get_repo(f"{self.repo_owner}/{self.repositories}") + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + return self.get_github_repo(github_client, attempt_num + 1) + + def get_github_repos( + self, github_client: Github, attempt_num: int = 0 + ) -> list[Repository.Repository]: + """Get specific repositories based on comma-separated repo_name string.""" + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" + ) + + try: + repos = [] + # Split repo_name by comma and strip whitespace + repo_names = [ + name.strip() for name in (cast(str, self.repositories)).split(",") + ] + + for repo_name in repo_names: + if repo_name: # Skip empty strings + try: + repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}") + repos.append(repo) + except GithubException as e: + logging.warning( + f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}" + ) + + return repos + except RateLimitExceededException: + sleep_after_rate_limit_exception(github_client) + return self.get_github_repos(github_client, attempt_num + 1) + + def get_all_repos( + self, github_client: Github, attempt_num: int = 0 + ) -> list[Repository.Repository]: + if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: + raise RuntimeError( + "Re-tried fetching repos too many times. 
Something is going wrong with fetching objects from Github"
+ )
+
+ try:
+ # Try to get organization first
+ try:
+ org = github_client.get_organization(self.repo_owner)
+ return list(org.get_repos())
+
+ except GithubException:
+ # If not an org, try as a user
+ user = github_client.get_user(self.repo_owner)
+ return list(user.get_repos())
+ except RateLimitExceededException:
+ sleep_after_rate_limit_exception(github_client)
+ return self.get_all_repos(github_client, attempt_num + 1)
+
+ def _pull_requests_func(
+ self, repo: Repository.Repository
+ ) -> Callable[[], PaginatedList[PullRequest]]:
+ return lambda: repo.get_pulls(
+ state=self.state_filter, sort="updated", direction="desc"
+ )
+
+ def _issues_func(
+ self, repo: Repository.Repository
+ ) -> Callable[[], PaginatedList[Issue]]:
+ return lambda: repo.get_issues(
+ state=self.state_filter, sort="updated", direction="desc"
+ )
+
+ def _fetch_from_github(
+ self,
+ checkpoint: GithubConnectorCheckpoint,
+ start: datetime | None = None,
+ end: datetime | None = None,
+ include_permissions: bool = False,
+ ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
+ if self.github_client is None:
+ raise ConnectorMissingCredentialError("GitHub")
+
+ checkpoint = copy.deepcopy(checkpoint)
+
+ # First run of the connector, fetch all repos and store in checkpoint
+ if checkpoint.cached_repo_ids is None:
+ repos = []
+ if self.repositories:
+ if "," in self.repositories:
+ # Multiple repositories specified
+ repos = self.get_github_repos(self.github_client)
+ else:
+ # Single repository (backward compatibility)
+ repos = [self.get_github_repo(self.github_client)]
+ else:
+ # All repositories
+ repos = self.get_all_repos(self.github_client)
+ if not repos:
+ checkpoint.has_more = False
+ return checkpoint
+
+ curr_repo = repos.pop()
+ checkpoint.cached_repo_ids = [repo.id for repo in repos]
+ checkpoint.cached_repo = SerializedRepository(
+ id=curr_repo.id,
+ headers=curr_repo.raw_headers,
+ raw_data=curr_repo.raw_data,
+ )
+ checkpoint.stage = GithubConnectorStage.PRS
+ checkpoint.curr_page = 0
+ # save checkpoint with repo ids retrieved
+ return checkpoint
+
+ if checkpoint.cached_repo is None:
+ raise ValueError("No repo saved in checkpoint")
+
+ # Deserialize the repository from the checkpoint
+ repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
+
+ cursor_url_callback = make_cursor_url_callback(checkpoint)
+ repo_external_access: ExternalAccess | None = None
+ if include_permissions:
+ repo_external_access = get_external_access_permission(
+ repo, self.github_client
+ )
+ if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
+ logging.info(f"Fetching PRs for repo: {repo.name}")
+
+ pr_batch = _get_batch_rate_limited(
+ self._pull_requests_func(repo),
+ checkpoint.curr_page,
+ checkpoint.cursor_url,
+ checkpoint.num_retrieved,
+ cursor_url_callback,
+ self.github_client,
+ )
+ checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback
+ done_with_prs = False
+ num_prs = 0
+ pr = None
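+ # (descriptive note) pr_batch is a generator, so documents are yielded as
+ # pages are fetched; when the cursor-based fallback is active,
+ # cursor_url_callback keeps checkpoint.cursor_url and
+ # checkpoint.num_retrieved current during iteration.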
+ for pr in pr_batch:
+ num_prs += 1
+ # we iterate backwards in time, so at this point we stop processing prs
+ if (
+ start is not None
+ and pr.updated_at
+ and pr.updated_at.replace(tzinfo=timezone.utc) <= start
+ ):
+ done_with_prs = True
+ break
+ # Skip PRs updated after the end date
+ if (
+ end is not None
+ and pr.updated_at
+ and pr.updated_at.replace(tzinfo=timezone.utc) > end
+ ):
+ continue
+ try:
+ yield _convert_pr_to_document(
+ cast(PullRequest, pr), repo_external_access
+ )
+ except Exception as e:
+ error_msg = f"Error converting PR to document: {e}"
+ logging.exception(error_msg)
+ yield ConnectorFailure(
+ failed_document=DocumentFailure(
+ document_id=str(pr.id), document_link=pr.html_url
+ ),
+ failure_message=error_msg,
+ exception=e,
+ )
+ continue
+
+ # If we reach this point with a cursor url in the checkpoint, we were using
+ # the fallback cursor-based pagination strategy. That strategy tries to get all
+ # PRs, so having cursor_url set means we are done with PRs. However, we need to
+ # return AFTER the checkpoint reset to avoid infinite loops.
+
+ # if we found any PRs on the page and there are more PRs to get, return the checkpoint.
+ # In offset mode, while indexing without time constraints, the pr batch
+ # will be empty when we're done.
+ used_cursor = checkpoint.cursor_url is not None
+ if num_prs > 0 and not done_with_prs and not used_cursor:
+ return checkpoint
+
+ # if we went past the start date during the loop or there are no more
+ # prs to get, we move on to issues
+ checkpoint.stage = GithubConnectorStage.ISSUES
+ checkpoint.reset()
+
+ if used_cursor:
+ # save the checkpoint after changing stage; next run will continue from issues
+ return checkpoint
+
+ if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
+ logging.info(f"Fetching issues for repo: {repo.name}")
+
+ issue_batch = list(
+ _get_batch_rate_limited(
+ self._issues_func(repo),
+ checkpoint.curr_page,
+ checkpoint.cursor_url,
+ checkpoint.num_retrieved,
+ cursor_url_callback,
+ self.github_client,
+ )
+ )
+ checkpoint.curr_page += 1
+ done_with_issues = False
+ num_issues = 0
+ for issue in issue_batch:
+ num_issues += 1
+ issue = cast(Issue, issue)
+ # we iterate backwards in time, so at this point we stop processing issues
+ if (
+ start is not None
+ and issue.updated_at.replace(tzinfo=timezone.utc) <= start
+ ):
+ done_with_issues = True
+ break
+ # Skip issues updated after the end date
+ if (
+ end is not None
+ and issue.updated_at.replace(tzinfo=timezone.utc) > end
+ ):
+ continue
+
+ if issue.pull_request is not None:
+ # PRs are handled separately
+ continue
+
+ try:
+ yield _convert_issue_to_document(issue, repo_external_access)
+ except Exception as e:
+ error_msg = f"Error converting issue to document: {e}"
+ logging.exception(error_msg)
+ yield ConnectorFailure(
+ failed_document=DocumentFailure(
+ document_id=str(issue.id),
+ document_link=issue.html_url,
+ ),
+ failure_message=error_msg,
+ exception=e,
+ )
+ continue
+
+ # if we found any issues on the page, and we're not done, return the checkpoint.
+ # don't return if we're using cursor-based pagination to avoid infinite loops + if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url: + return checkpoint + + # if we went past the start date during the loop or there are no more + # issues to get, we move on to the next repo + checkpoint.stage = GithubConnectorStage.PRS + checkpoint.reset() + + checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0 + if checkpoint.cached_repo_ids: + next_id = checkpoint.cached_repo_ids.pop() + next_repo = self.github_client.get_repo(next_id) + checkpoint.cached_repo = SerializedRepository( + id=next_id, + headers=next_repo.raw_headers, + raw_data=next_repo.raw_data, + ) + checkpoint.stage = GithubConnectorStage.PRS + checkpoint.reset() + + if checkpoint.cached_repo_ids: + logging.info( + f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})" + ) + else: + logging.info("No more repos remaining") + + return checkpoint + + def _load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + include_permissions: bool = False, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) + # add a day for timezone safety + end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY + + # Move start time back by 3 hours, since some Issues/PRs are getting dropped + # Could be due to delayed processing on GitHub side + # The non-updated issues since last poll will be shortcut-ed and not embedded + # adjusted_start_datetime = start_datetime - timedelta(hours=3) + + adjusted_start_datetime = start_datetime + + epoch = datetime.fromtimestamp(0, tz=timezone.utc) + if adjusted_start_datetime < epoch: + adjusted_start_datetime = epoch + + return self._fetch_from_github( + checkpoint, + start=adjusted_start_datetime, + end=end_datetime, + include_permissions=include_permissions, + ) + + @override + def load_from_checkpoint( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + return self._load_from_checkpoint( + start, end, checkpoint, include_permissions=False + ) + + @override + def load_from_checkpoint_with_perm_sync( + self, + start: SecondsSinceUnixEpoch, + end: SecondsSinceUnixEpoch, + checkpoint: GithubConnectorCheckpoint, + ) -> CheckpointOutput[GithubConnectorCheckpoint]: + return self._load_from_checkpoint( + start, end, checkpoint, include_permissions=True + ) + + def validate_connector_settings(self) -> None: + if self.github_client is None: + raise ConnectorMissingCredentialError("GitHub credentials not loaded.") + + if not self.repo_owner: + raise ConnectorValidationError( + "Invalid connector settings: 'repo_owner' must be provided." + ) + + try: + if self.repositories: + if "," in self.repositories: + # Multiple repositories specified + repo_names = [name.strip() for name in self.repositories.split(",")] + if not repo_names: + raise ConnectorValidationError( + "Invalid connector settings: No valid repository names provided." 
+ ) + + # Validate at least one repository exists and is accessible + valid_repos = False + validation_errors = [] + + for repo_name in repo_names: + if not repo_name: + continue + + try: + test_repo = self.github_client.get_repo( + f"{self.repo_owner}/{repo_name}" + ) + logging.info( + f"Successfully accessed repository: {self.repo_owner}/{repo_name}" + ) + test_repo.get_contents("") + valid_repos = True + # If at least one repo is valid, we can proceed + break + except GithubException as e: + validation_errors.append( + f"Repository '{repo_name}': {e.data.get('message', str(e))}" + ) + + if not valid_repos: + error_msg = ( + "None of the specified repositories could be accessed: " + ) + error_msg += ", ".join(validation_errors) + raise ConnectorValidationError(error_msg) + else: + # Single repository (backward compatibility) + test_repo = self.github_client.get_repo( + f"{self.repo_owner}/{self.repositories}" + ) + test_repo.get_contents("") + else: + # Try to get organization first + try: + org = self.github_client.get_organization(self.repo_owner) + total_count = org.get_repos().totalCount + if total_count == 0: + raise ConnectorValidationError( + f"Found no repos for organization: {self.repo_owner}. " + "Does the credential have the right scopes?" + ) + except GithubException as e: + # Check for missing SSO + MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower() + if MISSING_SSO_ERROR_MESSAGE in str(e).lower(): + SSO_GUIDE_LINK = ( + "https://docs.github.com/en/enterprise-cloud@latest/authentication/" + "authenticating-with-saml-single-sign-on/" + "authorizing-a-personal-access-token-for-use-with-saml-single-sign-on" + ) + raise ConnectorValidationError( + f"Your GitHub token is missing authorization to access the " + f"`{self.repo_owner}` organization. Please follow the guide to " + f"authorize your token: {SSO_GUIDE_LINK}" + ) + # If not an org, try as a user + user = self.github_client.get_user(self.repo_owner) + + # Check if we can access any repos + total_count = user.get_repos().totalCount + if total_count == 0: + raise ConnectorValidationError( + f"Found no repos for user: {self.repo_owner}. " + "Does the credential have the right scopes?" + ) + + except RateLimitExceededException: + raise UnexpectedValidationError( + "Validation failed due to GitHub rate-limits being exceeded. Please try again later." + ) + + except GithubException as e: + if e.status == 401: + raise CredentialExpiredError( + "GitHub credential appears to be invalid or expired (HTTP 401)." + ) + elif e.status == 403: + raise InsufficientPermissionsError( + "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)." 
+ ) + elif e.status == 404: + if self.repositories: + if "," in self.repositories: + raise ConnectorValidationError( + f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}" + ) + else: + raise ConnectorValidationError( + f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}" + ) + else: + raise ConnectorValidationError( + f"GitHub user or organization not found: {self.repo_owner}" + ) + else: + raise ConnectorValidationError( + f"Unexpected GitHub error (status={e.status}): {e.data}" + ) + + except Exception as exc: + raise Exception( + f"Unexpected error during GitHub settings validation: {exc}" + ) + + def validate_checkpoint_json( + self, checkpoint_json: str + ) -> GithubConnectorCheckpoint: + return GithubConnectorCheckpoint.model_validate_json(checkpoint_json) + + def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint: + return GithubConnectorCheckpoint( + stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0 + ) + + +if __name__ == "__main__": + # Initialize the connector + connector = GithubConnector( + repo_owner="EvoAgentX", + repositories="EvoAgentX", + include_issues=True, + include_prs=False, + ) + connector.load_credentials( + {"github_access_token": "