261 lines
9.5 KiB
Python
261 lines
9.5 KiB
Python
"""
|
|
Run a command with an input text.
|
|
"""
|
|
import os
|
|
import sys
|
|
import json
|
|
import threading
|
|
import subprocess
|
|
from typing import List, Dict
|
|
import shlex
|
|
|
|
from devchat.utils import get_logger
|
|
from .command_parser import Command
|
|
from .util import ToolUtil
|
|
|
|
|
|
# Module-level logger, created by devchat's `get_logger` helper and named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
# Printed to stderr when a command line still needs `$command_python` but no
# `command_python` value is available. The pieces are adjacent string literals,
# so each one must carry its own separating space.
DEVCHAT_COMMAND_MISS_ERROR_MESSAGE = (
    'devchat-commands environment is not installed yet. '
    'Please install it before using the current command. '
    'The devchat-command environment is automatically '
    'installed after the plugin starts,'
    ' and details can be viewed in the output window.'
)
|
|
|
|
def pipe_reader(pipe, out_data, out_flag):
    """
    Drain ``pipe`` one character at a time until it reports end of stream.

    Every character read is appended to ``out_data['out']`` and immediately
    echoed (flushed) to the ``out_flag`` stream, so the output is both
    captured and shown as it arrives. A falsy ``pipe`` is a no-op.
    """
    if not pipe:
        return

    # iter() with a sentinel keeps calling read(1) until it returns ''.
    for char in iter(lambda: pipe.read(1), ''):
        out_data['out'] += char
        print(char, end='', file=out_flag, flush=True)
|
|
|
|
|
|
# Equivalent of CommandRun in Python, which executes subprocesses
|
|
class CommandRunner:
    """
    Run the first step of a workflow command as a subprocess.

    Values from the process environment, the caller's parameters, any
    ``runtime.json`` files found along the command's path and the chat data
    are merged into one mapping. That mapping is used twice: its entries
    replace ``$name`` placeholders in the step's ``run`` string, and it is
    passed as the environment of the child process.
    """

    def __init__(self, model_name: str):
        # Not assigned anywhere else in this class.
        self.process = None
        # Model name; used to gate LLM parameter generation and exported to
        # the child process as LLM_MODEL.
        self._model_name = model_name

    def run_command(self,
                    command_name: str,
                    command: "Command",
                    history_messages: List[Dict],
                    input_text: str,
                    parent_hash: str):
        """
        Run ``command`` with the user's input text.

        The input is stripped, every occurrence of ``/<command_name>`` is
        removed, and quotes and newlines are backslash-escaped before the
        text is handed on as the ``input`` parameter. If the command declares
        parameters, their values are generated by the LLM from
        ``history_messages``.

        Returns:
            ``None`` when the command declares parameters but the model name
            does not start with ``gpt-``; ``(-1, "")`` when no parameters
            could be generated or the run failed; otherwise the
            ``(exit_code, stdout_text)`` tuple of the executed command.
        """
        input_text = (
            input_text.strip()
            .replace(f'/{command_name}', '')
            .replace('"', '\\"')
            .replace("'", "\\'")
            .replace('\n', '\\n')
        )

        arguments = {}
        if command.parameters:
            if not self._model_name.startswith("gpt-"):
                return None

            arguments = self._call_function_by_llm(command_name, command, history_messages)
            if not arguments:
                print("No valid parameters generated by LLM", file=sys.stderr, flush=True)
                return (-1, "")

        return self.run_command_with_parameters(
            command_name=command_name,
            command=command,
            parameters={
                "input": input_text,
                **arguments
            },
            parent_hash=parent_hash,
            history_messages=history_messages
        )

    def run_command_with_parameters(self,
                                    command_name: str,
                                    command: "Command",
                                    parameters: Dict[str, str],
                                    parent_hash: str,
                                    history_messages: List[Dict]):
        """
        Substitute ``$name`` placeholders in ``command.steps[0]["run"]`` and
        execute the resulting command line.

        Returns:
            ``(exit_code, stdout_text)`` of the child process, or
            ``(-1, "")`` if a pre-flight check fails or any exception is
            raised (the exception is printed to stderr and logged).
        """
        try:
            # Later updates win: process env < parameters < runtime.json < chat data.
            env = os.environ.copy()
            env.update(parameters)
            env.update(self.__load_command_runtime(command))
            env.update(self.__load_chat_data(self._model_name, parent_hash, history_messages))
            self.__update_devchat_python_path(env, command.steps[0]["run"])

            command_run = self._substitute_parameters(command.steps[0]["run"], env)

            # The checks below inspect the substituted command line, so a
            # placeholder that is still present had no value to replace it.
            if self.__check_command_python_error(command_run, env):
                return (-1, "")
            if self.__check_input_miss_error(command, command_name, env):
                return (-1, "")
            if self.__check_parameters_miss_error(command, command_run):
                return (-1, "")

            return self.__run_command_with_thread_output(command_run, env)
        except Exception as err:
            print("Exception:", type(err), err, file=sys.stderr, flush=True)
            logger.exception("Run command error: %s", err)
            return (-1, "")

    @staticmethod
    def _substitute_parameters(command_run: str, env: Dict) -> str:
        """
        Replace every ``$name`` in ``command_run`` with ``str(env[name])``.

        The replacement is done in a single pass with the longest names tried
        first, so ``$input_file`` is never mangled by the value of ``$input``
        and text inserted from a value is not scanned for placeholders again.
        Placeholders without a matching name are left untouched.
        """
        import re  # local import: the module's import block does not provide it

        names = sorted(
            (name for name in env if isinstance(name, str) and name),
            key=len,
            reverse=True,
        )
        if not names:
            return command_run

        pattern = re.compile(r'\$(' + '|'.join(re.escape(name) for name in names) + ')')
        # A function replacement keeps backslashes in the values literal.
        return pattern.sub(lambda match: str(env[match.group(1)]), command_run)

    @staticmethod
    def _stringify_env(env: Dict) -> Dict[str, str]:
        """
        Return a copy of ``env`` whose values are acceptable to ``Popen``.

        Lists and dicts are JSON-encoded; any other value that is not already
        ``str``/``bytes`` (for example an int or bool) is converted with
        ``str()``.
        """
        converted = {}
        for key, value in env.items():
            if isinstance(value, (list, dict)):
                value = json.dumps(value)
            elif not isinstance(value, (str, bytes)):
                value = str(value)
            converted[key] = value
        return converted

    def __run_command_with_thread_output(self, command_str: str, env: Dict[str, str]):
        """
        Run ``command_str`` (split with ``shlex``) and stream its output.

        stdout and stderr are each drained by their own ``pipe_reader``
        thread, which echoes to this process's stdout/stderr while capturing
        the text.

        Returns:
            ``(exit_code, stdout_text)``.
        """
        with subprocess.Popen(
            shlex.split(command_str),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=self._stringify_env(env),
            text=True
        ) as process:
            stdout_data, stderr_data = {'out': ''}, {'out': ''}
            readers = [
                threading.Thread(
                    target=pipe_reader,
                    args=(process.stdout, stdout_data, sys.stdout)),
                threading.Thread(
                    target=pipe_reader,
                    args=(process.stderr, stderr_data, sys.stderr)),
            ]
            for reader in readers:
                reader.start()
            for reader in readers:
                reader.join()
            return (process.wait(), stdout_data["out"])

    def __check_command_python_error(self, command_run: str, parameters: Dict[str, str]):
        """
        Report a missing ``command_python`` interpreter.

        Returns True (after printing the install hint to stderr) when the
        command line still contains ``$command_python `` and ``parameters``
        has no truthy ``command_python`` value; False otherwise.
        """
        need_command_python = '$command_python ' in command_run
        has_command_python = parameters.get('command_python', None)

        if need_command_python and not has_command_python:
            print(DEVCHAT_COMMAND_MISS_ERROR_MESSAGE, file=sys.stderr, flush=True)
            return True
        return False

    def __get_readme(self, command: "Command"):
        """
        Return the text of ``README.md`` next to the command file, or None if
        it does not exist or cannot be read.
        """
        try:
            command_dir = os.path.dirname(command.path)
            readme_file = os.path.join(command_dir, 'README.md')
            if os.path.exists(readme_file):
                with open(readme_file, 'r', encoding='utf8') as file:
                    return file.read()
            return None
        except Exception:
            # Best effort: the README is only used to build a help message.
            return None

    def __check_input_miss_error(
        self, command: "Command", command_name: str, parameters: Dict[str, str]
    ):
        """
        Report missing user input for a command whose ``input`` is "required".

        Returns True after printing the command's README (or a short usage
        hint when there is none) to stderr; False when input is not required
        or is non-empty. A missing ``input`` key is treated as empty input.
        """
        if command.input != "required" or parameters.get("input", "") != "":
            return False

        input_miss_error = (
            f"{command_name} workflow is missing input. Example usage: "
            f"'/{command_name} user input'\n"
        )
        readme_content = self.__get_readme(command)
        if readme_content:
            print(readme_content, file=sys.stderr, flush=True)
        else:
            print(input_miss_error, file=sys.stderr, flush=True)
        return True

    def __check_parameters_miss_error(self, command: "Command", command_run: str):
        """
        Report declared parameters that were never given a value.

        ``command_run`` is the command line after substitution, so a declared
        parameter whose ``$name`` is still present was not replaced. Returns
        True after printing the README (or the list of names) to stderr;
        False when nothing is missing.
        """
        missed_parameters = [
            parameter_name
            for parameter_name in (command.parameters or {})
            if '$' + parameter_name in command_run
        ]
        if not missed_parameters:
            return False

        readme_content = self.__get_readme(command)
        if readme_content:
            print(readme_content, file=sys.stderr, flush=True)
        else:
            print("Missing parameters:", missed_parameters, file=sys.stderr, flush=True)
        return True

    def __load_command_runtime(self, command: "Command"):
        """
        Merge every ``runtime.json`` found along the command's directory path.

        For a directory such as ``/usr/x1/x2/x3`` the prefixes are visited
        from shortest to longest, so values from deeper directories override
        those from their ancestors. Files that cannot be read or parsed are
        logged and skipped.

        Returns:
            The merged configuration dict; ``command_python``, if set, has
            its backslashes replaced by forward slashes.
        """
        command_path = os.path.dirname(command.path)
        if os.sep != '/':
            # Native Windows separators would otherwise leave a single
            # segment and the ancestor directories would never be visited.
            command_path = command_path.replace(os.sep, '/')
        runtime_config = {}

        # NOTE(review): for an absolute POSIX path the first prefix is the
        # empty string, so the first candidate is 'runtime.json' relative to
        # the current working directory - confirm this is intended.
        paths = command_path.split('/')
        for index in range(1, len(paths) + 1):
            runtime_file = os.path.join('/'.join(paths[:index]), 'runtime.json')
            try:
                if os.path.exists(runtime_file):
                    with open(runtime_file, 'r', encoding='utf8') as file:
                        runtime_config.update(json.load(file))
            except Exception as err:
                # Best effort: one broken file must not block the others.
                logger.warning("Failed to load runtime config %s: %s", runtime_file, err)

        # for windows
        if runtime_config.get('command_python', None):
            runtime_config['command_python'] = \
                runtime_config['command_python'].replace('\\', '/')
        return runtime_config

    def __load_chat_data(self, model_name: str, parent_hash: str, history_messages: List[Dict]):
        """
        Build the chat-related entries added to the command environment,
        substituting an empty string / empty list for falsy inputs.
        """
        return {
            "LLM_MODEL": model_name or "",
            "PARENT_HASH": parent_hash or "",
            "CONTEXT_CONTENTS": history_messages or [],
        }

    def __update_devchat_python_path(self, env: Dict[str, str], command_run: str):
        """
        Adjust the Python-related entries of ``env`` in place.

        ``DEVCHAT_PYTHONPATH`` is taken from this process's environment,
        falling back to its ``PYTHONPATH`` (or ""). ``PYTHONPATH`` is removed
        from ``env`` unless the command line uses ``$devchat_python ``, and
        ``devchat_python`` is set to the current interpreter path with
        forward slashes.
        """
        python_path = os.environ.get('PYTHONPATH', '')
        env['DEVCHAT_PYTHONPATH'] = os.environ.get('DEVCHAT_PYTHONPATH', python_path)
        if '$devchat_python ' not in command_run:
            # pop() instead of del: PYTHONPATH is frequently not set at all.
            env.pop('PYTHONPATH', None)
        env["devchat_python"] = sys.executable.replace('\\', '/')

    def _call_function_by_llm(self,
                              command_name: str,
                              command: "Command",
                              history_messages: List[Dict]):
        """
        Ask the LLM to produce the command's parameters.

        The command is described to the LLM as a single tool built by
        ``ToolUtil.make_function``.

        Returns:
            The ``"arguments"`` entry of the selected function call, or None
            when the LLM selected no function.
        """
        tools = [ToolUtil.make_function(command, command_name)]

        function_call = ToolUtil.select_function_by_llm(history_messages, tools)
        if not function_call:
            return None

        return function_call["arguments"]
|