4 changes: 4 additions & 0 deletions graphgen/common/init_llm.py
@@ -46,6 +46,10 @@ def __init__(self, backend: str, config: Dict[str, Any]):
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            self.llm_instance = VLLMWrapper(**config)
        elif backend == "ray_serve":
            from graphgen.models.llm.api.ray_serve_client import RayServeClient

            self.llm_instance = RayServeClient(**config)
        else:
            raise NotImplementedError(f"Backend {backend} is not implemented yet.")

88 changes: 88 additions & 0 deletions graphgen/models/llm/api/ray_serve_client.py
@@ -0,0 +1,88 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class RayServeClient(BaseLLMWrapper):
    """
    A client to interact with a Ray Serve deployment.
    """

    def __init__(
        self,
        *,
        app_name: Optional[str] = None,
        deployment_name: Optional[str] = None,
        serve_backend: Optional[str] = None,
        **kwargs: Any,
    ):
        try:
            from ray import serve
        except ImportError as e:
            raise ImportError(
                "Ray is not installed. Please install it with `pip install ray[serve]`."
            ) from e

        super().__init__(**kwargs)

        # Try to get existing handle first
        self.handle = None
        if app_name:
            try:
                self.handle = serve.get_app_handle(app_name)
            except Exception:
                pass
Comment on lines +34 to +35
Contributor (severity: high):
Catching a generic Exception can hide unexpected errors and make debugging difficult. It's better to catch more specific exceptions that you anticipate, or at least log the full traceback if a generic Exception must be caught.
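A minimal sketch of the logging fallback the comment allows, keeping the broad catch but recording the traceback; the helper name try_get_app_handle is illustrative and not part of the PR.

import logging

from ray import serve

logger = logging.getLogger(__name__)


def try_get_app_handle(app_name: str):
    """Return the handle for an existing Serve app, or None if the lookup fails."""
    try:
        return serve.get_app_handle(app_name)
    except Exception:  # broad catch kept deliberately; the traceback is logged
        logger.exception("Could not get handle for Ray Serve app %r", app_name)
        return None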

        elif deployment_name:
            try:
                self.handle = serve.get_deployment(deployment_name).get_handle()
            except Exception:
Comment on lines +38 to +39
Contributor (severity: high):
Similar to the previous comment, catching a generic Exception here can obscure the root cause of issues. Consider catching more specific exceptions related to serve.get_deployment or get_handle failures.

                pass

        # If no handle found, try to deploy if serve_backend is provided
        if self.handle is None:
            if serve_backend:
                if not app_name:
                    import uuid

                    app_name = f"llm_app_{serve_backend}_{uuid.uuid4().hex[:8]}"
Comment on lines +46 to +48
Contributor (severity: medium):
Imports should generally be placed at the top of the file (PEP 8 guideline). Moving import uuid to the top of the file improves readability and ensures consistency, even if Python allows conditional imports.
Suggested change:

import uuid

                    app_name = f"llm_app_{serve_backend}_{uuid.uuid4().hex[:8]}"


                print(
                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
                )
                from graphgen.models.llm.local.ray_serve_deployment import LLMDeployment

                # Filter kwargs to avoid passing unrelated args if necessary,
                # but LLMDeployment config accepts everything for now.
                # Note: We need to pass kwargs as the config dict.
Comment on lines +55 to +57
Contributor (severity: medium):
The comment indicates a potential issue with kwargs being passed directly. It's safer and clearer to explicitly construct a config dictionary with only the parameters relevant to LLMDeployment to avoid passing unintended arguments. This also makes the code more robust to future changes in LLMDeployment's constructor.
Suggested change:

                # Explicitly construct config for LLMDeployment
                deployment_config = {"model": kwargs.get("model"), "tokenizer": kwargs.get("tokenizer")} # Add other relevant config parameters
                deployment = LLMDeployment.bind(backend=serve_backend, config=deployment_config)

                deployment = LLMDeployment.bind(backend=serve_backend, config=kwargs)
                serve.run(deployment, name=app_name, route_prefix=f"/{app_name}")
Comment on lines +58 to +59
Contributor (severity: high):
serve.run() is a blocking call by default. If RayServeClient is intended to be used within an asynchronous application, this call will block the event loop, potentially causing performance issues or deadlocks. Consider if the Ray Serve application should be deployed out-of-band (e.g., as a separate script) or if serve.start() and serve.run_app() should be used in a non-blocking manner if the client is responsible for its lifecycle within an async context.
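If the client does end up owning deployment from async code, one workaround is to construct it on a worker thread so the synchronous work inside __init__ (including serve.run()) never runs on the event loop. A sketch only, not part of the PR; deploying out-of-band with serve run (see app_builder below) and passing only app_name sidesteps the concern entirely.

import asyncio

from graphgen.models.llm.api.ray_serve_client import RayServeClient


async def make_client(**config):
    # Run the blocking constructor off the event loop (asyncio.to_thread needs Python 3.9+).
    return await asyncio.to_thread(RayServeClient, serve_backend="vllm", **config)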

                self.handle = serve.get_app_handle(app_name)
            elif app_name or deployment_name:
                raise ValueError(
                    f"Ray Serve app/deployment '{app_name or deployment_name}' "
                    "not found and 'serve_backend' not provided to deploy it."
                )
            else:
                raise ValueError(
                    "Either 'app_name', 'deployment_name' or 'serve_backend' "
                    "must be provided for RayServeClient."
                )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        """Generate answer from the model."""
        return await self.handle.generate_answer.remote(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate top-k tokens for the next token prediction."""
        return await self.handle.generate_topk_per_token.remote(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        return await self.handle.generate_inputs_prob.remote(text, history, **extra)
84 changes: 84 additions & 0 deletions graphgen/models/llm/local/ray_serve_deployment.py
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict, List, Optional

from ray import serve
from starlette.requests import Request

from graphgen.bases.datatypes import Token
from graphgen.models.tokenizer import Tokenizer


@serve.deployment
class LLMDeployment:
    def __init__(self, backend: str, config: Dict[str, Any]):
        self.backend = backend

        # Initialize tokenizer if needed
        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        if "tokenizer" not in config:
            tokenizer = Tokenizer(model_name=tokenizer_model)
            config["tokenizer"] = tokenizer
Comment on lines +17 to +20
Contributor (severity: medium):
The tokenizer initialization logic is duplicated here and in LLMServiceActor. To improve maintainability and avoid redundancy, consider extracting this logic into a shared utility function or a common base class if applicable. This ensures consistent tokenizer handling across different LLM wrappers.
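A sketch of the kind of shared helper the comment suggests; the name ensure_tokenizer and where it would live are hypothetical.

import os
from typing import Any, Dict

from graphgen.models.tokenizer import Tokenizer


def ensure_tokenizer(config: Dict[str, Any]) -> Dict[str, Any]:
    """Attach a default tokenizer to a backend config if the caller did not supply one."""
    if "tokenizer" not in config:
        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        config["tokenizer"] = Tokenizer(model_name=tokenizer_model)
    return config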


if backend == "vllm":
from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

self.llm_instance = VLLMWrapper(**config)
elif backend == "huggingface":
from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

self.llm_instance = HuggingFaceWrapper(**config)
elif backend == "sglang":
from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

self.llm_instance = SGLangWrapper(**config)
else:
raise NotImplementedError(
f"Backend {backend} is not implemented for Ray Serve yet."
)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        return await self.llm_instance.generate_answer(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_topk_per_token(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_inputs_prob(text, history, **extra)

    async def __call__(self, request: Request) -> Dict:
Contributor (security, severity: high):
The __call__ method, which serves as the HTTP entry point for the Ray Serve deployment, lacks any authentication or authorization checks. This allows any user with network access to the Ray Serve port to execute LLM queries, potentially leading to unauthorized resource consumption and abuse of the service.
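An illustrative shared-secret check that could gate __call__; the SERVE_API_TOKEN variable and the is_authorized helper are hypothetical, not part of this PR or of Ray Serve.

import hmac
import os

from starlette.requests import Request


def is_authorized(request: Request) -> bool:
    """Compare the request's bearer token against a shared secret from the environment."""
    expected = os.environ.get("SERVE_API_TOKEN")
    if not expected:
        return False  # fail closed when no token is configured
    supplied = request.headers.get("authorization", "")
    return hmac.compare_digest(supplied, f"Bearer {expected}")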

        try:
            data = await request.json()
            text = data.get("text")
            history = data.get("history")
            method = data.get("method", "generate_answer")
            kwargs = data.get("kwargs", {})

            if method == "generate_answer":
                result = await self.generate_answer(text, history, **kwargs)
Contributor (security, severity: high):
The __call__ method passes untrusted user input (text and history) directly to the LLM's generate_answer method. Several LLM backends in this repository (e.g., HuggingFaceWrapper, SGLangWrapper) use manual string concatenation to build prompts, making them highly susceptible to prompt injection attacks. An attacker could provide crafted input to manipulate the LLM's behavior or spoof conversation history.
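Shape checks like the sketch below do not eliminate prompt injection, but they at least reject malformed payloads before they reach string-concatenated prompts; the limit and helper name are illustrative.

from typing import Any, List, Optional, Tuple

MAX_PROMPT_CHARS = 8192  # illustrative limit, not taken from the PR


def validate_payload(text: Any, history: Any) -> Tuple[str, Optional[List[str]]]:
    """Reject request payloads whose text/history do not have the expected shape."""
    if not isinstance(text, str) or not text.strip():
        raise ValueError("'text' must be a non-empty string")
    if len(text) > MAX_PROMPT_CHARS:
        raise ValueError("'text' is too long")
    if history is not None and (
        not isinstance(history, list) or not all(isinstance(h, str) for h in history)
    ):
        raise ValueError("'history' must be a list of strings")
    return text, history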

elif method == "generate_topk_per_token":
result = await self.generate_topk_per_token(text, history, **kwargs)
elif method == "generate_inputs_prob":
result = await self.generate_inputs_prob(text, history, **kwargs)
else:
return {"error": f"Method {method} not supported"}

return {"result": result}
except Exception as e:
return {"error": str(e)}
Contributor (security, severity: medium):
The current implementation exposes sensitive internal information by returning raw exception details, which could be exploited by attackers to gain deeper insights into the system's architecture and vulnerabilities. It's crucial to prevent the leakage of stack traces, file paths, or configuration details. Instead, catch specific exceptions and return generalized error messages, logging full details internally without exposing them to the client.
Suggested change:

        except Exception as e:
            # Log the error internally for debugging
            # import logging
            # logging.exception("Error in LLMDeployment")
            return {"error": "An internal error occurred while processing your request."}



def app_builder(args: Dict[str, str]) -> Any:
"""
Builder function for 'serve run'.
Usage: serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=...
"""
# args comes from the command line key=value pairs
backend = args.pop("backend", "vllm")
Contributor (severity: medium):
The app_builder function defaults backend to "vllm" if not provided. This default might not align with the user's expectation or the LLMDeployment's intended behavior if other backends are more common or desired as a default. Consider making the backend explicit or ensuring the default is well-documented and consistent with the overall system design.
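For example, a stricter variant of the current body (a sketch, not a committed suggestion):

    backend = args.pop("backend", None)
    if backend is None:
        raise ValueError(
            "app_builder requires an explicit backend=<vllm|huggingface|sglang> argument"
        )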

    # remaining args are treated as config
    return LLMDeployment.bind(backend=backend, config=args)
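For reference, the JSON protocol accepted by LLMDeployment.__call__ can be exercised over HTTP once an app is running, whether deployed by RayServeClient or via the serve run usage shown in the docstring above. A hypothetical client call, assuming Serve's default address 127.0.0.1:8000 and an app deployed under route prefix /llm_app:

import requests

resp = requests.post(
    "http://127.0.0.1:8000/llm_app",
    json={
        "text": "What is a knowledge graph?",
        "history": [],
        "method": "generate_answer",
        "kwargs": {},
    },
    timeout=60,
)
payload = resp.json()
print(payload.get("result") or payload.get("error"))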