Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import argparse
import os

import gradio as gr
import uvicorn
Expand Down Expand Up @@ -202,5 +203,5 @@ def create_app():
host=args.host,
port=args.port,
factory=True,
reload=True,
reload=os.getenv("HG_DEV_RELOAD") == "1",
)
39 changes: 30 additions & 9 deletions hugegraph-llm/src/hugegraph_llm/models/embeddings/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,41 @@ def get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[L
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
response = self.client.embed(model=self.model, input=batch)["embeddings"]
all_embeddings.extend([list(inner_sequence) for inner_sequence in response])
response = self.client.embed(model=self.model, input=batch)
all_embeddings.extend(self._get_embeddings_from_response(response))
return all_embeddings

def _get_embeddings_from_response(self, response) -> List[List[float]]:
Comment thread
zxuexingzhijie marked this conversation as resolved.
if "embeddings" not in response:
raise ValueError("Ollama embedding response missing 'embeddings'.")
embeddings = response["embeddings"]
if not embeddings:
raise ValueError("Ollama embedding response returned no embeddings.")
return [list(inner_sequence) for inner_sequence in embeddings]

async def async_get_text_embedding(self, text: str) -> List[float]:
    """Get the embedding for a single text asynchronously.

    The text is sent as a one-element batch through the async client's
    ``embed`` method, and the single resulting vector is returned.

    Args:
        text: The text to embed.

    Returns:
        The embedding vector for ``text``.

    Raises:
        AttributeError: If the async client does not provide ``embed``.
        ValueError: Propagated from ``_get_embeddings_from_response`` when
            the response has no usable ``"embeddings"`` entry.
    """
    if not hasattr(self.async_client, "embed"):
        error_message = (
            "The required 'embed' method was not found on the Ollama async client. "
            "Please ensure your ollama library is up-to-date and supports batch embedding. "
        )
        raise AttributeError(error_message)

    response = await self.async_client.embed(model=self.model, input=[text])
    # The request carried exactly one input, so the first vector is the answer.
    return self._get_embeddings_from_response(response)[0]

async def async_get_texts_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
    """Get embeddings for several texts asynchronously, in input order.

    The texts are split into consecutive slices of at most ``batch_size``
    items. Each slice is sent in one ``embed`` call, and the calls are
    awaited one after another so the output order matches ``texts``.

    Args:
        texts: The texts to embed.
        batch_size: Maximum number of texts sent per ``embed`` call.

    Returns:
        One embedding vector per processed text; an empty list when
        ``texts`` is empty.

    Raises:
        AttributeError: If the async client does not provide ``embed``.
        ValueError: Propagated from ``_get_embeddings_from_response`` when a
            response has no usable ``"embeddings"`` entry.
    """
    if not hasattr(self.async_client, "embed"):
        error_message = (
            "The required 'embed' method was not found on the Ollama async client. "
            "Please ensure your ollama library is up-to-date and supports batch embedding. "
        )
        raise AttributeError(error_message)

    results: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        response = await self.async_client.embed(model=self.model, input=batch)
        results.extend(self._get_embeddings_from_response(response))
    return results
14 changes: 8 additions & 6 deletions hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def __init__(
@retry(
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=2, max=5),
retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)),
retry=retry_if_exception_type((RateLimitError, APIError)),
reraise=True,
)
def generate(
self,
Expand All @@ -75,12 +76,13 @@ def generate(
return response.choices[0].message.content
except (RateLimitError, BudgetExceededError, APIError) as e:
log.error("Error in LiteLLM call: %s", e)
return f"Error: {str(e)}"
raise
Comment thread
zxuexingzhijie marked this conversation as resolved.

@retry(
stop=stop_after_attempt(2),
wait=wait_exponential(multiplier=1, min=2, max=5),
retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)),
retry=retry_if_exception_type((RateLimitError, APIError)),
reraise=True,
)
async def agenerate(
self,
Expand All @@ -104,7 +106,7 @@ async def agenerate(
return response.choices[0].message.content
except (RateLimitError, BudgetExceededError, APIError) as e:
log.error("Error in async LiteLLM call: %s", e)
return f"Error: {str(e)}"
raise

def generate_streaming(
self,
Expand Down Expand Up @@ -138,7 +140,7 @@ def generate_streaming(
return result
except (RateLimitError, BudgetExceededError, APIError) as e:
log.error("Error in streaming LiteLLM call: %s", e)
return f"Error: {str(e)}"
raise

async def agenerate_streaming(
self,
Expand Down Expand Up @@ -170,7 +172,7 @@ async def agenerate_streaming(
yield chunk.choices[0].delta.content
except (RateLimitError, BudgetExceededError, APIError) as e:
log.error("Error in async streaming LiteLLM call: %s", e)
yield f"Error: {str(e)}"
raise

def num_tokens_from_string(self, string: str) -> int:
"""Get token count from string."""
Expand Down
10 changes: 8 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/models/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ def generate(
max_tokens=self.max_tokens,
messages=messages,
)
log.info("Token usage: %s", completions.usage.model_dump_json())
if not completions.choices:
raise RuntimeError(f"Empty choices in LLM response: {str(completions)[:200]}")
if completions.usage:
log.info("Token usage: %s", completions.usage.model_dump_json())
return completions.choices[0].message.content
# catch context length / do not retry
except openai.BadRequestError as e:
Expand Down Expand Up @@ -105,7 +108,10 @@ async def agenerate(
max_tokens=self.max_tokens,
messages=messages,
)
log.info("Token usage: %s", completions.usage.model_dump_json())
if not completions.choices:
raise RuntimeError(f"Empty choices in LLM response: {str(completions)[:200]}")
if completions.usage:
log.info("Token usage: %s", completions.usage.model_dump_json())
return completions.choices[0].message.content
# catch context length / do not retry
except openai.BadRequestError as e:
Expand Down
12 changes: 8 additions & 4 deletions hugegraph-llm/src/hugegraph_llm/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import traceback
from typing import Dict, Optional

from pycgraph import CStatus, GNode
Expand All @@ -22,6 +23,11 @@
from hugegraph_llm.utils.log import log


def _format_node_err(node, exc: Exception, prefix: str = "Node failed") -> str:
node_info = f"Node type: {type(node).__name__}, Node object: {node}"
return f"{prefix}: {exc}\n{node_info}\n{traceback.format_exc()}"


class BaseNode(GNode):
"""
Base class for workflow nodes, providing context management and operation scheduling.
Expand Down Expand Up @@ -70,10 +76,8 @@ def run(self):
try:
res = self.operator_schedule(data_json)
except (ValueError, TypeError, KeyError, NotImplementedError) as exc:
import traceback

node_info = f"Node type: {type(self).__name__}, Node object: {self}"
err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}"
err_msg = _format_node_err(self, exc)
log.error(err_msg)
return CStatus(-1, err_msg)
# For unexpected exceptions, re-raise to let them propagate or be caught elsewhere

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def node_init(self):
def operator_schedule(self, data_json):
    """Run the schema builder on ``data_json`` and wrap its result.

    Args:
        data_json: Input passed unchanged to ``self.schema_builder.run``.

    Returns:
        A dict of the form ``{"schema": <builder result>}``.

    Raises:
        ValueError: If the schema builder raises ``ValueError`` or
            ``RuntimeError``; the original exception is chained as the cause.
    """
    try:
        schema_result = self.schema_builder.run(data_json)
    except (ValueError, RuntimeError) as e:
        log.error("Failed to generate schema: %s", e)
        # Raise instead of returning the error text under the "schema" key,
        # so a failure cannot be mistaken for a generated schema.
        raise ValueError(f"Failed to generate schema: {e}") from e
    return {"schema": schema_result}
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ def res = [:];
return res;
"""

result = self.graph.gremlin().exec(groovy_code)["data"]

response = self.graph.gremlin().exec(groovy_code)
result = response.get("data") if isinstance(response, dict) else None
if isinstance(result, list) and len(result) > 0:
graph_summary.update({key: result[i].get(key) for i, key in enumerate(keys)})
if len(result) == 1 and isinstance(result[0], dict):
graph_summary.update({key: result[0].get(key) for key in keys})
Comment thread
zxuexingzhijie marked this conversation as resolved.
else:
graph_summary.update(
{
key: result[i].get(key) if i < len(result) and isinstance(result[i], dict) else None
for i, key in enumerate(keys)
}
)
return graph_summary
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Dict, Optional

from pyhugegraph.client import PyHugeClient
from requests.exceptions import RequestException

from hugegraph_llm.config import huge_settings

Expand Down Expand Up @@ -57,9 +58,12 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]:
def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if context is None:
context = {}
schema = self.schema.getSchema()
try:
schema = self.schema.getSchema()
except RequestException as e:
raise ValueError(f"Failed to connect to HugeGraph to get schema '{self.graph_name}': {e}") from e
if not schema["vertexlabels"] and not schema["edgelabels"]:
raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!")
raise ValueError(f"Cannot get {self.graph_name}'s schema from HugeGraph!")

context.update({"schema": schema})
# TODO: enhance the logic here
Expand Down
62 changes: 62 additions & 0 deletions hugegraph-llm/src/tests/models/embeddings/test_ollama_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import os
import unittest
from unittest.mock import AsyncMock, MagicMock

from hugegraph_llm.models.embeddings.base import SimilarityMode
from hugegraph_llm.models.embeddings.ollama import OllamaEmbedding
Expand All @@ -40,3 +41,64 @@ def test_get_cosine_similarity(self):
embedding2 = ollama_embedding.get_text_embedding("bye world")
similarity = OllamaEmbedding.similarity(embedding1, embedding2, SimilarityMode.DEFAULT)
print(similarity)

def test_async_get_texts_embeddings_preserves_batch_order(self):
    """Three texts with ``batch_size=2`` yield two ``embed`` calls, results in order."""
    import asyncio

    fake_client = AsyncMock()
    # One canned response per expected batch: ["a", "b"] first, then ["c"].
    fake_client.embed.side_effect = [
        {"embeddings": [[1.0], [2.0]]},
        {"embeddings": [[3.0]]},
    ]
    embedder = OllamaEmbedding(model="test-model")
    embedder.async_client = fake_client

    vectors = asyncio.run(embedder.async_get_texts_embeddings(["a", "b", "c"], batch_size=2))

    self.assertEqual(vectors, [[1.0], [2.0], [3.0]])
    self.assertEqual(fake_client.embed.call_count, 2)
    fake_client.embed.assert_any_call(model="test-model", input=["a", "b"])
    fake_client.embed.assert_any_call(model="test-model", input=["c"])

def test_async_get_text_embedding_requires_embeddings_key(self):
    """An async ``embed`` response without ``embeddings`` raises ``ValueError``."""
    import asyncio

    fake_client = AsyncMock()
    fake_client.embed.return_value = {}
    embedder = OllamaEmbedding(model="test-model")
    embedder.async_client = fake_client

    # The error raised inside the coroutine propagates out of asyncio.run().
    with self.assertRaisesRegex(ValueError, "missing 'embeddings'"):
        asyncio.run(embedder.async_get_text_embedding("a"))

def test_get_texts_embeddings_requires_embeddings_key(self):
    """A sync ``embed`` response without ``embeddings`` raises ``ValueError``."""
    stub_client = MagicMock()
    stub_client.embed.return_value = {}
    embedder = OllamaEmbedding(model="test-model")
    embedder.client = stub_client

    with self.assertRaisesRegex(ValueError, "missing 'embeddings'"):
        embedder.get_texts_embeddings(["a"])

def test_get_texts_embeddings_requires_non_empty_embeddings(self):
    """A sync ``embed`` response with an empty ``embeddings`` list raises ``ValueError``."""
    stub_client = MagicMock()
    stub_client.embed.return_value = {"embeddings": []}
    embedder = OllamaEmbedding(model="test-model")
    embedder.client = stub_client

    with self.assertRaisesRegex(ValueError, "returned no embeddings"):
        embedder.get_texts_embeddings(["a"])

def test_async_get_text_embedding_requires_non_empty_embeddings(self):
    """An async ``embed`` response with an empty ``embeddings`` list raises ``ValueError``."""
    import asyncio

    fake_client = AsyncMock()
    fake_client.embed.return_value = {"embeddings": []}
    embedder = OllamaEmbedding(model="test-model")
    embedder.async_client = fake_client

    # The error raised inside the coroutine propagates out of asyncio.run().
    with self.assertRaisesRegex(ValueError, "returned no embeddings"):
        asyncio.run(embedder.async_get_text_embedding("a"))
83 changes: 83 additions & 0 deletions hugegraph-llm/src/tests/models/llms/test_litellm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio
import unittest
from unittest.mock import AsyncMock, patch

from litellm.exceptions import APIError, BudgetExceededError

from hugegraph_llm.models.llms.litellm import LiteLLMClient


class TestLiteLLMClient(unittest.TestCase):
    """Checks how ``LiteLLMClient`` propagates and retries LiteLLM errors."""

    # Shared fixtures: the model name and the patch targets inside the client module.
    _MODEL = "openai/gpt-4.1-mini"
    _COMPLETION = "hugegraph_llm.models.llms.litellm.completion"
    _ACOMPLETION = "hugegraph_llm.models.llms.litellm.acompletion"

    @staticmethod
    def _upstream_error():
        """Return the ``APIError`` used to simulate an upstream failure."""
        return APIError(status_code=500, message="upstream failed", llm_provider="openai", model="gpt-4.1-mini")

    def test_budget_exceeded_error_is_not_retried(self):
        """``BudgetExceededError`` surfaces after a single ``completion`` call."""
        llm = LiteLLMClient(model_name=self._MODEL)
        over_budget = BudgetExceededError(current_cost=2.0, max_budget=1.0)

        with patch(self._COMPLETION, side_effect=over_budget) as fake_completion:
            with self.assertRaises(BudgetExceededError):
                llm.generate(prompt="hello")

            fake_completion.assert_called_once()

    def test_generate_retries_api_error_and_reraises_original_exception(self):
        """``generate`` calls ``completion`` twice, then raises the ``APIError``."""
        llm = LiteLLMClient(model_name=self._MODEL)

        with patch(self._COMPLETION, side_effect=self._upstream_error()) as fake_completion:
            with self.assertRaises(APIError):
                llm.generate(prompt="hello")

            self.assertEqual(fake_completion.call_count, 2)

    def test_generate_streaming_reraises_api_error(self):
        """``generate_streaming`` lets the ``APIError`` propagate to the caller."""
        llm = LiteLLMClient(model_name=self._MODEL)

        with patch(self._COMPLETION, side_effect=self._upstream_error()), self.assertRaises(APIError):
            llm.generate_streaming(prompt="hello")

    def test_agenerate_retries_api_error_and_reraises_original_exception(self):
        """``agenerate`` calls ``acompletion`` twice, then raises the ``APIError``."""
        llm = LiteLLMClient(model_name=self._MODEL)
        fake_acompletion = AsyncMock(side_effect=self._upstream_error())

        with patch(self._ACOMPLETION, new=fake_acompletion):
            with self.assertRaises(APIError):
                asyncio.run(llm.agenerate(prompt="hello"))

            self.assertEqual(fake_acompletion.call_count, 2)

    def test_agenerate_streaming_reraises_api_error(self):
        """Iterating ``agenerate_streaming`` raises the ``APIError``."""
        llm = LiteLLMClient(model_name=self._MODEL)
        fake_acompletion = AsyncMock(side_effect=self._upstream_error())

        async def drain_stream():
            # Consume the async generator; the error is expected while iterating.
            async for _ in llm.agenerate_streaming(prompt="hello"):
                pass

        with patch(self._ACOMPLETION, new=fake_acompletion), self.assertRaises(APIError):
            asyncio.run(drain_stream())


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading
Loading