# Replicate API Integration

Integrate the Replicate API for AI model deployment. Use when generating images with Flux, SDXL, or custom LoRA models via Replicate.

## Purpose
Deploy and run AI models in production using Replicate's cloud platform. This skill focuses on image generation with Flux Dev and SDXL, custom LoRA fine-tuning, and prediction polling patterns.
## When to Use
- Generating images with Flux Dev, SDXL, or custom models
- Fine-tuning models with custom LoRA weights
- Running predictions with polling/webhook patterns
- Deploying custom models to production
- Managing long-running AI workloads
## Architecture Pattern
### Project Structure
```
backend/
├── services/
│   ├── replicate_service.py   # Main Replicate client
│   └── model_service.py       # Model-specific logic
├── models/
│   └── replicate_models.py    # Pydantic models
├── config/
│   └── replicate_config.py    # Configuration
└── utils/
    └── polling.py             # Polling utilities
```
## Installation
```bash
pip install replicate httpx python-dotenv pydantic
```
## Environment Setup
```bash
# .env
REPLICATE_API_TOKEN=r8_...
FRONTEND_URL=http://localhost:3000
DEFAULT_MODEL=black-forest-labs/flux-dev
LORA_MODEL_ID=your-username/your-model-id
```
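Since the dependency list already includes python-dotenv and pydantic, here is a minimal sketch of loading this configuration at startup. The `ReplicateSettings` class and its field names are illustrative, not part of the Replicate SDK:

```python
import os

from dotenv import load_dotenv
from pydantic import BaseModel

load_dotenv()  # populate os.environ from the .env file

class ReplicateSettings(BaseModel):
    api_token: str
    default_model: str = "black-forest-labs/flux-dev"
    lora_model_id: str | None = None

settings = ReplicateSettings(
    api_token=os.environ["REPLICATE_API_TOKEN"],  # fail fast if the token is missing
    default_model=os.getenv("DEFAULT_MODEL", "black-forest-labs/flux-dev"),
    lora_model_id=os.getenv("LORA_MODEL_ID"),
)
```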
## Quick Start
### Basic Image Generation
```python
import replicate
import os
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
# Simple prediction
output = client.run(
"black-forest-labs/flux-dev",
input={
"prompt": "A serene mountain landscape at sunset",
"num_outputs": 1,
"aspect_ratio": "16:9"
}
)
print(f"Generated image: {output[0]}")
```
### Async Pattern with Polling
```python
import asyncio
import os
from typing import List

import replicate
async def generate_images_async(
prompt: str,
num_images: int = 4
) -> List[str]:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    # Start prediction (recent replicate-python clients accept an official
    # model name; pass a pinned version= hash for other models)
    prediction = client.predictions.create(
        model="black-forest-labs/flux-dev",
input={
"prompt": prompt,
"num_outputs": num_images,
"guidance_scale": 3.5,
"num_inference_steps": 28
}
)
# Poll for completion
while prediction.status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(1)
prediction = client.predictions.get(prediction.id)
if prediction.status == "succeeded":
return prediction.output
else:
raise Exception(f"Prediction failed: {prediction.error}")
```
## Flux Dev Integration
### Complete Flux Dev Configuration
```python
import os
from typing import List, Literal

import replicate
from pydantic import BaseModel, Field
class FluxDevInput(BaseModel):
prompt: str = Field(description="Text description of image to generate")
aspect_ratio: Literal["1:1", "16:9", "21:9", "3:2", "2:3", "4:5", "5:4", "9:16", "9:21"] = "1:1"
num_outputs: int = Field(ge=1, le=4, default=1)
num_inference_steps: int = Field(ge=1, le=50, default=28)
guidance_scale: float = Field(ge=0, le=10, default=3.5)
output_format: Literal["webp", "jpg", "png"] = "webp"
output_quality: int = Field(ge=0, le=100, default=80)
seed: int | None = None
disable_safety_checker: bool = False
async def generate_flux_images(input: FluxDevInput) -> List[str]:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
output = await client.async_run(
"black-forest-labs/flux-dev",
input=input.model_dump()
)
return output
```
### Flux Dev Best Practices
```python
# Optimal settings for portraits
PORTRAIT_CONFIG = {
"aspect_ratio": "2:3",
"num_inference_steps": 30,
"guidance_scale": 4.0,
"output_format": "webp",
"output_quality": 90
}
# Optimal settings for landscapes
LANDSCAPE_CONFIG = {
"aspect_ratio": "16:9",
"num_inference_steps": 28,
"guidance_scale": 3.5,
"output_format": "webp",
"output_quality": 85
}
# Batch generation pattern
async def batch_generate(prompts: List[str]) -> List[List[str]]:
tasks = [generate_flux_images(FluxDevInput(prompt=p)) for p in prompts]
return await asyncio.gather(*tasks)
```
## LoRA Fine-Tuning Integration
### Using Custom LoRA Models
```python
class LoRAInput(BaseModel):
prompt: str
lora_scale: float = Field(ge=0, le=1, default=1.0, description="LoRA influence strength")
trigger_word: str | None = Field(default=None, description="Special token for identity")
num_outputs: int = Field(ge=1, le=10, default=1)
async def generate_with_lora(
model_id: str, # e.g., "daniel-carreon/danielcarrong"
input: LoRAInput
) -> List[str]:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
# Ensure trigger word is in prompt
prompt = input.prompt
if input.trigger_word and input.trigger_word not in prompt:
prompt = f"{input.trigger_word} {prompt}"
output = await client.async_run(
model_id,
input={
"prompt": prompt,
"lora_scale": input.lora_scale,
"num_outputs": input.num_outputs,
"num_inference_steps": 28,
"guidance_scale": 3.5
}
)
return output
```
### Training Custom LoRA
```python
async def train_lora(
images_zip_url: str,
trigger_word: str,
steps: int = 1000
) -> str:
"""Train custom LoRA model on Replicate"""
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    # trainings.create requires a pinned trainer version; the hash below is a
    # placeholder — look up the current one on the trainer's model page
    training = client.trainings.create(
        version="ostris/flux-dev-lora-trainer:<version-id>",
input={
"input_images": images_zip_url,
"trigger_word": trigger_word,
"steps": steps,
"lora_rank": 16,
"optimizer": "adamw8bit",
"batch_size": 1,
"learning_rate": 4e-4,
"caption_prefix": f"a photo of {trigger_word}"
},
destination=f"{os.getenv('REPLICATE_USERNAME')}/my-lora-model"
)
# Wait for training completion
while training.status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(30)
training = client.trainings.get(training.id)
    if training.status == "succeeded":
        return training.output  # reference to the trained model (destination/weights)
else:
raise Exception(f"Training failed: {training.error}")
```
## Polling Patterns
### Robust Polling with Retry
```python
import asyncio
import os
from typing import Any, Callable

import replicate
from pydantic import BaseModel

class PollConfig(BaseModel):
    max_wait: int = 300          # 5 minutes
    poll_interval: float = 1.0   # Initial delay between polls
    backoff_factor: float = 1.5  # Exponential backoff multiplier
    max_interval: float = 5.0    # Cap on the delay between polls

async def poll_prediction(
    prediction_id: str,
    on_progress: Callable[[float], None] | None = None,
    config: PollConfig | None = None
) -> Any:
    """Poll Replicate prediction with exponential backoff"""
    config = config or PollConfig()
    client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
    start_time = asyncio.get_event_loop().time()
    poll_interval = config.poll_interval
while True:
prediction = client.predictions.get(prediction_id)
# Report progress
if on_progress and hasattr(prediction, 'logs'):
progress = extract_progress(prediction.logs)
on_progress(progress)
# Check status
if prediction.status == "succeeded":
return prediction.output
elif prediction.status in ["failed", "canceled"]:
raise Exception(f"Prediction {prediction.status}: {prediction.error}")
        # Timeout check
        elapsed = asyncio.get_event_loop().time() - start_time
        if elapsed > config.max_wait:
            raise TimeoutError(f"Prediction timeout after {elapsed:.0f}s")
        # Exponential backoff
        await asyncio.sleep(poll_interval)
        poll_interval = min(poll_interval * config.backoff_factor, config.max_interval)
def extract_progress(logs: str) -> float:
"""Extract progress from logs (0.0 to 1.0)"""
# Example: "Progress: 50%"
import re
match = re.search(r"Progress: (\d+)%", logs or "")
return float(match.group(1)) / 100 if match else 0.0
```
### Webhook Pattern (Production)
```python
import os

import replicate
from fastapi import FastAPI, Request

app = FastAPI()
@app.post("/webhooks/replicate")
async def handle_webhook(request: Request):
"""Handle Replicate webhook callback"""
payload = await request.json()
prediction_id = payload["id"]
status = payload["status"]
if status == "succeeded":
output = payload["output"]
        # Process completed prediction (save_results is an app-specific helper)
        await save_results(prediction_id, output)
elif status == "failed":
error = payload["error"]
        # Handle error (log_error is an app-specific helper)
        await log_error(prediction_id, error)
return {"status": "received"}
# Start prediction with webhook
def create_prediction_with_webhook(prompt: str) -> str:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
prediction = client.predictions.create(
version="black-forest-labs/flux-dev",
input={"prompt": prompt},
webhook=f"{os.getenv('BACKEND_URL')}/webhooks/replicate",
webhook_events_filter=["completed"]
)
return prediction.id
```
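Webhook endpoints should not trust unauthenticated calls. Replicate signs webhook requests with Svix-style `webhook-id`, `webhook-timestamp`, and `webhook-signature` headers; below is a minimal verification sketch, assuming the signing secret (fetched once from Replicate's account API) is stored in a `REPLICATE_WEBHOOK_SIGNING_KEY` variable:

```python
import base64
import hashlib
import hmac
import os

from fastapi import HTTPException, Request

async def verify_replicate_webhook(request: Request) -> bytes:
    """Reject webhook calls whose signature doesn't match the signing secret."""
    body = await request.body()
    webhook_id = request.headers.get("webhook-id", "")
    timestamp = request.headers.get("webhook-timestamp", "")
    signature_header = request.headers.get("webhook-signature", "")
    # Secret looks like "whsec_<base64-key>"; the HMAC key is the decoded tail
    secret = os.environ["REPLICATE_WEBHOOK_SIGNING_KEY"].split("_", 1)[-1]
    signed_content = f"{webhook_id}.{timestamp}.".encode() + body
    digest = hmac.new(base64.b64decode(secret), signed_content, hashlib.sha256).digest()
    expected = base64.b64encode(digest).decode()
    # The header may hold several space-separated "v1,<sig>" entries
    candidates = [part.split(",", 1)[-1] for part in signature_header.split()]
    if not any(hmac.compare_digest(expected, c) for c in candidates):
        raise HTTPException(status_code=401, detail="Invalid webhook signature")
    return body
```

Call this at the top of the webhook handler (before `request.json()`) so unsigned or tampered payloads are rejected early.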
## Error Handling
### Comprehensive Error Handling
```python
import asyncio
import os
from typing import Any

import replicate
class ReplicateError(Exception):
"""Base Replicate error"""
pass
class RateLimitError(ReplicateError):
"""Rate limit exceeded"""
pass
class ModelNotFoundError(ReplicateError):
"""Model not found"""
pass
async def safe_replicate_call(
model: str,
input: dict,
max_retries: int = 3
) -> Any:
"""Call Replicate with retry logic"""
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
for attempt in range(max_retries):
try:
output = await client.async_run(model, input=input)
return output
except replicate.exceptions.ModelError as e:
if "not found" in str(e).lower():
raise ModelNotFoundError(f"Model {model} not found")
raise
except replicate.exceptions.ReplicateError as e:
if "rate limit" in str(e).lower():
if attempt < max_retries - 1:
await asyncio.sleep(2 ** attempt) # Exponential backoff
continue
raise RateLimitError("Rate limit exceeded")
raise ReplicateError(f"Replicate error: {e}")
except Exception as e:
if attempt < max_retries - 1:
await asyncio.sleep(1)
continue
raise
```
## FastAPI Integration
### Complete Replicate Endpoint
```python
import os
from typing import List

import replicate
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
app = FastAPI()
class GenerateRequest(BaseModel):
prompt: str
num_images: int = 4
use_lora: bool = False
lora_model_id: str | None = None
class GenerateResponse(BaseModel):
prediction_id: str
status: str
images: List[str] | None = None
@app.post("/generate", response_model=GenerateResponse)
async def generate_images(request: GenerateRequest):
"""Generate images using Replicate"""
try:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
# Select model
model = request.lora_model_id if request.use_lora else "black-forest-labs/flux-dev"
# Create prediction
prediction = client.predictions.create(
            model=model,  # works for official models and owner/name LoRA slugs
input={
"prompt": request.prompt,
"num_outputs": request.num_images
}
)
return GenerateResponse(
prediction_id=prediction.id,
status=prediction.status
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/predictions/{prediction_id}")
async def get_prediction(prediction_id: str):
"""Check prediction status"""
try:
client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
prediction = client.predictions.get(prediction_id)
return GenerateResponse(
prediction_id=prediction.id,
status=prediction.status,
images=prediction.output if prediction.status == "succeeded" else None
)
except Exception as e:
raise HTTPException(status_code=404, detail="Prediction not found")
```
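To round out the non-blocking pattern, here is a hypothetical client-side flow against these two endpoints using the already-installed httpx; the base URL and one-second poll interval are assumptions:

```python
import asyncio

import httpx

async def generate_and_wait(prompt: str, base_url: str = "http://localhost:8000") -> list[str]:
    async with httpx.AsyncClient(base_url=base_url) as http:
        # Kick off generation; the server returns a prediction id immediately
        resp = await http.post("/generate", json={"prompt": prompt})
        resp.raise_for_status()
        prediction_id = resp.json()["prediction_id"]
        # Poll the status endpoint until the prediction resolves
        while True:
            status = (await http.get(f"/predictions/{prediction_id}")).json()
            if status["status"] == "succeeded":
                return status["images"]
            if status["status"] in ("failed", "canceled"):
                raise RuntimeError(f"Prediction {status['status']}")
            await asyncio.sleep(1)
```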
## Rate Limiting
### Rate Limit Management
```python
import asyncio
from collections import deque
from datetime import datetime, timedelta
class RateLimiter:
def __init__(self, max_requests: int = 50, window: int = 60):
self.max_requests = max_requests
self.window = timedelta(seconds=window)
self.requests: deque[datetime] = deque()
async def acquire(self):
"""Wait if rate limit reached"""
now = datetime.now()
# Remove old requests
while self.requests and self.requests[0] < now - self.window:
self.requests.popleft()
        # Check limit; if we had to wait, record the post-wait time
        if len(self.requests) >= self.max_requests:
            wait_time = (self.requests[0] + self.window - now).total_seconds()
            if wait_time > 0:
                await asyncio.sleep(wait_time)
                now = datetime.now()
        self.requests.append(now)
# Usage
limiter = RateLimiter(max_requests=50, window=60)
async def generate_with_limit(prompt: str):
await limiter.acquire()
return await generate_flux_images(FluxDevInput(prompt=prompt))
```
## Best Practices
- Always use async/await for non-blocking I/O
- Implement polling with backoff to avoid rate limits
- Use webhooks in production for long-running tasks
- Cache prediction results to avoid redundant API calls (see the sketch after this list)
- Monitor costs - log prediction IDs and metrics
- Handle errors gracefully with retry logic
- Use typed inputs with Pydantic models
- Set timeouts for all predictions
- Validate outputs before returning to users
- Store metadata for debugging and analytics
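A minimal sketch of the caching practice above, reusing `generate_flux_images` from the Flux section; the in-process dict is illustrative (production code would typically use Redis or similar):

```python
import hashlib
import json

# Hypothetical in-process cache keyed on the serialized request
_prediction_cache: dict[str, list[str]] = {}

async def cached_generate(prompt: str) -> list[str]:
    key = hashlib.sha256(json.dumps({"prompt": prompt}, sort_keys=True).encode()).hexdigest()
    if key not in _prediction_cache:
        output = await generate_flux_images(FluxDevInput(prompt=prompt))
        _prediction_cache[key] = [str(url) for url in output]
    return _prediction_cache[key]
```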
## Common Pitfalls
❌ Don't: Poll too frequently (wastes API calls)
✅ Do: Use exponential backoff (1s → 1.5s → 2.25s → ...)
❌ Don't: Forget to handle prediction failures
✅ Do: Check status and handle errors
❌ Don't: Hardcode model versions
✅ Do: Use environment variables for flexibility
❌ Don't: Block on predictions in API endpoints
✅ Do: Return prediction ID immediately, poll separately
## Complete Example: Production Service
```python
import asyncio
import os
from typing import List

import replicate
from fastapi import FastAPI

# RateLimiter, poll_prediction, and GenerateRequest come from the sections above
app = FastAPI()
class ImageGenerationService:
def __init__(self):
self.client = replicate.Client(api_token=os.getenv("REPLICATE_API_TOKEN"))
self.rate_limiter = RateLimiter(max_requests=50, window=60)
async def generate(
self,
prompt: str,
num_images: int = 4,
model_id: str = "black-forest-labs/flux-dev"
) -> List[str]:
"""Generate images with rate limiting and error handling"""
await self.rate_limiter.acquire()
try:
            prediction = self.client.predictions.create(
                model=model_id,
input={
"prompt": prompt,
"num_outputs": num_images,
"num_inference_steps": 28,
"guidance_scale": 3.5
}
)
# Poll for completion
output = await poll_prediction(
prediction.id,
on_progress=lambda p: print(f"Progress: {p*100:.0f}%")
)
return output
except Exception as e:
print(f"Generation failed: {e}")
raise
# Global service instance
service = ImageGenerationService()
@app.post("/api/generate")
async def generate_endpoint(request: GenerateRequest):
try:
images = await service.generate(
prompt=request.prompt,
num_images=request.num_images
)
return {"status": "success", "images": images}
except Exception as e:
return {"status": "error", "message": str(e)}
```
## Resources
- [Replicate Docs](https://replicate.com/docs)
- [Flux Dev Model](https://replicate.com/black-forest-labs/flux-dev)
- [LoRA Training Guide](https://replicate.com/docs/guides/fine-tune-a-language-model)
- [Python Client](https://github.com/replicate/replicate-python)