feat(canvas): add tier-aware model selector to prompt node
This commit is contained in:
@@ -18,7 +18,11 @@ import BaseNodeWrapper from "./base-node-wrapper";
|
|||||||
import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context";
|
import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context";
|
||||||
import { useCanvasSync } from "@/components/canvas/canvas-sync-context";
|
import { useCanvasSync } from "@/components/canvas/canvas-sync-context";
|
||||||
import { useDebouncedCallback } from "@/hooks/use-debounced-callback";
|
import { useDebouncedCallback } from "@/hooks/use-debounced-callback";
|
||||||
import { DEFAULT_MODEL_ID, getModel } from "@/lib/ai-models";
|
import {
|
||||||
|
DEFAULT_MODEL_ID,
|
||||||
|
getAvailableImageModels,
|
||||||
|
getModel,
|
||||||
|
} from "@/lib/ai-models";
|
||||||
import {
|
import {
|
||||||
DEFAULT_ASPECT_RATIO,
|
DEFAULT_ASPECT_RATIO,
|
||||||
getAiImageNodeOuterSize,
|
getAiImageNodeOuterSize,
|
||||||
@@ -40,6 +44,7 @@ import { Sparkles, Loader2, Coins } from "lucide-react";
|
|||||||
import { useRouter } from "next/navigation";
|
import { useRouter } from "next/navigation";
|
||||||
import { toast } from "@/lib/toast";
|
import { toast } from "@/lib/toast";
|
||||||
import { classifyError } from "@/lib/ai-errors";
|
import { classifyError } from "@/lib/ai-errors";
|
||||||
|
import { normalizePublicTier } from "@/lib/tier-credits";
|
||||||
|
|
||||||
type PromptNodeData = {
|
type PromptNodeData = {
|
||||||
prompt?: string;
|
prompt?: string;
|
||||||
@@ -63,6 +68,7 @@ export default function PromptNode({
|
|||||||
const { getEdges, getNode } = useReactFlow();
|
const { getEdges, getNode } = useReactFlow();
|
||||||
|
|
||||||
const [prompt, setPrompt] = useState(nodeData.prompt ?? "");
|
const [prompt, setPrompt] = useState(nodeData.prompt ?? "");
|
||||||
|
const [modelId, setModelId] = useState(nodeData.model ?? DEFAULT_MODEL_ID);
|
||||||
const [aspectRatio, setAspectRatio] = useState(
|
const [aspectRatio, setAspectRatio] = useState(
|
||||||
nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO
|
nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO
|
||||||
);
|
);
|
||||||
@@ -72,14 +78,20 @@ export default function PromptNode({
|
|||||||
const nodes = useStore((store) => store.nodes);
|
const nodes = useStore((store) => store.nodes);
|
||||||
|
|
||||||
const promptRef = useRef(prompt);
|
const promptRef = useRef(prompt);
|
||||||
|
const modelIdRef = useRef(modelId);
|
||||||
const aspectRatioRef = useRef(aspectRatio);
|
const aspectRatioRef = useRef(aspectRatio);
|
||||||
promptRef.current = prompt;
|
promptRef.current = prompt;
|
||||||
|
modelIdRef.current = modelId;
|
||||||
aspectRatioRef.current = aspectRatio;
|
aspectRatioRef.current = aspectRatio;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setPrompt(nodeData.prompt ?? "");
|
setPrompt(nodeData.prompt ?? "");
|
||||||
}, [nodeData.prompt]);
|
}, [nodeData.prompt]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setModelId(nodeData.model ?? DEFAULT_MODEL_ID);
|
||||||
|
}, [nodeData.model]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO);
|
setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO);
|
||||||
}, [nodeData.aspectRatio]);
|
}, [nodeData.aspectRatio]);
|
||||||
@@ -113,7 +125,29 @@ export default function PromptNode({
|
|||||||
dataRef.current = data;
|
dataRef.current = data;
|
||||||
|
|
||||||
const balance = useAuthQuery(api.credits.getBalance);
|
const balance = useAuthQuery(api.credits.getBalance);
|
||||||
const creditCost = getModel(DEFAULT_MODEL_ID)?.creditCost ?? 4;
|
const subscription = useAuthQuery(api.credits.getSubscription);
|
||||||
|
const userTier = normalizePublicTier(subscription?.tier ?? "free");
|
||||||
|
const availableModels = useMemo(
|
||||||
|
() => getAvailableImageModels(userTier),
|
||||||
|
[userTier],
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (availableModels.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!availableModels.some((model) => model.id === modelId)) {
|
||||||
|
setModelId(availableModels[0]!.id);
|
||||||
|
}
|
||||||
|
}, [availableModels, modelId]);
|
||||||
|
|
||||||
|
const selectedModel =
|
||||||
|
getModel(modelId) ??
|
||||||
|
availableModels[0] ??
|
||||||
|
getModel(DEFAULT_MODEL_ID);
|
||||||
|
const resolvedModelId = selectedModel?.id ?? DEFAULT_MODEL_ID;
|
||||||
|
const creditCost = selectedModel?.creditCost ?? 4;
|
||||||
|
|
||||||
const availableCredits =
|
const availableCredits =
|
||||||
balance !== undefined ? balance.balance - balance.reserved : null;
|
balance !== undefined ? balance.balance - balance.reserved : null;
|
||||||
@@ -131,12 +165,13 @@ export default function PromptNode({
|
|||||||
void _statusMessage;
|
void _statusMessage;
|
||||||
void queueNodeDataUpdate({
|
void queueNodeDataUpdate({
|
||||||
nodeId: id as Id<"nodes">,
|
nodeId: id as Id<"nodes">,
|
||||||
data: {
|
data: {
|
||||||
...rest,
|
...rest,
|
||||||
prompt: promptRef.current,
|
prompt: promptRef.current,
|
||||||
aspectRatio: aspectRatioRef.current,
|
model: modelIdRef.current,
|
||||||
},
|
aspectRatio: aspectRatioRef.current,
|
||||||
});
|
},
|
||||||
|
});
|
||||||
}, 500);
|
}, 500);
|
||||||
|
|
||||||
const handlePromptChange = useCallback(
|
const handlePromptChange = useCallback(
|
||||||
@@ -156,6 +191,14 @@ export default function PromptNode({
|
|||||||
[debouncedSave]
|
[debouncedSave]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleModelChange = useCallback(
|
||||||
|
(value: string) => {
|
||||||
|
setModelId(value);
|
||||||
|
debouncedSave();
|
||||||
|
},
|
||||||
|
[debouncedSave],
|
||||||
|
);
|
||||||
|
|
||||||
const handleGenerate = useCallback(async () => {
|
const handleGenerate = useCallback(async () => {
|
||||||
if (!effectivePrompt.trim() || isGenerating) return;
|
if (!effectivePrompt.trim() || isGenerating) return;
|
||||||
if (status.isOffline) {
|
if (status.isOffline) {
|
||||||
@@ -229,8 +272,8 @@ export default function PromptNode({
|
|||||||
height: outer.height,
|
height: outer.height,
|
||||||
data: {
|
data: {
|
||||||
prompt: promptToUse,
|
prompt: promptToUse,
|
||||||
model: DEFAULT_MODEL_ID,
|
model: resolvedModelId,
|
||||||
modelTier: "standard",
|
modelTier: selectedModel?.tier ?? "standard",
|
||||||
canvasId,
|
canvasId,
|
||||||
aspectRatio,
|
aspectRatio,
|
||||||
outputWidth: viewport.width,
|
outputWidth: viewport.width,
|
||||||
@@ -249,7 +292,7 @@ export default function PromptNode({
|
|||||||
prompt: promptToUse,
|
prompt: promptToUse,
|
||||||
referenceStorageId,
|
referenceStorageId,
|
||||||
referenceImageUrl,
|
referenceImageUrl,
|
||||||
model: DEFAULT_MODEL_ID,
|
model: resolvedModelId,
|
||||||
aspectRatio,
|
aspectRatio,
|
||||||
}),
|
}),
|
||||||
{
|
{
|
||||||
@@ -285,6 +328,7 @@ export default function PromptNode({
|
|||||||
prompt,
|
prompt,
|
||||||
effectivePrompt,
|
effectivePrompt,
|
||||||
aspectRatio,
|
aspectRatio,
|
||||||
|
resolvedModelId,
|
||||||
isGenerating,
|
isGenerating,
|
||||||
nodeData.canvasId,
|
nodeData.canvasId,
|
||||||
id,
|
id,
|
||||||
@@ -292,6 +336,7 @@ export default function PromptNode({
|
|||||||
getNode,
|
getNode,
|
||||||
createNodeConnectedFromSource,
|
createNodeConnectedFromSource,
|
||||||
generateImage,
|
generateImage,
|
||||||
|
selectedModel?.tier,
|
||||||
creditCost,
|
creditCost,
|
||||||
availableCredits,
|
availableCredits,
|
||||||
hasEnoughCredits,
|
hasEnoughCredits,
|
||||||
@@ -338,6 +383,31 @@ export default function PromptNode({
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
<div className="flex flex-col gap-1.5">
|
||||||
|
<Label
|
||||||
|
htmlFor={`prompt-model-${id}`}
|
||||||
|
className="text-[11px] font-medium text-muted-foreground"
|
||||||
|
>
|
||||||
|
Modell
|
||||||
|
</Label>
|
||||||
|
<Select value={resolvedModelId} onValueChange={handleModelChange}>
|
||||||
|
<SelectTrigger
|
||||||
|
id={`prompt-model-${id}`}
|
||||||
|
className="nodrag nowheel w-full"
|
||||||
|
size="sm"
|
||||||
|
>
|
||||||
|
<SelectValue placeholder="Modell" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent className="nodrag">
|
||||||
|
{availableModels.map((model) => (
|
||||||
|
<SelectItem key={model.id} value={model.id}>
|
||||||
|
{model.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div className="flex flex-col gap-1.5">
|
<div className="flex flex-col gap-1.5">
|
||||||
<Label
|
<Label
|
||||||
htmlFor={`prompt-format-${id}`}
|
htmlFor={`prompt-format-${id}`}
|
||||||
|
|||||||
255
tests/prompt-node.test.ts
Normal file
255
tests/prompt-node.test.ts
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
// @vitest-environment jsdom
|
||||||
|
|
||||||
|
import React from "react";
|
||||||
|
import { act } from "react";
|
||||||
|
import { createRoot, type Root } from "react-dom/client";
|
||||||
|
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
|
import type { Id } from "@/convex/_generated/dataModel";
|
||||||
|
|
||||||
|
const mocks = vi.hoisted(() => ({
|
||||||
|
edges: [] as Array<{ source: string; target: string }>,
|
||||||
|
nodes: [] as Array<{ id: string; type: string; data: Record<string, unknown> }>,
|
||||||
|
balance: { balance: 100, reserved: 0 } as { balance: number; reserved: number } | undefined,
|
||||||
|
subscription: { tier: "starter" as const },
|
||||||
|
queueNodeDataUpdate: vi.fn(async () => undefined),
|
||||||
|
createNodeConnectedFromSource: vi.fn(async () => "ai-image-node-1" as Id<"nodes">),
|
||||||
|
generateImage: vi.fn(async () => ({ queued: true, nodeId: "ai-image-node-1" })),
|
||||||
|
getEdges: vi.fn(() => [] as Array<{ source: string; target: string }>),
|
||||||
|
getNode: vi.fn((id: string) =>
|
||||||
|
id === "prompt-1"
|
||||||
|
? { id, position: { x: 100, y: 50 }, measured: { width: 280, height: 220 } }
|
||||||
|
: null,
|
||||||
|
),
|
||||||
|
push: vi.fn(),
|
||||||
|
toastPromise: vi.fn(async <T,>(promise: Promise<T>) => await promise),
|
||||||
|
toastWarning: vi.fn(),
|
||||||
|
toastAction: vi.fn(),
|
||||||
|
toastError: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("next-intl", () => ({
|
||||||
|
useTranslations: () => (key: string) => key,
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("next/navigation", () => ({
|
||||||
|
useRouter: () => ({ push: mocks.push }),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("convex/react", () => ({
|
||||||
|
useAction: () => mocks.generateImage,
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/convex/_generated/api", () => ({
|
||||||
|
api: {
|
||||||
|
ai: {
|
||||||
|
generateImage: "ai.generateImage",
|
||||||
|
},
|
||||||
|
credits: {
|
||||||
|
getBalance: "credits.getBalance",
|
||||||
|
getSubscription: "credits.getSubscription",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/hooks/use-auth-query", () => ({
|
||||||
|
useAuthQuery: (query: string) => {
|
||||||
|
if (query === "credits.getSubscription") return mocks.subscription;
|
||||||
|
return mocks.balance;
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/hooks/use-debounced-callback", () => ({
|
||||||
|
useDebouncedCallback: (callback: (...args: Array<unknown>) => void) => callback,
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/components/canvas/canvas-sync-context", () => ({
|
||||||
|
useCanvasSync: () => ({
|
||||||
|
queueNodeDataUpdate: mocks.queueNodeDataUpdate,
|
||||||
|
status: { isOffline: false, isSyncing: false, pendingCount: 0 },
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/components/canvas/canvas-placement-context", () => ({
|
||||||
|
useCanvasPlacement: () => ({
|
||||||
|
createNodeConnectedFromSource: mocks.createNodeConnectedFromSource,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/lib/toast", () => ({
|
||||||
|
toast: {
|
||||||
|
promise: mocks.toastPromise,
|
||||||
|
warning: mocks.toastWarning,
|
||||||
|
action: mocks.toastAction,
|
||||||
|
error: mocks.toastError,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/lib/ai-errors", () => ({
|
||||||
|
classifyError: (error: unknown) => ({
|
||||||
|
type: "generic",
|
||||||
|
rawMessage: error instanceof Error ? error.message : String(error),
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/components/ui/label", () => ({
|
||||||
|
Label: ({ children, htmlFor }: { children: React.ReactNode; htmlFor?: string }) =>
|
||||||
|
React.createElement("label", { htmlFor }, children),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/components/ui/select", () => ({
|
||||||
|
Select: ({
|
||||||
|
value,
|
||||||
|
onValueChange,
|
||||||
|
children,
|
||||||
|
}: {
|
||||||
|
value: string;
|
||||||
|
onValueChange: (value: string) => void;
|
||||||
|
children: React.ReactNode;
|
||||||
|
}) =>
|
||||||
|
React.createElement(
|
||||||
|
"select",
|
||||||
|
{
|
||||||
|
"data-testid": value.includes("/") ? "model-select" : "format-select",
|
||||||
|
value,
|
||||||
|
onChange: (event: Event) => {
|
||||||
|
onValueChange((event.target as HTMLSelectElement).value);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
children,
|
||||||
|
),
|
||||||
|
SelectTrigger: ({ children }: { children: React.ReactNode }) => children,
|
||||||
|
SelectValue: () => null,
|
||||||
|
SelectContent: ({ children }: { children: React.ReactNode }) => children,
|
||||||
|
SelectItem: ({ children, value }: { children: React.ReactNode; value: string }) =>
|
||||||
|
React.createElement("option", { value }, children),
|
||||||
|
SelectGroup: ({ children }: { children: React.ReactNode }) => children,
|
||||||
|
SelectLabel: ({ children }: { children: React.ReactNode }) =>
|
||||||
|
React.createElement("optgroup", { label: String(children) }),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@/components/canvas/nodes/base-node-wrapper", () => ({
|
||||||
|
default: ({ children }: { children: React.ReactNode }) => React.createElement("div", null, children),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("@xyflow/react", () => ({
|
||||||
|
Handle: () => null,
|
||||||
|
Position: { Left: "left", Right: "right" },
|
||||||
|
useStore: (selector: (state: { edges: typeof mocks.edges; nodes: typeof mocks.nodes }) => unknown) =>
|
||||||
|
selector({ edges: mocks.edges, nodes: mocks.nodes }),
|
||||||
|
useReactFlow: () => ({
|
||||||
|
getEdges: mocks.getEdges,
|
||||||
|
getNode: mocks.getNode,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
import PromptNode from "@/components/canvas/nodes/prompt-node";
|
||||||
|
|
||||||
|
(globalThis as typeof globalThis & { IS_REACT_ACT_ENVIRONMENT?: boolean }).IS_REACT_ACT_ENVIRONMENT =
|
||||||
|
true;
|
||||||
|
|
||||||
|
describe("PromptNode", () => {
|
||||||
|
let container: HTMLDivElement | null = null;
|
||||||
|
let root: Root | null = null;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mocks.edges = [];
|
||||||
|
mocks.nodes = [];
|
||||||
|
mocks.balance = { balance: 100, reserved: 0 };
|
||||||
|
mocks.subscription = { tier: "starter" };
|
||||||
|
mocks.queueNodeDataUpdate.mockClear();
|
||||||
|
mocks.createNodeConnectedFromSource.mockClear();
|
||||||
|
mocks.generateImage.mockClear();
|
||||||
|
mocks.getEdges.mockClear();
|
||||||
|
mocks.getNode.mockClear();
|
||||||
|
mocks.push.mockClear();
|
||||||
|
mocks.toastPromise.mockClear();
|
||||||
|
mocks.toastWarning.mockClear();
|
||||||
|
mocks.toastAction.mockClear();
|
||||||
|
mocks.toastError.mockClear();
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
if (root) {
|
||||||
|
act(() => {
|
||||||
|
root?.unmount();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
container?.remove();
|
||||||
|
container = null;
|
||||||
|
root = null;
|
||||||
|
});
|
||||||
|
|
||||||
|
it("propagates selected image model into node creation and generation action", async () => {
|
||||||
|
container = document.createElement("div");
|
||||||
|
document.body.appendChild(container);
|
||||||
|
root = createRoot(container);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
root?.render(
|
||||||
|
React.createElement(PromptNode, {
|
||||||
|
id: "prompt-1",
|
||||||
|
selected: false,
|
||||||
|
dragging: false,
|
||||||
|
draggable: true,
|
||||||
|
selectable: true,
|
||||||
|
deletable: true,
|
||||||
|
zIndex: 1,
|
||||||
|
isConnectable: true,
|
||||||
|
type: "prompt",
|
||||||
|
data: {
|
||||||
|
prompt: "ein neugieriger hund im regen",
|
||||||
|
aspectRatio: "1:1",
|
||||||
|
canvasId: "canvas-1",
|
||||||
|
},
|
||||||
|
positionAbsoluteX: 0,
|
||||||
|
positionAbsoluteY: 0,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelSelect = container.querySelector('select[data-testid="model-select"]');
|
||||||
|
if (!(modelSelect instanceof HTMLSelectElement)) {
|
||||||
|
throw new Error("Model select not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
modelSelect.value = "openai/gpt-5-image-mini";
|
||||||
|
modelSelect.dispatchEvent(new Event("change", { bubbles: true }));
|
||||||
|
});
|
||||||
|
|
||||||
|
const button = Array.from(container.querySelectorAll("button")).find((element) =>
|
||||||
|
element.textContent?.includes("Bild generieren"),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!(button instanceof HTMLButtonElement)) {
|
||||||
|
throw new Error("Generate button not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
button.click();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mocks.createNodeConnectedFromSource).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mocks.createNodeConnectedFromSource).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
type: "ai-image",
|
||||||
|
sourceNodeId: "prompt-1",
|
||||||
|
data: expect.objectContaining({
|
||||||
|
model: "openai/gpt-5-image-mini",
|
||||||
|
modelTier: "premium",
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mocks.generateImage).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mocks.generateImage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
canvasId: "canvas-1",
|
||||||
|
nodeId: "ai-image-node-1",
|
||||||
|
prompt: "ein neugieriger hund im regen",
|
||||||
|
model: "openai/gpt-5-image-mini",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user