cyberpunk-image-gen/app.py
Hezi Aharon bf326688c5
Some checks failed
society-ai-hub-container-cache Actions Demo / build (push) Has been cancelled
Update app.py
2024-12-11 15:27:40 +00:00

155 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import requests
import base64
from PIL import Image
from io import BytesIO
import numpy as np
import random
import os
# Inference endpoint that generate_image() posts its request to.
API_URL = "https://hub.societyai.com/models/flux-1-schnell/infer"

# Bearer token sent in the Authorization header; read from the environment so
# it never lives in source. Falls back to an empty string when unset.
API_TOKEN = os.environ.get("SAI_API_TOKEN", "")

# Upper bound for the seed slider and for randomly drawn seeds (int32 max,
# matching the INT32 datatype the SEED input is sent with).
MAX_SEED = np.iinfo(np.int32).max

MAX_IMAGE_SIZE = 2048

# Style preamble prepended to every user prompt before it is sent to the model.
CYBERPUNK_STYLE = (
    "Create a highly detailed and visually stunning image in a Cyberpunk theme, "
    "regardless of the subject or concept described. Use neon lights, futuristic technology, "
    "gritty urban environments, and cybernetic enhancements as core elements of the design. "
    "Blend a mix of dystopian aesthetics, vibrant holographic displays, and advanced robotics "
    "to capture the essence of Cyberpunk, while adapting to the user's specific prompt or description seamlessly. "
    "Maintain a bold and immersive atmosphere that conveys a high-tech, neon-lit future."
)
# Build the page layout. The components created here (prompt, width, height,
# num_steps, guidance_scale, seed, randomize_seed, generate_button,
# output_image, message) are module-level names used when wiring the button.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("## FLUX.1-schnell Cyberpunk Image Generation")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")

    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            width = gr.Number(value=512, label="Width", maximum=1920)
            height = gr.Number(value=512, label="Height", maximum=1080)
        with gr.Row():
            num_steps = gr.Number(value=4, label="Number of Steps (1-4)", minimum=1, maximum=4)
            # Keyword arguments only: Slider's third positional parameter is
            # `value`, so passing a third positional together with `value=`
            # raises "got multiple values for argument 'value'".
            guidance_scale = gr.Slider(
                minimum=0.0,
                maximum=10.0,
                value=7.5,
                label="Guidance Scale",
            )
        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(type="numpy", label="Generated Image")
    message = gr.Textbox(label="Status", interactive=False)
def generate_image(prompt, width, height, num_steps, guidance_scale, seed, randomize_seed):
try:
# Validation: Ensure width and height are divisible by 16
if width % 8 != 0 or height % 8 != 0:
return None, "Error: Both width and height must be divisible by 8."
if randomize_seed:
seed = random.randint(0, MAX_SEED)
# Combine the Cyberpunk style with the user prompt
full_prompt = f"{CYBERPUNK_STYLE} {prompt.strip()}"
# Prepare the data payload
inputs = [
{
"name": "PROMPT",
"shape": [1],
"datatype": "BYTES",
"data": [full_prompt]
},
{
"name": "INIT_IMAGE",
"shape": [1],
"datatype": "BYTES",
"data": [""] # not supported
},
{
"name": "WIDTH",
"shape": [1],
"datatype": "INT32",
"data": [int(width)]
},
{
"name": "HEIGHT",
"shape": [1],
"datatype": "INT32",
"data": [int(height)]
},
{
"name": "NUM_STEPS",
"shape": [1],
"datatype": "INT32",
"data": [int(num_steps)]
},
{
"name": "GUIDANCE_SCALE",
"shape": [1],
"datatype": "FP32",
"data": [float(guidance_scale)]
},
{
"name": "SEED",
"shape": [1],
"datatype": "INT32",
"data": [int(seed)]
},
{
"name": "IMAGE_STRENGTH",
"shape": [1],
"datatype": "FP32",
"data": [0.0] # not supported
}
]
payload = {
"inputs": inputs,
"outputs": [
{
"name": "IMAGE"
}
]
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_TOKEN}"
}
# Send the POST request
response = requests.post(API_URL, headers=headers, json=payload)
if response.status_code == 200:
# Parse the response
result = response.json()
image_base64 = result['outputs'][0]['data'][0]
# Decode the base64 image data
image_data = base64.b64decode(image_base64)
# Convert to numpy array
image = Image.open(BytesIO(image_data))
image_np = np.array(image)
return image_np, "Cyberpunk Image generated successfully.", seed
else:
# Handle error
return None, f"Error: {response.status_code} - {response.text}", seed
except Exception as e:
return None, f"Error: {str(e)}", seed
# Event listeners can only be registered while a Blocks context is active, so
# re-enter `demo` to bind the button to the handler. The handler's three return
# values map, in order, onto the three output components.
with demo:
    generate_button.click(
        generate_image,
        inputs=[prompt, width, height, num_steps, guidance_scale, seed, randomize_seed],
        outputs=[output_image, message, seed],
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()