diff --git a/.gitignore b/.gitignore index 8b595b5..3a0ba57 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ node_modules .vscode/settings.json +workflows/ diff --git a/src/command/commandManager.ts b/src/command/commandManager.ts index ef168c6..83b0164 100644 --- a/src/command/commandManager.ts +++ b/src/command/commandManager.ts @@ -22,17 +22,6 @@ class CommandManager { public static getInstance(): CommandManager { if (!CommandManager.instance) { CommandManager.instance = new CommandManager(); - if (FT("ask-code")) { - CommandManager.instance.registerCommand({ - name: 'ask-code', - pattern: 'ask-code', - description: 'Ask anything about your codebase and get answers from our AI agent', - args: 0, - handler: async (commandName: string, userInput: string) => { - return ''; - } - }); - } } return CommandManager.instance; diff --git a/src/context/contextDefRefs.ts b/src/context/contextDefRefs.ts index 9459482..2d27c7c 100644 --- a/src/context/contextDefRefs.ts +++ b/src/context/contextDefRefs.ts @@ -1,9 +1,8 @@ import * as path from 'path'; -import * as vscode from 'vscode' +import * as vscode from 'vscode'; import { ChatContext } from './contextManager'; -import { createTempSubdirectory, git_ls_tree, runCommandStringAndWriteOutput } from '../util/commonUtil'; import { logger } from '../util/logger'; import { handleCodeSelected } from './contextCodeSelected'; diff --git a/src/handler/loadHandlers.ts b/src/handler/loadHandlers.ts index 0ca8a65..0f2aa90 100644 --- a/src/handler/loadHandlers.ts +++ b/src/handler/loadHandlers.ts @@ -6,7 +6,7 @@ import { doCommit } from './doCommit'; import { historyMessages } from './historyMessages'; import { regCommandList, regCommandListByDevChatRun } from './regCommandList'; import { regContextList } from './regContextList'; -import { sendMessage, stopDevChat, regeneration, deleteChatMessage, askCode } from './sendMessage'; +import { sendMessage, stopDevChat, regeneration, deleteChatMessage, userInput } from './sendMessage'; import { blockApply } from './showDiff'; import { showDiff } from './showDiff'; import { addConext } from './addContext'; @@ -83,7 +83,6 @@ messageHandler.registerHandler('applyAction', applyAction); // Response: { command: 'deletedChatMessage', result: } messageHandler.registerHandler('deleteChatMessage', deleteChatMessage); -messageHandler.registerHandler('askCode', askCode); // Execute vscode command // Response: none messageHandler.registerHandler('doCommand', doCommand); @@ -96,4 +95,6 @@ messageHandler.registerHandler('featureToggles', featureToggles); messageHandler.registerHandler('getUserAccessKey', getUserAccessKey); messageHandler.registerHandler('regModelList', regModelList); -messageHandler.registerHandler('isDevChatInstalled', isDevChatInstalled); \ No newline at end of file +messageHandler.registerHandler('isDevChatInstalled', isDevChatInstalled); + +messageHandler.registerHandler('userInput', userInput); diff --git a/src/handler/messageHandler.ts b/src/handler/messageHandler.ts index 9ba5abf..e1c83f8 100644 --- a/src/handler/messageHandler.ts +++ b/src/handler/messageHandler.ts @@ -36,6 +36,9 @@ export class MessageHandler { if (messageObject && messageObject.user && messageObject.user === 'merico-devchat') { message = messageObject; isNeedSendResponse = true; + if (messageObject.hasResponse) { + isNeedSendResponse = false; + } } } catch (e) { } @@ -53,13 +56,6 @@ export class MessageHandler { if (message.text.indexOf('/autox') !== -1) { autox = true; } - // if "/ask-code" in message.text, then call devchat-ask to get result - if (FT("ask-code")) { - if (message.text.indexOf('/ask-code') !== -1) { - message.command = 'askCode'; - message.text = message.text.replace('/ask-code', ''); - } - } } const handler = this.handlers[message.command]; diff --git a/src/handler/sendMessage.ts b/src/handler/sendMessage.ts index e30d106..c5534a4 100644 --- a/src/handler/sendMessage.ts +++ b/src/handler/sendMessage.ts @@ -11,15 +11,18 @@ import { ApiKeyManager } from '../util/apiKey'; import { logger } from '../util/logger'; import { exec as execCb } from 'child_process'; import { promisify } from 'util'; -import { CommandRun, createTempSubdirectory } from '../util/commonUtil'; +import { CommandResult, CommandRun, createTempSubdirectory } from '../util/commonUtil'; +import { WorkflowRunner } from './workflowExecutor'; +import DevChat from '../toolwrapper/devchat'; const exec = promisify(execCb); -let askcode_stop = true; -let askcode_runner : CommandRun | null = null; +let askcodeRunner : CommandRun | null = null; +let commandRunner : WorkflowRunner | null = null; let _lastMessage: any = undefined; + export function createTempFile(content: string): string { // Generate a unique file name const fileName = path.join(os.tmpdir(), `temp_${Date.now()}.txt`); @@ -35,121 +38,19 @@ export function deleteTempFiles(fileName: string): void { fs.unlinkSync(fileName); } -regInMessage({command: 'askCode', text: '', parent_hash: undefined}); -regOutMessage({ command: 'receiveMessage', text: 'xxxx', hash: 'xxx', user: 'xxx', date: 'xxx'}); -export async function askCode(message: any, panel: vscode.WebviewPanel|vscode.WebviewView): Promise { - try { - askcode_stop = false; - askcode_runner = null; - - _lastMessage = [message]; - _lastMessage[0]['askCode'] = true; - - const port = await UiUtilWrapper.getLSPBrigePort(); - - const pythonVirtualEnv: string | undefined = vscode.workspace.getConfiguration('DevChat').get('PythonVirtualEnv'); - if (!pythonVirtualEnv) { - MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: "Index code fail.", hash: "", user: "", date: 0, isError: true }); - return ; - } - - let envs = { - PYTHONUTF8:1, - ...process.env, - }; - - const llmModelData = await ApiKeyManager.llmModel(); - if (!llmModelData) { - logger.channel()?.error('No valid llm model is selected!'); - logger.channel()?.show(); - return; - } - - let openaiApiKey = llmModelData.api_key; - if (!openaiApiKey) { - logger.channel()?.error('The OpenAI key is invalid!'); - logger.channel()?.show(); - return; - } - envs['OPENAI_API_KEY'] = openaiApiKey; - - const openAiApiBase = llmModelData.api_base; - if (openAiApiBase) { - envs['OPENAI_API_BASE'] = openAiApiBase; - } - - const workspaceDir = UiUtilWrapper.workspaceFoldersFirstPath(); - if (askcode_stop) { - return; - } - - try { - let outputResult = ""; - askcode_runner = new CommandRun(); - const command = pythonVirtualEnv.trim(); - const args = [UiUtilWrapper.extensionPath() + "/tools/askcode_index_query.py", "query", message.text, `${port}`]; - const result = await askcode_runner.spawnAsync(command, args, { env: envs, cwd: workspaceDir }, (data) => { - outputResult += data; - MessageHandler.sendMessage(panel, { command: 'receiveMessagePartial', text: outputResult, hash:"", user:"", isError: false }); - logger.channel()?.info(data); - }, (data) => { - logger.channel()?.error(data); - }, undefined, undefined); - - if (result.exitCode === 0) { - // save askcode result to devchat - const stepIndex = result.stdout.lastIndexOf("```Step"); - const stepEndIndex = result.stdout.lastIndexOf("```"); - let resultOut = result.stdout; - if (stepIndex > 0 && stepEndIndex > 0) { - resultOut = result.stdout.substring(stepEndIndex+3, result.stdout.length); - } - let logHash = await insertDevChatLog(message, "/ask-code " + message.text, resultOut); - if (!logHash) { - logHash = ""; - logger.channel()?.error(`Failed to insert devchat log.`); - logger.channel()?.show(); - } - - MessageHandler.sendMessage(panel, { command: 'receiveMessagePartial', text: result.stdout, hash:logHash, user:"", isError: false }); - MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: result.stdout, hash:logHash, user:"", date:0, isError: false }); - - const dateStr = Math.floor(Date.now()/1000).toString(); - await handleTopic( - message.parent_hash, - {"text": "/ask-code " + message.text}, - { response: result.stdout, "prompt-hash": logHash, user: "", "date": dateStr, finish_reason: "", isError: false }); - } else { - logger.channel()?.info(`${result.stdout}`); - if (askcode_stop == false) { - MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: result.stderr, hash: "", user: "", date: 0, isError: true }); - } - } - } catch (error) { - if (error instanceof Error) { - logger.channel()?.error(`error: ${error.message}`); - } else { - logger.channel()?.error(`An unknown error occurred: ${error}`); - } - logger.channel()?.show(); - MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: "Did not get relevant context from AskCode.", hash: "", user: "", date: 0, isError: true }); - } - } finally { - askcode_stop = true; - askcode_runner = null; - } +regInMessage({command: 'userInput', text: '{"field": "value", "field2": "value2"}'});; +export async function userInput(message: any, panel: vscode.WebviewPanel|vscode.WebviewView): Promise { + commandRunner?.input(message.text); } +// eslint-disable-next-line @typescript-eslint/naming-convention regInMessage({command: 'sendMessage', text: '', parent_hash: undefined}); regOutMessage({ command: 'receiveMessage', text: 'xxxx', hash: 'xxx', user: 'xxx', date: 'xxx'}); regOutMessage({ command: 'receiveMessagePartial', text: 'xxxx', user: 'xxx', date: 'xxx'}); -// message: { command: 'sendMessage', text: 'xxx', hash: 'xxx'} -// return message: -// { command: 'receiveMessage', text: 'xxxx', hash: 'xxx', user: 'xxx', date: 'xxx'} -// { command: 'receiveMessagePartial', text: 'xxxx', user: 'xxx', date: 'xxx'} -export async function sendMessage(message: any, panel: vscode.WebviewPanel|vscode.WebviewView, function_name: string|undefined = undefined): Promise { - if (function_name !== undefined && function_name !== "") { +export async function sendMessage(message: any, panel: vscode.WebviewPanel|vscode.WebviewView, functionName: string|undefined = undefined): Promise { + // check whether the message is a command + if (functionName !== undefined && functionName !== "") { const messageText = _lastMessage[0].text.trim(); if (messageText[0] === '/' && message.text[0] !== '/') { const indexS = messageText.indexOf(' '); @@ -161,7 +62,39 @@ export async function sendMessage(message: any, panel: vscode.WebviewPanel|vscod message.text = preCommand + ' ' + message.text; } } - _lastMessage = [message, function_name]; + _lastMessage = [message, functionName]; + + const messageText = message.text.trim(); + if (messageText[0] === '/') { + // split messageText by ' ' or '\n' or '\t' + const messageTextArr = messageText.split(/ |\n|\t/); + // get command name from messageTextArr + const commandName = messageTextArr[0].substring(1); + // test whether the command is a execute command + const devChat = new DevChat(); + const stdout = await devChat.commandPrompt(commandName); + // try parse stdout by json + let stdoutJson: any = null; + try { + stdoutJson = JSON.parse(stdout); + } catch (error) { + // do nothing + } + + if (stdoutJson) { + // run command + try { + commandRunner = null; + + commandRunner = new WorkflowRunner(); + await commandRunner.run(commandName, stdoutJson, message, panel); + } finally { + commandRunner = null; + } + + return ; + } + } // Add a new field to store the names of temporary files let tempFiles: string[] = []; @@ -192,7 +125,7 @@ export async function sendMessage(message: any, panel: vscode.WebviewPanel|vscod const responseMessage = await sendMessageBase(message, (data: { command: string, text: string, user: string, date: string}) => { MessageHandler.sendMessage(panel, data, false); - }, function_name); + }, functionName); if (responseMessage) { MessageHandler.sendMessage(panel, responseMessage); } @@ -209,11 +142,7 @@ regInMessage({command: 'regeneration'}); export async function regeneration(message: any, panel: vscode.WebviewPanel|vscode.WebviewView): Promise { // call sendMessage to send last message again if (_lastMessage) { - if (_lastMessage[0]['askCode']) { - await askCode(_lastMessage[0], panel); - } else { - await sendMessage(_lastMessage[0], panel, _lastMessage[1]); - } + await sendMessage(_lastMessage[0], panel, _lastMessage[1]); } } @@ -221,13 +150,9 @@ regInMessage({command: 'stopDevChat'}); export async function stopDevChat(message: any, panel: vscode.WebviewPanel|vscode.WebviewView): Promise { stopDevChatBase(message); - if (askcode_stop === false) { - askcode_stop = true; - if (askcode_runner) { - askcode_runner.stop(); - askcode_runner = null; - } - await vscode.commands.executeCommand('DevChat.AskCodeIndexStop'); + if (commandRunner) { + commandRunner.stop(); + commandRunner = null; } } diff --git a/src/handler/workflowExecutor.ts b/src/handler/workflowExecutor.ts new file mode 100644 index 0000000..9f0ae77 --- /dev/null +++ b/src/handler/workflowExecutor.ts @@ -0,0 +1,281 @@ + + + +// TODO +// 临时解决方案,后续需要修改 + +import * as vscode from 'vscode'; +import { UiUtilWrapper } from "../util/uiUtil"; +import { MessageHandler } from "../handler/messageHandler"; +import { ApiKeyManager } from "../util/apiKey"; +import { logger } from "../util/logger"; +import { CommandResult, CommandRun, saveModelSettings } from "../util/commonUtil"; +import { handleTopic, insertDevChatLog } from "./sendMessageBase"; +import { regInMessage } from "@/util/reg_messages"; +import parseArgsStringToArgv from 'string-argv'; + + +async function handleWorkflowRequest(request): Promise { + /* + request: { + "command": "some command", + "args": { + "arg1": "value1", + "arg2": "value2" + } + } + response: { + "status": "success", + "result": "success", + "detail": "some detail" + } + */ + if (!request || !request.command) { + return undefined; + } + + if (request.command === "get_lsp_brige_port") { + return JSON.stringify({ + "status": "success", + "result": await UiUtilWrapper.getLSPBrigePort() + }); + } else { + return JSON.stringify({ + "status": "fail", + "result": "fail", + "detail": "command is not supported" + }); + } +} + + +// TODO +// 临时解决方案,后续需要修改 +// 执行workflow + +// workflow执行时,都是通过启动一个进程的方式来执行。 +// 与一般进程不同的是: +// 1. 通过UI交互可以停止该进程; +// 2. 需要在进程启动前初始化相关的环境变量 +// 3. 需要处理进程的通信 + + +export class WorkflowRunner { + private _commandRunner: CommandRun | null = null; + private _stop: boolean = false; + private _cacheOut: string = ""; + private _panel: vscode.WebviewPanel|vscode.WebviewView | null = null; + + constructor() {} + + private async _getApiKeyAndApiBase(): Promise<[string | undefined, string | undefined]> { + const llmModelData = await ApiKeyManager.llmModel(); + if (!llmModelData) { + logger.channel()?.error('No valid llm model is selected!'); + logger.channel()?.show(); + return [undefined, undefined]; + } + + let openaiApiKey = llmModelData.api_key; + if (!openaiApiKey) { + logger.channel()?.error('The OpenAI key is invalid!'); + logger.channel()?.show(); + return [undefined, undefined]; + } + + const openAiApiBase = llmModelData.api_base; + return [openaiApiKey, openAiApiBase]; + } + + private _parseCommandOutput(outputStr: string): string { + /* + output is format as: + <> + {"content": "data"} + <> + */ + const outputWitchCache = this._cacheOut + outputStr; + this._cacheOut = ""; + + let outputResult = ""; + let curPos = 0; + while (true) { + const startPos = outputWitchCache.indexOf('<>', curPos); + const startPos2 = outputWitchCache.indexOf('```', curPos); + if (startPos === -1 && startPos2 === -1) { + break; + } + + const isStart = (startPos2 === -1) || (startPos > -1 && startPos < startPos2); + + let endPos = -1; + if (isStart) { + endPos = outputWitchCache.indexOf('<>', startPos+9); + } else { + endPos = outputWitchCache.indexOf('```', startPos2+3); + } + + if (endPos === -1) { + this._cacheOut = outputWitchCache.substring(startPos, outputWitchCache.length); + break; + } + + let contentStr = ""; + if (isStart) { + contentStr = outputWitchCache.substring(startPos+9, endPos); + curPos = endPos+7; + } else { + contentStr = outputWitchCache.substring(startPos2, endPos+3); + curPos = endPos+3; + } + + outputResult += contentStr.trim() + "\n\n"; + } + + return outputResult; + } + + private async _runCommand(commandWithArgs: string, commandEnvs: any): Promise<[CommandResult | undefined, string]> { + const workspaceDir = UiUtilWrapper.workspaceFoldersFirstPath() || ""; + let commandOutput = ""; + let commandAnswer = ""; + + try { + const commandAndArgsList = parseArgsStringToArgv(commandWithArgs); + this._commandRunner = new CommandRun(); + await saveModelSettings(); + const result = await this._commandRunner.spawnAsync(commandAndArgsList[0], commandAndArgsList.slice(1), { env: commandEnvs, cwd: workspaceDir }, async (data) => { + // handle command stdout + const newData = this._parseCommandOutput(data); + // if newData is json string, then process it by handleWorkflowRequest + let newDataObj: any = undefined; + try { + newDataObj = JSON.parse(newData); + const result = await handleWorkflowRequest(newDataObj); + if (result) { + + this.input(result); + } else if (newDataObj!.result) { + commandAnswer = newDataObj!.result; + commandOutput += newDataObj!.result; + logger.channel()?.info(newDataObj!.result); + MessageHandler.sendMessage(this._panel!, { command: 'receiveMessagePartial', text: commandOutput, hash:"", user:"", isError: false }); + } + } catch (e) { + if (newData.length > 0){ + commandOutput += newData; + logger.channel()?.info(newData); + MessageHandler.sendMessage(this._panel!, { command: 'receiveMessagePartial', text: commandOutput, hash:"", user:"", isError: false }); + } + } + }, (data) => { + // handle command stderr + logger.channel()?.error(data); + logger.channel()?.show(); + }, undefined, undefined); + + + return [result, commandAnswer]; + } catch (error) { + if (error instanceof Error) { + logger.channel()?.error(`error: ${error.message}`); + } else { + logger.channel()?.error(`An unknown error occurred: ${error}`); + } + logger.channel()?.show(); + } + return [undefined, ""]; + } + + public stop(): void { + this._stop = true; + if (this._commandRunner) { + this._commandRunner.stop(); + this._commandRunner = null; + } + } + + public input(data): void { + const userInputWithFlag = `\n<>\n${data}\n<>\n`; + this._commandRunner?.write(userInputWithFlag); + } + + public async run(workflow: string, commandDefines: any, message: any, panel: vscode.WebviewPanel|vscode.WebviewView): Promise { + /* + 1. 判断workflow是否有输入存在 + 2. 获取workflow的环境变量信息 + 3. 执行workflow command + 4. 处理workflow command输出 + */ + + this._panel = panel; + + // 获取workflow的python命令 + const pythonVirtualEnv: string | undefined = vscode.workspace.getConfiguration('DevChat').get('PythonVirtualEnv'); + if (!pythonVirtualEnv) { + MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: "Index code fail.", hash: "", user: "", date: 0, isError: true }); + return ; + } + + // 获取扩展路径 + const extensionPath = UiUtilWrapper.extensionPath(); + + // 获取api_key 和 api_base + const [apiKey, aipBase] = await this._getApiKeyAndApiBase(); + if (!apiKey) { + MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: "The OpenAI key is invalid!", hash: "", user: "", date: 0, isError: true }); + return ; + } + + // 构建子进程环境变量 + const workflowEnvs = { + // eslint-disable-next-line @typescript-eslint/naming-convention + "PYTHONUTF8":1, + "DEVCHATPYTHON": UiUtilWrapper.getConfiguration("DevChat", "PythonPath") || "python3", + "PYTHONLIBPATH": `${extensionPath}/tools/site-packages`, + "PARENT_HASH": message.parent_hash, + ...process.env, + // eslint-disable-next-line @typescript-eslint/naming-convention + OPENAI_API_KEY: apiKey, + // eslint-disable-next-line @typescript-eslint/naming-convention + ...(aipBase ? { 'OPENAI_API_BASE': aipBase } : {}) + }; + + const requireInput = commandDefines.input === "required"; + if (requireInput && message.text.replace("/" + workflow, "").trim() === "") { + MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: `The workflow ${workflow} need input!`, hash: "", user: "", date: 0, isError: true }); + return ; + } + + const workflowCommand = commandDefines.steps[0].run.replace( + '$command_python', `${pythonVirtualEnv}`).replace( + '$input', `${message.text.replace("/" + workflow, "").trim()}`); + + const [commandResult, commandAnswer] = await this._runCommand(workflowCommand, workflowEnvs); + + if (commandResult && commandResult.exitCode === 0) { + const resultOut = commandAnswer === "" ? "success" : commandAnswer; + let logHash = await insertDevChatLog(message, message.text, resultOut); + if (!logHash) { + logHash = ""; + logger.channel()?.error(`Failed to insert devchat log.`); + logger.channel()?.show(); + } + + //MessageHandler.sendMessage(panel, { command: 'receiveMessagePartial', text: resultOut, hash:logHash, user:"", isError: false }); + MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: resultOut, hash:logHash, user:"", date:0, isError: false }); + + const dateStr = Math.floor(Date.now()/1000).toString(); + await handleTopic( + message.parent_hash, + {"text": message.text}, + // eslint-disable-next-line @typescript-eslint/naming-convention + { response: resultOut, "prompt-hash": logHash, user: "", "date": dateStr, finish_reason: "", isError: false }); + } else if (commandResult) { + logger.channel()?.info(`${commandResult.stdout}`); + if (this._stop === false) { + MessageHandler.sendMessage(panel, { command: 'receiveMessage', text: commandResult.stderr, hash: "", user: "", date: 0, isError: true }); + } + } + } +} diff --git a/src/toolwrapper/devchat.ts b/src/toolwrapper/devchat.ts index 2216141..aefc754 100644 --- a/src/toolwrapper/devchat.ts +++ b/src/toolwrapper/devchat.ts @@ -4,12 +4,11 @@ import * as path from 'path'; import * as fs from 'fs'; import { logger } from '../util/logger'; -import { CommandRun } from "../util/commonUtil"; +import { CommandRun, saveModelSettings } from "../util/commonUtil"; import ExtensionContextHolder from '../util/extensionContext'; import { UiUtilWrapper } from '../util/uiUtil'; import { ApiKeyManager } from '../util/apiKey'; import { exitCode } from 'process'; -import * as yaml from 'yaml'; const envPath = path.join(__dirname, '..', '.env'); @@ -217,8 +216,7 @@ class DevChat { logger.channel()?.show(); } - const openaiStream = UiUtilWrapper.getConfiguration('DevChat', 'OpenAI.stream'); - + // eslint-disable-next-line @typescript-eslint/naming-convention const openAiApiBaseObject = llmModelData.api_base? { OPENAI_API_BASE: llmModelData.api_base } : {}; const activeLlmModelKey = llmModelData.api_key; @@ -227,31 +225,7 @@ class DevChat { devChat = 'devchat'; } - const reduceModelData = Object.keys(llmModelData) - .filter(key => key !== 'api_key' && key !== 'provider' && key !== 'model' && key !== 'api_base') - .reduce((obj, key) => { - obj[key] = llmModelData[key]; - return obj; - }, {}); - let devchatConfig = {}; - devchatConfig[llmModelData.model] = { - "provider": llmModelData.provider, - "stream": openaiStream, - ...reduceModelData - }; - - let devchatModels = { - "default_model": llmModelData.model, - "models": devchatConfig}; - - // write to config file - const os = process.platform; - const userHome = os === 'win32' ? fs.realpathSync(process.env.USERPROFILE || '') : process.env.HOME; - - const configPath = path.join(userHome!, '.chat', 'config.yml'); - // write devchatConfig to configPath - const yamlString = yaml.stringify(devchatModels); - fs.writeFileSync(configPath, yamlString); + await saveModelSettings(); try { diff --git a/src/util/apiKey.ts b/src/util/apiKey.ts index 9874abf..c8f9aeb 100644 --- a/src/util/apiKey.ts +++ b/src/util/apiKey.ts @@ -6,19 +6,7 @@ export class ApiKeyManager { static toProviderKey(provider: string) : string | undefined { let providerNameMap = { "openai": "OpenAI", - "devchat": "DevChat", - "cohere": "Cohere", - "anthropic": "Anthropic", - "replicate": "Replicate", - "huggingface": "HuggingFace", - "together_ai": "TogetherAI", - "openrouter": "OpenRouter", - "vertex_ai": "VertexAI", - "ai21": "AI21", - "baseten": "Baseten", - "azure": "Azure", - "sagemaker": "SageMaker", - "bedrock": "Bedrock" + "devchat": "DevChat" }; return providerNameMap[provider]; } @@ -152,6 +140,7 @@ export class ApiKeyManager { return undefined; } } + if (apiBase) { modelProperties["api_base"] = apiBase; } else if (!apiKey) { diff --git a/src/util/commonUtil.ts b/src/util/commonUtil.ts index bd95ac7..8684feb 100644 --- a/src/util/commonUtil.ts +++ b/src/util/commonUtil.ts @@ -1,6 +1,9 @@ +/* eslint-disable @typescript-eslint/naming-convention */ import * as fs from 'fs'; import * as os from 'os'; import * as path from 'path'; +import * as yaml from 'yaml'; +import * as vscode from 'vscode'; import * as childProcess from 'child_process'; import { parseArgsStringToArgv } from 'string-argv'; @@ -10,6 +13,56 @@ import { spawn, exec } from 'child_process'; import { UiUtilWrapper } from './uiUtil'; import { ApiKeyManager } from './apiKey'; + +export async function saveModelSettings(): Promise { + // support models + const supportModels = { + "Model.gpt-3-5": "gpt-3.5-turbo", + "Model.gpt-3-5-1106": "gpt-3.5-turbo-1106", + "Model.gpt-3-5-16k": "gpt-3.5-turbo-16k", + "Model.gpt-4": "gpt-4", + "Model.gpt-4-turbo": "gpt-4-1106-preview", + "Model.claude-2": "claude-2", + "Model.xinghuo-2": "xinghuo-2", + "Model.chatglm_pro": "chatglm_pro", + "Model.ERNIE-Bot": "ERNIE-Bot", + "Model.CodeLlama-34b-Instruct": "CodeLlama-34b-Instruct", + "Model.llama-2-70b-chat": "llama-2-70b-chat" + }; + + // is enable stream + const openaiStream = UiUtilWrapper.getConfiguration('DevChat', 'OpenAI.stream'); + + let devchatConfig = {}; + for (const model of Object.keys(supportModels)) { + const modelConfig = UiUtilWrapper.getConfiguration('devchat', model); + if (modelConfig) { + devchatConfig[supportModels[model]] = { + "stream": openaiStream + }; + for (const key of Object.keys(modelConfig || {})) { + const property = modelConfig![key]; + devchatConfig[supportModels[model]][key] = property; + } + } + } + + let devchatModels = { + // eslint-disable-next-line @typescript-eslint/naming-convention + "default_model": "gpt-3.5-turbo-16k", + "models": devchatConfig + }; + + // write to config file + const os = process.platform; + const userHome = os === 'win32' ? fs.realpathSync(process.env.USERPROFILE || '') : process.env.HOME; + + const configPath = path.join(userHome!, '.chat', 'config.yml'); + // write devchatConfig to configPath + const yamlString = yaml.stringify(devchatModels); + fs.writeFileSync(configPath, yamlString); +} + async function createOpenAiKeyEnv() { let envs = {...process.env}; const llmModelData = await ApiKeyManager.llmModel(); @@ -106,6 +159,7 @@ export class CommandRun { logger.channel()?.show(); } + this.childProcess = null; if (code === 0) { resolve({ exitCode: code, stdout, stderr }); } else { @@ -115,6 +169,7 @@ export class CommandRun { // Add error event listener to handle command not found exception this.childProcess.on('error', (error: any) => { + this.childProcess = null; let errorMessage = error.message; if (error.code === 'ENOENT') { errorMessage = `Command not found: ${command}`; @@ -129,6 +184,12 @@ export class CommandRun { }); }; + public write(input: string) { + if (this.childProcess) { + this.childProcess.stdin.write(input); + } + } + public stop() { if (this.childProcess) { this.childProcess.kill(); @@ -236,7 +297,7 @@ export function runCommandStringAndWriteOutputSync(command: string, outputFile: } } -export function git_ls_tree(withAbsolutePath: boolean = false): string[] { +export function gitLsTree(withAbsolutePath: boolean = false): string[] { // Run the git ls-tree command const workspacePath = UiUtilWrapper.workspaceFoldersFirstPath() || '.'; const result = childProcess.execSync('git ls-tree -r --name-only HEAD', { diff --git a/tools b/tools index 9049c4a..d1f8662 160000 --- a/tools +++ b/tools @@ -1 +1 @@ -Subproject commit 9049c4aee8ed8329d1ad799134d98200cf149f8e +Subproject commit d1f8662061e9b857ac78db362077c1d76868377e