Distributed SaveLoad implementation for semi-auto strategy #59659

zhiqiu merged 45 commits into PaddlePaddle:develop from … dist_save_load
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
        paddle.distributed.get_world_size() > 1 or coordinator_rank != 0
    ):
        raise ValueError(
            f"use_dist is False, please set coordinator_rank to 0 and paddle.distributed.get_world_size() to 1, world_size:{paddle.distributed.get_world_size()}, coordinator_rank:{coordinator_rank}"
Why not allow use_dist=False and world_size > 1?
use_dist targets the single-card case, but it seems the user does not need to specify it; internally it can be determined via use_dist = True if world_size > 1 else False. save_state_dict is designed to export the model under the distributed strategy of the current training run: if the run is distributed, it exports a distributed checkpoint, and if it is single-card, it exports a single-card one. Exporting a single-card model directly from a distributed run is not supported; to do that, first define a single-card model, load it with load_state_dict, then export it with save_state_dict.
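The derivation described in the reply can be sketched as a small helper (the function name is hypothetical; the real code would read paddle.distributed.get_world_size() at runtime):

```python
def resolve_use_dist(world_size: int) -> bool:
    """Hypothetical helper: derive use_dist instead of asking the user.

    Mirrors the reply above: distributed save/load is used exactly when
    the job runs on more than one card.
    """
    return world_size > 1
```

Inside save_state_dict this would amount to `use_dist = resolve_use_dist(paddle.distributed.get_world_size())`, removing the parameter from the public API.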
    return tuple(local_shape), tuple(global_offset)


    def flatten_state_dict(state_dict):
This is a TODO. It is meant to support state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}, but that is not implemented yet, so for now the passed-in state_dict is not transformed.
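The TODO could look roughly like the following one-level flattening (a sketch only, not the merged implementation; the dotted key-joining scheme is an assumption):

```python
def flatten_state_dict(state_dict):
    # Sketch of the TODO above: support nested inputs such as
    # {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    # by flattening one level of nesting into dotted keys.
    flat = {}
    for outer_key, value in state_dict.items():
        if isinstance(value, dict):
            for inner_key, inner_value in value.items():
                flat[f"{outer_key}.{inner_key}"] = inner_value
        else:
            # Already-flat entries pass through unchanged.
            flat[outer_key] = value
    return flat
```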
    if coordinator_rank == paddle.distributed.get_rank():
        logger.debug(f"metadata:{metadata}")
        paddle.save(metadata, os.path.join(path, f"{unique_id}.metadata"))
Why not save the metadata on all ranks?
The metadata is global and identical on every rank, so only one copy needs to be saved.
I understand, but wouldn't saving it on every rank make debugging easier, so you don't always have to go to rank 0? The metadata doesn't take much space either.
That probably won't work: each machine has multiple cards, and multiple cards writing the same file concurrently can cause problems, leaving the file with unexpected contents.
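The single-writer rule this thread settles on can be sketched as a guard (a hypothetical helper; the quoted code above inlines the same check):

```python
def should_write_metadata(rank: int, coordinator_rank: int = 0) -> bool:
    # The metadata is global and identical on every rank; writing it from
    # only the coordinator avoids several cards on one machine racing to
    # write the same file.
    return rank == coordinator_rank
```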
        The identifier of a local tensor.
        """

        tensor_id: str
tensor_name doesn't seem quite right: this is an identifier, which is the structure_name in dynamic semi-auto and the tensor's name in static semi-auto. tensor_key means roughly the same as tensor_id and would also work; if you think tensor_key is more suitable, feel free to change it.
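The identifier under discussion might be modeled like this (all field names beyond tensor_id are assumed for illustration, not taken from the PR):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LocalTensorIndex:
    """Sketch of the identifier of a local tensor.

    tensor_id is the structure name in dynamic semi-auto and the tensor
    name in static semi-auto; frozen=True makes instances hashable so
    they can be compared and collected for duplicate detection.
    """

    tensor_id: str
    global_offset: tuple = ()
```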
        local_tensor_index not in tensor_id_list
    ), f"Duplicate tensor_id:{local_tensor_index} found. Check whether the metadata_file:{metadata_file} contains the same tensor metadata."
    tensor_id_list.append(local_tensor_index.tensor_id)
    if local_tensor_index.tensor_id in state_dict:
Is this state_dict the local_state_dict?
This state_dict is the one each rank maintains itself; it is local.
    for rank, local_files in enumerate(global_data_files):
        if len(local_files) > 0:
            local_files = [
                f for f in local_files if f in necessary_data_files_set
When does local_files differ from necessary_data_files_set?
necessary_data_files_set is the set of data files hit by the keys of the current state_dict; those files may live on other ranks. local_files here is a list that does contain every file each rank can read, but the union of readable files may be larger than the set of files the state_dict actually needs, so this filter keeps only the files that will be used.
If it is larger, that's fine and no warning is needed, since it does not affect loading the current parameters.
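The filtering described in the reply, as a standalone sketch (names follow the quoted snippet; the function wrapper is added for illustration):

```python
def keep_necessary_files(global_data_files, necessary_data_files_set):
    # Per rank, drop readable files that the current state_dict's keys do
    # not hit; a surplus of readable files is harmless, so no warning.
    return [
        [f for f in local_files if f in necessary_data_files_set]
        for local_files in global_data_files
    ]
```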
    @@ -0,0 +1,21 @@
    # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
    @@ -0,0 +1,497 @@
    # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
    if f not in file_to_ranks:
        file_to_ranks[f] = []
    file_to_ranks[f].append(r)
    logger.info(f"file_to_ranks:{file_to_ranks}")
Will these logger debug messages be cleaned up later? If not, please standardize them.
The plan is to clean them up before the final merge. If they are to be standardized, is there a specified format?
    @@ -0,0 +1,42 @@
    # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
    v._local_value().add_(paddle.ones_like(v._local_value()))
    paddle.distributed.load_state_dict(state_dict, ckpt_path())
    for k, v in state_dict.items():
        assert k in local_state_dict, k
That trailing k is the message to print; the assert syntax is `assert condition, error_message`.
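The point about the trailing k can be seen in isolation: the second operand of assert is not a condition, it is the failure message (dictionary contents here are made up for the demo):

```python
local_state_dict = {"linear_0.w_0": 1}
try:
    # The trailing operand becomes the AssertionError message
    # when the membership check fails.
    assert "linear_0.b_0" in local_state_dict, "linear_0.b_0"
except AssertionError as e:
    message = str(e)
```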
    assert k in local_state_dict, k
    if v._is_initialized():
        self.check_tensor_eq(v._local_value(), local_state_dict[k])
    os.system(f"rm -rf {ckpt_path()}")
Use tempfile.TemporaryDirectory(); you can find examples in other unit tests.
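The reviewer's suggestion, sketched: TemporaryDirectory cleans up automatically even if the test body raises, unlike shelling out with `os.system(f"rm -rf ...")`.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Write a stand-in checkpoint file inside the managed directory.
    meta_path = os.path.join(ckpt_dir, "0.metadata")
    with open(meta_path, "w") as f:
        f.write("meta")
    existed_inside = os.path.exists(meta_path)

# On exit, the directory and everything in it are removed.
removed_after = not os.path.exists(ckpt_dir)
```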
Chinese API docs PR: PaddlePaddle/docs#6355
    __all__ = [
        "save_state_dict",
        "load_state_dict",
    ]
Only add an API to the __all__ list at its recommended user path. Since we recommend paddle.distributed.save_state_dict and paddle.distributed.load_state_dict, there is no need to add them to this list; the imports above can be retained.
    def load_state_dict(
        state_dict,
        path,
        process_group=None,
        coordinator_rank=0,
    ) -> None:
I saw in the design document that there is a use_dist parameter. Do we need to implement use_dist, which is not implemented here? If not, please explain the reason and update the design document.
sunzhongkai588 left a comment
For the API docs, please refer to the English template, and pay close attention to blank lines and indentation.
    coordinator_rank(int): The rank used to save non distributed values. Rank0 is used by default.

    Examples:
        .. code-block:: python

    Examples:
        .. code-block:: python
            >>> # doctest: +SKIP('Save state dict.')
Suggested change:
    - >>> # doctest: +SKIP('Save state dict.')
    + >>> # doctest: +SKIP('state dict not exist')
Please state the reason for skipping the doctest more clearly, for readability.
    ) -> None:
        """
        Load the state_dict inplace from a checkpoint path.
        Args:
Suggested change: insert a blank line before `Args:`.
Add blank lines between the description, Args, and the other docstring sections; otherwise the official website may render them incorrectly.
    Example:
        .. code-block:: python
Suggested change: insert a blank line before `Example:` as well.
Same as above.
    coordinator_rank(int): The rank used to coordinate the checkpoint. Rank0 is used by default.
    Example:
        .. code-block:: python
            >>> # doctest: +SKIP('Load state dict.')
Suggested change:
    - >>> # doctest: +SKIP('Load state dict.')
    + >>> # doctest: +SKIP('state dict not exist')
Please state the reason more clearly, for readability.
sunzhongkai588 left a comment
LGTM. Merging first; follow-up changes can come later.
PR types
Others
PR changes
Others
Description
card-78318
Design the save_state_dict and load_state_dict APIs to support saving and loading checkpoints for dynamic and static graph semi-auto distributed training.
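A sketch of the intended round trip, assuming a multi-card job started with `python -m paddle.distributed.launch`; the snippet is illustrative, follows the signatures quoted in the review, and is not a verified example:

```python
>>> # doctest: +SKIP('requires a multi-card distributed environment')
>>> import paddle.distributed as dist
>>> state_dict = model.state_dict()  # model defined under the current strategy
>>> dist.save_state_dict(state_dict, "./checkpoint")
>>> # Later, with a model built under the same (or a new) strategy:
>>> state_dict = model.state_dict()
>>> dist.load_state_dict(state_dict, "./checkpoint")  # loads in place
```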