fix(image-pipeline): wire webgl preview shader execution

This commit is contained in:
Matthias
2026-04-04 22:04:24 +02:00
parent 423eb76581
commit 80f12739f9
6 changed files with 444 additions and 117 deletions

View File

@@ -245,25 +245,55 @@ export function createBackendRouter(options?: {
}; };
} }
const rolloutFeatureFlags = getBackendFeatureFlags(); type RolloutRouterState = {
const rolloutCapabilities = detectBackendCapabilities(); router: BackendRouter;
const rolloutWebglAvailable = rolloutCapabilities.webgl; webglAvailable: boolean;
const rolloutWebglEnabled = rolloutFeatureFlags.webglEnabled && !rolloutFeatureFlags.forceCpu; webglEnabled: boolean;
};
const rolloutRouter = createBackendRouter({ let cachedRolloutState: RolloutRouterState | null = null;
backends: [cpuBackend, createWebglPreviewBackend()], let cachedRolloutKey: string | null = null;
defaultBackendId: "webgl",
backendAvailability: { function getRolloutRouterState(): RolloutRouterState {
webgl: { const featureFlags = getBackendFeatureFlags();
supported: rolloutWebglAvailable, const capabilities = detectBackendCapabilities();
enabled: rolloutWebglEnabled, const webglAvailable = capabilities.webgl;
}, const webglEnabled = featureFlags.webglEnabled && !featureFlags.forceCpu;
}, const rolloutKey = JSON.stringify({
featureFlags: rolloutFeatureFlags, forceCpu: featureFlags.forceCpu,
}); webglEnabled: featureFlags.webglEnabled,
wasmEnabled: featureFlags.wasmEnabled,
webglAvailable,
});
if (cachedRolloutState && cachedRolloutKey === rolloutKey) {
return cachedRolloutState;
}
cachedRolloutState = {
router: createBackendRouter({
backends: [cpuBackend, createWebglPreviewBackend()],
defaultBackendId: "webgl",
backendAvailability: {
webgl: {
supported: webglAvailable,
enabled: webglEnabled,
},
},
featureFlags,
}),
webglAvailable,
webglEnabled,
};
cachedRolloutKey = rolloutKey;
return cachedRolloutState;
}
export function getPreviewBackendHintForSteps(steps: readonly PreviewBackendRequest["step"][]): BackendHint { export function getPreviewBackendHintForSteps(steps: readonly PreviewBackendRequest["step"][]): BackendHint {
if (!rolloutWebglEnabled || !rolloutWebglAvailable) { const rolloutState = getRolloutRouterState();
if (!rolloutState.webglEnabled || !rolloutState.webglAvailable) {
return CPU_BACKEND_ID; return CPU_BACKEND_ID;
} }
@@ -271,9 +301,9 @@ export function getPreviewBackendHintForSteps(steps: readonly PreviewBackendRequ
} }
export function runPreviewStepWithBackendRouter(request: PreviewBackendRequest): void { export function runPreviewStepWithBackendRouter(request: PreviewBackendRequest): void {
rolloutRouter.runPreviewStep(request); getRolloutRouterState().router.runPreviewStep(request);
} }
export function runFullPipelineWithBackendRouter(request: FullBackendRequest): void { export function runFullPipelineWithBackendRouter(request: FullBackendRequest): void {
rolloutRouter.runFullPipeline(request); getRolloutRouterState().router.runFullPipeline(request);
} }

View File

@@ -1,4 +1,3 @@
#version 100
precision mediump float; precision mediump float;
varying vec2 vUv; varying vec2 vUv;

View File

@@ -1,4 +1,3 @@
#version 100
precision mediump float; precision mediump float;
varying vec2 vUv; varying vec2 vUv;

View File

@@ -0,0 +1,4 @@
declare module "*.glsl?raw" {
const source: string;
export default source;
}

View File

@@ -1,40 +1,13 @@
import { applyPipelineStep, applyPipelineSteps } from "@/lib/image-pipeline/render-core";
import type { import type {
BackendPipelineRequest, BackendPipelineRequest,
BackendStepRequest, BackendStepRequest,
ImagePipelineBackend, ImagePipelineBackend,
} from "@/lib/image-pipeline/backend/backend-types"; } from "@/lib/image-pipeline/backend/backend-types";
import type { PipelineStep } from "@/lib/image-pipeline/contracts"; import type { PipelineStep } from "@/lib/image-pipeline/contracts";
import colorAdjustFragmentShaderSource from "@/lib/image-pipeline/backend/webgl/shaders/color-adjust.frag.glsl?raw";
import curvesFragmentShaderSource from "@/lib/image-pipeline/backend/webgl/shaders/curves.frag.glsl?raw";
const CURVES_FRAGMENT_SHADER_SOURCE = `#version 100 const VERTEX_SHADER_SOURCE = `
precision mediump float;
varying vec2 vUv;
uniform sampler2D uSource;
uniform float uGamma;
void main() {
vec4 color = texture2D(uSource, vUv);
color.rgb = pow(max(color.rgb, vec3(0.0)), vec3(max(uGamma, 0.001)));
gl_FragColor = color;
}
`;
const COLOR_ADJUST_FRAGMENT_SHADER_SOURCE = `#version 100
precision mediump float;
varying vec2 vUv;
uniform sampler2D uSource;
uniform vec3 uColorShift;
void main() {
vec4 color = texture2D(uSource, vUv);
color.rgb = clamp(color.rgb + uColorShift, 0.0, 1.0);
gl_FragColor = color;
}
`;
const VERTEX_SHADER_SOURCE = `#version 100
attribute vec2 aPosition; attribute vec2 aPosition;
varying vec2 vUv; varying vec2 vUv;
@@ -46,6 +19,13 @@ void main() {
type SupportedPreviewStepType = "curves" | "color-adjust"; type SupportedPreviewStepType = "curves" | "color-adjust";
type WebglBackendContext = {
gl: WebGLRenderingContext;
curvesProgram: WebGLProgram;
colorAdjustProgram: WebGLProgram;
quadBuffer: WebGLBuffer;
};
const SUPPORTED_PREVIEW_STEP_TYPES = new Set<SupportedPreviewStepType>([ const SUPPORTED_PREVIEW_STEP_TYPES = new Set<SupportedPreviewStepType>([
"curves", "curves",
"color-adjust", "color-adjust",
@@ -59,34 +39,33 @@ function assertSupportedStep(step: PipelineStep): void {
throw new Error(`WebGL backend does not support step type '${step.type}'.`); throw new Error(`WebGL backend does not support step type '${step.type}'.`);
} }
function createGlContext(): WebGLRenderingContext | WebGL2RenderingContext { function createGlContext(): WebGLRenderingContext {
if (typeof document !== "undefined") { if (typeof document !== "undefined") {
const canvas = document.createElement("canvas"); const canvas = document.createElement("canvas");
return ( const context = canvas.getContext("webgl", {
canvas.getContext("webgl2") ?? alpha: true,
canvas.getContext("webgl") ?? antialias: false,
(() => { premultipliedAlpha: false,
throw new Error("WebGL context is unavailable."); preserveDrawingBuffer: true,
})() });
); if (context) {
return context;
}
} }
if (typeof OffscreenCanvas !== "undefined") { if (typeof OffscreenCanvas !== "undefined") {
const canvas = new OffscreenCanvas(1, 1); const canvas = new OffscreenCanvas(1, 1);
return ( const context = canvas.getContext("webgl");
canvas.getContext("webgl2") ?? if (context) {
canvas.getContext("webgl") ?? return context;
(() => { }
throw new Error("WebGL context is unavailable.");
})()
);
} }
throw new Error("WebGL context is unavailable."); throw new Error("WebGL context is unavailable.");
} }
function compileShader( function compileShader(
gl: WebGLRenderingContext | WebGL2RenderingContext, gl: WebGLRenderingContext,
source: string, source: string,
shaderType: number, shaderType: number,
): WebGLShader { ): WebGLShader {
@@ -108,9 +87,9 @@ function compileShader(
} }
function compileProgram( function compileProgram(
gl: WebGLRenderingContext | WebGL2RenderingContext, gl: WebGLRenderingContext,
fragmentShaderSource: string, fragmentShaderSource: string,
): void { ): WebGLProgram {
const vertexShader = compileShader(gl, VERTEX_SHADER_SOURCE, gl.VERTEX_SHADER); const vertexShader = compileShader(gl, VERTEX_SHADER_SOURCE, gl.VERTEX_SHADER);
const fragmentShader = compileShader(gl, fragmentShaderSource, gl.FRAGMENT_SHADER); const fragmentShader = compileShader(gl, fragmentShaderSource, gl.FRAGMENT_SHADER);
const program = gl.createProgram(); const program = gl.createProgram();
@@ -129,8 +108,7 @@ function compileProgram(
gl.deleteShader(fragmentShader); gl.deleteShader(fragmentShader);
if (gl.getProgramParameter(program, gl.LINK_STATUS)) { if (gl.getProgramParameter(program, gl.LINK_STATUS)) {
gl.deleteProgram(program); return program;
return;
} }
const info = gl.getProgramInfoLog(program) ?? "Unknown program link error."; const info = gl.getProgramInfoLog(program) ?? "Unknown program link error.";
@@ -138,6 +116,150 @@ function compileProgram(
throw new Error(`WebGL program link failed: ${info}`); throw new Error(`WebGL program link failed: ${info}`);
} }
function createQuadBuffer(gl: WebGLRenderingContext): WebGLBuffer {
const quadBuffer = gl.createBuffer();
if (!quadBuffer) {
throw new Error("WebGL quad buffer allocation failed.");
}
gl.bindBuffer(gl.ARRAY_BUFFER, quadBuffer);
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([-1, -1, 1, -1, -1, 1, 1, 1]), gl.STATIC_DRAW);
return quadBuffer;
}
function mapCurvesGamma(step: PipelineStep): number {
const gamma = (step.params as { levels?: { gamma?: unknown } })?.levels?.gamma;
if (typeof gamma === "number" && Number.isFinite(gamma)) {
return Math.max(gamma, 0.001);
}
return 1;
}
function mapColorShift(step: PipelineStep): [number, number, number] {
const params = step.params as {
hsl?: { luminance?: unknown };
temperature?: unknown;
tint?: unknown;
};
const luminance = typeof params?.hsl?.luminance === "number" ? params.hsl.luminance : 0;
const temperature = typeof params?.temperature === "number" ? params.temperature : 0;
const tint = typeof params?.tint === "number" ? params.tint : 0;
return [
(luminance + temperature) / 255,
(luminance + tint) / 255,
(luminance - temperature) / 255,
];
}
function runStepOnGpu(context: WebglBackendContext, request: BackendStepRequest): void {
const { gl } = context;
const shaderProgram = request.step.type === "curves" ? context.curvesProgram : context.colorAdjustProgram;
gl.useProgram(shaderProgram);
gl.bindBuffer(gl.ARRAY_BUFFER, context.quadBuffer);
const positionLocation = gl.getAttribLocation(shaderProgram, "aPosition");
if (positionLocation >= 0) {
gl.enableVertexAttribArray(positionLocation);
gl.vertexAttribPointer(positionLocation, 2, gl.FLOAT, false, 0, 0);
}
const sourceTexture = gl.createTexture();
if (!sourceTexture) {
throw new Error("WebGL source texture allocation failed.");
}
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, sourceTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texImage2D(
gl.TEXTURE_2D,
0,
gl.RGBA,
request.width,
request.height,
0,
gl.RGBA,
gl.UNSIGNED_BYTE,
request.pixels,
);
const outputTexture = gl.createTexture();
if (!outputTexture) {
gl.deleteTexture(sourceTexture);
throw new Error("WebGL output texture allocation failed.");
}
gl.bindTexture(gl.TEXTURE_2D, outputTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texImage2D(
gl.TEXTURE_2D,
0,
gl.RGBA,
request.width,
request.height,
0,
gl.RGBA,
gl.UNSIGNED_BYTE,
null,
);
const framebuffer = gl.createFramebuffer();
if (!framebuffer) {
gl.deleteTexture(sourceTexture);
gl.deleteTexture(outputTexture);
throw new Error("WebGL framebuffer allocation failed.");
}
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, outputTexture, 0);
if (gl.checkFramebufferStatus(gl.FRAMEBUFFER) !== gl.FRAMEBUFFER_COMPLETE) {
gl.deleteFramebuffer(framebuffer);
gl.deleteTexture(sourceTexture);
gl.deleteTexture(outputTexture);
throw new Error("WebGL framebuffer is incomplete.");
}
const sourceLocation = gl.getUniformLocation(shaderProgram, "uSource");
if (sourceLocation) {
gl.uniform1i(sourceLocation, 0);
}
if (request.step.type === "curves") {
const gammaLocation = gl.getUniformLocation(shaderProgram, "uGamma");
if (gammaLocation) {
gl.uniform1f(gammaLocation, mapCurvesGamma(request.step));
}
} else {
const colorShiftLocation = gl.getUniformLocation(shaderProgram, "uColorShift");
if (colorShiftLocation) {
const [r, g, b] = mapColorShift(request.step);
gl.uniform3f(colorShiftLocation, r, g, b);
}
}
gl.viewport(0, 0, request.width, request.height);
gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);
const readback = new Uint8Array(request.pixels.length);
gl.readPixels(0, 0, request.width, request.height, gl.RGBA, gl.UNSIGNED_BYTE, readback);
request.pixels.set(readback);
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
gl.deleteFramebuffer(framebuffer);
gl.deleteTexture(sourceTexture);
gl.deleteTexture(outputTexture);
}
export function isWebglPreviewStepSupported(step: PipelineStep): boolean { export function isWebglPreviewStepSupported(step: PipelineStep): boolean {
return SUPPORTED_PREVIEW_STEP_TYPES.has(step.type as SupportedPreviewStepType); return SUPPORTED_PREVIEW_STEP_TYPES.has(step.type as SupportedPreviewStepType);
} }
@@ -147,45 +269,45 @@ export function isWebglPreviewPipelineSupported(steps: readonly PipelineStep[]):
} }
export function createWebglPreviewBackend(): ImagePipelineBackend { export function createWebglPreviewBackend(): ImagePipelineBackend {
let initialized = false; let context: WebglBackendContext | null = null;
function ensureInitialized(): void { function ensureInitialized(): WebglBackendContext {
if (initialized) { if (context) {
return; return context;
} }
const gl = createGlContext(); const gl = createGlContext();
compileProgram(gl, CURVES_FRAGMENT_SHADER_SOURCE); context = {
compileProgram(gl, COLOR_ADJUST_FRAGMENT_SHADER_SOURCE); gl,
initialized = true; curvesProgram: compileProgram(gl, curvesFragmentShaderSource),
colorAdjustProgram: compileProgram(gl, colorAdjustFragmentShaderSource),
quadBuffer: createQuadBuffer(gl),
};
return context;
} }
return { return {
id: "webgl", id: "webgl",
runPreviewStep(request: BackendStepRequest): void { runPreviewStep(request: BackendStepRequest): void {
assertSupportedStep(request.step); assertSupportedStep(request.step);
ensureInitialized(); runStepOnGpu(ensureInitialized(), request);
applyPipelineStep(
request.pixels,
request.step,
request.width,
request.height,
request.executionOptions,
);
}, },
runFullPipeline(request: BackendPipelineRequest): void { runFullPipeline(request: BackendPipelineRequest): void {
if (!isWebglPreviewPipelineSupported(request.steps)) { if (!isWebglPreviewPipelineSupported(request.steps)) {
throw new Error("WebGL backend does not support all pipeline steps."); throw new Error("WebGL backend does not support all pipeline steps.");
} }
ensureInitialized(); const initializedContext = ensureInitialized();
applyPipelineSteps( for (const step of request.steps) {
request.pixels, runStepOnGpu(initializedContext, {
request.steps, pixels: request.pixels,
request.width, step,
request.height, width: request.width,
request.executionOptions, height: request.height,
); executionOptions: request.executionOptions,
});
}
}, },
}; };
} }

View File

@@ -3,7 +3,6 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { PipelineStep } from "@/lib/image-pipeline/contracts"; import type { PipelineStep } from "@/lib/image-pipeline/contracts";
import type { ImagePipelineBackend } from "@/lib/image-pipeline/backend/backend-types";
function createCurvesStep(): PipelineStep { function createCurvesStep(): PipelineStep {
return { return {
@@ -65,24 +64,6 @@ function createUnsupportedStep(): PipelineStep {
}; };
} }
function createCpuAndWebglBackends(args: {
webglPreview?: ImagePipelineBackend["runPreviewStep"];
cpuPreview?: ImagePipelineBackend["runPreviewStep"];
}): readonly ImagePipelineBackend[] {
return [
{
id: "cpu",
runPreviewStep: args.cpuPreview ?? vi.fn(),
runFullPipeline: vi.fn(),
},
{
id: "webgl",
runPreviewStep: args.webglPreview ?? vi.fn(),
runFullPipeline: vi.fn(),
},
];
}
describe("webgl backend poc", () => { describe("webgl backend poc", () => {
beforeEach(() => { beforeEach(() => {
vi.resetModules(); vi.resetModules();
@@ -95,8 +76,102 @@ describe("webgl backend poc", () => {
vi.unmock("@/lib/image-pipeline/backend/webgl/webgl-backend"); vi.unmock("@/lib/image-pipeline/backend/webgl/webgl-backend");
vi.unmock("@/lib/image-pipeline/backend/backend-router"); vi.unmock("@/lib/image-pipeline/backend/backend-router");
vi.unmock("@/lib/image-pipeline/source-loader"); vi.unmock("@/lib/image-pipeline/source-loader");
vi.unmock("@/lib/image-pipeline/render-core");
}); });
function createFakeWebglContext(options?: {
compileSuccess?: boolean;
linkSuccess?: boolean;
readbackPixels?: Uint8Array;
}): WebGLRenderingContext {
const compileSuccess = options?.compileSuccess ?? true;
const linkSuccess = options?.linkSuccess ?? true;
const readbackPixels = options?.readbackPixels ?? new Uint8Array([0, 0, 0, 255]);
return {
VERTEX_SHADER: 0x8b31,
FRAGMENT_SHADER: 0x8b30,
COMPILE_STATUS: 0x8b81,
LINK_STATUS: 0x8b82,
ARRAY_BUFFER: 0x8892,
STATIC_DRAW: 0x88e4,
TRIANGLE_STRIP: 0x0005,
FLOAT: 0x1406,
TEXTURE_2D: 0x0de1,
RGBA: 0x1908,
UNSIGNED_BYTE: 0x1401,
TEXTURE0: 0x84c0,
TEXTURE_MIN_FILTER: 0x2801,
TEXTURE_MAG_FILTER: 0x2800,
TEXTURE_WRAP_S: 0x2802,
TEXTURE_WRAP_T: 0x2803,
CLAMP_TO_EDGE: 0x812f,
NEAREST: 0x2600,
FRAMEBUFFER: 0x8d40,
COLOR_ATTACHMENT0: 0x8ce0,
FRAMEBUFFER_COMPLETE: 0x8cd5,
createShader: vi.fn(() => ({ shader: true })),
shaderSource: vi.fn(),
compileShader: vi.fn(),
getShaderParameter: vi.fn((_shader: unknown, pname: number) => {
if (pname === 0x8b81) {
return compileSuccess;
}
return true;
}),
getShaderInfoLog: vi.fn(() => "compile error"),
deleteShader: vi.fn(),
createProgram: vi.fn(() => ({ program: true })),
attachShader: vi.fn(),
linkProgram: vi.fn(),
getProgramParameter: vi.fn((_program: unknown, pname: number) => {
if (pname === 0x8b82) {
return linkSuccess;
}
return true;
}),
getProgramInfoLog: vi.fn(() => "link error"),
deleteProgram: vi.fn(),
useProgram: vi.fn(),
createBuffer: vi.fn(() => ({ buffer: true })),
bindBuffer: vi.fn(),
bufferData: vi.fn(),
getAttribLocation: vi.fn(() => 0),
enableVertexAttribArray: vi.fn(),
vertexAttribPointer: vi.fn(),
createTexture: vi.fn(() => ({ texture: true })),
bindTexture: vi.fn(),
texParameteri: vi.fn(),
texImage2D: vi.fn(),
activeTexture: vi.fn(),
getUniformLocation: vi.fn(() => ({ uniform: true })),
uniform1i: vi.fn(),
uniform1f: vi.fn(),
uniform3f: vi.fn(),
createFramebuffer: vi.fn(() => ({ framebuffer: true })),
bindFramebuffer: vi.fn(),
framebufferTexture2D: vi.fn(),
checkFramebufferStatus: vi.fn(() => 0x8cd5),
deleteFramebuffer: vi.fn(),
viewport: vi.fn(),
drawArrays: vi.fn(),
deleteTexture: vi.fn(),
readPixels: vi.fn(
(
_x: number,
_y: number,
_width: number,
_height: number,
_format: number,
_type: number,
pixels: Uint8Array,
) => {
pixels.set(readbackPixels);
},
),
} as unknown as WebGLRenderingContext;
}
it("selects webgl for preview when webgl is available and enabled", async () => { it("selects webgl for preview when webgl is available and enabled", async () => {
const webglPreview = vi.fn(); const webglPreview = vi.fn();
@@ -224,8 +299,49 @@ describe("webgl backend poc", () => {
} }
}); });
it("runs a supported preview step through gpu shader path with readback", async () => {
const cpuPreview = vi.fn();
vi.doMock("@/lib/image-pipeline/render-core", async () => {
const actual = await vi.importActual<typeof import("@/lib/image-pipeline/render-core")>(
"@/lib/image-pipeline/render-core",
);
return {
...actual,
applyPipelineStep: cpuPreview,
};
});
const fakeGl = createFakeWebglContext({
readbackPixels: new Uint8Array([10, 20, 30, 255]),
});
vi.spyOn(HTMLCanvasElement.prototype, "getContext").mockImplementation((contextId) => {
if (contextId === "webgl") {
return fakeGl;
}
return null;
});
const { createWebglPreviewBackend } = await import("@/lib/image-pipeline/backend/webgl/webgl-backend");
const pixels = new Uint8ClampedArray([200, 100, 50, 255]);
const backend = createWebglPreviewBackend();
backend.runPreviewStep({
pixels,
step: createCurvesStep(),
width: 1,
height: 1,
});
expect(Array.from(pixels)).toEqual([10, 20, 30, 255]);
expect(cpuPreview).not.toHaveBeenCalled();
expect(fakeGl.readPixels).toHaveBeenCalledTimes(1);
});
it("downgrades compile/link failures to cpu with runtime_error reason", async () => { it("downgrades compile/link failures to cpu with runtime_error reason", async () => {
const { createBackendRouter } = await import("@/lib/image-pipeline/backend/backend-router"); const { createBackendRouter } = await import("@/lib/image-pipeline/backend/backend-router");
const { createWebglPreviewBackend } = await import("@/lib/image-pipeline/backend/webgl/webgl-backend");
const cpuPreview = vi.fn(); const cpuPreview = vi.fn();
const fallbackEvents: Array<{ const fallbackEvents: Array<{
reason: string; reason: string;
@@ -233,13 +349,25 @@ describe("webgl backend poc", () => {
fallbackBackend: string; fallbackBackend: string;
}> = []; }> = [];
const fakeGl = createFakeWebglContext({
compileSuccess: false,
});
vi.spyOn(HTMLCanvasElement.prototype, "getContext").mockImplementation((contextId) => {
if (contextId === "webgl") {
return fakeGl;
}
return null;
});
const router = createBackendRouter({ const router = createBackendRouter({
backends: createCpuAndWebglBackends({ backends: [
cpuPreview, {
webglPreview: () => { id: "cpu",
throw new Error("WebGL shader compile failed"); runPreviewStep: cpuPreview,
runFullPipeline: vi.fn(),
}, },
}), createWebglPreviewBackend(),
],
defaultBackendId: "webgl", defaultBackendId: "webgl",
backendAvailability: { backendAvailability: {
webgl: { webgl: {
@@ -277,4 +405,49 @@ describe("webgl backend poc", () => {
}, },
]); ]);
}); });
it("re-evaluates rollout flags and capabilities at runtime", async () => {
const runtimeState = {
flags: {
forceCpu: false,
webglEnabled: false,
wasmEnabled: false,
},
capabilities: {
webgl: false,
wasmSimd: false,
offscreenCanvas: false,
},
};
vi.doMock("@/lib/image-pipeline/backend/feature-flags", async () => {
const actual = await vi.importActual<typeof import("@/lib/image-pipeline/backend/feature-flags")>(
"@/lib/image-pipeline/backend/feature-flags",
);
return {
...actual,
getBackendFeatureFlags: () => runtimeState.flags,
};
});
vi.doMock("@/lib/image-pipeline/backend/capabilities", async () => {
const actual = await vi.importActual<typeof import("@/lib/image-pipeline/backend/capabilities")>(
"@/lib/image-pipeline/backend/capabilities",
);
return {
...actual,
detectBackendCapabilities: () => runtimeState.capabilities,
};
});
const { getPreviewBackendHintForSteps } = await import("@/lib/image-pipeline/backend/backend-router");
const steps = [createCurvesStep()] as const;
expect(getPreviewBackendHintForSteps(steps)).toBe("cpu");
runtimeState.flags.webglEnabled = true;
runtimeState.capabilities.webgl = true;
expect(getPreviewBackendHintForSteps(steps)).toBe("webgl");
});
}); });