Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(integrations): Add support for diffusion pipelines not in the list of supported pipelines #7450

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

soumik12345
Copy link
Contributor

Description

The diffusers integration currently only supports pipelines that are mentioned in the SUPPORTED_MULTIMODAL_PIPELINES. This PR updates the integration to track pipelines that are not in this list, irrespective of whether the pipeline is even part of the diffusers library or not.

Example 1: Tracking the Pixart Sigma pipeline from the official codebase

  1. First clone the library using https://github.com/PixArt-alpha/PixArt-sigma.
  2. Next, rename the repository to PixArt_sigma in order to treat it as a python module.
  3. Install diffusers from source using pip install git+https://github.com/huggingface/diffusers.
  4. Run the code using the autologger:
import torch
from diffusers import Transformer2DModel
from PixArt_sigma.scripts.diffusers_patches import (
    pixart_sigma_init_patched_inputs,
    PixArtSigmaPipeline,
)
from wandb.integration.diffusers import autolog

# We tell the autolog exactly which pipeline to track
pipeline_log_config = (
    dict(
        api_module="PixArt_sigma.scripts.diffusers_patches",
        pipeline=PixArtSigmaPipeline,
        kwarg_logging=["prompt", "negative_prompt"],
    )
    if not autolog.check_pipeline_support(PixArtSigmaPipeline)
    else dict()
)
autolog(init=dict(project="diffusers_logging", job_type="test"), **pipeline_log_config)

assert getattr(
    Transformer2DModel, "_init_patched_inputs", False
), "Need to Upgrade diffusers: pip install git+https://github.com/huggingface/diffusers"
setattr(Transformer2DModel, "_init_patched_inputs", pixart_sigma_init_patched_inputs)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe.to(device)

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]

Sample Run: https://wandb.ai/geekyrakshit/diffusers_logging/runs/31f4mpf1

Example 2: Tracking the StableCascade Pipeline

The StableCascadeCombinedPipeline is currently not part of the SUPPORTED_MULTIMODAL_PIPELINES.

import torch
from diffusers import StableCascadeCombinedPipeline

from wandb.integration.diffusers import autolog


# We tell the autolog exactly which pipeline to track
pipeline_log_config = (
    dict(
        api_module="diffusers",
        pipeline=StableCascadeCombinedPipeline,
        kwarg_logging=["prompt", "negative_prompt"],
    )
    if not autolog.check_pipeline_support(StableCascadeCombinedPipeline)
    else dict()
)
autolog(init=dict(project="diffusers_logging", job_type="test"), **pipeline_log_config)

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
images = pipe(prompt=prompt)

Sample Run: https://wandb.ai/geekyrakshit/diffusers_logging/runs/31f4mpf1

  • I updated CHANGELOG.md, or it's not applicable

Copy link

codecov bot commented Apr 22, 2024

Codecov Report

Attention: Patch coverage is 0% with 152 lines in your changes missing coverage. Please review.

Project coverage is 74.70%. Comparing base (23e3023) to head (a44e045).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7450      +/-   ##
==========================================
+ Coverage   72.86%   74.70%   +1.84%     
==========================================
  Files         492      491       -1     
  Lines       53352    51560    -1792     
==========================================
- Hits        38875    38520     -355     
+ Misses      14017    12580    -1437     
  Partials      460      460              
Flag Coverage Δ
func 44.61% <0.00%> (-0.02%) ⬇️
system 61.98% <0.00%> (-0.01%) ⬇️
unit 57.12% <0.00%> (+1.20%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
wandb/integration/diffusers/utils.py 0.00% <ø> (ø)
wandb/integration/diffusers/__init__.py 0.00% <0.00%> (ø)
wandb/integration/diffusers/autologger.py 0.00% <0.00%> (ø)
...ndb/integration/diffusers/diffusers_autolog_api.py 0.00% <0.00%> (ø)
wandb/integration/diffusers/pipeline_resolver.py 0.00% <0.00%> (ø)

... and 115 files with indirect coverage changes

@soumik12345 soumik12345 requested a review from a team June 10, 2024 13:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant