diff --git a/app.py b/app.py index b034bf1..6dbb651 100644 --- a/app.py +++ b/app.py @@ -4,115 +4,108 @@ import base64 from PIL import Image from io import BytesIO import numpy as np -import random import os API_URL = 'https://hub.societyai.com/models/flux-1-schnell/infer' API_TOKEN = os.environ.get("SAI_API_TOKEN", "") MAX_SEED = np.iinfo(np.int32).max -MAX_IMAGE_SIZE = 2048 with gr.Blocks(css="footer {visibility: hidden}") as demo: gr.Markdown("## FLUX.1-schnell Image Generation") - with gr.Row(): - prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here") - with gr.Accordion("Advanced Settings", open=False): - with gr.Row(): - width = gr.Number(value=512, label="Width", maximum=1920) - height = gr.Number(value=512, label="Height", maximum=1080) + # Dropdown menus + avatar = gr.Dropdown(label="Avatar", choices=["Wizard", "Cyborg", "Clown", "Samurai"], value="Wizard") + hair = gr.Dropdown(label="Hair", choices=["Long", "Short", "Mohawk", "Ponytail"], value="Long") + theme = gr.Dropdown(label="Theme", choices=["Cyberpunk", "Fantasy", "Anime", "Dreamscape"], value="Cyberpunk") + color = gr.Dropdown(label="Color", choices=["Pink", "Green", "Blue", "Red"], value="Pink") - with gr.Row(): - num_steps = gr.Number(value=4, label="Number of Steps (1-4)", minimum=1, maximum=4) - guidance_scale = gr.Slider(0.0, 10.0, 0.0, value=7.5, label="Guidance Scale") - - with gr.Row(): - seed = gr.Slider( - label="Seed", - minimum=0, - maximum=MAX_SEED, - step=1, - value=0, - ) - randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + # Fixed parameters + width = 256 + height = 256 + num_steps = 4 + guidance_scale = 7.5 + seed = 123 + randomize_seed = False generate_button = gr.Button("Generate Image") output_image = gr.Image(type="numpy", label="Generated Image") message = gr.Textbox(label="Status", interactive=False) - def generate_image(prompt, width, height, num_steps, guidance_scale, seed, randomize_seed): - try: - # Validation: Ensure width and height are divisible by 16 - if width % 8 != 0 or height % 8 != 0: - return None, "Error: Both width and height must be divisible by 8." - if randomize_seed: - seed = random.randint(0, MAX_SEED) - # Prepare the data payload - inputs = [ + def generate_image(avatar, hair, theme, color): + # Construct the prompt + prompt = f"image of a {avatar} with {hair} hair, in a {theme} style with {color} as the main color" + + # Validation: Ensure width and height are divisible by 8 + if width % 8 != 0 or height % 8 != 0: + return None, "Error: Both width and height must be divisible by 8." + + # Prepare the data payload + inputs = [ + { + "name": "PROMPT", + "shape": [1], + "datatype": "BYTES", + "data": [prompt] + }, + { + "name": "INIT_IMAGE", + "shape": [1], + "datatype": "BYTES", + "data": [""] # not supported + }, + { + "name": "WIDTH", + "shape": [1], + "datatype": "INT32", + "data": [int(width)] + }, + { + "name": "HEIGHT", + "shape": [1], + "datatype": "INT32", + "data": [int(height)] + }, + { + "name": "NUM_STEPS", + "shape": [1], + "datatype": "INT32", + "data": [int(num_steps)] + }, + { + "name": "GUIDANCE_SCALE", + "shape": [1], + "datatype": "FP32", + "data": [float(guidance_scale)] + }, + { + "name": "SEED", + "shape": [1], + "datatype": "INT32", + "data": [int(seed)] + }, + { + "name": "IMAGE_STRENGTH", + "shape": [1], + "datatype": "FP32", + "data": [0.0] # not supported + } + ] + + payload = { + "inputs": inputs, + "outputs": [ { - "name": "PROMPT", - "shape": [1], - "datatype": "BYTES", - "data": [prompt] - }, - { - "name": "INIT_IMAGE", - "shape": [1], - "datatype": "BYTES", - "data": [""] # not supported - }, - { - "name": "WIDTH", - "shape": [1], - "datatype": "INT32", - "data": [int(width)] - }, - { - "name": "HEIGHT", - "shape": [1], - "datatype": "INT32", - "data": [int(height)] - }, - { - "name": "NUM_STEPS", - "shape": [1], - "datatype": "INT32", - "data": [int(num_steps)] - }, - { - "name": "GUIDANCE_SCALE", - "shape": [1], - "datatype": "FP32", - "data": [float(guidance_scale)] - }, - { - "name": "SEED", - "shape": [1], - "datatype": "INT32", - "data": [int(seed)] - }, - { - "name": "IMAGE_STRENGTH", - "shape": [1], - "datatype": "FP32", - "data": [0.0] # not supported + "name": "IMAGE" } ] + } - payload = { - "inputs": inputs, - "outputs": [ - { - "name": "IMAGE" - } - ] - } - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {API_TOKEN}" - } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {API_TOKEN}" + } + try: # Send the POST request response = requests.post(API_URL, headers=headers, json=payload) @@ -125,17 +118,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo: # Convert to numpy array image = Image.open(BytesIO(image_data)) image_np = np.array(image) - return image_np, "Image generated successfully.", seed + return image_np, "Image generated successfully." else: # Handle error - return None, f"Error: {response.status_code} - {response.text}", seed + return None, f"Error: {response.status_code} - {response.text}" except Exception as e: - return None, f"Error: {str(e)}", seed + return None, f"Error: {str(e)}" generate_button.click( generate_image, - inputs=[prompt, width, height, num_steps, guidance_scale, seed, randomize_seed], - outputs=[output_image, message, seed] + inputs=[avatar, hair, theme, color], + outputs=[output_image, message] ) if __name__ == "__main__":