# workflows/ask-code/ask-code.py
# Python · 71 lines · 2.1 KiB  (web-view header; views: Raw, Normal, History)
# 2023-11-30 07:57:01 +08:00
import os
import sys
from chat.ask_codebase.chains.smart_qa import SmartQA

# Make the ``libs`` directories under this script's parent and grandparent
# directories importable (searched in this order, after the existing entries).
# Presumably this is where ``ide_services`` lives — it is imported just below.
sys.path.extend(
    [
        os.path.join(os.path.dirname(__file__), "..", "libs"),
        os.path.join(os.path.dirname(__file__), "..", "..", "libs"),
    ]
)

from ide_services import get_lsp_brige_port  # noqa: E402
2023-11-30 07:57:01 +08:00
# Prices in USD per 1,000 tokens as (prompt, completion). Any model name not
# listed here is billed at the "others" rate.
_MODEL_PRICES = {
    "gpt-4": (0.03, 0.06),
    "gpt-4-32k": (0.06, 0.12),
    "gpt-3.5-turbo": (0.0015, 0.002),
    "gpt-3.5-turbo-16k": (0.003, 0.004),
    "claude-2": (0.01102, 0.03268),
    "starchat-alpha": (0.0004, 0.0004),
    "CodeLlama-34b-Instruct": (0.0008, 0.0008),
    "llama-2-70b-chat": (0.001, 0.001),
    "gpt-3.5-turbo-1106": (0.001, 0.002),
    "gpt-4-1106-preview": (0.01, 0.03),
    "gpt-4-1106-vision-preview": (0.01, 0.03),
    "others": (0.001, 0.002),
}

# NOTE(review): the summed token cost is divided by this factor before it is
# printed, so the reported figure is ~1.43x the raw token cost. The reason is
# not visible in this file — confirm whether this is an intentional markup.
_COST_DIVISOR = 0.7


def query(question, lsp_brige_port):
    """Answer *question* about the codebase in the current directory.

    Runs SmartQA rooted at ``os.getcwd()`` and prints the answer text. It then
    prints an approximate USD cost, computed from the token usages SmartQA
    reports.

    Args:
        question: The natural-language question to ask.
        lsp_brige_port: Local port of the IDE LSP bridge. It is passed to
            SmartQA as ``http://localhost:<port>``.
    """
    root_path = os.getcwd()

    smart_qa = SmartQA(root_path)
    answer = smart_qa.run(
        question=question,
        verbose=False,
        dfs_depth=3,
        dfs_max_visit=10,
        bridge_url=f"http://localhost:{lsp_brige_port}",
    )

    # Print the answer. answer[0] is the text shown to the user. answer[2] is
    # read as a mapping that may carry a "usages" list — the exact shape of the
    # SmartQA result is not visible here.
    print(answer[0])

    spent_money = 0.0
    for token_usage in answer[2].get("usages", []):
        prompt_price, completion_price = _MODEL_PRICES.get(
            token_usage.model, _MODEL_PRICES["others"]
        )
        spent_money += (prompt_price * token_usage.prompt_tokens) / 1000 + (
            completion_price * token_usage.completion_tokens
        ) / 1000
    print(f"***/ask-code has costed approximately ${spent_money/_COST_DIVISOR} USD for this question.***")
2023-11-30 07:57:01 +08:00
2023-12-18 09:14:42 +08:00
2023-11-30 07:57:01 +08:00
def main():
    """Command-line entry point: ``python <this script> query <question>``.

    ``sys.argv[2]`` is taken as the question. ``sys.argv[1]`` is expected to be
    the word ``query`` but is not checked. The LSP bridge port is obtained from
    ``get_lsp_brige_port()``, not from the command line.

    Exits with status 0 on success, and 1 on bad usage or any exception.
    """
    try:
        if len(sys.argv) < 3:
            # Show this script's actual file name. The port is discovered via
            # get_lsp_brige_port(), so it is not a command-line argument.
            print(f"Usage: python {os.path.basename(__file__)} query <question>")
            sys.exit(1)

        port = get_lsp_brige_port()
        question = sys.argv[2]
        query(question, port)
        sys.exit(0)
    except Exception as e:
        # sys.exit() raises SystemExit, which is not an Exception subclass, so
        # the exits above propagate rather than being caught here.
        print("Exception: ", e, file=sys.stderr, flush=True)
        sys.exit(1)


if __name__ == "__main__":
    main()