# Generated from society-ai/simple-image-gen
# Hezi Aharon  bf326688c5
# Some checks failed: society-ai-hub-container-cache Actions Demo / build (push) Has been cancelled
# 155 lines, 5.5 KiB, Python
import base64
import logging
import os
import random
from io import BytesIO

import gradio as gr
import numpy as np
import requests
from PIL import Image

# Inference endpoint on the Society AI hub for the FLUX.1-schnell model.
API_URL = "https://hub.societyai.com/models/flux-1-schnell/infer"

# Bearer token for the endpoint; falls back to "" when SAI_API_TOKEN is unset.
API_TOKEN = os.getenv("SAI_API_TOKEN", "")

# Largest seed offered by the UI and by seed randomization: the maximum signed
# 32-bit integer, because the seed is sent to the endpoint as INT32.
MAX_SEED = np.iinfo(np.int32).max

# Upper bound on image side length in pixels (judging by the name).
MAX_IMAGE_SIZE = 2048

# Style instructions placed in front of every user prompt so that all outputs
# share a Cyberpunk look. The fragments are joined with single spaces.
CYBERPUNK_STYLE = " ".join(
    (
        "Create a highly detailed and visually stunning image in a Cyberpunk theme,",
        "regardless of the subject or concept described. Use neon lights, futuristic technology,",
        "gritty urban environments, and cybernetic enhancements as core elements of the design.",
        "Blend a mix of dystopian aesthetics, vibrant holographic displays, and advanced robotics",
        "to capture the essence of Cyberpunk, while adapting to the user’s specific prompt or description seamlessly.",
        "Maintain a bold and immersive atmosphere that conveys a high-tech, neon-lit future.",
    )
)
# Seconds to wait for the hub endpoint before giving up. Without a timeout,
# requests.post() can block a Gradio worker forever on a stalled connection.
# NOTE(review): 120 s is a guess at a generous upper bound - tune as needed.
REQUEST_TIMEOUT = 120

# Width and height must both be a multiple of this value.
# NOTE(review): FLUX pipelines often require multiples of 16; this app has
# always enforced 8 - confirm which constraint the endpoint actually needs.
SIZE_MULTIPLE = 8

logger = logging.getLogger(__name__)


def _scalar_input(name, datatype, value):
    """Return one entry of the request's ``inputs`` list holding a single value."""
    return {"name": name, "shape": [1], "datatype": datatype, "data": [value]}


def _build_payload(prompt, width, height, num_steps, guidance_scale, seed):
    """Build the JSON request body for the FLUX.1-schnell endpoint.

    The user's prompt is appended to CYBERPUNK_STYLE so every request carries
    the same style instructions. INIT_IMAGE and IMAGE_STRENGTH are sent as
    empty/zero placeholders because this app does not support them.
    """
    full_prompt = f"{CYBERPUNK_STYLE} {prompt.strip()}"
    return {
        "inputs": [
            _scalar_input("PROMPT", "BYTES", full_prompt),
            _scalar_input("INIT_IMAGE", "BYTES", ""),  # not supported
            _scalar_input("WIDTH", "INT32", int(width)),
            _scalar_input("HEIGHT", "INT32", int(height)),
            _scalar_input("NUM_STEPS", "INT32", int(num_steps)),
            _scalar_input("GUIDANCE_SCALE", "FP32", float(guidance_scale)),
            _scalar_input("SEED", "INT32", int(seed)),
            _scalar_input("IMAGE_STRENGTH", "FP32", 0.0),  # not supported
        ],
        "outputs": [{"name": "IMAGE"}],
    }


def _decode_image(result):
    """Decode the base64 image in a successful response body into a NumPy array."""
    image_base64 = result["outputs"][0]["data"][0]
    image_data = base64.b64decode(image_base64)
    # Context manager closes the PIL image once its pixels are copied out.
    with Image.open(BytesIO(image_data)) as image:
        return np.array(image)


def generate_image(prompt, width, height, num_steps, guidance_scale, seed, randomize_seed):
    """Generate a Cyberpunk-styled image through the Society AI hub endpoint.

    Args:
        prompt: User description; appended to CYBERPUNK_STYLE.
        width: Requested image width in pixels.
        height: Requested image height in pixels.
        num_steps: Number of inference steps forwarded to the model.
        guidance_scale: Guidance scale forwarded to the model.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.

    Returns:
        A ``(image, status_message, seed)`` tuple - one value per output
        component wired in the click handler. ``image`` is a NumPy array on
        success and ``None`` on any failure; ``seed`` is the seed actually used.
    """
    try:
        # Every early return yields three values: Gradio requires exactly one
        # value per output component (output_image, message, seed).
        if width is None or height is None or width <= 0 or height <= 0:
            return None, "Error: Width and height must be positive numbers.", seed
        if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
            return None, f"Error: Width and height must be at most {MAX_IMAGE_SIZE}.", seed
        if width % SIZE_MULTIPLE != 0 or height % SIZE_MULTIPLE != 0:
            return (
                None,
                f"Error: Both width and height must be divisible by {SIZE_MULTIPLE}.",
                seed,
            )

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        payload = _build_payload(prompt, width, height, num_steps, guidance_scale, seed)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_TOKEN}",
        }
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            return None, f"Error: {response.status_code} - {response.text}", seed

        image_np = _decode_image(response.json())
        return image_np, "Cyberpunk Image generated successfully.", seed
    except Exception as e:
        # UI boundary: show the failure in the status box rather than crashing
        # the handler, but keep the full traceback in the server log.
        logger.exception("Image generation failed")
        return None, f"Error: {e}", seed


with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("## FLUX.1-schnell Cyberpunk Image Generation")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Enter your prompt here")
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            # precision=0 makes Gradio pass ints (not floats) to the handler.
            width = gr.Number(value=512, label="Width", maximum=1920, precision=0)
            height = gr.Number(value=512, label="Height", maximum=1080, precision=0)

        with gr.Row():
            num_steps = gr.Number(
                value=4, label="Number of Steps (1-4)", minimum=1, maximum=4, precision=0
            )
            # Keyword arguments only: gr.Slider's third positional parameter is
            # `value`, so also passing value=... raises TypeError at startup.
            # NOTE(review): FLUX.1-schnell is commonly run with guidance 0.0 -
            # confirm the 7.5 default is intended.
            guidance_scale = gr.Slider(
                minimum=0.0, maximum=10.0, value=7.5, step=0.1, label="Guidance Scale"
            )

        with gr.Row():
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(type="numpy", label="Generated Image")
    message = gr.Textbox(label="Status", interactive=False)

    # Registered inside the Blocks context; the returned seed is written back
    # to the seed slider so the user can see which seed produced the image.
    generate_button.click(
        generate_image,
        inputs=[prompt, width, height, num_steps, guidance_scale, seed, randomize_seed],
        outputs=[output_image, message, seed],
    )


if __name__ == "__main__":
    demo.launch()