Skip to content

Commit 59ed41e

Browse files
authored
Fix tp (#42368)
* up * oups you need renamed key for merge moduleoist * oups there was something I forgot * update
1 parent 5169c23 commit 59ed41e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/transformers/core_model_loading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def convert_and_load_state_dict_in_model(
745745
mapping.distributed_operation = tp_layer(
746746
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
747747
)
748-
shard_index = len(mapping.collected_tensors)
748+
shard_index = len(mapping.collected_tensors.get(original_key, []))
749749
future = spawn_tp_materialize(
750750
thread_pool,
751751
tensor,

0 commit comments

Comments
 (0)