update apply diff logic

This commit is contained in:
bobo.yang 2023-05-23 08:59:42 +08:00
parent 1ea353cfd9
commit 25915cec29
3 changed files with 164 additions and 95 deletions

View File

@ -71,7 +71,12 @@ async function getNewCode(message: any) : Promise<string | undefined> {
let newCode = message.content;
if (isValidActionString(message.content)) {
newCode = applyCodeChanges(codeTextObj.text, message.content);
if (codeTextObj.select) {
const diffResult = applyCodeChanges(codeTextObj.select, message.content);
newCode = codeTextObj.beforSelect + diffResult + codeTextObj.afterSelect;
} else {
newCode = applyCodeChanges(codeTextObj.text, message.content);
}
} else {
// if select some text, then reconstruct the code
if (codeTextObj.select) {

View File

@ -57,7 +57,7 @@ class DevChat {
this.commandRun.stop();
}
async chat(content: string, options: ChatOptions = {}, onData: (data: ChatResponse) => void): Promise<ChatResponse> {
async buildArgs(options: ChatOptions): Promise<string[]> {
let args = ["prompt"];
if (options.reference) {
@ -76,10 +76,14 @@ class DevChat {
}
}
args.push(content)
if (options.parent) {
args.push("-p", options.parent);
}
const workspaceDir = vscode.workspace.workspaceFolders?.[0].uri.fsPath;
return args;
}
async getOpenAiApiKey(): Promise<string | undefined> {
const secretStorage: vscode.SecretStorage = ExtensionContextHolder.context!.secrets;
let openaiApiKey = await secretStorage.get("devchat_OPENAI_API_KEY");
if (!openaiApiKey) {
@ -88,11 +92,71 @@ class DevChat {
if (!openaiApiKey) {
openaiApiKey = process.env.OPENAI_API_KEY;
}
return openaiApiKey;
}
private parseOutData(stdout: string, isPartial: boolean): ChatResponse {
const responseLines = stdout.trim().split("\n");
if (responseLines.length < 2) {
return {
"prompt-hash": "",
user: "",
date: "",
response: "",
isError: isPartial ? false : true,
};
}
const userLine = responseLines.shift()!;
const user = (userLine.match(/User: (.+)/)?.[1]) ?? "";
const dateLine = responseLines.shift()!;
const date = (dateLine.match(/Date: (.+)/)?.[1]) ?? "";
let promptHashLine = "";
for (let i = responseLines.length - 1; i >= 0; i--) {
if (responseLines[i].startsWith("prompt")) {
promptHashLine = responseLines[i];
responseLines.splice(i, 1);
break;
}
}
if (!promptHashLine) {
return {
"prompt-hash": "",
user: user,
date: date,
response: responseLines.join("\n"),
isError: isPartial ? false : true,
};
}
const promptHash = promptHashLine.split(" ")[1];
const response = responseLines.join("\n");
return {
"prompt-hash": promptHash,
user,
date,
response,
isError: false,
};
}
async chat(content: string, options: ChatOptions = {}, onData: (data: ChatResponse) => void): Promise<ChatResponse> {
const args = await this.buildArgs(options);
args.push(content);
const workspaceDir = vscode.workspace.workspaceFolders?.[0].uri.fsPath;
let openaiApiKey = await this.getOpenAiApiKey();
if (!openaiApiKey) {
logger.channel()?.error('openAI key is invalid!');
logger.channel()?.show();
}
const openAiApiBase = vscode.workspace.getConfiguration('DevChat').get('OpenAI.EndPoint');
const openAiApiBaseObject = openAiApiBase ? { OPENAI_API_BASE: openAiApiBase } : {};
@ -107,10 +171,6 @@ class DevChat {
devChat = 'devchat';
}
if (options.parent) {
args.push("-p", options.parent);
}
const devchatConfig = {
model: openaiModel,
provider: llmModel,
@ -127,61 +187,11 @@ class DevChat {
fs.writeFileSync(configPath, configJson);
try {
const parseOutData = (stdout: string, isPartial: boolean) => {
const responseLines = stdout.trim().split("\n");
if (responseLines.length < 2) {
return {
"prompt-hash": "",
user: "",
date: "",
response: "",
isError: isPartial ? false : true,
};
}
const userLine = responseLines.shift()!;
const user = (userLine.match(/User: (.+)/)?.[1]) ?? "";
const dateLine = responseLines.shift()!;
const date = (dateLine.match(/Date: (.+)/)?.[1]) ?? "";
let promptHashLine = "";
for (let i = responseLines.length - 1; i >= 0; i--) {
if (responseLines[i].startsWith("prompt")) {
promptHashLine = responseLines[i];
responseLines.splice(i, 1);
break;
}
}
if (!promptHashLine) {
return {
"prompt-hash": "",
user: user,
date: date,
response: responseLines.join("\n"),
isError: isPartial ? false : true,
};
}
const promptHash = promptHashLine.split(" ")[1];
const response = responseLines.join("\n");
return {
"prompt-hash": promptHash,
user,
date,
response,
isError: false,
};
};
let receviedStdout = "";
const onStdoutPartial = (stdout: string) => {
receviedStdout += stdout;
const data = parseOutData(receviedStdout, true);
const data = this.parseOutData(receviedStdout, true);
onData(data);
};
@ -207,7 +217,7 @@ class DevChat {
};
}
const response = parseOutData(stdout, false);
const response = this.parseOutData(stdout, false);
return response;
} catch (error: any) {
return {
@ -226,30 +236,25 @@ class DevChat {
const workspaceDir = vscode.workspace.workspaceFolders?.[0].uri.fsPath;
const openaiApiKey = process.env.OPENAI_API_KEY;
try {
logger.channel()?.info(`Running devchat with args: ${args.join(" ")}`);
const { exitCode: code, stdout, stderr } = await this.commandRun.spawnAsync(devChat, args, {
maxBuffer: 10 * 1024 * 1024, // Set maxBuffer to 10 MB
cwd: workspaceDir,
env: {
...process.env,
OPENAI_API_KEY: openaiApiKey,
},
}, undefined, undefined, undefined, undefined);
logger.channel()?.info(`Running devchat with args: ${args.join(" ")}`);
const spawnOptions = {
maxBuffer: 10 * 1024 * 1024, // Set maxBuffer to 10 MB
cwd: workspaceDir,
env: {
...process.env,
OPENAI_API_KEY: openaiApiKey,
},
};
const { exitCode: code, stdout, stderr } = await this.commandRun.spawnAsync(devChat, args, spawnOptions, undefined, undefined, undefined, undefined);
logger.channel()?.info(`Finish devchat with args: ${args.join(" ")}`);
if (stderr) {
logger.channel()?.error(`Error getting log: ${stderr}`);
logger.channel()?.show();
return [];
}
return JSON.parse(stdout.trim()).reverse();
} catch (error) {
logger.channel()?.error(`Error getting log: ${error}`);
logger.channel()?.info(`Finish devchat with args: ${args.join(" ")}`);
if (stderr) {
logger.channel()?.error(`Error getting log: ${stderr}`);
logger.channel()?.show();
return [];
}
return JSON.parse(stdout.trim()).reverse();
}
private buildLogArgs(options: LogOptions): string[] {

View File

@ -5,58 +5,117 @@ type Action = {
action: "delete" | "insert" | "modify";
content?: string;
insert_after?: string;
insert_before?: string;
original_content?: string;
new_content?: string;
};
function findMatchingIndex(list1: string[], list2: string[]): number {
function findMatchingIndex(list1: string[], list2: string[]): number[] {
logger.channel()?.info(`findMatchingIndex start: ${list2.join('\n')}`);
const matchingIndexes: number[] = [];
for (let i = 0; i <= list1.length - list2.length; i++) {
let isMatch = true;
for (let j = 0; j < list2.length; j++) {
if (list1[i + j].trim() !== list2[j].trim()) {
if (j > 0) {
logger.channel()?.info(`findMatchingIndex end at ${j} ${list1[i + j].trim()} != ${list2[j].trim()}`);
}
isMatch = false;
break;
}
}
if (isMatch) {
return i;
matchingIndexes.push(i);
}
}
return -1;
logger.channel()?.info(`findMatchingIndex result: ${matchingIndexes.join(' ')}`);
return matchingIndexes;
}
export function applyCodeChanges(originalCode: string, actionsString: string): string {
const actions = JSON.parse(actionsString) as Array<Action>;
const lines = originalCode.split('\n');
// 构建与lines等长的数组用于记录哪些行被修改过
const modifiedIndexes: number[] = new Array(lines.length).fill(0);
// 构建子函数,用于在多个匹配索引中找出最优的索引
const findOptimalMatchingIndex = (matchingIndexList: number[]) => {
// 优先找出未被修改过的索引
const optimalMatchingIndex = matchingIndexList.find(index => modifiedIndexes[index] === 0);
// 如果所有索引都被修改过,则找出第一个索引
if (optimalMatchingIndex === undefined) {
if (matchingIndexList.length > 0) {
return matchingIndexList[0];
} else {
return undefined;
}
}
return optimalMatchingIndex;
};
for (const action of actions) {
const contentLines = action.content?.split('\n') || [];
const insertAfterLines = action.insert_after?.split('\n') || [];
const insertBeforeLines = action.insert_before?.split('\n') || [];
const originalContentLines = action.original_content?.split('\n') || [];
switch (action.action) {
case 'delete':
// find the matching index
const matchingIndex = findMatchingIndex(lines, contentLines);
if (matchingIndex !== -1) {
lines.splice(matchingIndex, contentLines.length);
const matchingIndexList = findMatchingIndex(lines, contentLines);
const optimalMatchingIndex = findOptimalMatchingIndex(matchingIndexList);
if (matchingIndexList.length > 0) {
if (optimalMatchingIndex !== undefined) {
lines.splice(optimalMatchingIndex, contentLines.length);
// 同步删除modifiedIndexes中记录
modifiedIndexes.splice(optimalMatchingIndex, contentLines.length);
}
}
break;
case 'insert':
// find the matching index
const matchingIndex2 = findMatchingIndex(lines, insertAfterLines);
if (matchingIndex2 !== -1) {
lines.splice(matchingIndex2 + 1, 0, ...contentLines);
if (insertBeforeLines.length > 0) {
const matchingIndexList1 = findMatchingIndex(lines, insertBeforeLines);
const optimalMatchingIndex1 = findOptimalMatchingIndex(matchingIndexList1);
if (matchingIndexList1.length > 0) {
if (optimalMatchingIndex1 !== undefined) {
lines.splice(optimalMatchingIndex1, 0, ...contentLines);
// 同步modifiedIndexes添加记录
modifiedIndexes.splice(optimalMatchingIndex1, 0, ...new Array(contentLines.length).fill(1));
}
}
}
if (insertAfterLines.length > 0) {
const matchingIndexList2 = findMatchingIndex(lines, insertAfterLines);
const optimalMatchingIndex2 = findOptimalMatchingIndex(matchingIndexList2);
if (matchingIndexList2.length > 0) {
if (optimalMatchingIndex2 !== undefined) {
lines.splice(optimalMatchingIndex2 + insertAfterLines.length, 0, ...contentLines);
// 同步modifiedIndexes添加记录
modifiedIndexes.splice(optimalMatchingIndex2 + insertAfterLines.length, 0, ...new Array(contentLines.length).fill(1));
}
}
}
break;
case 'modify':
// find the matching index
const matchingIndex3 = findMatchingIndex(lines, originalContentLines);
if (matchingIndex3 !== -1) {
lines.splice(matchingIndex3, originalContentLines.length, ...action.new_content!.split('\n'));
const matchingIndexList3 = findMatchingIndex(lines, originalContentLines);
const optimalMatchingIndex3 = findOptimalMatchingIndex(matchingIndexList3);
if (matchingIndexList3.length > 0) {
if (optimalMatchingIndex3 !== undefined) {
lines.splice(optimalMatchingIndex3, originalContentLines.length, ...action.new_content!.split('\n'));
// 同步modifiedIndexes添加记录
modifiedIndexes.splice(optimalMatchingIndex3, originalContentLines.length, ...new Array(action.new_content!.split('\n').length).fill(1));
}
}
break;
}
@ -80,7 +139,7 @@ export function isValidActionString(actionString: string): boolean {
return false;
}
if (action.action === "insert" && (!action.content || !action.insert_after)) {
if (action.action === "insert" && (!action.content || (!action.insert_after && !action.insert_before))) {
logger.channel()?.error(`Invalid action string: ${action}`);
return false;
}