A small trick for memory efficiency #10

Open
Hannibal046 opened this issue Jun 28, 2022 · 1 comment

Comments

@Hannibal046

Hi, in this part,

BRIO/modeling_bart.py

Lines 1863 to 1869 in 135f0e5

if self.is_scoring_mode:
    cand_num = decoder_input_ids.size(1)
    encoder_hidden_states = encoder_outputs[0]
    encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, cand_num, dim=0)
    attention_mask = torch.repeat_interleave(attention_mask, cand_num, dim=0)
    decoder_input_ids = decoder_input_ids.view(-1, decoder_input_ids.size(-1))
    decoder_attention_mask = decoder_attention_mask.view(-1, decoder_attention_mask.size(-1))

Since encoder_hidden_states and attention_mask are not modified in the decoder, a view of them is more memory-efficient than repeat_interleave: the repeat operation in PyTorch copies the underlying data storage, as illustrated here:

[screenshot illustrating that the repeat operation copies the tensor's data]

Using index_select with a proper index will be better:

[notebook screenshot demonstrating the index_select approach]
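For completeness, here is a minimal, self-contained sketch (not from the original notebook; the shapes are made up purely for illustration) of the behaviour described above: repeat_interleave materializes a new tensor with its own storage, the index_select-based expansion produces the same expanded batch, and a pure expand() view shares the original storage.

import torch

batch_size, cand_num, src_len, hidden = 2, 4, 16, 8
encoder_hidden_states = torch.randn(batch_size, src_len, hidden)

# repeat_interleave allocates a new (batch_size * cand_num, src_len, hidden) tensor
repeated = torch.repeat_interleave(encoder_hidden_states, cand_num, dim=0)
print(repeated.data_ptr() == encoder_hidden_states.data_ptr())  # False: separate storage

# index_select with an explicit index, as in the suggested modification below
expanded_return_idx = torch.arange(batch_size).view(-1, 1).repeat(1, cand_num).view(-1)
selected = encoder_hidden_states.index_select(0, expanded_return_idx)
print(torch.equal(repeated, selected))  # True: identical values and shape

# a pure expand() view shares the original storage (no data copy)
viewed = encoder_hidden_states.unsqueeze(1).expand(-1, cand_num, -1, -1)
print(viewed.data_ptr() == encoder_hidden_states.data_ptr())  # True: shared storage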
So a simple modification is:

if self.is_scoring_mode:
    batch_size, cand_num, _ = decoder_input_ids.shape
    encoder_hidden_states = encoder_outputs[0]
    expanded_return_idx = torch.arange(batch_size).view(-1, 1).repeat(1, cand_num).view(-1).to(encoder_hidden_states.device)
    encoder_hidden_states = encoder_hidden_states.index_select(0, expanded_return_idx)
    attention_mask = attention_mask.index_select(0, expanded_return_idx)
    decoder_input_ids = decoder_input_ids.view(-1, decoder_input_ids.size(-1))
    decoder_attention_mask = decoder_attention_mask.view(-1, decoder_attention_mask.size(-1))
@yixinL7
Owner

yixinL7 commented Jul 20, 2022

Thanks a lot for the suggestion! I was wondering if you have tried this modification and observed a difference in memory usage. It would be really great if you could provide some statistics on the change in memory usage and create a pull request for your suggested changes :)
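In case it helps, here is a minimal sketch (the helper name peak_memory_mb is made up here, a CUDA device is assumed, and the shapes are arbitrary) of one way such memory statistics could be collected for the two expansion strategies:

import torch

def peak_memory_mb(fn):
    # Run fn() and report the peak GPU memory allocated by PyTorch, in MB.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024 ** 2

x = torch.randn(8, 512, 1024, device="cuda")
cand_num = 16
idx = torch.arange(x.size(0), device=x.device).view(-1, 1).repeat(1, cand_num).view(-1)

print(peak_memory_mb(lambda: torch.repeat_interleave(x, cand_num, dim=0)))
print(peak_memory_mb(lambda: x.index_select(0, idx)))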
