2024-12-13 17:02:40 +00:00
|
|
|
import gradio as gr
|
|
|
|
import requests
|
|
|
|
import base64
|
|
|
|
from PIL import Image
|
|
|
|
from io import BytesIO
|
|
|
|
import numpy as np
|
|
|
|
import os
|
2024-12-13 17:14:13 +00:00
|
|
|
import random
|
2024-12-13 17:02:40 +00:00
|
|
|
|
|
|
|
# Inference endpoint of the hosted FLUX.1-schnell model.
API_URL = "https://hub.societyai.com/models/flux-1-schnell/infer"

# Environment variable holding the bearer token for the inference API.
_TOKEN_ENV_VAR = "SAI_API_TOKEN"
# Empty string when the variable is unset.
API_TOKEN = os.getenv(_TOKEN_ENV_VAR, "")

# Largest seed value: the maximum of a signed 32-bit integer, which matches
# the INT32 datatype the request uses for its SEED input.
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
|
# Fixed generation parameters. They are not exposed in the UI;
# generate_image reads them as module-level globals.
width = 512
height = 512
num_steps = 4
guidance_scale = 7.5
fixed_seed = 123

# (connect, read) timeout in seconds for the inference request. Without a
# timeout, requests.post can block the Gradio worker forever if the service
# stops responding. The read value is a generous guess; tune it to the
# service's observed latency.
REQUEST_TIMEOUT = (10, 120)


def _build_prompt(gender, avatar, hair, theme, color):
    """Return the text prompt describing the selected avatar options."""
    return (
        f"image of a {gender} {avatar} with {hair} hair, "
        f"in a {theme} style with {color} as the main color"
    )


def _tensor(name, datatype, value):
    """Return one single-element entry for the request's ``inputs`` list."""
    return {"name": name, "shape": [1], "datatype": datatype, "data": [value]}


def _build_payload(prompt, seed, width, height, num_steps, guidance_scale):
    """Return the JSON request body for the inference endpoint.

    INIT_IMAGE and IMAGE_STRENGTH are sent with placeholder values ("" and
    0.0) because this app does not support them.
    """
    inputs = [
        _tensor("PROMPT", "BYTES", prompt),
        _tensor("INIT_IMAGE", "BYTES", ""),  # not supported
        _tensor("WIDTH", "INT32", int(width)),
        _tensor("HEIGHT", "INT32", int(height)),
        _tensor("NUM_STEPS", "INT32", int(num_steps)),
        _tensor("GUIDANCE_SCALE", "FP32", float(guidance_scale)),
        _tensor("SEED", "INT32", int(seed)),
        _tensor("IMAGE_STRENGTH", "FP32", 0.0),  # not supported
    ]
    return {"inputs": inputs, "outputs": [{"name": "IMAGE"}]}


def _decode_image(image_base64):
    """Decode a base64-encoded image into a numpy array for ``gr.Image``.

    The PIL image is closed once its pixels have been copied into the array.
    """
    image_bytes = base64.b64decode(image_base64)
    with Image.open(BytesIO(image_bytes)) as image:
        # Palette ("P") and other exotic modes would otherwise turn into
        # arrays of raw indices/values and display with the wrong colors.
        if image.mode not in ("RGB", "RGBA", "L"):
            image = image.convert("RGB")
        return np.array(image)


def generate_image(gender, avatar, hair, theme, color, randomize_seed):
    """Generate an avatar image with the FLUX.1-schnell inference API.

    Args:
        gender, avatar, hair, theme, color: Dropdown selections that are
            combined into the text prompt.
        randomize_seed: When true, draw a random seed in ``[0, MAX_SEED]``;
            otherwise use ``fixed_seed``.

    Returns:
        An ``(image, status)`` tuple. ``image`` is a numpy array, or None on
        failure. ``status`` is the message shown in the status textbox.
    """
    prompt = _build_prompt(gender, avatar, hair, theme, color)
    seed = random.randint(0, MAX_SEED) if randomize_seed else fixed_seed

    # Dimensions must be multiples of 8; checked before any network call.
    if width % 8 != 0 or height % 8 != 0:
        return None, "Error: Both width and height must be divisible by 8."

    payload = _build_payload(prompt, seed, width, height, num_steps, guidance_scale)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_TOKEN}",
    }

    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            return None, f"Error: {response.status_code} - {response.text}"
        image_base64 = response.json()["outputs"][0]["data"][0]
        return _decode_image(image_base64), "Image generated successfully."
    except Exception as e:
        # UI boundary: surface any failure (network, timeout, malformed
        # response, undecodable image) in the status box instead of
        # letting it propagate into Gradio.
        return None, f"Error: {str(e)}"


with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("## FLUX.1-schnell Image Generation")

    # Dropdown menus whose values are combined into the prompt.
    gender = gr.Dropdown(label="Gender", choices=["Male", "Female"], value="Male")
    avatar = gr.Dropdown(
        label="Avatar", choices=["Wizard", "Cyborg", "Clown", "Samurai"], value="Wizard"
    )
    hair = gr.Dropdown(
        label="Hair", choices=["Long", "Short", "Mohawk", "Ponytail"], value="Long"
    )
    theme = gr.Dropdown(
        label="Theme",
        choices=["Cyberpunk", "Fantasy", "Anime", "Dreamscape"],
        value="Cyberpunk",
    )
    color = gr.Dropdown(
        label="Color", choices=["Pink", "Green", "Blue", "Red"], value="Pink"
    )

    # Checkbox for randomize seed
    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(type="numpy", label="Generated Image")
    message = gr.Textbox(label="Status", interactive=False)

    # Event listeners must be registered inside the Blocks context.
    generate_button.click(
        generate_image,
        inputs=[gender, avatar, hair, theme, color, randomize_seed],
        outputs=[output_image, message],
    )
|
|
|
|
|
|
|
|
def main():
    """Start the Gradio server for the ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    main()
|