Ability to process in Batches #299

Open
charliemday opened this issue Apr 29, 2024 · 5 comments

Comments

@charliemday

Is there a way to batch process requests, similar to OpenAI's Batch API (https://help.openai.com/en/articles/9197833-batch-api-faq)?

@mattt
Member

mattt commented Apr 29, 2024

Hi @charliemday. No, Replicate doesn't currently implement a batch processing API like OpenAI. It's something we're considering, though. Can you share more about your intended use case?

@charliemday
Author

Hi @mattt, sure.

My intended use case: I have a CSV file of ~1k rows, and I want to send a request for each row. Each response takes ~1-2 seconds, so sending the rows one at a time will take roughly 17 to 33 minutes. I would like to batch the rows into groups of 10 and send each group at once, which should cut the time to completion considerably.
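A rough sketch of the arithmetic, assuming each request takes at most 2 seconds and that every request in a batch runs in parallel on enough model instances: 1,000 rows sent one at a time take about 2,000 seconds (~33 minutes), while 100 batches of 10 take about 200 seconds. `chunked` and `estimate_seconds` are hypothetical helpers for illustration, and the estimate ignores queueing, cold boots, and any per-minute rate limit.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `size` items (hypothetical helper)."""
    it = iter(items)
    while batch := list(islice(it, size)):
        yield batch


def estimate_seconds(n_rows: int, seconds_per_request: float, batch_size: int) -> float:
    """Rough wall-clock estimate, assuming all requests in a batch run in
    parallel and a batch takes as long as a single request."""
    n_batches = -(-n_rows // batch_size)  # ceiling division
    return n_batches * seconds_per_request


rows = range(1000)
batches = list(chunked(rows, 10))
print(len(batches))                          # 100 batches of 10 rows
print(estimate_seconds(1000, 2.0, 1) / 60)   # one at a time: ~33.3 minutes
print(estimate_seconds(1000, 2.0, 10) / 60)  # batches of 10: ~3.3 minutes
```

In practice the batched figure is a lower bound: if the model has fewer instances than the batch size, requests queue and the batch takes longer.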

@mattt
Member

mattt commented Apr 29, 2024

@charliemday Replicate does support creating up to 6000 concurrent predictions per minute. Depending on how far the model is scaled out, you could process all of them much more quickly using our async API. Here's an example from the README that you can adapt (I'd recommend processing rows in async batches of 100 or so and keeping track of successful and failing rows):

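import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async def main():
    # asyncio.TaskGroup (Python 3.11+) waits for every task before the block exits
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
            for prompt in prompts
        ]
    # All tasks are done here, so read their results in prompt order
    results = [task.result() for task in tasks]
    print(results)

# Top-level `await` only works in a notebook or the asyncio REPL; in a script, use asyncio.run
asyncio.run(main())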
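Following the recommendation to process rows in async batches of about 100 and track which rows succeed or fail, here is a minimal sketch under stated assumptions. `process_in_batches` and its `run_one` parameter are hypothetical names; `run_one` would wrap a call such as `replicate.async_run(model_version, input=...)`. The sketch uses `asyncio.gather(..., return_exceptions=True)` rather than `asyncio.TaskGroup`, because a TaskGroup cancels the remaining tasks once one of them fails, whereas `gather` with that flag returns each exception as an ordinary result.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple


async def process_in_batches(
    rows: List[Dict[str, Any]],
    run_one: Callable[[Dict[str, Any]], Awaitable[Any]],
    batch_size: int = 100,
) -> Tuple[Dict[int, Any], Dict[int, BaseException]]:
    """Run `run_one` on every row, `batch_size` rows at a time.

    Returns (successes, failures), both keyed by the row's index in `rows`.
    """
    successes: Dict[int, Any] = {}
    failures: Dict[int, BaseException] = {}
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # return_exceptions=True: a failing row is recorded as an exception
        # instead of aborting the rest of the batch
        results = await asyncio.gather(
            *(run_one(row) for row in batch), return_exceptions=True
        )
        for offset, result in enumerate(results):
            index = start + offset
            if isinstance(result, BaseException):
                failures[index] = result
            else:
                successes[index] = result
    return successes, failures


# Hypothetical usage with Replicate (not executed here):
#   async def run_one(row):
#       return await replicate.async_run(model_version, input={"prompt": row["prompt"]})
#
#   with open("rows.csv", newline="") as f:
#       rows = list(csv.DictReader(f))
#   successes, failures = asyncio.run(process_in_batches(rows, run_one))
```

Keying both dictionaries by row index makes it easy to rerun only the failed rows afterwards, e.g. `[rows[i] for i in failures]`.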

@charliemday
Author

Thanks @mattt, this should work for my use case.

Maybe I'm missing something, but I don't see 6000 RPM on the link (only 600)?

@mattt
Member

mattt commented Apr 29, 2024

@charliemday Apologies, yes: the rate limit is 600 per minute. That was a typo on my part.
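With prediction creation capped at 600 per minute, a client that fires many batches back to back may want to throttle itself. One option, sketched here as an assumption rather than anything the Replicate client provides, is a sliding-window limiter: before each call, wait until fewer than `max_calls` calls have started in the last `period` seconds. `SlidingWindowLimiter` and `limited` are hypothetical names; for the limit quoted here you would pass `max_calls=600, period=60.0`. The sketch does not retry requests the server rejects.

```python
import asyncio
import time
from collections import deque


class SlidingWindowLimiter:
    """Allow at most `max_calls` acquisitions in any `period`-second window."""

    def __init__(self, max_calls: int, period: float) -> None:
        self.max_calls = max_calls
        self.period = period
        self._starts: deque = deque()  # monotonic timestamps of recent acquisitions
        self._lock = asyncio.Lock()    # one waiter at a time keeps the window consistent

    async def acquire(self) -> None:
        async with self._lock:
            while True:
                now = time.monotonic()
                # Drop timestamps that have left the window
                while self._starts and now - self._starts[0] >= self.period:
                    self._starts.popleft()
                if len(self._starts) < self.max_calls:
                    self._starts.append(now)
                    return
                # Sleep until the oldest call leaves the window, then re-check
                await asyncio.sleep(self.period - (now - self._starts[0]))


async def limited(limiter: SlidingWindowLimiter, coro_fn, *args, **kwargs):
    """Wait for a slot from `limiter`, then await coro_fn(*args, **kwargs)."""
    await limiter.acquire()
    return await coro_fn(*args, **kwargs)


# Hypothetical usage with the limit discussed above (inside an async function):
#   limiter = SlidingWindowLimiter(max_calls=600, period=60.0)
#   output = await limited(limiter, replicate.async_run, model_version, input={"prompt": p})
```

The limiter only spaces out when predictions are *started*; it does not cap how many are running at once, so it can be combined with the batching approach from earlier in the thread.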
