diff --git a/components/canvas/__tests__/use-canvas-edge-insertions.test.tsx b/components/canvas/__tests__/use-canvas-edge-insertions.test.tsx index 23de626..2c6cde9 100644 --- a/components/canvas/__tests__/use-canvas-edge-insertions.test.tsx +++ b/components/canvas/__tests__/use-canvas-edge-insertions.test.tsx @@ -223,7 +223,7 @@ describe("useCanvasEdgeInsertions", () => { label: "Prompt", width: 320, height: 220, - defaultData: { prompt: "", model: "", aspectRatio: "1:1" }, + defaultData: { prompt: "", aspectRatio: "1:1" }, } as CanvasNodeTemplate); }); @@ -268,7 +268,7 @@ describe("useCanvasEdgeInsertions", () => { label: "Prompt", width: 320, height: 220, - defaultData: { prompt: "", model: "", aspectRatio: "1:1" }, + defaultData: { prompt: "", aspectRatio: "1:1" }, } as CanvasNodeTemplate); }); @@ -319,7 +319,7 @@ describe("useCanvasEdgeInsertions", () => { label: "Prompt", width: 320, height: 220, - defaultData: { prompt: "", model: "", aspectRatio: "1:1" }, + defaultData: { prompt: "", aspectRatio: "1:1" }, } as CanvasNodeTemplate); }); @@ -332,7 +332,7 @@ describe("useCanvasEdgeInsertions", () => { height: 220, data: { prompt: "", - model: "", + model: "google/gemini-2.5-flash-image", aspectRatio: "1:1", canvasId: "canvas-1", }, @@ -650,7 +650,7 @@ describe("useCanvasEdgeInsertions", () => { label: "Prompt", width: 320, height: 220, - defaultData: { prompt: "", model: "", aspectRatio: "1:1" }, + defaultData: { prompt: "", model: "google/gemini-2.5-flash-image", aspectRatio: "1:1" }, } as CanvasNodeTemplate); }); diff --git a/components/canvas/canvas-graph-context.tsx b/components/canvas/canvas-graph-context.tsx index 23e2267..907d42c 100644 --- a/components/canvas/canvas-graph-context.tsx +++ b/components/canvas/canvas-graph-context.tsx @@ -4,7 +4,6 @@ import { createContext, useCallback, useContext, - useEffect, useMemo, useState, type ReactNode, @@ -84,12 +83,6 @@ export function CanvasGraphProvider({ [nodes, previewNodeDataOverrides], ); - useEffect(() => { - if (prunedPreviewNodeDataOverrides !== previewNodeDataOverrides) { - setPreviewNodeDataOverrides(prunedPreviewNodeDataOverrides); - } - }, [previewNodeDataOverrides, prunedPreviewNodeDataOverrides]); - const graph = useMemo( () => buildGraphSnapshot(nodes, edges, { diff --git a/components/canvas/nodes/prompt-node.tsx b/components/canvas/nodes/prompt-node.tsx index 2de5512..2b9ce04 100644 --- a/components/canvas/nodes/prompt-node.tsx +++ b/components/canvas/nodes/prompt-node.tsx @@ -18,7 +18,11 @@ import BaseNodeWrapper from "./base-node-wrapper"; import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context"; import { useCanvasSync } from "@/components/canvas/canvas-sync-context"; import { useDebouncedCallback } from "@/hooks/use-debounced-callback"; -import { DEFAULT_MODEL_ID, getModel } from "@/lib/ai-models"; +import { + DEFAULT_MODEL_ID, + getAvailableImageModels, + getModel, +} from "@/lib/ai-models"; import { DEFAULT_ASPECT_RATIO, getAiImageNodeOuterSize, @@ -40,6 +44,7 @@ import { Sparkles, Loader2, Coins } from "lucide-react"; import { useRouter } from "next/navigation"; import { toast } from "@/lib/toast"; import { classifyError } from "@/lib/ai-errors"; +import { normalizePublicTier } from "@/lib/tier-credits"; type PromptNodeData = { prompt?: string; @@ -63,6 +68,7 @@ export default function PromptNode({ const { getEdges, getNode } = useReactFlow(); const [prompt, setPrompt] = useState(nodeData.prompt ?? ""); + const [modelId, setModelId] = useState(nodeData.model ?? DEFAULT_MODEL_ID); const [aspectRatio, setAspectRatio] = useState( nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO ); @@ -72,14 +78,20 @@ export default function PromptNode({ const nodes = useStore((store) => store.nodes); const promptRef = useRef(prompt); + const modelIdRef = useRef(modelId); const aspectRatioRef = useRef(aspectRatio); promptRef.current = prompt; + modelIdRef.current = modelId; aspectRatioRef.current = aspectRatio; useEffect(() => { setPrompt(nodeData.prompt ?? ""); }, [nodeData.prompt]); + useEffect(() => { + setModelId(nodeData.model ?? DEFAULT_MODEL_ID); + }, [nodeData.model]); + useEffect(() => { setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO); }, [nodeData.aspectRatio]); @@ -113,7 +125,29 @@ export default function PromptNode({ dataRef.current = data; const balance = useAuthQuery(api.credits.getBalance); - const creditCost = getModel(DEFAULT_MODEL_ID)?.creditCost ?? 4; + const subscription = useAuthQuery(api.credits.getSubscription); + const userTier = normalizePublicTier(subscription?.tier ?? "free"); + const availableModels = useMemo( + () => getAvailableImageModels(userTier), + [userTier], + ); + + useEffect(() => { + if (availableModels.length === 0) { + return; + } + + if (!availableModels.some((model) => model.id === modelId)) { + setModelId(availableModels[0]!.id); + } + }, [availableModels, modelId]); + + const selectedModel = + getModel(modelId) ?? + availableModels[0] ?? + getModel(DEFAULT_MODEL_ID); + const resolvedModelId = selectedModel?.id ?? DEFAULT_MODEL_ID; + const creditCost = selectedModel?.creditCost ?? 4; const availableCredits = balance !== undefined ? balance.balance - balance.reserved : null; @@ -131,12 +165,13 @@ export default function PromptNode({ void _statusMessage; void queueNodeDataUpdate({ nodeId: id as Id<"nodes">, - data: { - ...rest, - prompt: promptRef.current, - aspectRatio: aspectRatioRef.current, - }, - }); + data: { + ...rest, + prompt: promptRef.current, + model: modelIdRef.current, + aspectRatio: aspectRatioRef.current, + }, + }); }, 500); const handlePromptChange = useCallback( @@ -156,6 +191,14 @@ export default function PromptNode({ [debouncedSave] ); + const handleModelChange = useCallback( + (value: string) => { + setModelId(value); + debouncedSave(); + }, + [debouncedSave], + ); + const handleGenerate = useCallback(async () => { if (!effectivePrompt.trim() || isGenerating) return; if (status.isOffline) { @@ -229,8 +272,8 @@ export default function PromptNode({ height: outer.height, data: { prompt: promptToUse, - model: DEFAULT_MODEL_ID, - modelTier: "standard", + model: resolvedModelId, + modelTier: selectedModel?.tier ?? "standard", canvasId, aspectRatio, outputWidth: viewport.width, @@ -249,7 +292,7 @@ export default function PromptNode({ prompt: promptToUse, referenceStorageId, referenceImageUrl, - model: DEFAULT_MODEL_ID, + model: resolvedModelId, aspectRatio, }), { @@ -285,6 +328,7 @@ export default function PromptNode({ prompt, effectivePrompt, aspectRatio, + resolvedModelId, isGenerating, nodeData.canvasId, id, @@ -292,6 +336,7 @@ export default function PromptNode({ getNode, createNodeConnectedFromSource, generateImage, + selectedModel?.tier, creditCost, availableCredits, hasEnoughCredits, @@ -338,6 +383,31 @@ export default function PromptNode({ /> )} +
+ + +
+