From c44891a39fb41b75e53dcff77dd55a8911d5aee6 Mon Sep 17 00:00:00 2001
From: kagami
Date: Wed, 13 Mar 2024 11:28:49 +0800
Subject: [PATCH] Handle token budget when recommend test context

---
 .../assistants/recommend_test_context.py      | 62 ++++++++++++++-----
 1 file changed, 46 insertions(+), 16 deletions(-)

diff --git a/unit_tests/assistants/recommend_test_context.py b/unit_tests/assistants/recommend_test_context.py
index e173958..f7d849f 100644
--- a/unit_tests/assistants/recommend_test_context.py
+++ b/unit_tests/assistants/recommend_test_context.py
@@ -19,7 +19,7 @@ ENCODING = (
     if USE_USER_MODEL
     else get_encoding("cl100k_base")
 )
-# TODO: handle token budget
+
 TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.9)
 
 
@@ -58,29 +58,59 @@ JSON Format Example:
 """
 
 
+def _mk_user_msg(func_to_test: FuncToTest, contexts: List) -> str:
+    """
+    Create a user message to be sent to the model within the token budget.
+
+    Contexts are dropped from the end of the list, one at a time, until the
+    formatted prompt fits into TOKEN_BUDGET or no context is left. The
+    caller's list is left untouched.
+
+    Args:
+        func_to_test: Function under test; its func_content, func_name and
+            file_path are substituted into the prompt template.
+        contexts: Known context items, each rendered with str().
+
+    Returns:
+        The formatted prompt. It can still exceed TOKEN_BUDGET when the
+        prompt does not fit even with every context removed.
+    """
+    # Trim a copy: popping from the caller's list would silently shrink the
+    # known context for any code that reuses it after this call.
+    remaining = list(contexts)
+    while True:
+        msg = recommend_symbol_context_prompt.format(
+            function_content=func_to_test.func_content,
+            context_content="\n\n".join(str(c) for c in remaining),
+            function_name=func_to_test.func_name,
+            file_path=func_to_test.file_path,
+        )
+        token_count = len(ENCODING.encode(msg, disallowed_special=()))
+        if not remaining or token_count <= TOKEN_BUDGET:
+            return msg
+        # Over budget: drop the last context and try again.
+        remaining.pop()
+
+
 def get_recommended_symbols(
     func_to_test: FuncToTest, known_context: Optional[List] = None
 ) -> List[str]:
     known_context = known_context or []
-    context_content = "\n\n".join([str(c) for c in known_context])
-
-    msg = recommend_symbol_context_prompt.format(
-        function_content=func_to_test.func_content,
-        context_content=context_content,
-        function_name=func_to_test.func_name,
-        file_path=func_to_test.file_path,
-    )
+    msg = _mk_user_msg(func_to_test, known_context)
 
     json_res = {}
     if USE_USER_MODEL:
         # Use the wrapped api parameters
-        json_res = chat_completion_no_stream_return_json(
-            messages=[{"role": "user", "content": msg}],
-            llm_config={
-                "model": MODEL,
-                "temperature": 0.1,
-            },
-        ) or {}
+        json_res = (
+            chat_completion_no_stream_return_json(
+                messages=[{"role": "user", "content": msg}],
+                llm_config={
+                    "model": MODEL,
+                    "temperature": 0.1,
+                },
+            )
+            or {}
+        )
     else:
         response = create_chat_completion_content(