diff --git a/components/canvas/nodes/prompt-node.tsx b/components/canvas/nodes/prompt-node.tsx
index 2de5512..2b9ce04 100644
--- a/components/canvas/nodes/prompt-node.tsx
+++ b/components/canvas/nodes/prompt-node.tsx
@@ -18,7 +18,11 @@ import BaseNodeWrapper from "./base-node-wrapper";
import { useCanvasPlacement } from "@/components/canvas/canvas-placement-context";
import { useCanvasSync } from "@/components/canvas/canvas-sync-context";
import { useDebouncedCallback } from "@/hooks/use-debounced-callback";
-import { DEFAULT_MODEL_ID, getModel } from "@/lib/ai-models";
+import {
+ DEFAULT_MODEL_ID,
+ getAvailableImageModels,
+ getModel,
+} from "@/lib/ai-models";
import {
DEFAULT_ASPECT_RATIO,
getAiImageNodeOuterSize,
@@ -40,6 +44,7 @@ import { Sparkles, Loader2, Coins } from "lucide-react";
import { useRouter } from "next/navigation";
import { toast } from "@/lib/toast";
import { classifyError } from "@/lib/ai-errors";
+import { normalizePublicTier } from "@/lib/tier-credits";
type PromptNodeData = {
prompt?: string;
@@ -63,6 +68,7 @@ export default function PromptNode({
const { getEdges, getNode } = useReactFlow();
const [prompt, setPrompt] = useState(nodeData.prompt ?? "");
+ const [modelId, setModelId] = useState(nodeData.model ?? DEFAULT_MODEL_ID);
const [aspectRatio, setAspectRatio] = useState(
nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO
);
@@ -72,14 +78,20 @@ export default function PromptNode({
const nodes = useStore((store) => store.nodes);
const promptRef = useRef(prompt);
+ const modelIdRef = useRef(modelId);
const aspectRatioRef = useRef(aspectRatio);
promptRef.current = prompt;
+ modelIdRef.current = modelId;
aspectRatioRef.current = aspectRatio;
useEffect(() => {
setPrompt(nodeData.prompt ?? "");
}, [nodeData.prompt]);
+ useEffect(() => {
+ setModelId(nodeData.model ?? DEFAULT_MODEL_ID);
+ }, [nodeData.model]);
+
useEffect(() => {
setAspectRatio(nodeData.aspectRatio ?? DEFAULT_ASPECT_RATIO);
}, [nodeData.aspectRatio]);
@@ -113,7 +125,29 @@ export default function PromptNode({
dataRef.current = data;
const balance = useAuthQuery(api.credits.getBalance);
- const creditCost = getModel(DEFAULT_MODEL_ID)?.creditCost ?? 4;
+ const subscription = useAuthQuery(api.credits.getSubscription);
+ const userTier = normalizePublicTier(subscription?.tier ?? "free");
+ const availableModels = useMemo(
+ () => getAvailableImageModels(userTier),
+ [userTier],
+ );
+
+ useEffect(() => {
+ if (availableModels.length === 0) {
+ return;
+ }
+
+ if (!availableModels.some((model) => model.id === modelId)) {
+ setModelId(availableModels[0]!.id);
+ }
+ }, [availableModels, modelId]);
+
+ const selectedModel =
+ getModel(modelId) ??
+ availableModels[0] ??
+ getModel(DEFAULT_MODEL_ID);
+ const resolvedModelId = selectedModel?.id ?? DEFAULT_MODEL_ID;
+ const creditCost = selectedModel?.creditCost ?? 4;
const availableCredits =
balance !== undefined ? balance.balance - balance.reserved : null;
@@ -131,12 +165,13 @@ export default function PromptNode({
void _statusMessage;
void queueNodeDataUpdate({
nodeId: id as Id<"nodes">,
- data: {
- ...rest,
- prompt: promptRef.current,
- aspectRatio: aspectRatioRef.current,
- },
- });
+ data: {
+ ...rest,
+ prompt: promptRef.current,
+ model: modelIdRef.current,
+ aspectRatio: aspectRatioRef.current,
+ },
+ });
}, 500);
const handlePromptChange = useCallback(
@@ -156,6 +191,14 @@ export default function PromptNode({
[debouncedSave]
);
+ const handleModelChange = useCallback(
+ (value: string) => {
+ setModelId(value);
+ debouncedSave();
+ },
+ [debouncedSave],
+ );
+
const handleGenerate = useCallback(async () => {
if (!effectivePrompt.trim() || isGenerating) return;
if (status.isOffline) {
@@ -229,8 +272,8 @@ export default function PromptNode({
height: outer.height,
data: {
prompt: promptToUse,
- model: DEFAULT_MODEL_ID,
- modelTier: "standard",
+ model: resolvedModelId,
+ modelTier: selectedModel?.tier ?? "standard",
canvasId,
aspectRatio,
outputWidth: viewport.width,
@@ -249,7 +292,7 @@ export default function PromptNode({
prompt: promptToUse,
referenceStorageId,
referenceImageUrl,
- model: DEFAULT_MODEL_ID,
+ model: resolvedModelId,
aspectRatio,
}),
{
@@ -285,6 +328,7 @@ export default function PromptNode({
prompt,
effectivePrompt,
aspectRatio,
+ resolvedModelId,
isGenerating,
nodeData.canvasId,
id,
@@ -292,6 +336,7 @@ export default function PromptNode({
getNode,
createNodeConnectedFromSource,
generateImage,
+ selectedModel?.tier,
creditCost,
availableCredits,
hasEnoughCredits,
@@ -338,6 +383,31 @@ export default function PromptNode({
/>
)}
+
+
+
+
+