feat(canvas): add tier-aware model selector to prompt node
This commit is contained in:
@@ -18,7 +18,11 @@ import BaseNodeWrapper from "./base-node-wrapper";
|
||||
import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context";
|
||||
import { useCanvasSync } from "@/components/canvas/canvas-sync-context";
|
||||
import { useDebouncedCallback } from "@/hooks/use-debounced-callback";
|
||||
import { DEFAULT_MODEL_ID, getModel } from "@/lib/ai-models";
|
||||
import {
|
||||
DEFAULT_MODEL_ID,
|
||||
getAvailableImageModels,
|
||||
getModel,
|
||||
} from "@/lib/ai-models";
|
||||
import {
|
||||
DEFAULT_ASPECT_RATIO,
|
||||
getAiImageNodeOuterSize,
|
||||
@@ -40,6 +44,7 @@ import { Sparkles, Loader2, Coins } from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { toast } from "@/lib/toast";
|
||||
import { classifyError } from "@/lib/ai-errors";
|
||||
import { normalizePublicTier } from "@/lib/tier-credits";
|
||||
|
||||
type PromptNodeData = {
|
||||
prompt?: string;
|
||||
@@ -63,6 +68,7 @@ export default function PromptNode({
|
||||
const { getEdges, getNode } = useReactFlow();
|
||||
|
||||
const [prompt, setPrompt] = useState(nodeData.prompt ?? "");
|
||||
const [modelId, setModelId] = useState(nodeData.model ?? DEFAULT_MODEL_ID);
|
||||
const [aspectRatio, setAspectRatio] = useState(
|
||||
nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO
|
||||
);
|
||||
@@ -72,14 +78,20 @@ export default function PromptNode({
|
||||
const nodes = useStore((store) => store.nodes);
|
||||
|
||||
const promptRef = useRef(prompt);
|
||||
const modelIdRef = useRef(modelId);
|
||||
const aspectRatioRef = useRef(aspectRatio);
|
||||
promptRef.current = prompt;
|
||||
modelIdRef.current = modelId;
|
||||
aspectRatioRef.current = aspectRatio;
|
||||
|
||||
useEffect(() => {
|
||||
setPrompt(nodeData.prompt ?? "");
|
||||
}, [nodeData.prompt]);
|
||||
|
||||
useEffect(() => {
|
||||
setModelId(nodeData.model ?? DEFAULT_MODEL_ID);
|
||||
}, [nodeData.model]);
|
||||
|
||||
useEffect(() => {
|
||||
setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO);
|
||||
}, [nodeData.aspectRatio]);
|
||||
@@ -113,7 +125,29 @@ export default function PromptNode({
|
||||
dataRef.current = data;
|
||||
|
||||
const balance = useAuthQuery(api.credits.getBalance);
|
||||
const creditCost = getModel(DEFAULT_MODEL_ID)?.creditCost ?? 4;
|
||||
const subscription = useAuthQuery(api.credits.getSubscription);
|
||||
const userTier = normalizePublicTier(subscription?.tier ?? "free");
|
||||
const availableModels = useMemo(
|
||||
() => getAvailableImageModels(userTier),
|
||||
[userTier],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (availableModels.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!availableModels.some((model) => model.id === modelId)) {
|
||||
setModelId(availableModels[0]!.id);
|
||||
}
|
||||
}, [availableModels, modelId]);
|
||||
|
||||
const selectedModel =
|
||||
getModel(modelId) ??
|
||||
availableModels[0] ??
|
||||
getModel(DEFAULT_MODEL_ID);
|
||||
const resolvedModelId = selectedModel?.id ?? DEFAULT_MODEL_ID;
|
||||
const creditCost = selectedModel?.creditCost ?? 4;
|
||||
|
||||
const availableCredits =
|
||||
balance !== undefined ? balance.balance - balance.reserved : null;
|
||||
@@ -131,12 +165,13 @@ export default function PromptNode({
|
||||
void _statusMessage;
|
||||
void queueNodeDataUpdate({
|
||||
nodeId: id as Id<"nodes">,
|
||||
data: {
|
||||
...rest,
|
||||
prompt: promptRef.current,
|
||||
aspectRatio: aspectRatioRef.current,
|
||||
},
|
||||
});
|
||||
data: {
|
||||
...rest,
|
||||
prompt: promptRef.current,
|
||||
model: modelIdRef.current,
|
||||
aspectRatio: aspectRatioRef.current,
|
||||
},
|
||||
});
|
||||
}, 500);
|
||||
|
||||
const handlePromptChange = useCallback(
|
||||
@@ -156,6 +191,14 @@ export default function PromptNode({
|
||||
[debouncedSave]
|
||||
);
|
||||
|
||||
const handleModelChange = useCallback(
|
||||
(value: string) => {
|
||||
setModelId(value);
|
||||
debouncedSave();
|
||||
},
|
||||
[debouncedSave],
|
||||
);
|
||||
|
||||
const handleGenerate = useCallback(async () => {
|
||||
if (!effectivePrompt.trim() || isGenerating) return;
|
||||
if (status.isOffline) {
|
||||
@@ -229,8 +272,8 @@ export default function PromptNode({
|
||||
height: outer.height,
|
||||
data: {
|
||||
prompt: promptToUse,
|
||||
model: DEFAULT_MODEL_ID,
|
||||
modelTier: "standard",
|
||||
model: resolvedModelId,
|
||||
modelTier: selectedModel?.tier ?? "standard",
|
||||
canvasId,
|
||||
aspectRatio,
|
||||
outputWidth: viewport.width,
|
||||
@@ -249,7 +292,7 @@ export default function PromptNode({
|
||||
prompt: promptToUse,
|
||||
referenceStorageId,
|
||||
referenceImageUrl,
|
||||
model: DEFAULT_MODEL_ID,
|
||||
model: resolvedModelId,
|
||||
aspectRatio,
|
||||
}),
|
||||
{
|
||||
@@ -285,6 +328,7 @@ export default function PromptNode({
|
||||
prompt,
|
||||
effectivePrompt,
|
||||
aspectRatio,
|
||||
resolvedModelId,
|
||||
isGenerating,
|
||||
nodeData.canvasId,
|
||||
id,
|
||||
@@ -292,6 +336,7 @@ export default function PromptNode({
|
||||
getNode,
|
||||
createNodeConnectedFromSource,
|
||||
generateImage,
|
||||
selectedModel?.tier,
|
||||
creditCost,
|
||||
availableCredits,
|
||||
hasEnoughCredits,
|
||||
@@ -338,6 +383,31 @@ export default function PromptNode({
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<Label
|
||||
htmlFor={`prompt-model-${id}`}
|
||||
className="text-[11px] font-medium text-muted-foreground"
|
||||
>
|
||||
Modell
|
||||
</Label>
|
||||
<Select value={resolvedModelId} onValueChange={handleModelChange}>
|
||||
<SelectTrigger
|
||||
id={`prompt-model-${id}`}
|
||||
className="nodrag nowheel w-full"
|
||||
size="sm"
|
||||
>
|
||||
<SelectValue placeholder="Modell" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{availableModels.map((model) => (
|
||||
<SelectItem key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<Label
|
||||
htmlFor={`prompt-format-${id}`}
|
||||
|
||||
Reference in New Issue
Block a user