feat(canvas): add tier-aware model selector to prompt node

This commit is contained in:
2026-04-07 23:27:21 +02:00
parent 39d435d58e
commit 91fdd6c143
2 changed files with 336 additions and 11 deletions

View File

@@ -18,7 +18,11 @@ import BaseNodeWrapper from "./base-node-wrapper";
import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context";
import { useCanvasSync } from "@/components/canvas/canvas-sync-context";
import { useDebouncedCallback } from "@/hooks/use-debounced-callback";
import { DEFAULT_MODEL_ID, getModel } from "@/lib/ai-models";
import {
DEFAULT_MODEL_ID,
getAvailableImageModels,
getModel,
} from "@/lib/ai-models";
import {
DEFAULT_ASPECT_RATIO,
getAiImageNodeOuterSize,
@@ -40,6 +44,7 @@ import { Sparkles, Loader2, Coins } from "lucide-react";
import { useRouter } from "next/navigation";
import { toast } from "@/lib/toast";
import { classifyError } from "@/lib/ai-errors";
import { normalizePublicTier } from "@/lib/tier-credits";
type PromptNodeData = {
prompt?: string;
@@ -63,6 +68,7 @@ export default function PromptNode({
const { getEdges, getNode } = useReactFlow();
const [prompt, setPrompt] = useState(nodeData.prompt ?? "");
const [modelId, setModelId] = useState(nodeData.model ?? DEFAULT_MODEL_ID);
const [aspectRatio, setAspectRatio] = useState(
nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO
);
@@ -72,14 +78,20 @@ export default function PromptNode({
const nodes = useStore((store) => store.nodes);
const promptRef = useRef(prompt);
const modelIdRef = useRef(modelId);
const aspectRatioRef = useRef(aspectRatio);
promptRef.current = prompt;
modelIdRef.current = modelId;
aspectRatioRef.current = aspectRatio;
useEffect(() => {
setPrompt(nodeData.prompt ?? "");
}, [nodeData.prompt]);
useEffect(() => {
setModelId(nodeData.model ?? DEFAULT_MODEL_ID);
}, [nodeData.model]);
useEffect(() => {
setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO);
}, [nodeData.aspectRatio]);
@@ -113,7 +125,29 @@ export default function PromptNode({
dataRef.current = data;
const balance = useAuthQuery(api.credits.getBalance);
const creditCost = getModel(DEFAULT_MODEL_ID)?.creditCost ?? 4;
const subscription = useAuthQuery(api.credits.getSubscription);
const userTier = normalizePublicTier(subscription?.tier ?? "free");
const availableModels = useMemo(
() => getAvailableImageModels(userTier),
[userTier],
);
useEffect(() => {
if (availableModels.length === 0) {
return;
}
if (!availableModels.some((model) => model.id === modelId)) {
setModelId(availableModels[0]!.id);
}
}, [availableModels, modelId]);
const selectedModel =
getModel(modelId) ??
availableModels[0] ??
getModel(DEFAULT_MODEL_ID);
const resolvedModelId = selectedModel?.id ?? DEFAULT_MODEL_ID;
const creditCost = selectedModel?.creditCost ?? 4;
const availableCredits =
balance !== undefined ? balance.balance - balance.reserved : null;
@@ -131,12 +165,13 @@ export default function PromptNode({
void _statusMessage;
void queueNodeDataUpdate({
nodeId: id as Id<"nodes">,
data: {
...rest,
prompt: promptRef.current,
aspectRatio: aspectRatioRef.current,
},
});
data: {
...rest,
prompt: promptRef.current,
model: modelIdRef.current,
aspectRatio: aspectRatioRef.current,
},
});
}, 500);
const handlePromptChange = useCallback(
@@ -156,6 +191,14 @@ export default function PromptNode({
[debouncedSave]
);
const handleModelChange = useCallback(
(value: string) => {
setModelId(value);
debouncedSave();
},
[debouncedSave],
);
const handleGenerate = useCallback(async () => {
if (!effectivePrompt.trim() || isGenerating) return;
if (status.isOffline) {
@@ -229,8 +272,8 @@ export default function PromptNode({
height: outer.height,
data: {
prompt: promptToUse,
model: DEFAULT_MODEL_ID,
modelTier: "standard",
model: resolvedModelId,
modelTier: selectedModel?.tier ?? "standard",
canvasId,
aspectRatio,
outputWidth: viewport.width,
@@ -249,7 +292,7 @@ export default function PromptNode({
prompt: promptToUse,
referenceStorageId,
referenceImageUrl,
model: DEFAULT_MODEL_ID,
model: resolvedModelId,
aspectRatio,
}),
{
@@ -285,6 +328,7 @@ export default function PromptNode({
prompt,
effectivePrompt,
aspectRatio,
resolvedModelId,
isGenerating,
nodeData.canvasId,
id,
@@ -292,6 +336,7 @@ export default function PromptNode({
getNode,
createNodeConnectedFromSource,
generateImage,
selectedModel?.tier,
creditCost,
availableCredits,
hasEnoughCredits,
@@ -338,6 +383,31 @@ export default function PromptNode({
/>
)}
<div className="flex flex-col gap-1.5">
<Label
htmlFor={`prompt-model-${id}`}
className="text-[11px] font-medium text-muted-foreground"
>
Modell
</Label>
<Select value={resolvedModelId} onValueChange={handleModelChange}>
<SelectTrigger
id={`prompt-model-${id}`}
className="nodrag nowheel w-full"
size="sm"
>
<SelectValue placeholder="Modell" />
</SelectTrigger>
<SelectContent className="nodrag">
{availableModels.map((model) => (
<SelectItem key={model.id} value={model.id}>
{model.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="flex flex-col gap-1.5">
<Label
htmlFor={`prompt-format-${id}`}