diff --git a/packages/opencode/src/util/rpc.ts b/packages/opencode/src/util/rpc.ts index 02586ebcfc60..b4c564378331 100644 --- a/packages/opencode/src/util/rpc.ts +++ b/packages/opencode/src/util/rpc.ts @@ -2,12 +2,21 @@ type Definition = { [method: string]: (input: any) => any } +function errorMessage(error: unknown) { + if (error instanceof Error) return error.message + return String(error) +} + export function listen(rpc: Definition) { onmessage = async (evt) => { const parsed = JSON.parse(evt.data) if (parsed.type === "rpc.request") { - const result = await rpc[parsed.method](parsed.input) - postMessage(JSON.stringify({ type: "rpc.result", result, id: parsed.id })) + try { + const result = await rpc[parsed.method](parsed.input) + postMessage(JSON.stringify({ type: "rpc.result", result, id: parsed.id })) + } catch (error) { + postMessage(JSON.stringify({ type: "rpc.error", error: errorMessage(error), id: parsed.id })) + } } } } @@ -20,15 +29,22 @@ export function client(target: { postMessage: (data: string) => void | null onmessage: ((this: Worker, ev: MessageEvent) => any) | null }) { - const pending = new Map void>() + const pending = new Map void; reject: (error: Error) => void }>() const listeners = new Map void>>() let id = 0 target.onmessage = async (evt) => { const parsed = JSON.parse(evt.data) if (parsed.type === "rpc.result") { - const resolve = pending.get(parsed.id) - if (resolve) { - resolve(parsed.result) + const request = pending.get(parsed.id) + if (request) { + request.resolve(parsed.result) + pending.delete(parsed.id) + } + } + if (parsed.type === "rpc.error") { + const request = pending.get(parsed.id) + if (request) { + request.reject(new Error(parsed.error)) pending.delete(parsed.id) } } @@ -44,9 +60,14 @@ export function client(target: { return { call(method: Method, input: Parameters[0]): Promise> { const requestId = id++ - return new Promise((resolve) => { - pending.set(requestId, resolve) - target.postMessage(JSON.stringify({ type: "rpc.request", method, input, id: requestId })) + return new Promise((resolve, reject) => { + pending.set(requestId, { resolve, reject }) + try { + target.postMessage(JSON.stringify({ type: "rpc.request", method, input, id: requestId })) + } catch (error) { + pending.delete(requestId) + reject(error) + } }) }, on(event: string, handler: (data: Data) => void) { diff --git a/packages/opencode/test/util/rpc.test.ts b/packages/opencode/test/util/rpc.test.ts new file mode 100644 index 000000000000..c246381eb736 --- /dev/null +++ b/packages/opencode/test/util/rpc.test.ts @@ -0,0 +1,66 @@ +import { expect, test } from "bun:test" +import { Rpc } from "../../src/util/rpc" + +test("rpc listener returns handler errors without rejecting the message handler", async () => { + const global = globalThis as typeof globalThis & { + onmessage?: (evt: { data: string }) => Promise + postMessage?: (data: string) => void + } + const previousOnMessage = global.onmessage + const previousPostMessage = global.postMessage + const messages: string[] = [] + + global.postMessage = (data) => messages.push(data) + + try { + Rpc.listen({ + fail: () => { + throw new Error("boom") + }, + }) + + await expect( + global.onmessage?.({ + data: JSON.stringify({ type: "rpc.request", method: "fail", id: 1 }), + }), + ).resolves.toBeUndefined() + + expect(JSON.parse(messages[0] ?? "{}")).toMatchObject({ + type: "rpc.error", + id: 1, + error: "boom", + }) + } finally { + global.onmessage = previousOnMessage + global.postMessage = previousPostMessage + } +}) + +test("rpc client rejects rpc errors and can continue handling later results", async () => { + const target: { + postMessage: (data: string) => void + onmessage: ((evt: { data: string }) => void) | null + } = { + postMessage(data) { + const parsed = JSON.parse(data) + target.onmessage?.({ + data: JSON.stringify( + parsed.method === "fail" + ? { type: "rpc.error", id: parsed.id, error: "boom" } + : { type: "rpc.result", id: parsed.id, result: "ok" }, + ), + }) + }, + onmessage: null, + } + const client = Rpc.client<{ fail: () => string; ok: () => string }>(target) + let failed: string | undefined + + void client.call("fail", undefined).catch((error) => { + failed = error instanceof Error ? error.message : String(error) + }) + await Promise.resolve() + + expect(failed).toBe("boom") + await expect(client.call("ok", undefined)).resolves.toBe("ok") +})