Update app.py
All checks were successful
society-ai-hub-container-cache Actions Demo / build (push) Successful in 26s

This commit is contained in:
Hezi Aharon 2024-12-13 17:06:09 +00:00
parent d6982b0d9c
commit a828c1f0cc

181
app.py

@ -4,115 +4,108 @@ import base64
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
import random
import os import os
API_URL = 'https://hub.societyai.com/models/flux-1-schnell/infer' API_URL = 'https://hub.societyai.com/models/flux-1-schnell/infer'
API_TOKEN = os.environ.get("SAI_API_TOKEN", "") API_TOKEN = os.environ.get("SAI_API_TOKEN", "")
MAX_SEED = np.iinfo(np.int32).max MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
with gr.Blocks(css="footer {visibility: hidden}") as demo: with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown("## FLUX.1-schnell Image Generation") gr.Markdown("## FLUX.1-schnell Image Generation")
with gr.Row(): # Dropdown menus
prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here") avatar = gr.Dropdown(label="Avatar", choices=["Wizard", "Cyborg", "Clown", "Samurai"], value="Wizard")
with gr.Accordion("Advanced Settings", open=False): hair = gr.Dropdown(label="Hair", choices=["Long", "Short", "Mohawk", "Ponytail"], value="Long")
with gr.Row(): theme = gr.Dropdown(label="Theme", choices=["Cyberpunk", "Fantasy", "Anime", "Dreamscape"], value="Cyberpunk")
width = gr.Number(value=512, label="Width", maximum=1920) color = gr.Dropdown(label="Color", choices=["Pink", "Green", "Blue", "Red"], value="Pink")
height = gr.Number(value=512, label="Height", maximum=1080)
with gr.Row(): # Fixed parameters
num_steps = gr.Number(value=4, label="Number of Steps (1-4)", minimum=1, maximum=4) width = 256
guidance_scale = gr.Slider(0.0, 10.0, 0.0, value=7.5, label="Guidance Scale") height = 256
num_steps = 4
with gr.Row(): guidance_scale = 7.5
seed = gr.Slider( seed = 123
label="Seed", randomize_seed = False
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
generate_button = gr.Button("Generate Image") generate_button = gr.Button("Generate Image")
output_image = gr.Image(type="numpy", label="Generated Image") output_image = gr.Image(type="numpy", label="Generated Image")
message = gr.Textbox(label="Status", interactive=False) message = gr.Textbox(label="Status", interactive=False)
def generate_image(prompt, width, height, num_steps, guidance_scale, seed, randomize_seed): def generate_image(avatar, hair, theme, color):
try: # Construct the prompt
# Validation: Ensure width and height are divisible by 16 prompt = f"image of a {avatar} with {hair} hair, in a {theme} style with {color} as the main color"
if width % 8 != 0 or height % 8 != 0:
return None, "Error: Both width and height must be divisible by 8." # Validation: Ensure width and height are divisible by 8
if randomize_seed: if width % 8 != 0 or height % 8 != 0:
seed = random.randint(0, MAX_SEED) return None, "Error: Both width and height must be divisible by 8."
# Prepare the data payload
inputs = [ # Prepare the data payload
inputs = [
{
"name": "PROMPT",
"shape": [1],
"datatype": "BYTES",
"data": [prompt]
},
{
"name": "INIT_IMAGE",
"shape": [1],
"datatype": "BYTES",
"data": [""] # not supported
},
{
"name": "WIDTH",
"shape": [1],
"datatype": "INT32",
"data": [int(width)]
},
{
"name": "HEIGHT",
"shape": [1],
"datatype": "INT32",
"data": [int(height)]
},
{
"name": "NUM_STEPS",
"shape": [1],
"datatype": "INT32",
"data": [int(num_steps)]
},
{
"name": "GUIDANCE_SCALE",
"shape": [1],
"datatype": "FP32",
"data": [float(guidance_scale)]
},
{
"name": "SEED",
"shape": [1],
"datatype": "INT32",
"data": [int(seed)]
},
{
"name": "IMAGE_STRENGTH",
"shape": [1],
"datatype": "FP32",
"data": [0.0] # not supported
}
]
payload = {
"inputs": inputs,
"outputs": [
{ {
"name": "PROMPT", "name": "IMAGE"
"shape": [1],
"datatype": "BYTES",
"data": [prompt]
},
{
"name": "INIT_IMAGE",
"shape": [1],
"datatype": "BYTES",
"data": [""] # not supported
},
{
"name": "WIDTH",
"shape": [1],
"datatype": "INT32",
"data": [int(width)]
},
{
"name": "HEIGHT",
"shape": [1],
"datatype": "INT32",
"data": [int(height)]
},
{
"name": "NUM_STEPS",
"shape": [1],
"datatype": "INT32",
"data": [int(num_steps)]
},
{
"name": "GUIDANCE_SCALE",
"shape": [1],
"datatype": "FP32",
"data": [float(guidance_scale)]
},
{
"name": "SEED",
"shape": [1],
"datatype": "INT32",
"data": [int(seed)]
},
{
"name": "IMAGE_STRENGTH",
"shape": [1],
"datatype": "FP32",
"data": [0.0] # not supported
} }
] ]
}
payload = { headers = {
"inputs": inputs, "Content-Type": "application/json",
"outputs": [ "Authorization": f"Bearer {API_TOKEN}"
{ }
"name": "IMAGE"
}
]
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_TOKEN}"
}
try:
# Send the POST request # Send the POST request
response = requests.post(API_URL, headers=headers, json=payload) response = requests.post(API_URL, headers=headers, json=payload)
@ -125,17 +118,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
# Convert to numpy array # Convert to numpy array
image = Image.open(BytesIO(image_data)) image = Image.open(BytesIO(image_data))
image_np = np.array(image) image_np = np.array(image)
return image_np, "Image generated successfully.", seed return image_np, "Image generated successfully."
else: else:
# Handle error # Handle error
return None, f"Error: {response.status_code} - {response.text}", seed return None, f"Error: {response.status_code} - {response.text}"
except Exception as e: except Exception as e:
return None, f"Error: {str(e)}", seed return None, f"Error: {str(e)}"
generate_button.click( generate_button.click(
generate_image, generate_image,
inputs=[prompt, width, height, num_steps, guidance_scale, seed, randomize_seed], inputs=[avatar, hair, theme, color],
outputs=[output_image, message, seed] outputs=[output_image, message]
) )
if __name__ == "__main__": if __name__ == "__main__":