# simple-image-gen/app.py
#
# 143 lines, 4.7 KiB, Python
# History: 2024-11-11 16:57:23 +07:00
import gradio as gr
import requests
import base64
from PIL import Image
from io import BytesIO
import numpy as np
import random
# History: 2024-12-11 08:34:53 +00:00
import os
# History: 2024-11-11 16:57:23 +07:00
# History: 2024-12-11 08:54:38 +00:00
# Inference endpoint that generate_image() POSTs to.
API_URL = 'https://hub.societyai.com/models/flux-1-schnell/infer'

# Bearer token sent in the Authorization header. Read from the environment so
# it never lives in the source; falls back to an empty string when unset.
API_TOKEN = os.environ.get("SAI_API_TOKEN", "")

# Largest value representable as a signed 32-bit integer; used as the upper
# bound for the seed slider and for randomly drawn seeds (the seed is sent to
# the endpoint with datatype INT32).
MAX_SEED = np.iinfo(np.int32).max

# NOTE(review): not referenced anywhere in the visible code — the width/height
# inputs carry their own maxima. Confirm whether it is still needed.
MAX_IMAGE_SIZE = 2048
# History: 2024-12-30 14:16:59 +00:00
# UI layout: a prompt box, a collapsible group of advanced settings, the
# generate button, and the two result widgets (image + status line).
# NOTE(review): the original indentation was lost; the nesting below (all
# three settings rows inside the accordion, button and outputs outside it)
# is a reconstruction — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## FLUX.1-schnell Image Generation")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(value=512, label="Width", maximum=1920)
            height = gr.Number(value=512, label="Height", maximum=1080)
        with gr.Row():
            num_steps = gr.Number(value=4, label="Number of Steps (1-4)", minimum=1, maximum=4)
            # The original call passed a third positional argument (0.0) in
            # addition to value=7.5. Slider's third positional parameter is
            # `value`, so that raised "got multiple values for argument
            # 'value'". Keyword arguments make the intent unambiguous.
            guidance_scale = gr.Slider(
                minimum=0.0,
                maximum=10.0,
                value=7.5,
                label="Guidance Scale",
            )
        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(type="numpy", label="Generated Image")
    message = gr.Textbox(label="Status", interactive=False)
def generate_image(prompt, width, height, num_steps, guidance_scale, seed, randomize_seed):
    """Request one image from the inference endpoint and decode it for display.

    Args:
        prompt: Text sent as the ``PROMPT`` input.
        width: Requested image width; must be a multiple of 8.
        height: Requested image height; must be a multiple of 8.
        num_steps: Sent as ``NUM_STEPS`` after conversion to ``int``.
        guidance_scale: Sent as ``GUIDANCE_SCALE`` after conversion to ``float``.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]`` before the request is built.

    Returns:
        A 3-tuple ``(image, status, seed)``: the decoded image as a numpy
        array (``None`` on any failure), a human-readable status message, and
        the seed that was in effect. Every path returns all three values
        because the click handler is wired to three output components.
    """
    try:
        # Validation: the check and the message both use 8.
        # NOTE(review): the original comment said "divisible by 16" — confirm
        # which multiple the backend actually requires.
        if width % 8 != 0 or height % 8 != 0:
            # Must return three values like every other path; the original
            # returned only two here, which breaks the three-output handler.
            return None, "Error: Both width and height must be divisible by 8.", seed

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        # (name, datatype, value) for each request input, in the order the
        # original payload listed them. Every input has shape [1].
        fields = [
            ("PROMPT", "BYTES", prompt),
            ("INIT_IMAGE", "BYTES", ""),  # not supported
            ("WIDTH", "INT32", int(width)),
            ("HEIGHT", "INT32", int(height)),
            ("NUM_STEPS", "INT32", int(num_steps)),
            ("GUIDANCE_SCALE", "FP32", float(guidance_scale)),
            ("SEED", "INT32", int(seed)),
            ("IMAGE_STRENGTH", "FP32", 0.0),  # not supported
        ]
        payload = {
            "inputs": [
                {"name": name, "shape": [1], "datatype": datatype, "data": [value]}
                for name, datatype, value in fields
            ],
            "outputs": [{"name": "IMAGE"}],
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_TOKEN}",
        }

        # Bound the wait (10 s to connect, 300 s to read) so a stalled
        # endpoint cannot hang the UI forever; a timeout surfaces through the
        # except branch below as a normal error message.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=(10, 300))

        if response.status_code != 200:
            return None, f"Error: {response.status_code} - {response.text}", seed

        # The image is read from the first data element of the first output
        # and is base64-encoded.
        image_base64 = response.json()['outputs'][0]['data'][0]
        image = Image.open(BytesIO(base64.b64decode(image_base64)))
        return np.array(image), "Image generated successfully.", seed
    except Exception as e:
        # UI boundary: report any failure (bad input types, network errors,
        # unexpected response shape) in the status box instead of raising.
        return None, f"Error: {str(e)}", seed
# Event listeners have to be registered inside a Blocks context. Re-entering
# `demo` here keeps the wiring valid regardless of where generate_image is
# defined relative to the layout block.
with demo:
    # The handler's three return values map, in order, onto the image, the
    # status box, and the seed slider (so a randomized seed is shown back).
    generate_button.click(
        generate_image,
        inputs=[prompt, width, height, num_steps, guidance_scale, seed, randomize_seed],
        outputs=[output_image, message, seed],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()