diff --git a/lib/image-pipeline/backend/backend-router.ts b/lib/image-pipeline/backend/backend-router.ts index e249371..c285197 100644 --- a/lib/image-pipeline/backend/backend-router.ts +++ b/lib/image-pipeline/backend/backend-router.ts @@ -196,31 +196,60 @@ export function createBackendRouter(options?: { }); } - try { - args.runBackend(selection.backend); - return; - } catch (error: unknown) { - const shouldAbort = args.executionOptions?.shouldAbort; - if (shouldAbort?.()) { - throw error; - } + const shouldAbort = args.executionOptions?.shouldAbort; + let backend = selection.backend; - if (selection.backend.id.toLowerCase() === cpuFallbackBackend.id.toLowerCase()) { - throw error; - } + while (true) { + try { + args.runBackend(backend); + return; + } catch (error: unknown) { + if (shouldAbort?.()) { + throw error; + } - const normalizedError = - error instanceof Error ? error : new Error("Image pipeline backend execution failed."); - emitFallback({ - reason: "runtime_error", - requestedBackend: selection.backend.id.toLowerCase(), - fallbackBackend: cpuFallbackBackend.id, - error: normalizedError, - }); - args.runBackend(cpuFallbackBackend); + if (backend.id.toLowerCase() === cpuFallbackBackend.id.toLowerCase()) { + throw error; + } + + const fallbackBackend = resolveRuntimeFallbackBackend(backend.id); + if (!fallbackBackend || fallbackBackend.id.toLowerCase() === backend.id.toLowerCase()) { + throw error; + } + + const normalizedError = + error instanceof Error ? error : new Error("Image pipeline backend execution failed."); + emitFallback({ + reason: "runtime_error", + requestedBackend: backend.id.toLowerCase(), + fallbackBackend: fallbackBackend.id, + error: normalizedError, + }); + + backend = fallbackBackend; + } } } + function resolveRuntimeFallbackBackend(failedBackendId: string): ImagePipelineBackend | null { + const normalizedFailedBackendId = failedBackendId.toLowerCase(); + + if (normalizedFailedBackendId === "webgl") { + const wasmBackend = byId.get("wasm"); + const wasmAvailability = readAvailability("wasm"); + if ( + wasmBackend && + isBackendEnabledByFlags("wasm") && + wasmAvailability?.enabled !== false && + wasmAvailability?.supported !== false + ) { + return wasmBackend; + } + } + + return cpuFallbackBackend; + } + return { resolveBackend(backendHint) { return resolveBackendWithFallbackReason(backendHint).backend; @@ -307,7 +336,15 @@ export function getPreviewBackendHintForSteps(steps: readonly PreviewBackendRequ const rolloutState = getRolloutRouterState(); if (rolloutState.webglEnabled && rolloutState.webglAvailable) { - return isWebglPreviewPipelineSupported(steps) ? "webgl" : CPU_BACKEND_ID; + if (isWebglPreviewPipelineSupported(steps)) { + return "webgl"; + } + + if (rolloutState.wasmEnabled && rolloutState.wasmAvailable) { + return "wasm"; + } + + return CPU_BACKEND_ID; } if (rolloutState.wasmEnabled && rolloutState.wasmAvailable) { diff --git a/tests/image-pipeline/wasm-backend.test.ts b/tests/image-pipeline/wasm-backend.test.ts index 6fecb9a..fbe3afd 100644 --- a/tests/image-pipeline/wasm-backend.test.ts +++ b/tests/image-pipeline/wasm-backend.test.ts @@ -70,9 +70,203 @@ describe("wasm backend rollout selection", () => { expect(backendRouter.getPreviewBackendHintForSteps([createStep()])).toBe("wasm"); }); + + it("prefers wasm when webgl is enabled+available but unsupported for the step set", async () => { + vi.doMock("@/lib/image-pipeline/backend/feature-flags", async () => { + const actual = await vi.importActual("@/lib/image-pipeline/backend/feature-flags"); + return { + ...actual, + getBackendFeatureFlags: () => ({ + forceCpu: false, + webglEnabled: true, + wasmEnabled: true, + }), + }; + }); + + vi.doMock("@/lib/image-pipeline/backend/capabilities", async () => { + const actual = await vi.importActual("@/lib/image-pipeline/backend/capabilities"); + return { + ...actual, + detectBackendCapabilities: () => ({ + webgl: true, + wasmSimd: true, + offscreenCanvas: true, + }), + }; + }); + + vi.doMock("@/lib/image-pipeline/backend/webgl/webgl-backend", async () => { + const actual = await vi.importActual("@/lib/image-pipeline/backend/webgl/webgl-backend"); + return { + ...actual, + createWebglPreviewBackend: () => ({ + id: "webgl", + runPreviewStep: vi.fn(), + runFullPipeline: vi.fn(), + }), + isWebglPreviewPipelineSupported: () => false, + }; + }); + + const backendRouter = await import("@/lib/image-pipeline/backend/backend-router"); + + expect(backendRouter.getPreviewBackendHintForSteps([createStep()])).toBe("wasm"); + }); }); describe("wasm backend fallback behavior", () => { + it("uses wasm as runtime fallback before cpu when webgl fails", () => { + const fallbackEvents: Array<{ + reason: string; + requestedBackend: string; + fallbackBackend: string; + }> = []; + const webglPreview = vi.fn(() => { + throw new Error("webgl failed"); + }); + const wasmPreview = vi.fn(); + const cpuPreview = vi.fn(); + const router = createBackendRouter({ + backends: [ + { + id: "cpu", + runPreviewStep: cpuPreview, + runFullPipeline: vi.fn(), + }, + { + id: "wasm", + runPreviewStep: wasmPreview, + runFullPipeline: vi.fn(), + }, + { + id: "webgl", + runPreviewStep: webglPreview, + runFullPipeline: vi.fn(), + }, + ], + defaultBackendId: "webgl", + backendAvailability: { + webgl: { + supported: true, + enabled: true, + }, + wasm: { + supported: true, + enabled: true, + }, + }, + featureFlags: { + forceCpu: false, + webglEnabled: true, + wasmEnabled: true, + }, + onFallback: (event) => { + fallbackEvents.push({ + reason: event.reason, + requestedBackend: event.requestedBackend, + fallbackBackend: event.fallbackBackend, + }); + }, + }); + + router.runPreviewStep({ + pixels: new Uint8ClampedArray(4), + step: createStep(), + width: 1, + height: 1, + backendHint: "webgl", + }); + + expect(webglPreview).toHaveBeenCalledTimes(1); + expect(wasmPreview).toHaveBeenCalledTimes(1); + expect(cpuPreview).not.toHaveBeenCalled(); + expect(fallbackEvents).toEqual([ + { + reason: "runtime_error", + requestedBackend: "webgl", + fallbackBackend: "wasm", + }, + ]); + }); + + it("falls through to cpu when both webgl and wasm fail at runtime", () => { + const fallbackEvents: Array<{ + reason: string; + requestedBackend: string; + fallbackBackend: string; + }> = []; + const cpuPreview = vi.fn(); + const router = createBackendRouter({ + backends: [ + { + id: "cpu", + runPreviewStep: cpuPreview, + runFullPipeline: vi.fn(), + }, + { + id: "wasm", + runPreviewStep: () => { + throw new Error("wasm failed"); + }, + runFullPipeline: vi.fn(), + }, + { + id: "webgl", + runPreviewStep: () => { + throw new Error("webgl failed"); + }, + runFullPipeline: vi.fn(), + }, + ], + defaultBackendId: "webgl", + backendAvailability: { + webgl: { + supported: true, + enabled: true, + }, + wasm: { + supported: true, + enabled: true, + }, + }, + featureFlags: { + forceCpu: false, + webglEnabled: true, + wasmEnabled: true, + }, + onFallback: (event) => { + fallbackEvents.push({ + reason: event.reason, + requestedBackend: event.requestedBackend, + fallbackBackend: event.fallbackBackend, + }); + }, + }); + + router.runPreviewStep({ + pixels: new Uint8ClampedArray(4), + step: createStep(), + width: 1, + height: 1, + backendHint: "webgl", + }); + + expect(cpuPreview).toHaveBeenCalledTimes(1); + expect(fallbackEvents).toEqual([ + { + reason: "runtime_error", + requestedBackend: "webgl", + fallbackBackend: "wasm", + }, + { + reason: "runtime_error", + requestedBackend: "wasm", + fallbackBackend: "cpu", + }, + ]); + }); + it("downgrades to cpu with runtime_error when wasm initialization fails", () => { const fallbackEvents: Array<{ reason: string;