
feature(whl): add rlhf pipeline. #748

Open
kxzxvbk wants to merge 18 commits into opendilab:main from kxzxvbk:rlhf

Conversation

@kxzxvbk
Contributor

@kxzxvbk kxzxvbk commented Nov 6, 2023

Description

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added enhancement New feature or request algo Add new algorithm or improve old one labels Nov 6, 2023
from .model import PPOFModel
from .config import get_instance_config, get_instance_env, get_hybrid_shape
from ding.bonus.common import TrainingReturn, EvalReturn
from ..framework.middleware.collector import ChatCollector
Member

merge it into ding.framework

"""
Overview:
The class of the collector that runs by steps, including model inference and the transition \
process. Use the `__call__` method to execute the whole collection process.
Member

why indent here


def top_p_logits(logits, topp=0.9, filter_value=0, min_topk=1):
"""
Filter a distribution of logits using nucleus (top-p) filtering
Member

polish comments and add unittest

if topp > 0:
logits_sorted, inds = torch.sort(logits, dim=-1, descending=True)
mask = (logits_sorted.cumsum(dim=-1) - logits_sorted) >= topp
mask[:, :min_topk] = False
Member

..., :min_topk
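The suggestion is to index the mask with an ellipsis (`mask[..., :min_topk]`) so the `min_topk` guard works regardless of how many leading batch dimensions the tensor has. The filtering condition itself, `(cumsum - sorted) >= topp`, drops every token whose cumulative probability *excluding itself* already reaches `topp`. A dependency-free sketch of that logic (function name and values are illustrative; the PR's version operates on batched PyTorch tensors):

```python
def top_p_filter(probs, topp=0.9, min_topk=1):
    """Keep the smallest set of tokens whose cumulative probability
    reaches `topp`; always keep at least `min_topk` tokens.
    Filtered entries are zeroed out. Illustrative stand-in for the
    tensor version under review, using a 1-D list of probabilities."""
    # Indices sorted by descending probability.
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = set(), 0.0
    for rank, idx in enumerate(order):
        # Keep while cumulative mass excluding the current token is
        # still below topp, or while fewer than min_topk tokens are kept.
        if cum < topp or rank < min_topk:
            kept.add(idx)
        cum += probs[idx]
    return [p if i in kept else 0.0 for i, p in enumerate(probs)]
```

With `probs = [0.5, 0.25, 0.125, 0.125]` and `topp = 0.8`, the first three tokens survive and the last is zeroed, matching the mask semantics of the diff above.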

@@ -1,4 +1,7 @@
from typing import Union, Dict, Optional

Member

move these modifications to a new single file: lm_vac.py


def __init__(self, config, opt, tokenizer):
super().__init__(config)
self.opt = opt
Member

why define opt here

else:
logits = self.reward_head(output.last_hidden_state).squeeze(-1)

return (logits, )
Member

why return a tuple here
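The one-element tuple presumably mirrors the HuggingFace convention of tuple-shaped forward outputs; a plain tensor would also work if every caller is adjusted. A toy sketch of the two conventions (names and values are illustrative, not the PR's code):

```python
def reward_forward(last_hidden_state, head_weight):
    """Toy stand-in for `self.reward_head(output.last_hidden_state).squeeze(-1)`:
    projects each hidden vector to a scalar reward via a dot product.
    Returning a one-element tuple mimics the HuggingFace-style output shape;
    callers must remember to unpack it with `out[0]` or `logits, = out`."""
    logits = [
        sum(h * w for h, w in zip(hidden, head_weight))
        for hidden in last_hidden_state
    ]
    return (logits, )
```

Whether to keep the tuple is the reviewer's question: it buys interface uniformity with other model forwards at the cost of an unpacking step at every call site.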

self._init_flag = False

def reset(self):
self.last_batch = next(self.generator)
Member

Do you need to restart the generator here?
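The concern is that a plain generator is exhausted after one pass, so `next(self.generator)` in `reset()` would eventually raise `StopIteration`. One way to sidestep this, shown here as an illustrative sketch rather than the PR's code, is to build the stream with `itertools.cycle` so it never needs restarting:

```python
import itertools

def make_batch_stream(dataset):
    """Return an endless iterator over `dataset` (a hypothetical helper,
    not part of the PR). Because the stream cycles, `next(...)` inside
    `reset()` always succeeds and no restart logic is needed."""
    return itertools.cycle(dataset)

stream = make_batch_stream(["batch_a", "batch_b"])
first_five = [next(stream) for _ in range(5)]
```

If the intent is instead to reshuffle each epoch, the generator does need explicit recreation on exhaustion, which is likely what the reviewer is probing.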


class LlamaRewardModel(LlamaForCausalLM):

def __init__(self, config, opt, tokenizer):
Member

Should we move the creation of the tokenizer inside the constructor of the RM?
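Both designs are workable; the trade-off is caller-side injection versus self-contained construction. A toy sketch (all class and helper names are hypothetical, and the loader is a stand-in for something like `AutoTokenizer.from_pretrained`):

```python
class RewardModelInjected:
    """Current PR style: the caller constructs the tokenizer and passes
    it in, which eases testing and sharing one tokenizer across models."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer


class RewardModelSelfContained:
    """Reviewer's suggestion: the RM builds its own tokenizer, so model
    and tokenizer can never be mismatched by the caller."""

    def __init__(self, tokenizer_name):
        self.tokenizer = self._load_tokenizer(tokenizer_name)

    @staticmethod
    def _load_tokenizer(name):
        # Hypothetical stand-in for a real tokenizer loader.
        return {"name": name}
```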

@@ -0,0 +1,50 @@
from easydict import EasyDict
Member

move it to dizoo/chat/entry

@codecov

codecov bot commented Jan 3, 2024

Codecov Report

❌ Patch coverage is 20.50473% with 252 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.83%. Comparing base (d7a61c2) to head (f3a8245).
⚠️ Report is 94 commits behind head on main.

Files with missing lines Patch % Lines
ding/model/template/lm_vac.py 20.00% 92 Missing ⚠️
ding/policy/ppof.py 5.74% 82 Missing ⚠️
ding/framework/middleware/collector.py 15.62% 27 Missing ⚠️
ding/rl_utils/gae.py 11.11% 16 Missing ⚠️
ding/reward_model/language_reward_model.py 31.57% 13 Missing ⚠️
ding/bonus/ppof.py 0.00% 12 Missing ⚠️
ding/bonus/config.py 0.00% 10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   76.78%   76.83%   +0.04%     
==========================================
  Files         671      674       +3     
  Lines       53196    53935     +739     
==========================================
+ Hits        40847    41440     +593     
- Misses      12349    12495     +146     
Flag Coverage Δ
unittests 76.83% <20.50%> (+0.04%) ⬆️

Flags with carried forward coverage won't be shown.


Labels

algo Add new algorithm or improve old one enhancement New feature or request


2 participants