diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py index 34cd07110..c687df6d6 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py @@ -16,6 +16,7 @@ # under the License. import argparse +import os import gradio as gr import uvicorn @@ -202,5 +203,5 @@ def create_app(): host=args.host, port=args.port, factory=True, - reload=True, + reload=os.getenv("HG_DEV_RELOAD") == "1", ) diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py index 195ea6f6f..bb0abccfe 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py @@ -75,20 +75,41 @@ def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[L all_embeddings = [] for i in range(0, len(texts), batch_size): batch = texts[i : i + batch_size] - response = self.client.embed(model=self.model, input=batch)["embeddings"] - all_embeddings.extend([list(inner_sequence) for inner_sequence in response]) + response = self.client.embed(model=self.model, input=batch) + all_embeddings.extend(self._get_embeddings_from_response(response)) return all_embeddings + def _get_embeddings_from_response(self, response) -> List[List[float]]: + if "embeddings" not in response: + raise ValueError("Ollama embedding response missing 'embeddings'.") + embeddings = response["embeddings"] + if not embeddings: + raise ValueError("Ollama embedding response returned no embeddings.") + return [list(inner_sequence) for inner_sequence in embeddings] + async def async_get_text_embedding(self, text: str) -> List[float]: """Get embedding for a single text asynchronously.""" - response = await self.async_client.embeddings(model=self.model, prompt=text) - return list(response["embedding"]) + if not hasattr(self.async_client, "embed"): + error_message = ( + "The required 'embed' method was not found on the Ollama async client. " + "Please ensure your ollama library is up-to-date and supports batch embedding. " + ) + raise AttributeError(error_message) + + response = await self.async_client.embed(model=self.model, input=[text]) + return self._get_embeddings_from_response(response)[0] async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]: - # Ollama python client may not provide batch async embeddings; fallback per item - # batch_size parameter included for consistency with base class signature + if not hasattr(self.async_client, "embed"): + error_message = ( + "The required 'embed' method was not found on the Ollama async client. " + "Please ensure your ollama library is up-to-date and supports batch embedding. " + ) + raise AttributeError(error_message) + results: List[List[float]] = [] - for t in texts: - response = await self.async_client.embeddings(model=self.model, prompt=t) - results.append(list(response["embedding"])) + for i in range(0, len(texts), batch_size): + batch = texts[i : i + batch_size] + response = await self.async_client.embed(model=self.model, input=batch) + results.extend(self._get_embeddings_from_response(response)) return results diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py index dcf479c2f..60ad481ff 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py @@ -51,7 +51,8 @@ def __init__( @retry( stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=5), - retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)), + retry=retry_if_exception_type((RateLimitError, APIError)), + reraise=True, ) def generate( self, @@ -75,12 +76,13 @@ def generate( return response.choices[0].message.content except (RateLimitError, BudgetExceededError, APIError) as e: log.error("Error in LiteLLM call: %s", e) - return f"Error: {str(e)}" + raise @retry( stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=5), - retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)), + retry=retry_if_exception_type((RateLimitError, APIError)), + reraise=True, ) async def agenerate( self, @@ -104,7 +106,7 @@ async def agenerate( return response.choices[0].message.content except (RateLimitError, BudgetExceededError, APIError) as e: log.error("Error in async LiteLLM call: %s", e) - return f"Error: {str(e)}" + raise def generate_streaming( self, @@ -138,7 +140,7 @@ def generate_streaming( return result except (RateLimitError, BudgetExceededError, APIError) as e: log.error("Error in streaming LiteLLM call: %s", e) - return f"Error: {str(e)}" + raise async def agenerate_streaming( self, @@ -170,7 +172,7 @@ async def agenerate_streaming( yield chunk.choices[0].delta.content except (RateLimitError, BudgetExceededError, APIError) as e: log.error("Error in async streaming LiteLLM call: %s", e) - yield f"Error: {str(e)}" + raise def num_tokens_from_string(self, string: str) -> int: """Get token count from string.""" diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py index f7a6d3f9c..3370d47d0 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py @@ -70,7 +70,10 @@ def generate( max_tokens=self.max_tokens, messages=messages, ) - log.info("Token usage: %s", completions.usage.model_dump_json()) + if not completions.choices: + raise RuntimeError(f"Empty choices in LLM response: {str(completions)[:200]}") + if completions.usage: + log.info("Token usage: %s", completions.usage.model_dump_json()) return completions.choices[0].message.content # catch context length / do not retry except openai.BadRequestError as e: @@ -105,7 +108,10 @@ async def agenerate( max_tokens=self.max_tokens, messages=messages, ) - log.info("Token usage: %s", completions.usage.model_dump_json()) + if not completions.choices: + raise RuntimeError(f"Empty choices in LLM response: {str(completions)[:200]}") + if completions.usage: + log.info("Token usage: %s", completions.usage.model_dump_json()) return completions.choices[0].message.content # catch context length / do not retry except openai.BadRequestError as e: diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py index 7ed8af8a7..4da43e902 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import traceback from typing import Dict, Optional from pycgraph import CStatus, GNode @@ -22,6 +23,11 @@ from hugegraph_llm.utils.log import log +def _format_node_err(node, exc: Exception, prefix: str = "Node failed") -> str: + node_info = f"Node type: {type(node).__name__}, Node object: {node}" + return f"{prefix}: {exc}\n{node_info}\n{traceback.format_exc()}" + + class BaseNode(GNode): """ Base class for workflow nodes, providing context management and operation scheduling. @@ -70,10 +76,8 @@ def run(self): try: res = self.operator_schedule(data_json) except (ValueError, TypeError, KeyError, NotImplementedError) as exc: - import traceback - - node_info = f"Node type: {type(self).__name__}, Node object: {self}" - err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}" + err_msg = _format_node_err(self, exc) + log.error(err_msg) return CStatus(-1, err_msg) # For unexpected exceptions, re-raise to let them propagate or be caught elsewhere diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py index 01b2ca64d..e40d69c90 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -81,8 +81,7 @@ def node_init(self): def operator_schedule(self, data_json): try: schema_result = self.schema_builder.run(data_json) - return {"schema": schema_result} except (ValueError, RuntimeError) as e: log.error("Failed to generate schema: %s", e) - return {"schema": f"Schema generation failed: {e}"} + raise ValueError(f"Failed to generate schema: {e}") from e diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py index c3f427e93..8916356ae 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/fetch_graph_data.py @@ -44,8 +44,16 @@ def res = [:]; return res; """ - result = self.graph.gremlin().exec(groovy_code)["data"] - + response = self.graph.gremlin().exec(groovy_code) + result = response.get("data") if isinstance(response, dict) else None if isinstance(result, list) and len(result) > 0: - graph_summary.update({key: result[i].get(key) for i, key in enumerate(keys)}) + if len(result) == 1 and isinstance(result[0], dict): + graph_summary.update({key: result[0].get(key) for key in keys}) + else: + graph_summary.update( + { + key: result[i].get(key) if i < len(result) and isinstance(result[i], dict) else None + for i, key in enumerate(keys) + } + ) return graph_summary diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index c265646fa..c51cb9587 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Optional from pyhugegraph.client import PyHugeClient +from requests.exceptions import RequestException from hugegraph_llm.config import huge_settings @@ -57,9 +58,12 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]: if context is None: context = {} - schema = self.schema.getSchema() + try: + schema = self.schema.getSchema() + except RequestException as e: + raise ValueError(f"Failed to connect to HugeGraph to get schema '{self.graph_name}': {e}") from e if not schema["vertexlabels"] and not schema["edgelabels"]: - raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") + raise ValueError(f"Cannot get {self.graph_name}'s schema from HugeGraph!") context.update({"schema": schema}) # TODO: enhance the logic here diff --git a/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py b/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py index c919a2d65..e7a2702a4 100644 --- a/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py +++ b/hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py @@ -18,6 +18,7 @@ import os import unittest +from unittest.mock import AsyncMock, MagicMock from hugegraph_llm.models.embeddings.base import SimilarityMode from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding @@ -40,3 +41,64 @@ def test_get_cosine_similarity(self): embedding2 = ollama_embedding.get_text_embedding("bye world") similarity = OllamaEmbedding.similarity(embedding1, embedding2, SimilarityMode.DEFAULT) print(similarity) + + def test_async_get_texts_embeddings_preserves_batch_order(self): + ollama_embedding = OllamaEmbedding(model="test-model") + ollama_embedding.async_client = AsyncMock() + ollama_embedding.async_client.embed.side_effect = [ + {"embeddings": [[1.0], [2.0]]}, + {"embeddings": [[3.0]]}, + ] + + async def run_async_test(): + result = await ollama_embedding.async_get_texts_embeddings(["a", "b", "c"], batch_size=2) + self.assertEqual(result, [[1.0], [2.0], [3.0]]) + self.assertEqual(ollama_embedding.async_client.embed.call_count, 2) + ollama_embedding.async_client.embed.assert_any_call(model="test-model", input=["a", "b"]) + ollama_embedding.async_client.embed.assert_any_call(model="test-model", input=["c"]) + + import asyncio + + asyncio.run(run_async_test()) + + def test_async_get_text_embedding_requires_embeddings_key(self): + ollama_embedding = OllamaEmbedding(model="test-model") + ollama_embedding.async_client = AsyncMock() + ollama_embedding.async_client.embed.return_value = {} + + async def run_async_test(): + with self.assertRaisesRegex(ValueError, "missing 'embeddings'"): + await ollama_embedding.async_get_text_embedding("a") + + import asyncio + + asyncio.run(run_async_test()) + + def test_get_texts_embeddings_requires_embeddings_key(self): + ollama_embedding = OllamaEmbedding(model="test-model") + ollama_embedding.client = MagicMock() + ollama_embedding.client.embed.return_value = {} + + with self.assertRaisesRegex(ValueError, "missing 'embeddings'"): + ollama_embedding.get_texts_embeddings(["a"]) + + def test_get_texts_embeddings_requires_non_empty_embeddings(self): + ollama_embedding = OllamaEmbedding(model="test-model") + ollama_embedding.client = MagicMock() + ollama_embedding.client.embed.return_value = {"embeddings": []} + + with self.assertRaisesRegex(ValueError, "returned no embeddings"): + ollama_embedding.get_texts_embeddings(["a"]) + + def test_async_get_text_embedding_requires_non_empty_embeddings(self): + ollama_embedding = OllamaEmbedding(model="test-model") + ollama_embedding.async_client = AsyncMock() + ollama_embedding.async_client.embed.return_value = {"embeddings": []} + + async def run_async_test(): + with self.assertRaisesRegex(ValueError, "returned no embeddings"): + await ollama_embedding.async_get_text_embedding("a") + + import asyncio + + asyncio.run(run_async_test()) diff --git a/hugegraph-llm/src/tests/models/llms/test_litellm_client.py b/hugegraph-llm/src/tests/models/llms/test_litellm_client.py new file mode 100644 index 000000000..3b27d780b --- /dev/null +++ b/hugegraph-llm/src/tests/models/llms/test_litellm_client.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import asyncio +import unittest +from unittest.mock import AsyncMock, patch + +from litellm.exceptions import APIError, BudgetExceededError + +from hugegraph_llm.models.llms.litellm import LiteLLMClient + + +class TestLiteLLMClient(unittest.TestCase): + def test_budget_exceeded_error_is_not_retried(self): + client = LiteLLMClient(model_name="openai/gpt-4.1-mini") + error = BudgetExceededError(current_cost=2.0, max_budget=1.0) + + with patch("hugegraph_llm.models.llms.litellm.completion", side_effect=error) as mock_completion: + with self.assertRaises(BudgetExceededError): + client.generate(prompt="hello") + + mock_completion.assert_called_once() + + def test_generate_retries_api_error_and_reraises_original_exception(self): + client = LiteLLMClient(model_name="openai/gpt-4.1-mini") + error = APIError(status_code=500, message="upstream failed", llm_provider="openai", model="gpt-4.1-mini") + + with patch("hugegraph_llm.models.llms.litellm.completion", side_effect=error) as mock_completion: + with self.assertRaises(APIError): + client.generate(prompt="hello") + + self.assertEqual(mock_completion.call_count, 2) + + def test_generate_streaming_reraises_api_error(self): + client = LiteLLMClient(model_name="openai/gpt-4.1-mini") + error = APIError(status_code=500, message="upstream failed", llm_provider="openai", model="gpt-4.1-mini") + + with patch("hugegraph_llm.models.llms.litellm.completion", side_effect=error): + with self.assertRaises(APIError): + client.generate_streaming(prompt="hello") + + def test_agenerate_retries_api_error_and_reraises_original_exception(self): + client = LiteLLMClient(model_name="openai/gpt-4.1-mini") + error = APIError(status_code=500, message="upstream failed", llm_provider="openai", model="gpt-4.1-mini") + + async def run_async_test(): + with patch("hugegraph_llm.models.llms.litellm.acompletion", new=AsyncMock(side_effect=error)) as mock_call: + with self.assertRaises(APIError): + await client.agenerate(prompt="hello") + + self.assertEqual(mock_call.call_count, 2) + + asyncio.run(run_async_test()) + + def test_agenerate_streaming_reraises_api_error(self): + client = LiteLLMClient(model_name="openai/gpt-4.1-mini") + error = APIError(status_code=500, message="upstream failed", llm_provider="openai", model="gpt-4.1-mini") + + async def run_async_test(): + with patch("hugegraph_llm.models.llms.litellm.acompletion", new=AsyncMock(side_effect=error)): + with self.assertRaises(APIError): + async for _ in client.agenerate_streaming(prompt="hello"): + pass + + asyncio.run(run_async_test()) + + +if __name__ == "__main__": + unittest.main() diff --git a/hugegraph-llm/src/tests/models/llms/test_openai_client.py b/hugegraph-llm/src/tests/models/llms/test_openai_client.py index b9f8a113d..20e5aaacc 100644 --- a/hugegraph-llm/src/tests/models/llms/test_openai_client.py +++ b/hugegraph-llm/src/tests/models/llms/test_openai_client.py @@ -92,6 +92,23 @@ def test_generate_with_messages(self, mock_openai_class): messages=messages, ) + @patch("hugegraph_llm.models.llms.openai.OpenAI") + def test_generate_raises_runtime_error_with_empty_choices(self, mock_openai_class): + """Test generate method with an empty choices response.""" + # Setup mock client + mock_client = MagicMock() + empty_response = MagicMock() + empty_response.choices = [] + empty_response.usage = None + mock_client.chat.completions.create.return_value = empty_response + mock_openai_class.return_value = mock_client + + # Test the method + openai_client = OpenAIClient(model_name="gpt-3.5-turbo") + + with self.assertRaisesRegex(RuntimeError, "Empty choices in LLM response"): + openai_client.generate(prompt="What is the capital of France?") + @patch("hugegraph_llm.models.llms.openai.AsyncOpenAI") def test_agenerate(self, mock_async_openai_class): """Test agenerate method with mocked async OpenAI client.""" @@ -118,6 +135,26 @@ async def run_async_test(): asyncio.run(run_async_test()) + @patch("hugegraph_llm.models.llms.openai.AsyncOpenAI") + def test_agenerate_raises_runtime_error_with_empty_choices(self, mock_async_openai_class): + """Test agenerate method with an empty choices response.""" + # Setup mock async client + mock_async_client = MagicMock() + empty_response = MagicMock() + empty_response.choices = [] + empty_response.usage = None + mock_async_client.chat.completions.create = AsyncMock(return_value=empty_response) + mock_async_openai_class.return_value = mock_async_client + + # Test the method + openai_client = OpenAIClient(model_name="gpt-3.5-turbo") + + async def run_async_test(): + with self.assertRaisesRegex(RuntimeError, "Empty choices in LLM response"): + await openai_client.agenerate(prompt="What is the capital of France?") + + asyncio.run(run_async_test()) + @patch("hugegraph_llm.models.llms.openai.OpenAI") def test_stream_generate(self, mock_openai_class): """Test generate_streaming method with mocked OpenAI client.""" diff --git a/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py b/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py index 64c093eda..8527502d0 100644 --- a/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py +++ b/hugegraph-llm/src/tests/operators/hugegraph_op/test_fetch_graph_data.py @@ -73,6 +73,21 @@ def test_run_with_none_graph_summary(self): self.assertIn("g.V().id().limit(10000).toList()", groovy_code) self.assertIn("g.E().id().limit(200).toList()", groovy_code) + def test_run_with_legacy_ordered_single_field_dicts_result(self): + """Test run method with legacy ordered single-field dict rows.""" + # Setup mock + self.mock_gremlin.exec.return_value = self.sample_result + + # Call the method + result = self.fetcher.run({}) + + # Verify the result + self.assertEqual(result["vertex_num"], 100) + self.assertEqual(result["edge_num"], 200) + self.assertEqual(result["vertices"], ["v1", "v2", "v3"]) + self.assertEqual(result["edges"], ["e1", "e2"]) + self.assertEqual(result["note"], "Only ≤10000 VIDs and ≤ 200 EIDs for brief overview .") + def test_run_with_existing_graph_summary(self): """Test run method with existing graph_summary.""" # Setup mock @@ -97,6 +112,77 @@ def test_run_with_existing_graph_summary(self): self.assertEqual(result["edges"], ["e1", "e2"]) self.assertIn("note", result) + def test_run_with_single_summary_dict_result(self): + """Test run method with Gremlin map result wrapped as one data row.""" + # Setup mock + self.mock_gremlin.exec.return_value = { + "data": [ + { + "vertex_num": 100, + "edge_num": 200, + "vertices": ["v1", "v2", "v3"], + "edges": ["e1", "e2"], + "note": "Only ≤10000 VIDs and ≤ 200 EIDs for brief overview .", + } + ] + } + + # Call the method + result = self.fetcher.run({}) + + # Verify the result + self.assertEqual(result["vertex_num"], 100) + self.assertEqual(result["edge_num"], 200) + self.assertEqual(result["vertices"], ["v1", "v2", "v3"]) + self.assertEqual(result["edges"], ["e1", "e2"]) + self.assertEqual(result["note"], "Only ≤10000 VIDs and ≤ 200 EIDs for brief overview .") + + def test_run_with_partial_single_summary_dict_result(self): + """Test run method handles a single Gremlin map with missing summary fields.""" + # Setup mock + self.mock_gremlin.exec.return_value = { + "data": [ + { + "vertex_num": 100, + "vertices": ["v1", "v2", "v3"], + } + ] + } + + # Call the method + result = self.fetcher.run({}) + + # Verify the result + self.assertEqual(result["vertex_num"], 100) + self.assertIsNone(result["edge_num"]) + self.assertEqual(result["vertices"], ["v1", "v2", "v3"]) + self.assertIsNone(result["edges"]) + self.assertIsNone(result["note"]) + + def test_run_with_empty_single_summary_dict_result(self): + """Test run method treats one empty dict as a summary row.""" + # Setup mock + self.mock_gremlin.exec.return_value = {"data": [{}]} + + # Call the method + result = self.fetcher.run({}) + + # Verify the result + self.assertIsNone(result["vertex_num"]) + self.assertIsNone(result["edge_num"]) + self.assertIsNone(result["vertices"]) + self.assertIsNone(result["edges"]) + self.assertIsNone(result["note"]) + + def test_run_reraises_gremlin_exec_exception(self): + """Test run method does not hide Gremlin execution failures.""" + # Setup mock + self.mock_gremlin.exec.side_effect = RuntimeError("Gremlin endpoint unavailable") + + # Call the method and verify the original failure is visible + with self.assertRaisesRegex(RuntimeError, "Gremlin endpoint unavailable"): + self.fetcher.run({}) + def test_run_with_empty_result(self): """Test run method with empty result from gremlin.""" # Setup mock diff --git a/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py b/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py index a20857aec..b77ae01f7 100644 --- a/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py +++ b/hugegraph-llm/src/tests/operators/hugegraph_op/test_schema_manager.py @@ -158,7 +158,7 @@ def test_run_with_empty_schema(self): self.schema_manager.run({}) # Verify the exception message - self.assertIn(f"Can not get {self.graph_name}'s schema from HugeGraph!", str(cm.exception)) + self.assertIn(f"Cannot get {self.graph_name}'s schema from HugeGraph!", str(cm.exception)) def test_run_with_existing_context(self): """Test run method with an existing context.""" diff --git a/hugegraph-python-client/src/pyhugegraph/utils/util.py b/hugegraph-python-client/src/pyhugegraph/utils/util.py index d8c833b49..b4a6f6448 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/util.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/util.py @@ -101,9 +101,21 @@ def __call__(self, response: requests.Response, method: str, path: str): log.info("Resource %s not found (404)", path) else: try: - details = response.json().get("exception", "key 'exception' not found") - except (ValueError, KeyError): - details = "key 'exception' not found" + body = response.json() + if isinstance(body, dict): + status = body.get("status") + status_message = status.get("message") if isinstance(status, dict) else None + details = ( + body.get("exception") + or body.get("message") + or status_message + or response.text + or "unknown error" + ) + else: + details = response.text or "unknown error" + except (ValueError, KeyError, AttributeError, TypeError): + details = response.text or "unknown error" req_body = response.request.body if response.request.body else "Empty body" req_body = req_body.encode("utf-8").decode("unicode_escape") diff --git a/hugegraph-python-client/src/tests/api/test_response_validation.py b/hugegraph-python-client/src/tests/api/test_response_validation.py new file mode 100644 index 000000000..759d97184 --- /dev/null +++ b/hugegraph-python-client/src/tests/api/test_response_validation.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from unittest.mock import Mock + +import requests +from pyhugegraph.utils.util import ResponseValidation + + +class TestResponseValidation(unittest.TestCase): + def _mock_error_response(self, body, text): + response = Mock(spec=requests.Response) + response.status_code = 400 + response.text = text + response.content = response.text.encode("utf-8") + response.json.return_value = body + response.request = Mock() + response.request.body = "g.V2()" + response.raise_for_status.side_effect = requests.exceptions.HTTPError("400 Client Error") + return response + + def test_numeric_status_body_raises_server_exception_with_message(self): + response = self._mock_error_response( + {"status": 400, "message": "bad gremlin"}, + '{"status":400,"message":"bad gremlin"}', + ) + validator = ResponseValidation() + + with self.assertRaisesRegex(Exception, "Server Exception: bad gremlin"): + validator(response, "POST", "/gremlin") + + def test_non_dict_json_body_raises_server_exception_with_response_text(self): + response = self._mock_error_response(["bad gremlin"], "bad gremlin") + validator = ResponseValidation() + + with self.assertRaisesRegex(Exception, "Server Exception: bad gremlin"): + validator(response, "POST", "/gremlin") + + +if __name__ == "__main__": + unittest.main()