Initial commit: V1

This commit is contained in:
theliu
2026-04-25 14:10:09 +08:00
parent 76b5751518
commit 3fe9b00de7
9 changed files with 305 additions and 279 deletions
+70 -77
View File
@@ -1,28 +1,24 @@
"""
image_gen.py - 统一文生图接口
支持两个模型:
- Kolors(便宜快速)→ SiliconFlow API(同步)
- Qwen-Image(高质量)→ ModelScope API(异步轮询)
image_gen.py - Unified text-to-image interface.
Providers:
- SiliconFlow (Kolors) — sync API
- ModelScope (Qwen-Image) — async polling API
"""
import requests
import os
import time
from datetime import datetime
from config import (
SILICONFLOW_API_KEY,
SILICONFLOW_API_BASE,
MODELSCOPE_API_KEY,
MODELSCOPE_API_BASE,
MODELSCOPE_POLL_INTERVAL,
MODELSCOPE_MAX_WAIT,
IMAGE_MODELS,
NEGATIVE_PROMPT,
)
from config import IMAGE_MODELS, NEGATIVE_PROMPT
def _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filename):
"""SiliconFlow 同步 APIKolors"""
def _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filename, api_key, api_base):
"""SiliconFlow sync API"""
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
payload = {
"model": model_id,
"prompt": prompt,
@@ -33,37 +29,32 @@ def _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filen
"negative_prompt": neg,
}
headers = {
"Authorization": f"Bearer {SILICONFLOW_API_KEY}",
"Content-Type": "application/json",
}
print(f" [SiliconFlow] {prompt[:60]}{'...' if len(prompt) > 60 else ''}")
print(f" [SiliconFlow] 提交: {prompt[:60]}{'...' if len(prompt) > 60 else ''}")
for attempt in range(6): # 最多重试 5 次
resp = requests.post(SILICONFLOW_API_BASE, headers=headers, json=payload, timeout=120)
for attempt in range(6):
resp = requests.post(api_base, headers=headers, json=payload, timeout=120)
print(f" HTTP {resp.status_code}: {resp.text[:300]}")
if resp.status_code == 429:
wait = 15 * (attempt + 1) # 15s, 30s, 45s, 60s, 75s
print(f" [!] 限频,等待 {wait}s 后重试 ({attempt+1}/5)...")
wait = 15 * (attempt + 1)
print(f" [!] Rate limited, waiting {wait}s ({attempt+1}/5)...")
time.sleep(wait)
continue
if resp.status_code != 200:
raise Exception(f"SiliconFlow 生成失败 ({resp.status_code}): {resp.text[:300]}")
raise Exception(f"SiliconFlow error ({resp.status_code}): {resp.text[:300]}")
break
else:
raise Exception("SiliconFlow 持续限频,已重试 5 次,请稍后再试或切换模型")
raise Exception("SiliconFlow rate limit, retried 5 times.")
result = resp.json()
images = result.get("images", [])
if not images:
raise Exception(f"SiliconFlow 返回无图片: {result}")
raise Exception(f"SiliconFlow returned no images: {result}")
img_url = images[0].get("url")
if not img_url:
raise Exception(f"返回图片 URL 为空: {result}")
raise Exception(f"Empty image URL: {result}")
img_data = requests.get(img_url, timeout=60).content
@@ -78,12 +69,12 @@ def _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filen
return {"url": img_url, "filepath": filepath}
def _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filename):
"""ModelScope 异步轮询 APIQwen-Image"""
def _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filename, api_key, api_base):
"""ModelScope async polling API"""
submit_headers = {
"Authorization": f"Bearer {MODELSCOPE_API_KEY}",
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"X-ModelScope-Async-Mode": "true"
"X-ModelScope-Async-Mode": "true",
}
payload = {
"model": model_id,
@@ -94,31 +85,32 @@ def _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filena
"negative_prompt": neg,
}
print(f" [ModelScope] 提交: {prompt[:60]}{'...' if len(prompt) > 60 else ''}")
resp = requests.post(MODELSCOPE_API_BASE, headers=submit_headers, json=payload, timeout=60)
print(f" [ModelScope] {prompt[:60]}{'...' if len(prompt) > 60 else ''}")
resp = requests.post(api_base, headers=submit_headers, json=payload, timeout=60)
if resp.status_code != 200:
raise Exception(f"ModelScope 提交失败 ({resp.status_code}): {resp.text[:300]}")
raise Exception(f"ModelScope submit failed ({resp.status_code}): {resp.text[:300]}")
result = resp.json()
task_id = result.get("task_id")
if not task_id:
raise Exception(f"未找到 task_id: {result}")
raise Exception(f"No task_id: {result}")
print(f" task_id: {task_id}")
# 轮询结果
query_headers = {
"Authorization": f"Bearer {MODELSCOPE_API_KEY}",
"X-ModelScope-Task-Type": "image_generation"
"Authorization": f"Bearer {api_key}",
"X-ModelScope-Task-Type": "image_generation",
}
status_url = f"https://api-inference.modelscope.cn/v1/tasks/{task_id}"
poll_interval = IMAGE_MODELS["Qwen-Image (ModelScope)"].get("poll_interval", 3)
max_wait = IMAGE_MODELS["Qwen-Image (ModelScope)"].get("max_wait", 180)
start = time.time()
for attempt in range(100):
if attempt > 0:
time.sleep(MODELSCOPE_POLL_INTERVAL)
time.sleep(poll_interval)
elapsed = int(time.time() - start)
if elapsed > MODELSCOPE_MAX_WAIT:
raise Exception(f"ModelScope 超时({MODELSCOPE_MAX_WAIT}s")
if elapsed > max_wait:
raise Exception(f"ModelScope timeout ({max_wait}s)")
qresp = requests.get(status_url, headers=query_headers, timeout=30)
if qresp.status_code != 200:
@@ -130,11 +122,13 @@ def _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filena
print(f" [{elapsed}s] {task_status}")
if task_status == "SUCCEED":
output_images = (qresult.get("output_images")
or qresult.get("outputs", {}).get("output_images")
or [])
output_images = (
qresult.get("output_images")
or qresult.get("outputs", {}).get("output_images")
or []
)
if not output_images:
raise Exception(f"SUCCEED 但无图片: {qresult}")
raise Exception(f"SUCCEED but no images: {qresult}")
url = output_images[0]
img_data = requests.get(url, timeout=180).content
@@ -149,34 +143,30 @@ def _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filena
return {"url": url, "filepath": filepath}
elif task_status == "FAILED":
raise Exception(f"ModelScope 任务失败: {qresult.get('errors', qresult)}")
raise Exception(f"ModelScope task failed: {qresult.get('errors', qresult)}")
raise Exception(f"ModelScope 超时({MODELSCOPE_MAX_WAIT}s")
raise Exception(f"ModelScope timeout ({max_wait}s)")
def image_generate(
prompt: str,
save_dir: str = "./generated_images",
model_name: str = None,
n: int = 1,
seed: int = None,
num_inference_steps: int = 20,
guidance_scale: float = None,
negative_prompt: str = None,
filename: str = None,
image_size: str = None,
guidance_scale: float = None,
negative_prompt: str = None,
) -> dict:
"""
统一文生图接口
"""Unified text-to-image interface.
Args:
prompt: 生成提示词
save_dir: 保存目录
model_name: 模型名称(IMAGE_MODELS 的 key),默认用 config 中的 DEFAULT_IMAGE_MODEL
image_size: 图片尺寸,默认 1280x72016:9
prompt: generation prompt
save_dir: output directory
model_name: model name (key in IMAGE_MODELS), None = default
filename: output filename, None = auto
image_size: image size, None = model default
Returns:
dict: {"url": str, "filepath": str}
{"url": str, "filepath": str}
"""
from config import DEFAULT_IMAGE_MODEL
@@ -185,31 +175,34 @@ def image_generate(
model_config = IMAGE_MODELS.get(model_name)
if not model_config:
raise ValueError(f"未知模型: {model_name},可选: {list(IMAGE_MODELS.keys())}")
raise ValueError(f"Unknown model: {model_name}, available: {list(IMAGE_MODELS.keys())}")
api_key = model_config.get("api_key", "")
if not api_key:
raise ValueError(
f"API key not configured for '{model_name}'. "
f"Edit config.py and fill in the api_key field."
)
model_id = model_config["model"]
size = image_size or model_config["default_size"]
guidance = guidance_scale if guidance_scale is not None else model_config["guidance_scale"]
neg = negative_prompt or NEGATIVE_PROMPT
provider = model_config["provider"]
api_base = model_config.get("api_base", "")
os.makedirs(save_dir, exist_ok=True)
provider = model_config["provider"]
if provider == "siliconflow":
return _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filename)
return _generate_siliconflow(prompt, model_id, size, guidance, neg, save_dir, filename, api_key, api_base)
elif provider == "modelscope":
return _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filename)
return _generate_modelscope(prompt, model_id, size, guidance, neg, save_dir, filename, api_key, api_base)
else:
raise ValueError(f"未知 provider: {provider}")
def get_available_models() -> list[str]:
"""返回可用的文生图模型名称列表"""
return list(IMAGE_MODELS.keys())
raise ValueError(f"Unknown provider: {provider}")
if __name__ == "__main__":
for name in get_available_models():
print(f"\n测试模型: {name}")
result = image_generate("A cute cat sitting on a desk, 16:9 aspect ratio", model_name=name)
print(f" 路径: {result['filepath']}")
for name in list(IMAGE_MODELS.keys()):
print(f"\nTesting: {name}")
result = image_generate("A cute cat sitting on a desk, 16:9", model_name=name)
print(f" Path: {result['filepath']}")