
Add flash-attn support to load_from_name #312

Open
ZechengLi19 opened this issue May 11, 2024 · 4 comments

Comments

@ZechengLi19

Thank you for this excellent code implementation; it has helped me a lot. However, when using the load_from_name function I found that it does not support flash-attn, so I implemented this part myself. It runs fine, but I am not sure whether the implementation is correct.

Here is the code snippet:

###### ------- ps: add use_flash_attention keyword ------- ######
def load_from_name(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
                   download_root: str = None, vision_model_name: str = None, text_model_name: str = None, 
                   input_resolution: int = None, use_flash_attention: bool = False):
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_name, model_input_resolution = _MODEL_INFO[name]['struct'], _MODEL_INFO[name]['input_resolution']
    elif os.path.isfile(name):
        assert vision_model_name and text_model_name and input_resolution, "Please specify specific 'vision_model_name', 'text_model_name', and 'input_resolution'"
        model_path = name
        model_name, model_input_resolution = f'{vision_model_name}@{text_model_name}', input_resolution
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        # loading saved checkpoint
        checkpoint = torch.load(opened_file, map_location="cpu")

    model = create_model(model_name, checkpoint, use_flash_attention=use_flash_attention)
    if str(device) == "cpu":
        model.float()
    else:
        model.to(device)
    return model, image_transform(model_input_resolution)
###### ------- ps: convert flash_attention weight ------- ######
def create_model(model_name, checkpoint=None, use_flash_attention=False):
    vision_model, text_model = model_name.split('@')
    # Initialize the model.
    vision_model_config_file = Path(
        __file__).parent / f"model_configs/{vision_model.replace('/', '-')}.json"
    print('Loading vision model config from', vision_model_config_file)
    assert os.path.exists(vision_model_config_file)

    text_model_config_file = Path(
        __file__).parent / f"model_configs/{text_model.replace('/', '-')}.json"
    print('Loading text model config from', text_model_config_file)
    assert os.path.exists(text_model_config_file)

    with open(vision_model_config_file, 'r') as fv, open(text_model_config_file, 'r') as ft:
        model_info = json.load(fv)
        for k, v in json.load(ft).items():
            model_info[k] = v
    if isinstance(model_info['vision_layers'], str):
        model_info['vision_layers'] = eval(model_info['vision_layers'])
    print('Model info', model_info)
    if use_flash_attention:
        model_info['use_flash_attention'] = use_flash_attention
    model = CLIP(**model_info)
    convert_weights(model)
            
    if checkpoint:
        if use_flash_attention:
            sd = checkpoint["state_dict"]
            sd = {k: v for k, v in sd.items() if "bert.pooler" not in k}
            if next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
            # Resize the positional embedding by interpolation, if needed
            resize_pos_embed(sd, model, prefix="module.")
            # Adapt the weights to the flash attention layout
            sd = convert_state_dict(sd)
        else:
            sd = checkpoint["state_dict"]
            if next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items() if "bert.pooler" not in k}
        # Load the state dict
        model.load_state_dict(sd)
    return model
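
To show how I intend to use it, here is a minimal call sketch (assuming the patch above is applied inside cn_clip.clip.utils and flash-attn is installed; the model name and download_root are just examples):

import torch
from cn_clip.clip import load_from_name

# Minimal usage sketch (hypothetical): relies on the modified load_from_name above
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name(
    "ViT-B-16",                   # any key in _MODELS
    device=device,
    download_root="./",
    use_flash_attention=True,     # the new keyword proposed above
)
model.eval()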

It would be great if you could check whether this implementation is correct when you have time~

If it is correct, feel free to add my implementation to the repository~

Many thanks.

@DtYXs
Collaborator

DtYXs commented May 22, 2024

Hello, currently when training with flash-attn enabled, the saved checkpoint format is exactly the same as when it is disabled. So a checkpoint trained with flash-attn should be directly loadable; you can try that first.

"state_dict": model.state_dict() if not args.use_flash_attention else convert_state_dict(model.state_dict()),

@ZechengLi19
Author

@DtYXs Thanks for your reply, but I think you may have misunderstood what I meant.

What I want to do is call load_from_name directly in my own code to get a model, with the option to switch that model into flash-attn mode. However, the current load_from_name does not provide a flash-attn option~

@DtYXs
Collaborator

DtYXs commented May 23, 2024

@ZechengLi19 I see what you mean~ My understanding is that the flash-attn format defined in the current code only applies to the Chinese-CLIP project itself, and models trained with Chinese-CLIP are automatically converted from the flash-attn format back to the normal format. So I would like to know in what situation you actually need to load a model in the flash-attn format.

@ZechengLi19
Author

ZechengLi19 commented May 23, 2024

@DtYXs For example, I want to use your pretrained Chinese-CLIP in other downstream tasks.

Say I have a baseline codebase for such a downstream task and want to swap in a different backbone; then I would like to call load_from_name to create a CLIP backbone. If I further want to fine-tune CLIP, enabling flash-attn would help speed up my code, that's the idea~

In other words, I use your repository as a package, so the only function I really need to look at is load_from_name. If it had flash-attn support, that might help more people use it in downstream tasks?
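
For concreteness, a rough sketch of the downstream usage I have in mind (assumes the proposed use_flash_attention keyword; the classification head and the 512-dim feature size are placeholders for whatever the downstream task needs):

import torch
import torch.nn as nn
from cn_clip.clip import load_from_name

class DownstreamModel(nn.Module):
    # Sketch: Chinese-CLIP image encoder as a finetunable backbone
    def __init__(self, num_classes: int):
        super().__init__()
        # use_flash_attention is the keyword proposed in this issue
        self.backbone, _ = load_from_name("ViT-B-16", use_flash_attention=True)
        self.head = nn.Linear(512, num_classes)  # placeholder head for the downstream task

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        feats = self.backbone.encode_image(images)  # flash-attn is used inside when enabled
        return self.head(feats.float())

(After construction I would move the whole module to the training device and fine-tune as usual.)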
