266 lines
9.2 KiB
Python
Raw Normal View History

2023-11-29 14:07:47 +08:00
"""
Run a command with an input text.
"""
import json
2023-11-29 14:07:47 +08:00
import os
import shlex
import subprocess
2023-11-29 14:07:47 +08:00
import sys
import threading
from typing import Dict, List
2023-11-29 14:07:47 +08:00
from devchat.utils import get_logger
2023-11-29 14:07:47 +08:00
from .command_parser import Command
2024-02-08 13:19:24 +08:00
from .util import ToolUtil
2023-11-29 14:07:47 +08:00
logger = get_logger(__name__)
2024-02-08 13:19:24 +08:00
# Printed to stderr when a command line still needs `$command_python` but no
# `command_python` value is available to substitute for it.
# The literals below are concatenated, so every sentence boundary needs its
# own separating space.
DEVCHAT_COMMAND_MISS_ERROR_MESSAGE = (
    "devchat-commands environment is not installed yet. "
    "Please install it before using the current command. "
    "The devchat-command environment is automatically "
    "installed after the plugin starts,"
    " and details can be viewed in the output window."
)
2024-02-08 13:19:24 +08:00
def pipe_reader(pipe, out_data, out_flag):
    """
    Relay a text stream to another stream while recording what was read.

    Characters are read from ``pipe`` one at a time until ``read`` returns an
    empty string. Each character is appended to ``out_data["out"]`` and
    immediately echoed (flushed) to ``out_flag``, so the output shows up as it
    is produced rather than when the stream closes.

    Nothing happens when ``pipe`` is falsy (for example ``None``).
    """
    if not pipe:
        return

    # iter(callable, sentinel) keeps calling read(1) until it yields "".
    for char in iter(lambda: pipe.read(1), ""):
        out_data["out"] += char
        print(char, end="", file=out_flag, flush=True)
2024-02-08 13:19:24 +08:00
2023-11-29 14:07:47 +08:00
# Equivalent of CommandRun in Python, which executes subprocesses
class CommandRunner:
    """
    Execute a workflow command as a subprocess.

    The first step of a command (``command.steps[0]["run"]``) is used as a
    template: ``$name`` placeholders are replaced with values taken from the
    process environment, the caller's parameters, any ``runtime.json`` files
    found along the command's directory path, and chat metadata. The resulting
    command line is then run with its stdout/stderr streamed to this process.
    """

    def __init__(self, model_name: str):
        self.process = None
        self._model_name = model_name

    def run_command(
        self,
        command_name: str,
        command: "Command",
        history_messages: List[Dict],
        input_text: str,
        parent_hash: str,
    ):
        """
        Run ``command`` for the given user input.

        If the command declares parameters, their values are generated by the
        LLM from ``history_messages`` before the command is executed.

        Returns:
            ``None`` when the command declares parameters but the model name
            does not start with ``"gpt-"``; ``(-1, "")`` when the LLM produced
            no arguments; otherwise the result of
            ``run_command_with_parameters``.
        """
        # Drop the "/<command_name>" marker and escape quotes/newlines so the
        # text can be embedded in the command line.
        input_text = (
            input_text.strip()
            .replace(f"/{command_name}", "")
            .replace('"', '\\"')
            .replace("'", "\\'")
            .replace("\n", "\\n")
        )

        arguments = {}
        if command.parameters and len(command.parameters) > 0:
            if not self._model_name.startswith("gpt-"):
                return None

            arguments = self._call_function_by_llm(command_name, command, history_messages)
            if not arguments:
                print("No valid parameters generated by LLM", file=sys.stderr, flush=True)
                return (-1, "")

        return self.run_command_with_parameters(
            command_name=command_name,
            command=command,
            parameters={"input": input_text, **arguments},
            parent_hash=parent_hash,
            history_messages=history_messages,
        )

    def run_command_with_parameters(
        self,
        command_name: str,
        command: "Command",
        parameters: Dict[str, str],
        parent_hash: str,
        history_messages: List[Dict],
    ):
        """
        Replace ``$xxx`` placeholders in ``command.steps[0]["run"]`` and run it.

        The substitution values come from, in increasing priority: the current
        process environment, ``parameters``, the command's ``runtime.json``
        settings and the chat metadata.

        Returns:
            ``(exit_code, stdout_text)`` of the subprocess. ``(-1, "")`` is
            returned when the command could not be run (missing
            ``command_python``, missing input/parameters without a README, or
            any exception, which is printed and logged rather than raised).
            ``(0, "")`` is returned when input or parameters are missing but a
            README exists for the command.
        """
        result = (-1, "")
        try:
            env = os.environ.copy()
            env.update(parameters)
            env.update(self.__load_command_runtime(command))
            env.update(self.__load_chat_data(self._model_name, parent_hash, history_messages))
            self.__update_devchat_python_path(env, command.steps[0]["run"])

            command_run = self._expand_variables(command.steps[0]["run"], env)

            if self.__check_command_python_error(command_run, env):
                return result
            if self.__check_input_miss_error(command, command_name, env):
                if self.__get_readme(command):
                    result = (0, "")
                return result
            if self.__check_parameters_miss_error(command, command_run):
                if self.__get_readme(command):
                    result = (0, "")
                return result

            result = self.__run_command_with_thread_output(command_run, env)
        except Exception as err:
            print("Exception:", type(err), err, file=sys.stderr, flush=True)
            logger.exception("Run command error: %s", err)
        return result

    @staticmethod
    def _expand_variables(template: str, env: Dict[str, str]) -> str:
        """
        Return ``template`` with every ``$name`` replaced by ``str(env[name])``.

        Longer names are substituted first so that a short name (``$PATH``)
        cannot clobber the prefix of a longer one (``$PATHEXT``).
        """
        for name in sorted(env, key=len, reverse=True):
            template = template.replace("$" + name, str(env[name]))
        return template

    def __run_command_with_thread_output(self, command_str: str, env: Dict[str, str]):
        """
        Run ``command_str`` (split with ``shlex``) and stream its output.

        stdout and stderr of the child are each relayed by a ``pipe_reader``
        thread to this process's stdout/stderr. Returns
        ``(exit_code, captured_stdout)``.
        """

        def handle_output(process):
            stdout_data, stderr_data = {"out": ""}, {"out": ""}
            readers = [
                threading.Thread(
                    target=pipe_reader, args=(process.stdout, stdout_data, sys.stdout)
                ),
                threading.Thread(
                    target=pipe_reader, args=(process.stderr, stderr_data, sys.stderr)
                ),
            ]
            for reader in readers:
                reader.start()
            for reader in readers:
                reader.join()
            return (process.wait(), stdout_data["out"])

        # List/dict values are passed to the child as JSON text.
        for key in env:
            if isinstance(env[key], (list, dict)):
                env[key] = json.dumps(env[key])

        with subprocess.Popen(
            shlex.split(command_str),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=env,
            text=True,
        ) as process:
            return handle_output(process)

    def __check_command_python_error(self, command_run: str, parameters: Dict[str, str]):
        """
        Return True (after printing an error to stderr) when ``command_run``
        still contains ``$command_python `` and ``parameters`` has no truthy
        ``command_python`` value.
        """
        need_command_python = command_run.find("$command_python ") != -1
        has_command_python = parameters.get("command_python", None)

        if need_command_python and not has_command_python:
            print(DEVCHAT_COMMAND_MISS_ERROR_MESSAGE, file=sys.stderr, flush=True)
            return True
        return False

    def __get_readme(self, command: "Command"):
        """
        Return the text of ``README.md`` next to the command file, or ``None``
        when it does not exist or cannot be read (best effort).
        """
        try:
            command_dir = os.path.dirname(command.path)
            readme_file = os.path.join(command_dir, "README.md")
            if os.path.exists(readme_file):
                with open(readme_file, "r", encoding="utf8") as file:
                    return file.read()
            return None
        except Exception:
            return None

    def __check_input_miss_error(
        self, command: "Command", command_name: str, parameters: Dict[str, str]
    ):
        """
        Return True when the command requires input but ``parameters["input"]``
        is empty.

        In that case the command's README is printed to stdout when available;
        otherwise a usage hint is printed to stderr.
        """
        is_input_required = command.input == "required"
        if not (is_input_required and parameters["input"] == ""):
            return False

        input_miss_error = (
            f"{command_name} workflow is missing input. Example usage: "
            f"'/{command_name} user input'\n"
        )
        readme_content = self.__get_readme(command)
        if readme_content:
            print(readme_content, flush=True)
        else:
            print(input_miss_error, file=sys.stderr, flush=True)
        return True

    def __check_parameters_miss_error(self, command: "Command", command_run: str):
        """
        Return True when a declared parameter is still present as ``$name`` in
        ``command_run``, i.e. no value was substituted for it.

        In that case the command's README is printed to stdout when available;
        otherwise the missing parameter names are printed to stderr.
        """
        parameter_names = command.parameters.keys() if command.parameters else []
        missed_parameters = [name for name in parameter_names if "$" + name in command_run]
        if not missed_parameters:
            return False

        readme_content = self.__get_readme(command)
        if readme_content:
            print(readme_content, flush=True)
        else:
            print("Missing parameters:", missed_parameters, file=sys.stderr, flush=True)
        return True

    def __load_command_runtime(self, command: "Command"):
        """
        Merge every ``runtime.json`` found along the command's directory path.

        For a directory such as ``a/b/c`` the files ``a/runtime.json``,
        ``a/b/runtime.json`` and ``a/b/c/runtime.json`` are read in that order,
        so deeper files override shallower ones. A file that cannot be read or
        parsed is skipped with a logged warning.
        """
        command_path = os.path.dirname(command.path)
        runtime_config = {}

        # NOTE: for an absolute POSIX path the first component is "", which
        # makes the first candidate "runtime.json" relative to the current
        # working directory.
        paths = command_path.split("/")
        for index in range(1, len(paths) + 1):
            runtime_file = os.path.join("/".join(paths[:index]), "runtime.json")
            try:
                if os.path.exists(runtime_file):
                    with open(runtime_file, "r", encoding="utf8") as file:
                        command_runtime_config = json.loads(file.read())
                        runtime_config.update(command_runtime_config)
            except Exception as err:
                # Best effort: a broken runtime.json must not block the command.
                logger.warning("Ignoring unreadable runtime file %s: %s", runtime_file, err)

        # for windows
        if runtime_config.get("command_python", None):
            runtime_config["command_python"] = runtime_config["command_python"].replace("\\", "/")
        return runtime_config

    def __load_chat_data(self, model_name: str, parent_hash: str, history_messages: List[Dict]):
        """
        Build the chat metadata handed to the command; falsy arguments are
        replaced by an empty string (or an empty list for the history).
        """
        return {
            "LLM_MODEL": model_name if model_name else "",
            "PARENT_HASH": parent_hash if parent_hash else "",
            "CONTEXT_CONTENTS": history_messages if history_messages else [],
        }

    def __update_devchat_python_path(self, env: Dict[str, str], command_run: str):
        """
        Prepare the Python-related entries of ``env`` in place.

        ``DEVCHAT_PYTHONPATH`` is set from this process's ``DEVCHAT_PYTHONPATH``
        (falling back to its ``PYTHONPATH``), and ``devchat_python`` is set to
        the current interpreter path with forward slashes. ``PYTHONPATH`` is
        removed from ``env`` unless ``command_run`` uses ``$devchat_python ``.
        """
        python_path = os.environ.get("PYTHONPATH", "")
        env["DEVCHAT_PYTHONPATH"] = os.environ.get("DEVCHAT_PYTHONPATH", python_path)
        if command_run.find("$devchat_python ") == -1:
            # pop() instead of del: PYTHONPATH is frequently not set at all,
            # and a KeyError here would abort the whole command run.
            env.pop("PYTHONPATH", None)
        env["devchat_python"] = sys.executable.replace("\\", "/")

    def _call_function_by_llm(
        self, command_name: str, command: "Command", history_messages: List[Dict]
    ):
        """
        command needs multi parameters, so we need parse each
        parameter by LLM from input_text

        Returns the ``"arguments"`` entry of the function call selected by the
        LLM, or ``None`` when no function call was selected.
        """
        tools = [ToolUtil.make_function(command, command_name)]

        function_call = ToolUtil.select_function_by_llm(history_messages, tools)
        if not function_call:
            return None

        return function_call["arguments"]