Handle token budget when recommend test context
This commit is contained in:
parent
3b1b24e1b2
commit
c44891a39f
@ -19,7 +19,7 @@ ENCODING = (
|
|||||||
if USE_USER_MODEL
|
if USE_USER_MODEL
|
||||||
else get_encoding("cl100k_base")
|
else get_encoding("cl100k_base")
|
||||||
)
|
)
|
||||||
# TODO: handle token budget
|
|
||||||
TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.9)
|
TOKEN_BUDGET = int(CONTEXT_SIZE.get(MODEL, DEFAULT_CONTEXT_SIZE) * 0.9)
|
||||||
|
|
||||||
|
|
||||||
@ -58,29 +58,49 @@ JSON Format Example:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _mk_user_msg(func_to_test: FuncToTest, contexts: List) -> str:
|
||||||
|
"""
|
||||||
|
Create a user message to be sent to the model within the token budget.
|
||||||
|
"""
|
||||||
|
msg = None
|
||||||
|
while msg is None:
|
||||||
|
context_content = "\n\n".join([str(c) for c in contexts])
|
||||||
|
|
||||||
|
msg = recommend_symbol_context_prompt.format(
|
||||||
|
function_content=func_to_test.func_content,
|
||||||
|
context_content=context_content,
|
||||||
|
function_name=func_to_test.func_name,
|
||||||
|
file_path=func_to_test.file_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
token_count = len(ENCODING.encode(msg, disallowed_special=()))
|
||||||
|
if contexts and token_count > TOKEN_BUDGET:
|
||||||
|
# Remove the last context and try again
|
||||||
|
contexts.pop()
|
||||||
|
msg = None
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
def get_recommended_symbols(
|
def get_recommended_symbols(
|
||||||
func_to_test: FuncToTest, known_context: Optional[List] = None
|
func_to_test: FuncToTest, known_context: Optional[List] = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
known_context = known_context or []
|
known_context = known_context or []
|
||||||
context_content = "\n\n".join([str(c) for c in known_context])
|
msg = _mk_user_msg(func_to_test, known_context)
|
||||||
|
|
||||||
msg = recommend_symbol_context_prompt.format(
|
|
||||||
function_content=func_to_test.func_content,
|
|
||||||
context_content=context_content,
|
|
||||||
function_name=func_to_test.func_name,
|
|
||||||
file_path=func_to_test.file_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
json_res = {}
|
json_res = {}
|
||||||
if USE_USER_MODEL:
|
if USE_USER_MODEL:
|
||||||
# Use the wrapped api parameters
|
# Use the wrapped api parameters
|
||||||
json_res = chat_completion_no_stream_return_json(
|
json_res = (
|
||||||
messages=[{"role": "user", "content": msg}],
|
chat_completion_no_stream_return_json(
|
||||||
llm_config={
|
messages=[{"role": "user", "content": msg}],
|
||||||
"model": MODEL,
|
llm_config={
|
||||||
"temperature": 0.1,
|
"model": MODEL,
|
||||||
},
|
"temperature": 0.1,
|
||||||
) or {}
|
},
|
||||||
|
)
|
||||||
|
or {}
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = create_chat_completion_content(
|
response = create_chat_completion_content(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user