프로젝트/당일

[AI] Stable Diffusion + LoRA를 통한 이미지 생성 기능 구현

cks._.hong 2024. 5. 29. 08:51

Stable Diffusion + LoRA를 통한 이미지 생성 기능 구현 왜 궁금했을까❓

이번 포스팅에서는 Stable Diffusion과 LoRA를 활용하여 4가지 화풍의 그림을 뽑아내는 기능을 구현할 것이다.

 

 

[AI] Stable Diffusion fine-tuning(LoRA)

Stable Diffusion fine-tuning(LoRA) 왜 궁금했을까❓"당일" 서비스는 사용자가 입력한 일기를 기반으로 4가지 화풍을 가진 대표 이미지를 생성해준다. 이를 위해서는 Fine-Tuning과정이 필요한데 Stable Diffusio

pslog.co.kr

위 포스팅을 통해 LoRA에 대한 개념을 알 수 있다.

1. CheckPoint 및 LoRA 모델 Load

Stable Diffusion을 사용하기 위해서는 초기 모델 CheckPoint 1개와 4가지 화풍을 생성할 LoRA 모델 4개를 로드해야 한다.
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline

def load_sdxl():
    print("=================== Checkpoint loaded start ===================")
    base_dir = app.utils.global_vars.base_dir
    base = StableDiffusionXLPipeline.from_single_file(base_dir + "checkpoint/stable-diffusion-xl-base-1.0.safetensors", torch_dtype=torch.float16, variant="fp16").to("cuda")
    base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
    print("=================== Checkpoint loaded end ===================")
    
    print("=================== lora loaded start ===================")
    base.load_lora_weights(base_dir + "lora/detail.safetensors", adapter_name="detail")
    base.load_lora_weights(base_dir + "lora/children.safetensors", adapter_name="child")
    base.load_lora_weights(base_dir + "lora/animation.safetensors", adapter_name="animate")
    base.load_lora_weights(base_dir + "lora/pixel.safetensors", adapter_name="pixel")
    
    ...
 

stabilityai/stable-diffusion-xl-base-1.0 · Hugging Face

SD-XL 1.0-base Model Card Model SDXL consists of an ensemble of experts pipeline for latent diffusion: In a first step, the base model is used to generate (noisy) latents, which are then further processed with a refinement model (available here: https://hu

huggingface.co

  • 고품질의 이미지를 뽑기 위해 CheckPoint의 경우 SDXL 1.0 버전을 선택했고 위 링크에서 다운받아 서버에 저장해놨다.
  • FastAPI 서버를 시작할 때 해당 모델을 로드하도록 코드를 작성해놨다.
  • 또한, Scheduler의 경우 Denoise의 방식을 정하는 것인데 낮은 Step에서 높은 품질을 뽑을 수 있도록 DPM++ SDE Karras를 선택했다.
 

Civitai: The Home of Open-Source Generative AI

Explore thousands of high-quality Stable Diffusion models, share your AI-generated art, and engage with a vibrant community of creators

civitai.com

  • LoRA의 경우 위 링크를 통해 구할 수 있었는데, 팀 회의를 통해 상의한 결과 "당일" 서비스는 실사, 아이들이 그린 그림, 애니메이션, 픽셀로 4가지 화풍을 선정했다. (요즘 트렌드와 대중성을 고려해서 선정하였다)

 

2. 4가지 화풍의 이미지 생성

def create_image(emotion, prompt):
    # 모델 설정
    base = app.utils.global_vars.base
    # 이미지 타입
    types = ["childrens_book_illustration, ", "Aardman Animations Style page, ", "pixel art, 64 bit, ", "realistic, "]
    # Lora 타입
    lora_types = ["child", "animate", "pixel", "detail"]
    # 추론 횟수
    n_steps = 40
    
    # 생성된 이미지 저장
    images = []

    # 이미지 생성
    for i in range(4):
        base.set_adapters(lora_types[i])
        image = base(
            prompt=types[i] + emotion + prompt,
            num_inference_steps=n_steps,
        ).images[0]
        images.append(image)
  • 1번 섹션에서 로드한 CheckPoint를 다시 불러와 이미지를 생성할 모델을 설정한다.
  • types 배열을 생성했는데, LoRA를 작동시키기 위해서는 특정 단어가 프롬프트에 포함되어 있어야 한다. 그리고 lora_types 배열의 경우 4가지 화풍을 1번 섹션에서 로드했는데 이를 지칭하는 adpater_name들이 담겨있는 것이다.
  • 추론 횟수의 경우 너무 적어도, 너무 많아도 이미지가 저품질로 출력될 수 있기에 테스트를 통해 40 Step이 가장 적절하다는 것을 알았다.
  • 이미지 생성 부분을 보면 ChekcPoint에 set_adapters를 통하여 LoRA를 변경해주는 것을 알 수 있다. 이를 통해 4가지 화풍을 가진 사진을 만들 수 있는 것이다.
  • 당일 서비스는 일기를 작성하기전에 사용자가 느낀 하루의 감정을 선택할 수 있는데 해당 감정은 emotion 변수에 담길 것이고 prompt는 사용자가 입력한 일기가 Chat GPT에 의해 영어로 번역되어 담겨져 있을 것이다.

3. 이미지 생성 결과

  • 위와 같이 4가지 화풍의 그림을 사용자에게 보여주면 본인의 일기와 가장 근접하거나 선호하는 화풍을 골라 다이어리에 대표 사진으로 저장될 것이다.