"""
|
|
THE FACTORY - Trabajo Iterativo Generativo
|
|
RunPod Serverless Handler
|
|
|
|
Tareas:
|
|
- image_generate: Genera imagen desde prompt
|
|
- image_variant: Genera variante de imagen existente
|
|
- image_upscale: Aumenta resolución
|
|
"""

import runpod
import base64
import os
from datetime import datetime
from io import BytesIO

# Force CUDA device
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Available models
MODELS = {
    "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
    "flux": "black-forest-labs/FLUX.1-schnell",
    "sdxl-turbo": "stabilityai/sdxl-turbo"
}

# Lazily loaded model pipelines, keyed by model name
_loaded_models = {}


def get_model(model_name: str):
    """Load a model on demand and cache it for reuse."""
    global _loaded_models

    if model_name not in _loaded_models:
        try:
            import torch
            from diffusers import AutoPipelineForText2Image

            # Prefer CUDA, fall back to CPU
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            model_id = MODELS.get(model_name, MODELS["sdxl-turbo"])

            pipe = AutoPipelineForText2Image.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                variant="fp16",
                use_safetensors=True
            )
            pipe = pipe.to(device)

            _loaded_models[model_name] = pipe
        except Exception as e:
            return None, str(e)

    return _loaded_models[model_name], None

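
# Illustrative note (not in the original source): because get_model() caches
# pipelines in _loaded_models, repeated calls with the same name reuse the
# already-loaded weights, e.g.:
#   pipe, err = get_model("sdxl-turbo")  # first call loads the pipeline
#   pipe, err = get_model("sdxl-turbo")  # later calls return the cached object
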
def generate_image(prompt: str, model: str = "sdxl-turbo",
                   width: int = 1024, height: int = 1024,
                   steps: int = 4, guidance: float = 0.0) -> dict:
    """Generate an image from a prompt."""
    pipe, error = get_model(model)
    if error:
        return {"error": f"Model load failed: {error}"}

    try:
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=guidance
        ).images[0]

        # Encode as base64 PNG
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode()

        return {
            "image_base64": img_base64,
            "width": width,
            "height": height,
            "model": model
        }
    except Exception as e:
        return {"error": str(e)}


def generate_variant(image_base64: str, prompt: str,
                     strength: float = 0.5, model: str = "sdxl-turbo") -> dict:
    """Generate a variant of an existing image."""
    try:
        import torch
        from diffusers import AutoPipelineForImage2Image
        from PIL import Image

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Decode the input image
        img_data = base64.b64decode(image_base64)
        init_image = Image.open(BytesIO(img_data)).convert("RGB")

        model_id = MODELS.get(model, MODELS["sdxl-turbo"])
        pipe = AutoPipelineForImage2Image.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant="fp16"
        ).to(device)

        image = pipe(
            prompt=prompt,
            image=init_image,
            strength=strength,
            num_inference_steps=4
        ).images[0]

        buffer = BytesIO()
        image.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode()

        return {"image_base64": img_base64}
    except Exception as e:
        return {"error": str(e)}


def upscale_image(image_base64: str, scale: int = 2) -> dict:
    """Upscale an image using PIL LANCZOS resampling."""
    try:
        from PIL import Image

        img_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(img_data))

        new_size = (image.width * scale, image.height * scale)
        upscaled = image.resize(new_size, Image.LANCZOS)

        buffer = BytesIO()
        upscaled.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode()

        return {
            "image_base64": img_base64,
            "width": new_size[0],
            "height": new_size[1],
            "scale": scale
        }
    except Exception as e:
        return {"error": str(e)}


def handler(job):
    """
    Main handler for THE FACTORY.

    Expected input:
    {
        "task": "image_generate",   # Task to run
        "prompt": "...",            # Prompt for generation
        "model": "sdxl-turbo",      # Model to use
        "width": 1024,              # Width (optional)
        "height": 1024,             # Height (optional)
        "image_base64": "...",      # For variant/upscale
        "strength": 0.5,            # For variant
        "scale": 2                  # For upscale
    }

    Available tasks:
    - image_generate: Generate an image from a prompt
    - image_variant: Generate a variant
    - image_upscale: Increase resolution
    """
    job_input = job.get("input", {})
    trace_id = job_input.get("trace_id", str(datetime.utcnow().timestamp()))
    task = job_input.get("task", "image_generate")

    result = {"trace_id": trace_id, "task": task}

    if task == "image_generate":
        prompt = job_input.get("prompt")
        if not prompt:
            return {"error": "prompt is required for image_generate"}

        gen_result = generate_image(
            prompt=prompt,
            model=job_input.get("model", "sdxl-turbo"),
            width=job_input.get("width", 1024),
            height=job_input.get("height", 1024),
            steps=job_input.get("steps", 4),
            guidance=job_input.get("guidance", 0.0)
        )
        result.update(gen_result)

    elif task == "image_variant":
        image_base64 = job_input.get("image_base64")
        prompt = job_input.get("prompt", "")
        if not image_base64:
            return {"error": "image_base64 is required for image_variant"}

        var_result = generate_variant(
            image_base64=image_base64,
            prompt=prompt,
            strength=job_input.get("strength", 0.5),
            model=job_input.get("model", "sdxl-turbo")
        )
        result.update(var_result)

    elif task == "image_upscale":
        image_base64 = job_input.get("image_base64")
        if not image_base64:
            return {"error": "image_base64 is required for image_upscale"}

        up_result = upscale_image(
            image_base64=image_base64,
            scale=job_input.get("scale", 2)
        )
        result.update(up_result)

    else:
        return {"error": f"Task '{task}' not recognized. Available: image_generate, image_variant, image_upscale"}

    return result

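
# Illustrative local smoke test (an assumption, not part of the deployed handler):
# a job dict shaped like what the RunPod runtime passes to handler(). Uncomment to
# run the handler once without the serverless wrapper.
# _example_job = {
#     "input": {
#         "task": "image_generate",
#         "prompt": "a lighthouse at dusk, oil painting",
#         "model": "sdxl-turbo",
#         "steps": 4,
#         "guidance": 0.0
#     }
# }
# print(handler(_example_job).keys())
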

# RunPod serverless
runpod.serverless.start({"handler": handler})
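
# Illustrative client call against the deployed endpoint, assuming RunPod's standard
# /runsync REST route; <ENDPOINT_ID> and <RUNPOD_API_KEY> are placeholders:
#
#   curl -X POST https://api.runpod.ai/v2/<ENDPOINT_ID>/runsync \
#     -H "Authorization: Bearer <RUNPOD_API_KEY>" \
#     -H "Content-Type: application/json" \
#     -d '{"input": {"task": "image_upscale", "image_base64": "<PNG_BASE64>", "scale": 2}}'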