Add VQ-BeT #166

Draft · wants to merge 74 commits into main

Conversation

@jayLEE0301 commented May 10, 2024

What this does

Add VQ-BeT for PushT env.

How it was tested

  • Added configuration_vqbet.py and modeling_vqbet.py in the vqbet folder.

How to checkout & try? (for the reviewer)

python lerobot/scripts/train.py policy=vqbet env=pusht dataset_repo_id=lerobot/pusht

@alexander-soare (Collaborator) left a comment

@jayLEE0301 thanks so much for being the first to PR a model to LeRobot! The paper for VQ-BeT was a really nice read.

So, for the review: I've left a bunch of comments (many of them nits, but some blockers), and actually decided to stop reviewing partway through. That's because I noticed there are some high-level points I can share here. So instead, please consider these high-level comments as my primary review, and my inline comments as supporting examples.

So, our goal is to make this code highly accessible to the community, meaning it's easy to read and understand, and is easily hackable. A side effect of aiming for these goals is usually that the code is maintainable.

With that overarching goal in mind, here are 3 high-level points:

  1. Consider the VQBeTPolicy class as the only "public" object in the modeling file. Everything else is there for the sole purpose of VQBeTPolicy. This means:

    • Go minimalist. We should drop any kwargs, conditional branching, or other logic that is unused. The other functions and logic should only be as dynamic as needed to serve VQBeTPolicy. Rule of thumb: if it can't be activated via the configuration parameters, it can go.
    • Use the config instead of many kwargs. Most of the other modules can take a config argument and make a self.config (this avoids relisting parameters twice and makes sure there's one source of truth for what the params mean - no need to repeat documentation or type-hinting). See the sketch after this list.
  2. Consolidate code: We want to avoid too much nesting or duplication of code. Consider for example my inline comment about the MLPs. I think it's reasonable to use one class for MLPs (and it can be simpler and shorter than the 3 existing classes now). This is just an example though, there may be more opportunities for consolidation.

  3. Documentation and naming: We want to make sure that everything is well understood by a first-time reader. Wear the hat of someone who has read through your paper once, and enters the code via the VQBeTPolicy class. They should be able to traverse the submodule hierarchy, understanding what everything is as they go. And they should be able to make sense of what's happening in the forward function.

    • Above all, please make sure the VQBeTConfig documentation is solid.
    • Please add docstrings to classes and methods when it wouldn't be obvious what they are in relation to the main policy and paper.
    • Please separate long methods into logical blocks with comments so that one doesn't get lost along the way. (btw: this doesn't mean separating them into smaller functions)
    • Please make sure it's easy to follow what's happening with tensor dimensions. einops is also helpful for that.
    • Favor full words over abbreviations: embd -> embed and try to match the terminology/naming in your paper.
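
For illustration, here is a minimal sketch of the config-over-kwargs pattern from point 1. VQBeTHead and the config fields shown are hypothetical stand-ins, not the PR's actual classes:

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class VQBeTConfig:
    # Field names here are illustrative, not the PR's actual config schema.
    gpt_hidden_dim: int = 512
    action_dim: int = 2


class VQBeTHead(nn.Module):
    """Hypothetical submodule: takes the whole config rather than re-listing kwargs."""

    def __init__(self, config: VQBeTConfig):
        super().__init__()
        self.config = config  # single source of truth for parameter meanings and defaults
        self.proj = nn.Linear(config.gpt_hidden_dim, config.action_dim)

    def forward(self, x):
        return self.proj(x)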

When in doubt, please take inspiration from LeRobot's ACT and TD-MPC (Diffusion Policy is good too but may need a little more work).

@aliberts added the 🧠 Policies (Something policies-related) label on May 12, 2024

@jayLEE0301 (Author)

Thank you for the review!

Following these high-level points, across all parts of this PR we have:

  1. removed as many kwargs as possible, keeping only the configs,
  2. consolidated all the similar functions,
  3. added comments and changed names.

@alexander-soare (Collaborator) left a comment

Checkpoint. I will continue next week.


# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None

Collaborator

nit: perhaps a call to self.reset() at the bottom of the __init__ would be more appropriate? See

Author

added self.reset() at the bottom of __init__

Collaborator

Any chance we can also drop self._queues = None please? It doesn't hurt from a logic perspective, but it does potentially confuse someone who will wonder why there's a redundant line of code here.
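
A minimal sketch of the suggested pattern (queue keys, maxlen values, and config attribute names are illustrative, not the PR's exact code):

from collections import deque

import torch.nn as nn


class VQBeTPolicy(nn.Module):  # simplified stand-in for the real class
    def __init__(self, config):
        super().__init__()
        self.config = config
        # ... build submodules ...
        self.reset()  # populates self._queues, so no separate `self._queues = None` is needed

    def reset(self):
        """Clear the rollout queues. Called once at the end of __init__ and before each rollout."""
        self._queues = {
            "observation.image": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
            "action": deque(maxlen=self.config.n_action_pred_chunk),
        }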

features = self.policy(observation_feature)
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode

# only extract the output tokens at the position of action query

Collaborator

Curious: if this is the case, what function do the other action tokens serve, other than to increase compute for a forward pass?

Author

This increases computation, but it can help improve overall learning performance and avoid overfitting (though not always).

You can think of it as similar to how a diffusion policy predicts a longer sequence of actions than the sequence that is actually executed.

Collaborator

Cool! Mind adding that in as a comment (if it's not already mentioned in your paper)?

@jayLEE0301 (Author) commented Jun 7, 2024

I added:
Behavior Transformer (BeT) and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observations to sequential actions (please refer to section 2.2 of the BeT paper https://arxiv.org/pdf/2206.11251). Thus, they predict the historical action sequence in addition to current and future actions (predicting future actions is optional).
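
To make the quoted indexing above concrete, here is a small worked example (the value of gpt_num_obs_mode is just illustrative):

import numpy as np

n_obs_steps = 5
gpt_num_obs_mode = 2  # observation tokens per timestep (illustrative value)

# each timestep contributes gpt_num_obs_mode observation tokens followed by 1 action token,
# so the action-query position of timestep t is t * (gpt_num_obs_mode + 1) + gpt_num_obs_mode
historical_act_pred_index = np.arange(0, n_obs_steps) * (gpt_num_obs_mode + 1) + gpt_num_obs_mode
print(historical_act_pred_index)  # [ 2  5  8 11 14]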

spatial_softmax_num_keypoints: int = 32
# VQ-VAE
discretize_step: int = 3000
vqvae_groups: int = 2

Collaborator

In some places in the code this is handled statically, meaning that changing this number will break things. Can we please either remove it as a parameter or make sure the code can handle it dynamically?

One example is the cbet_loss.

Author

Made the code handle vqvae_groups dynamically.
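
For illustration only, a minimal sketch of what handling vqvae_groups dynamically could look like in a loss like cbet_loss, looping over RVQ layers instead of hardcoding two of them (shapes and names are illustrative, not the PR's implementation):

import torch
import torch.nn.functional as F


def cbet_loss_sketch(bin_logits: torch.Tensor, target_codes: torch.Tensor) -> torch.Tensor:
    """bin_logits: (batch, vqvae_groups, vqvae_n_embed); target_codes: (batch, vqvae_groups)."""
    n_groups = target_codes.shape[1]
    loss = bin_logits.new_zeros(())
    for layer in range(n_groups):  # works for any number of RVQ layers, not just 2
        loss = loss + F.cross_entropy(bin_logits[:, layer], target_codes[:, layer])
    return loss / n_groups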

Collaborator

I'm still seeing the use of "primary" and "secondary" in the code. For example VQBeTOptimizer.__init__. Am I misunderstanding something?

@alexander-soare (Collaborator) left a comment

Just publishing my responses in a batch. Thanks for resolving these :D

Now moving on with the review.


# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.check_discretized():
    loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
    return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}

Collaborator

Thanks! I think I understand this. Can you let me know if my understanding is correct?

  • n_different_codes: how many of the total possible VQ codes are being used (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed`.
  • n_different_combinations: how many different code combinations are being used out of all possible combinations. This can be at most `vqvae_n_embed ^ vqvae_groups` (hint: consider the RVQ as a decision tree).

But shouldn't `n_different_codes` max out at `vqvae_n_embed * vqvae_groups`? That's how many codes there are in total. Or are you only referring to the codes of the first RVQ layer?

Btw: I think this is a great metric to track!
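
For reference, a hedged sketch of how such coverage metrics could be computed from RVQ code indices. This version counts codes across all layers (maximum vqvae_n_embed * vqvae_groups); it is an illustration, not the PR's implementation:

import torch


def code_usage_metrics(codes: torch.Tensor) -> tuple[int, int]:
    """codes: (batch, vqvae_groups) integer code indices in [0, vqvae_n_embed)."""
    # a code counts as "used" if at least one sample maps to it in that layer
    n_different_codes = sum(
        int(torch.unique(codes[:, layer]).numel()) for layer in range(codes.shape[1])
    )
    # distinct code tuples across layers (think of the RVQ as a decision tree)
    n_different_combinations = int(torch.unique(codes, dim=0).shape[0])
    return n_different_codes, n_different_combinations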


self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
and the output dimension of `self.map_to_cbet_preds_bin` is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
`self.config.vqvae_groups` is number of RVQ layers, and

Collaborator

nit: Same as an earlier revision above, can we please remove these duplicated explanations of what these variables mean?

Author

I removed the duplicated parts:) Thank you


}
return loss_dict

class VQBeTOptimizer(nn.Module):

Collaborator

Can I please ask that we consolidate and simplify here? Consider where we can get away with using one optimizer instead of many. I count 4 optimizers being initialized here and I'm not sure all of them are needed. I'll let you double check, but I think we might be able to get away with 2 or even just 1 (if you no_grad the quantizer when the discretization is done).

Feel free to let me know if this is not possible. I checked briefly, but not exhaustively.

At a higher level, we have a plan to have some way of the policy code providing the optimizer and scheduler, so I think you have made a good step towards that here. Right now we have train.py handling this logic and that's not nice. Ideally, what I think we want here is one method in the top-level policy class, make_optimizer, which handles everything. That way train.py can just call make_optimizer without having to know which specific policy it is. Here, this would mean taking Karpathy's configure-optimizers logic and consolidating it into that same make_optimizer method. I don't think we want the optimizer creation distributed throughout various modules of the file.

Happy to get your input on all these thoughts.
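
To sketch the make_optimizer idea (the method name, the bias-based parameter grouping, and the stand-in class are illustrative, not an existing LeRobot API):

import torch
import torch.nn as nn


class VQBeTPolicy(nn.Module):
    """Simplified stand-in: only the optimizer-construction method is sketched."""

    def make_optimizer(self, cfg) -> torch.optim.Optimizer:
        # group parameters by whether they should receive weight decay (illustrative rule)
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            (no_decay if name.endswith("bias") else decay).append(param)
        optim_groups = [
            {"params": decay, "weight_decay": cfg.training.adam_weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
        ]
        return torch.optim.Adam(optim_groups, lr=cfg.training.lr, betas=cfg.training.adam_betas)


# train.py would then only need something like:
# optimizer = policy.make_optimizer(cfg)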

@jayLEE0301 (Author) commented Jun 8, 2024

Thank you for the suggestions!
I removed all the redundant optimizers and merged the optimizers for phase 1 and phase 2 into one, leaving only one optimizer. I also deleted def step and def zero_grad (and put all the parameters for phase 2 in the same scheduler).

We haven't done much analysis on how this affects the stability of training at this time, but (after running two seeds) we have found that it can produce similar performance to the uploaded model (https://huggingface.co/JayLee131/vqbet_pusht) based on the best checkpoint.

Perhaps a more diverse hyperparameter search may be needed in the future.

class VQBeTOptimizer(torch.optim.Adam):
    def __init__(self, policy, cfg):
        vqvae_params = (
            list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
            + list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
            + list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
        )
        decay_params, no_decay_params = policy.vqbet.policy.configure_parameters()
        decay_params = (
            decay_params
            + list(policy.vqbet.rgb_encoder.parameters())
            + list(policy.vqbet.state_projector.parameters())
            + list(policy.vqbet.rgb_feature_projector.parameters())
            + [policy.vqbet._action_token]
            + list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters())
        )

        if cfg.policy.sequentially_select:
            decay_params = (
                decay_params
                + list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters())
                + list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters())
            )
        else:
            decay_params = (
                decay_params
                + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters())
            )

        optim_groups = [
            {
                "params": decay_params,
                "weight_decay": cfg.training.adam_weight_decay,
                "lr": cfg.training.lr,
            },
            {
                "params": vqvae_params,
                "weight_decay": 0.0001,
                "lr": cfg.training.vqvae_lr,
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "lr": cfg.training.lr,
            },
        ]
        super(VQBeTOptimizer, self).__init__(
            optim_groups,
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
        )

else:
    self.eval()

def draw_logits_forward(self, encoding_logits):

Collaborator

Can we please add a docstring here or change the function name to something more apparent? I'm not sure what it means to draw logits forward.

Note: I think most of the function names are self-explanatory, so I really do just mean this one and draw_code_forward.

Author

Thank you for your opinion :) I removed def draw_logits_forward since it is not used, and changed def draw_code_forward to def get_embeddings_from_code.
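
For readers unfamiliar with residual VQ, here is a hedged sketch of what a get_embeddings_from_code-style lookup does (attribute names and shapes are illustrative, not the PR's exact implementation):

import torch


def get_embeddings_from_code(codes: torch.Tensor, codebooks: list[torch.Tensor]) -> torch.Tensor:
    """codes: (batch, n_rvq_layers) integer indices; codebooks[i]: (vqvae_n_embed, embed_dim)."""
    embeddings = torch.zeros(codes.shape[0], codebooks[0].shape[1])
    for layer, codebook in enumerate(codebooks):
        embeddings = embeddings + codebook[codes[:, layer]]  # residual layers' vectors sum up
    return embeddings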

@jayLEE0301 (Author)

Thank you for the review!

I've resolved all the comments. At a high level:

  • I removed and consolidated all the redundant optimizers and schedulers (merged all the optimizers for phase 1 and phase 2 into one, leaving only one optimizer). I also deleted def step and def zero_grad.
  • I removed redundant functions and now use the original load_state_dict, train, and eval of nn.Module. These functions existed to prevent EMA updates after RVQ training has ended; this is now implemented via self.vq_layer.freeze_codebook = torch.tensor(True) and torch.no_grad().
  • I added comments, changed some confusing function names, and removed unused parts.
