diff --git a/app.py b/app.py new file mode 100644 index 0000000..48e76c4 --- /dev/null +++ b/app.py @@ -0,0 +1,417 @@ +import os +from datetime import datetime +from fastapi import FastAPI, HTTPException, Depends +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from pydantic import BaseModel +from typing import Optional +from runware import Runware, IImageInference, IVideoInference +from dotenv import load_dotenv +from jose import jwt +from motor.motor_asyncio import AsyncIOMotorClient +from models import UserUsage, UserData + +# Load environment variables +load_dotenv(override=True) + +app = FastAPI() + +# Global Runware instance +RUNWARE_API_KEY = os.getenv("RUNWARE_API_KEY") +runware = Runware(api_key=RUNWARE_API_KEY) + +# MongoDB and Clerk Configuration +MONGO_URL = os.getenv("MONGO_URL") +DAILY_LIMIT = 20 + +# Load Clerk Public Key from file (safest) +CLERK_JWT_PUBLIC_KEY = None +if os.path.exists("clerk_public.pem"): + with open("clerk_public.pem", "r") as f: + CLERK_JWT_PUBLIC_KEY = f.read().strip() +else: + # Fallback to env + CLERK_JWT_PUBLIC_KEY = os.getenv("CLERK_JWT_PUBLIC_KEY") + +@app.on_event("startup") +async def startup_event(): + print("\n--- Backend Startup Diagnostics ---") + print(f"RUNWARE_API_KEY: {'✅ Found' if RUNWARE_API_KEY else '❌ Missing'}") + print(f"MONGO_URL: {'✅ Found' if MONGO_URL else '❌ Missing'}") + print(f"CLERK_JWT_PUBLIC_KEY: {'✅ Found' if CLERK_JWT_PUBLIC_KEY else '❌ Missing'}") + if CLERK_JWT_PUBLIC_KEY: + print(f"Clerk Key Length: {len(CLERK_JWT_PUBLIC_KEY)} chars") + print("-----------------------------------\n") + + await runware.connect() + +# Initialize MongoDB client (global) +if MONGO_URL: + client = AsyncIOMotorClient(MONGO_URL) + db = client.get_database() +else: + print("❌ Critical: MONGO_URL not found!") + db = None + +security = HTTPBearer() + +async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)): + if not CLERK_JWT_PUBLIC_KEY: + # For development/safety, if public key is not set, we might want to warn or skip + # but in production this should be a hard requirement. + raise HTTPException(status_code=500, detail="CLERK_JWT_PUBLIC_KEY not configured") + + token = credentials.credentials + try: + # Clerk uses RS256 for JWTs + payload = jwt.decode(token, CLERK_JWT_PUBLIC_KEY, algorithms=["RS256"]) + return payload + except Exception as e: + raise HTTPException(status_code=401, detail=f"Unauthorized: {str(e)}") + +async def check_limit_and_credits(user_payload: dict = Depends(verify_token)): + try: + user_id = user_payload["sub"] + today = datetime.utcnow().date().isoformat() + + # 1. Check Daily Limit + usage_data = await db.usage.find_one({"user_id": user_id, "date": today}) + if usage_data: + usage = UserUsage(**usage_data) + if usage.count >= DAILY_LIMIT: + raise HTTPException(status_code=429, detail="Daily generation limit reached") + + # 2. Check Credits & Update Email + email = user_payload.get("email") # Get email from Clerk JWT + + user_doc = await db.users.find_one({"user_id": user_id}) + if not user_doc: + # Auto-create user with 3 free credits and email if not exists + new_user = UserData(user_id=user_id, email=email, credits=3) + await db.users.insert_one(new_user.model_dump()) + user_doc = new_user.model_dump() + else: + # Update email if it changed or was missing + if email and user_doc.get("email") != email: + await db.users.update_one({"user_id": user_id}, {"$set": {"email": email}}) + + user_data = UserData(**user_doc) + if user_data.credits <= 0: + raise HTTPException(status_code=402, detail="No credits left") + + return user_payload + except HTTPException: + raise + except Exception as e: + print(f"ERROR in check_limit_and_credits: {str(e)}") + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) + +async def track_usage(user_id: str): + today = datetime.utcnow().date().isoformat() + # Update Daily Limit + await db.usage.update_one( + {"user_id": user_id, "date": today}, + {"$inc": {"count": 1}}, + upsert=True + ) + # Deduct Credit and get new count + result = await db.users.find_one_and_update( + {"user_id": user_id}, + {"$inc": {"credits": -1}}, + return_document=True + ) + return result.get("credits", 0) if result else 0 + +@app.get("/profile") +async def get_profile(user: dict = Depends(verify_token)): + user_id = user["sub"] + user_doc = await db.users.find_one({"user_id": user_id}) + today = datetime.utcnow().date().isoformat() + + if not user_doc: + # Create user if they don't exist yet + email = user.get("email") + new_user = UserData(user_id=user_id, email=email, credits=10) + await db.users.insert_one(new_user.model_dump()) + user_doc = new_user.model_dump() + else: + # Check for date reset + if user_doc.get("last_ad_date") != today: + await db.users.update_one( + {"user_id": user_id}, + {"$set": { + "ads_watched_today": 0, + "ad_credits_earned_today": 0, + "last_ad_date": today + }} + ) + user_doc["ads_watched_today"] = 0 + user_doc["ad_credits_earned_today"] = 0 + user_doc["last_ad_date"] = today + + user_data = UserData(**user_doc) + return { + "user_id": user_data.user_id, + "email": user_data.email, + "credits": user_data.credits, + "plan": user_data.plan, + "ads_watched_today": user_data.ads_watched_today, + "ad_credits_earned_today": user_data.ad_credits_earned_today + } + +@app.post("/add-reward-credits") +async def add_reward_credits(user: dict = Depends(verify_token)): + user_id = user["sub"] + today = datetime.utcnow().date().isoformat() + + user_doc = await db.users.find_one({"user_id": user_id}) + if not user_doc: + raise HTTPException(status_code=404, detail="User not found") + + # Reset logic if it's a new day + if user_doc.get("last_ad_date") != today: + user_doc["ads_watched_today"] = 0 + user_doc["ad_credits_earned_today"] = 0 + user_doc["last_ad_date"] = today + + if user_doc.get("ads_watched_today", 0) >= 5: + raise HTTPException(status_code=429, detail="Daily ad limit reached (5/5)") + + # Increment ads watched + new_ads_watched = user_doc.get("ads_watched_today", 0) + 1 + + # Check if we should grant credit (limit 3 per day) + credit_granted = False + update_data = { + "ads_watched_today": new_ads_watched, + "last_ad_date": today + } + + inc_data = {} + if user_doc.get("ad_credits_earned_today", 0) < 3: + inc_data["credits"] = 1 + inc_data["ad_credits_earned_today"] = 1 + credit_granted = True + + update_op = {"$set": update_data} + if inc_data: + update_op["$inc"] = inc_data + + result = await db.users.find_one_and_update( + {"user_id": user_id}, + update_op, + return_document=True + ) + + if not result: + raise HTTPException(status_code=500, detail="Failed to update credits") + + return { + "success": True, + "credits": result.get("credits", 0), + "credit_granted": credit_granted, + "ads_watched_today": new_ads_watched, + "ad_credits_earned_today": result.get("ad_credits_earned_today", 0), + "message": "Reward credit added successfully" if credit_granted else "Ad watched, but daily credit limit reached" + } + +@app.on_event("startup") +async def startup_event(): + await runware.connect() + +@app.on_event("shutdown") +async def shutdown_event(): + await runware.disconnect() + if 'client' in globals() and client: + client.close() + +class T2IRequest(BaseModel): + prompt: str + aspect_ratio: Optional[str] = "1:1" # 1:1, 16:9, 9:16, 4:3, 3:4 + style: Optional[str] = None + quality: Optional[str] = "medium" # low, medium, high + +class I2IRequest(BaseModel): + prompt: str + source_url: str + aspect_ratio: Optional[str] = "1:1" + style: Optional[str] = None + quality: Optional[str] = "medium" + +class VideoRequest(BaseModel): + prompt: str + source_url: Optional[str] = None + aspect_ratio: Optional[str] = "16:9" + duration: Optional[int] = 2 + audio: Optional[bool] = False + +def get_dimensions(aspect_ratio: str): + ratios = { + "1:1": (1024, 1024), + "16:9": (1344, 768), + "9:16": (768, 1344), + "4:3": (1152, 864), + "3:4": (864, 1152) + } + return ratios.get(aspect_ratio, (1024, 1024)) + +def get_steps(quality: str): + quality_map = { + "low": 20, + "medium": 30, + "high": 50 + } + return quality_map.get(quality, 30) + +@app.post("/text-to-image") +async def text_to_image(request: T2IRequest, user: dict = Depends(check_limit_and_credits)): + try: + user_id = user["sub"] + width, height = get_dimensions(request.aspect_ratio) + steps = get_steps(request.quality) + + # Combine prompt with style + final_prompt = request.prompt + if request.style: + final_prompt = f"{request.prompt}, in the style of {request.style}" + + request_image = IImageInference( + positivePrompt=final_prompt, + negativePrompt="blurry, low quality, distorted face, bad anatomy", + model="runware:400@6", # FLUX.2 [klein] 9B KV + numberResults=1, + height=height, + width=width, + steps=steps, + includeCost=True + ) + images = await runware.imageInference(requestImage=request_image) + if not images: + raise HTTPException(status_code=500, detail="No image generated") + + # Track usage after successful generation + remaining_credits = await track_usage(user_id) + + return { + "url": images[0].imageURL, + "cost": images[0].cost, + "remaining_credits": remaining_credits + } + except Exception as e: + print(f"ERROR in text_to_image: {str(e)}") + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) + +@app.post("/image-to-image") +async def image_to_image(request: I2IRequest, user: dict = Depends(check_limit_and_credits)): + try: + user_id = user["sub"] + width, height = get_dimensions(request.aspect_ratio) + + final_prompt = request.prompt + if request.style: + final_prompt = f"{request.prompt}, in the style of {request.style}" + + # FLUX.2 uses referenceImages for image-to-image instead of seedImage + # Note: FLUX models do not support the 'strength' parameter. + request_image = IImageInference( + positivePrompt=final_prompt, + negativePrompt="blurry, low quality, distorted face, bad anatomy", + model="runware:400@6", # FLUX.2 [klein] 9B KV + referenceImages=[request.source_url], + numberResults=1, + height=height, + width=width, + steps=4, # FLUX turbo variants use 4 steps + CFGScale=4.0, # FLUX standard CFG is around 4 + outputQuality=85, + includeCost=True + ) + images = await runware.imageInference(requestImage=request_image) + if not images: + raise HTTPException(status_code=500, detail="No image generated") + + # Track usage after successful generation + remaining_credits = await track_usage(user_id) + + return { + "url": images[0].imageURL, + "cost": images[0].cost, + "remaining_credits": remaining_credits + } + except Exception as e: + print(f"ERROR in image_to_image: {str(e)}") + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=str(e)) + +# async def track_video_usage(user_id: str): +# today = datetime.utcnow().date().isoformat() +# # Update Daily Limit +# await db.usage.update_one( +# {"user_id": user_id, "date": today}, +# {"$inc": {"count": 1}}, +# upsert=True +# ) +# # Deduct 5 Credits and get new count +# result = await db.users.find_one_and_update( +# {"user_id": user_id}, +# {"$inc": {"credits": -5}}, +# return_document=True +# ) +# return result.get("credits", 0) if result else 0 +# +# @app.post("/generate-video") +# async def generate_video(request: VideoRequest, user: dict = Depends(check_limit_and_credits)): +# try: +# user_id = user["sub"] +# +# # Check if user has at least 5 credits +# user_doc = await db.users.find_one({"user_id": user_id}) +# if not user_doc or user_doc.get("credits", 0) < 5: +# raise HTTPException(status_code=402, detail="Insufficient credits for video generation (Requires 5 credits)") +# +# width, height = get_dimensions(request.aspect_ratio) +# +# # Note: "runware:bytedance-seedance-1-5-pro" or just "bytedance-seedance-1-5-pro" +# # Often Runware models might need "runware:xx" but "bytedance-seedance-1-5-pro" works as model string. +# +# video_args = { +# "positivePrompt": request.prompt, +# "negativePrompt": "blurry, low quality, distorted", +# "model": "runware:bytedance-seedance-1-5-pro", +# "height": height, +# "width": width, +# "duration": request.duration, +# "includeCost": True +# } +# +# if request.source_url: +# video_args["referenceImages"] = [request.source_url] +# +# request_video = IVideoInference(**video_args) +# +# # videoInference returns a list of videos +# videos = await runware.videoInference(requestVideo=request_video) +# if not videos: +# raise HTTPException(status_code=500, detail="No video generated") +# +# # Track usage after successful generation (cost 5 credits) +# remaining_credits = await track_video_usage(user_id) +# +# return { +# "url": videos[0].videoURL, +# "cost": videos[0].cost, +# "remaining_credits": remaining_credits +# } +# except Exception as e: +# print(f"ERROR in generate_video: {str(e)}") +# import traceback +# traceback.print_exc() +# raise HTTPException(status_code=500, detail=str(e)) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/clerk_public.pem b/clerk_public.pem new file mode 100644 index 0000000..556dd68 --- /dev/null +++ b/clerk_public.pem @@ -0,0 +1,9 @@ +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzspx8CSwc9+7DZL2CaVR +ksT6DU+Zx0E1itDSjRE6DtJiIFCODHST6rBSsR8T2Xeub7p6Zb+0fH4lWFkIJbud +E3uhlnXSRNAit6l81KZgz2PMCpwYtAhnYFnJ9XgRxrDLzLq+XXTkN9lFTef+8cg6 +CQWwiOUfcVAQsY/ZSCz83CrxdKe9lMYERBNQh/C/mxVo6u2UfccE8YikErr5Y8jD +JWGL+1pufeT1R9byNfPJM1nUvGhb+gqdxmhdHztc0YR1rUghoeJbZCO0VYWkgYZH +xb2pxiN0J08env3TXe1dD6bh+4BHRoh6WnExBKZD9mkYwyaZ4/RZCna2d+N7VGsv +nQIDAQAB +-----END PUBLIC KEY----- diff --git a/models.py b/models.py new file mode 100644 index 0000000..2eb7236 --- /dev/null +++ b/models.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional +from datetime import datetime + +class UserUsage(BaseModel): + model_config = ConfigDict(extra='ignore', populate_by_name=True) + + user_id: str + date: str # ISO format YYYY-MM-DD + count: int = Field(default=0, ge=0) + +class UserData(BaseModel): + model_config = ConfigDict(extra='ignore', populate_by_name=True) + + user_id: str + email: Optional[str] = None + credits: int = Field(default=0, ge=0) + plan: str = "free" # "free", "pro", etc. + created_at: datetime = Field(default_factory=datetime.utcnow) + + # Ad Tracking Limits + ads_watched_today: int = Field(default=0, ge=0) + ad_credits_earned_today: int = Field(default=0, ge=0) + last_ad_date: Optional[str] = None # YYYY-MM-DD diff --git a/requirements.txt b/requirements.txt index 4d38ca9..a083245 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,8 @@ aiofiles>=23.2.1 httpx>=0.27.0 python-dotenv>=1.0.1 websockets>=12.0 +fastapi>=0.115.0 +uvicorn>=0.30.0 +motor>=3.6.0 +python-jose[cryptography]>=3.3.0 +pydantic-settings>=2.6.0 diff --git a/runware_demo.py b/runware_demo.py new file mode 100644 index 0000000..348088d --- /dev/null +++ b/runware_demo.py @@ -0,0 +1,66 @@ +import asyncio +import os +from runware import Runware, IImageInference +from dotenv import load_dotenv + +# Load environment variables from .env file +load_dotenv() + +async def text_to_image(runware: Runware): + print("\n--- Running Text-to-Image ---") + request_image = IImageInference( + positivePrompt="a futuristic city with neon lights, highly detailed, digital art", + model="runware:400@6", # FLUX.2 [klein] 9B KV + numberResults=1, + height=1024, + width=1024, + ) + + images = await runware.imageInference(requestImage=request_image) + for image in images: + print(f"Generated Image URL: {image.imageURL}") + return images[0].imageURL if images else None + +async def image_to_image(runware: Runware, source_image_url: str): + print("\n--- Running Image-to-Image ---") + # Image-to-Image uses seedImage and strength + # strength: 0.0 to 1.0 (lower means closer to source image, higher means closer to prompt) + request_image = IImageInference( + positivePrompt="same futuristic city but in daytime with bright sunlight", + model="runware:400@6", # FLUX.2 [klein] 9B KV + seedImage=source_image_url, + strength=0.6, + numberResults=1, + height=1024, + width=1024, + ) + + images = await runware.imageInference(requestImage=request_image) + for image in images: + print(f"Img2Img Result URL: {image.imageURL}") + +async def main(): + api_key = os.getenv("RUNWARE_API_KEY") + if not api_key: + print("Error: RUNWARE_API_KEY not found in .env file.") + return + + runware = Runware(api_key=api_key) + await runware.connect() + + try: + # 1. Text to Image + generated_url = await text_to_image(runware) + + # 2. Image to Image (using the result of the first generation as source) + if generated_url: + await image_to_image(runware, generated_url) + else: + # Fallback if text-to-image failed to return a URL + print("Skipping Img2Img as no source image was generated.") + + finally: + await runware.disconnect() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scratch/check_clerk.py b/scratch/check_clerk.py new file mode 100644 index 0000000..35897ca --- /dev/null +++ b/scratch/check_clerk.py @@ -0,0 +1,22 @@ +import os +import asyncio +from motor.motor_asyncio import AsyncIOMotorClient +from dotenv import load_dotenv + +load_dotenv() + +async def check_clerk_users(): + mongo_url = os.getenv("MONGO_URL") + client = AsyncIOMotorClient(mongo_url) + db = client.get_database() + + count = await db.users.count_documents({"user_id": {"$exists": True}}) + print(f"Users with user_id: {count}") + + sample = await db.users.find_one({"user_id": {"$exists": True}}) + print(f"Sample Clerk User: {sample}") + + client.close() + +if __name__ == "__main__": + asyncio.run(check_clerk_users()) diff --git a/scratch/check_db_credits.py b/scratch/check_db_credits.py new file mode 100644 index 0000000..35c7ba8 --- /dev/null +++ b/scratch/check_db_credits.py @@ -0,0 +1,23 @@ +import os +import asyncio +from motor.motor_asyncio import AsyncIOMotorClient +from dotenv import load_dotenv + +load_dotenv() + +async def check_credits(): + mongo_url = os.getenv("MONGO_URL") + client = AsyncIOMotorClient(mongo_url) + db = client.get_database() + + # Get all users to see their credits + users = await db.users.find().to_list(length=10) + print("--- User Credits in DB ---") + for user in users: + print(f"User ID: {user.get('user_id')} | Email: {user.get('email')} | Credits: {user.get('credits')}") + print("--------------------------") + + client.close() + +if __name__ == "__main__": + asyncio.run(check_credits()) diff --git a/scratch/diagnose.py b/scratch/diagnose.py new file mode 100644 index 0000000..212199b --- /dev/null +++ b/scratch/diagnose.py @@ -0,0 +1,57 @@ +import os +import asyncio +from motor.motor_asyncio import AsyncIOMotorClient +from runware import Runware +from dotenv import load_dotenv + +load_dotenv() + +async def diagnose(): + print("--- Diagnostics Started ---") + + # 1. Check Runware + api_key = os.getenv("RUNWARE_API_KEY") + print(f"Runware API Key present: {bool(api_key)}") + if api_key: + try: + runware = Runware(api_key=api_key) + await runware.connect() + print("✅ Runware connection successful") + await runware.disconnect() + except Exception as e: + print(f"❌ Runware connection failed: {e}") + + # 2. Check MongoDB + mongo_url = os.getenv("MONGO_URL") + print(f"MongoDB URL present: {bool(mongo_url)}") + if mongo_url: + try: + client = AsyncIOMotorClient(mongo_url) + # Try to ping the database + await client.admin.command('ping') + print("✅ MongoDB connection successful") + + db = client.get_database() + print(f"Using database: {db.name}") + + # Check collections + collections = await db.list_collection_names() + print(f"Collections found: {collections}") + + client.close() + except Exception as e: + print(f"❌ MongoDB connection failed: {e}") + + # 3. Check Clerk Key + clerk_key = os.getenv("CLERK_JWT_PUBLIC_KEY") + print(f"Clerk Public Key present: {bool(clerk_key)}") + if clerk_key: + if "BEGIN PUBLIC KEY" in clerk_key and "END PUBLIC KEY" in clerk_key: + print("✅ Clerk Public Key format looks valid (PEM)") + else: + print("❌ Clerk Public Key format invalid (missing PEM headers)") + + print("--- Diagnostics Finished ---") + +if __name__ == "__main__": + asyncio.run(diagnose()) diff --git a/scratch/find_clerk_user.py b/scratch/find_clerk_user.py new file mode 100644 index 0000000..5f1094e --- /dev/null +++ b/scratch/find_clerk_user.py @@ -0,0 +1,28 @@ +import os +import asyncio +from motor.motor_asyncio import AsyncIOMotorClient +from dotenv import load_dotenv + +load_dotenv() + +async def find_clerk_user(): + mongo_url = os.getenv("MONGO_URL") + client = AsyncIOMotorClient(mongo_url) + db = client.get_database() + + # Find any user that has a non-null user_id + clerk_users = await db.users.find({"user_id": {"$exists": True, "$ne": None}}).to_list(length=5) + print(f"Found {len(clerk_users)} Clerk users.") + for u in clerk_users: + print(f"User: {u.get('email')} | Credits: {u.get('credits')} | ID: {u.get('user_id')}") + + # Also check the 'usage' collection + usage = await db.usage.find().to_list(length=5) + print(f"\nFound {len(usage)} usage records.") + for res in usage: + print(f"Usage: {res}") + + client.close() + +if __name__ == "__main__": + asyncio.run(find_clerk_user()) diff --git a/scratch/find_model.py b/scratch/find_model.py new file mode 100644 index 0000000..89d0df7 --- /dev/null +++ b/scratch/find_model.py @@ -0,0 +1,30 @@ +import asyncio +import os +from runware import Runware, IModelSearch +from dotenv import load_dotenv + +load_dotenv() + +async def main(): + api_key = os.getenv("RUNWARE_API_KEY") + if not api_key: + print("RUNWARE_API_KEY not found in .env") + return + + runware = Runware(api_key=api_key) + await runware.connect() + + print("Searching for 'flux klein' models...") + search_results = await runware.modelSearch( + payload=IModelSearch(search="flux 2") + ) + + for model in search_results.results: + print(f"Name: {model.name}") + print(f"AIR: {model.air}") + print("-" * 20) + + await runware.disconnect() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scratch/inspect_db.py b/scratch/inspect_db.py new file mode 100644 index 0000000..54c64ac --- /dev/null +++ b/scratch/inspect_db.py @@ -0,0 +1,19 @@ +import os +import asyncio +from motor.motor_asyncio import AsyncIOMotorClient +from dotenv import load_dotenv + +load_dotenv() + +async def inspect(): + mongo_url = os.getenv("MONGO_URL") + client = AsyncIOMotorClient(mongo_url) + db = client.get_database() + + user = await db.users.find_one() + print(f"Sample User Document: {user}") + + client.close() + +if __name__ == "__main__": + asyncio.run(inspect())