[tinker][SkyRL] Add sample and renderer API to new inference client #1287
nithinvc wants to merge 38 commits into NovaSky-AI:main from
Conversation
```python
    return s.getsockname()[1]

def find_and_reserve_port(start_port: int) -> Tuple[int, socket.socket]:
```
I think this would be best done in a separate PR. Making separate PRs is useful, e.g. if one of the changes needs to be reverted, and it also helps reviewing.
The same goes for the tests associated to this change of course.
Sounds good! I can split this into two? One for the sample / render and one for the port collision.
```python
    session_id: Optional[Union[str, int]] = None,
) -> Dict[str, Any]:
    """
    Render chat messages into a tokenized prompt via /v1/chat/completions/render.
```
This is very nice! I hadn't known about this endpoint, looks very useful. For client-side rendering, I have found https://github.com/thinking-machines-lab/tinker-cookbook/tree/main/tinker_cookbook/renderers useful, but you are right to use the vLLM endpoint for this PR, it will make it much easier to keep everything consistent :)
I saw this earlier! It's very useful - I think it might make sense to model the training side changes in a similar way so there's little drift between the client renderer <> training backend
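For reference, the request body a render-style helper might send to the endpoint can be sketched as below. The field names (`model`, `messages`, `session_id`) are assumptions based on the signature quoted above, not the exact payload the vLLM endpoint accepts.

```python
from typing import Any, Dict, List, Optional

def build_render_request(
    messages: List[Dict[str, str]],
    model_name: str,
    session_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble an illustrative JSON body for POST /v1/chat/completions/render.

    Field names here are hypothetical; check the actual endpoint schema.
    """
    body: Dict[str, Any] = {"model": model_name, "messages": messages}
    if session_id is not None:
        # Only include the session id when the caller supplies one.
        body["session_id"] = session_id
    return body

body = build_render_request([{"role": "user", "content": "Hello"}], "my-model")
```

Keeping the body construction in one small helper makes it easy to adjust if the server-side schema changes.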
```python
@pytest.mark.asyncio
@pytest.mark.skipif(not _SKYRL_USE_NEW_INFERENCE, reason="Render API only supported with new inference client")
async def test_render_chat_completion(self, client):
```
For this we should also test if the result is actually correct / is what we expect.
Added a check for the prompt and prompt token ids returned by the mock
```python
client = vllm_server.client
messages = [{"role": "user", "content": "Hello"}]
result = asyncio.run(client.render_chat_completion(messages=messages))
# vLLM returns [conversation, engine_prompts]
```
Again it would be good to have stricter checks here
Added a local tokenizer check for the prompt ids
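A local tokenizer cross-check like the one mentioned can be sketched as follows. The response keys (`prompt`, `prompt_token_ids`) and the stub encoder are hypothetical, standing in for the mock response and a real HuggingFace tokenizer's `encode`.

```python
from typing import Any, Callable, Dict, List

def check_render_result(
    result: Dict[str, Any],
    encode: Callable[[str], List[int]],
) -> bool:
    """Verify server-returned token ids against a local tokenizer.

    `result` is assumed to carry 'prompt' and 'prompt_token_ids' keys;
    `encode` is any text -> token-ids function, e.g. tokenizer.encode.
    """
    return encode(result["prompt"]) == result["prompt_token_ids"]

# Stub tokenizer for illustration: one "token" per character code.
fake_encode = lambda text: [ord(c) for c in text]
result = {"prompt": "Hi", "prompt_token_ids": [72, 105]}
ok = check_render_result(result, fake_encode)
```

In the real test the stub would be replaced by the same tokenizer the server uses, so any drift between server and client tokenization fails the assertion.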
It looks good to me overall, maybe @kouroshHakha has some more feedback :) One thing that seems a little strange is that we need all of {sample, render_chat_completion, chat_completion, completion, tokenize, detokenize}. It seems to me we would always go through cc @CharlieFRuan since he has been thinking a lot about token in token out :)
Code Review
This pull request introduces sample() and render_chat_completion() methods to the RemoteInferenceClient and integrates the new HTTP-based inference pathway into SkyRLTrainBackend, controlled by the _SKYRL_USE_NEW_INFERENCE environment variable. The changes are accompanied by new tests and updates to existing ones. My review found a couple of areas for improvement: an unused parameter in the new sample method and a configuration issue in a new GPU test that prevents it from testing the intended scenario.
Resolved review threads (outdated):
- tests/backends/skyrl_train/gpu/gpu_ci/inference_servers/test_backend_weight_sync.py
- skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
kouroshHakha left a comment
Just one big comment:
```python
async def sample(
    self,
```
I would make the interface of this consistent with how chat/completion are implemented, i.e. request_body goes in and then it is parsed out. The request_body should have the same API as the tinker sample API.
Re lora_id / model_id, I think we should do this:
InferenceClient has a field called model_name that should be used for the APIs. I looked at the sample API spec from tinker and they don't need to define it there because the client carries the model name information. We should do a similar thing here.
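A minimal sketch of the request_body-in shape suggested here, with the client carrying `model_name` so callers never pass it per request. The class and field names are illustrative, not the actual SkyRL interface.

```python
import asyncio
from typing import Any, Dict

class InferenceClientSketch:
    """Illustrative only: sample() takes a request_body and parses it out,
    mirroring how chat/completion are described in the review."""

    def __init__(self, model_name: str):
        # The client owns the model name, as in the tinker sample API spec.
        self.model_name = model_name

    async def sample(self, request_body: Dict[str, Any]) -> Dict[str, Any]:
        # Merge the client-held model name into the caller's request body.
        body = {"model": self.model_name, **request_body}
        # A real client would POST `body` to the server here; we just
        # return the assembled request to show the shape.
        return body

client = InferenceClientSketch("my-model")
out = asyncio.run(client.sample({"prompt_token_ids": [1, 2], "num_samples": 4}))
```

The advantage of this shape is that per-request payloads stay small and the model identity cannot drift between calls on the same client.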
```python
@pytest.fixture(scope="class")
def ray_env_with_new_inference():
```
don't we already have a conftest for this?
Yes, in fact this test was redundant with some of the existing tests, so I removed it.
Code Review
This pull request introduces a new HTTP-based inference pathway by adding sample() and render_chat_completion() methods to RemoteInferenceClient and integrating it into SkyRLTrainBackend. The changes are gated by the _SKYRL_USE_NEW_INFERENCE environment variable, which is a good practice for introducing significant new functionality. The implementation is well-structured, and the addition of comprehensive unit and integration tests for the new APIs is commendable. I've identified a couple of areas for improvement in the RemoteInferenceClient, including a potential runtime type error and a best-practice violation regarding function side effects.
@kouroshHakha @pcmoritz I modified the sample method to follow the
…client.py Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com>
```python
return RemoteInferenceClient(
    proxy_url=proxy_url,
    server_urls=server_urls,
    model_name=self._cfg.trainer.policy.model.path,
)
```
🔴 _create_remote_inference_client ignores served_model_name, causing request rejection when configured
_create_remote_inference_client always uses self._cfg.trainer.policy.model.path as model_name, but when served_model_name is configured in the inference engine config, the vLLM server only accepts that name (not the model path). This causes all data plane requests (sample, generate, chat_completion, etc.) to fail with a "model not found" error.
The old InferenceEngineClient correctly handles this at skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py:68-70, and the test utility at tests/backends/skyrl_train/gpu/utils.py:512 also correctly uses served_model_name if served_model_name else cfg.trainer.policy.model.path. The production code here omits this logic.
Note: main_base.py:377 has the same pre-existing issue, which this code mirrors — but it should be fixed here nonetheless.
Suggested change:
```python
# before
return RemoteInferenceClient(
    proxy_url=proxy_url,
    server_urls=server_urls,
    model_name=self._cfg.trainer.policy.model.path,
)

# after
ie_served_name = self._cfg.generator.inference_engine.served_model_name
return RemoteInferenceClient(
    proxy_url=proxy_url,
    server_urls=server_urls,
    model_name=ie_served_name if ie_served_name else self._cfg.trainer.policy.model.path,
)
```
Summary
This PR adds the `sample()` and `render_chat_completion()` methods to RemoteInferenceClient and wires the new HTTP-based inference pathway into SkyRLTrainBackend, gated behind the `_SKYRL_USE_NEW_INFERENCE` env var.

Issues addressed: #1286 #1288
Changes
- `RemoteInferenceClient` — Added `sample()` (fires `num_samples` parallel `_generate_single` calls and aggregates into a single `InferenceEngineOutput`) and `render_chat_completion()` (calls `/v1/chat/completions/render` to tokenize chat messages without generating). Tested against vLLM 0.16.0.
- `SkyRLTrainBackend` — Added `_create_remote_inference_client()` with the same 4-way branching logic as `main_base.py` (external proxy + servers, proxy only, servers only, fully internal `ServerGroup` + `InferenceRouter`). `_ensure_inference_engines()` now branches on `_SKYRL_USE_NEW_INFERENCE` to use the new HTTP client path.
- `test_engine_generation.py` — Removed the guard on the sample API test. The new sample API passes.
- `test_save_weights_for_sampler.py` — Fixed GPU test to pass tokenizer to `run_inference` and added `gpu_memory_utilization=0.5` for colocated placement to avoid OOM when running on L4 GPUs.
- Added a `/v1/chat/completions/render` endpoint and `test_render_chat_completion` to `test_remote_inference_client.py`.
- Added `test_client_render_chat_completion` to `test_new_inference_generation.py`.

Testing
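The fan-out-and-aggregate behavior described for `sample()` can be sketched roughly as follows. `_generate_single` is stubbed here and the output keys are assumptions, not the real `InferenceEngineOutput` schema.

```python
import asyncio
from typing import Any, Dict, List

async def _generate_single(prompt_ids: List[int], params: Dict[str, Any]) -> Dict[str, Any]:
    # Stand-in for one HTTP generation call; the real client would POST
    # to the inference server. Here we append a dummy token.
    return {"token_ids": prompt_ids + [0], "stop_reason": "stop"}

async def sample(prompt_ids: List[int], num_samples: int, params: Dict[str, Any]) -> Dict[str, Any]:
    # Fire num_samples generations in parallel and aggregate the results
    # into a single output, mirroring the description above.
    results = await asyncio.gather(
        *(_generate_single(prompt_ids, params) for _ in range(num_samples))
    )
    return {
        "response_ids": [r["token_ids"] for r in results],
        "stop_reasons": [r["stop_reason"] for r in results],
    }

out = asyncio.run(sample([1, 2], 3, {}))
```

`asyncio.gather` preserves input order, so the i-th response always corresponds to the i-th fired request.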
- `test_remote_inference_client.py` — no regressions + render API tests
- `test_engine_generation.py` — sample API test
- `test_save_weights_for_sampler.py` — GPU inference + weight syncing tests

Limitations
`lora_rank=0` is required. LoRA support will be added once we move to the native vLLM weight sync API.