-
Notifications
You must be signed in to change notification settings - Fork 588
[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API #2523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API #2523
Conversation
|
/te-ci jax L0 L1 |
|
/te-ci jax L0 L1 |
Greptile SummaryThis PR fixes incorrect segment position calculation in Key Changes:
Critical Issue Found:
Impact:
Confidence Score: 1/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant SequenceDescriptor
participant from_segment_ids_and_pos
participant generate_default_pos
User->>from_segment_ids_and_pos: segment_ids, segment_pos=None, is_thd, is_segment_ids_reordered
alt segment_pos is None
from_segment_ids_and_pos->>from_segment_ids_and_pos: Check is_segment_ids_reordered
alt is_segment_ids_reordered=True
from_segment_ids_and_pos->>from_segment_ids_and_pos: assert not is_thd (line 833)
Note over from_segment_ids_and_pos: Reject THD + reordered<br/>without explicit segment_pos
from_segment_ids_and_pos->>from_segment_ids_and_pos: assert is_thd (line 839) ❌
Note over from_segment_ids_and_pos: CRITICAL BUG:<br/>Contradictory assertion!
end
from_segment_ids_and_pos->>generate_default_pos: seg_ids
alt is_thd=True
generate_default_pos->>generate_default_pos: Calculate segment starts
generate_default_pos->>generate_default_pos: Calculate segment_start_offsets
generate_default_pos->>generate_default_pos: Find last non-zero index
generate_default_pos->>generate_default_pos: Compute positions and mask padding
generate_default_pos-->>from_segment_ids_and_pos: THD segment_pos
else is_thd=False (BSHD)
generate_default_pos->>generate_default_pos: Simple arange broadcast
generate_default_pos-->>from_segment_ids_and_pos: BSHD segment_pos
end
else segment_pos provided
from_segment_ids_and_pos->>from_segment_ids_and_pos: Use provided segment_pos
end
from_segment_ids_and_pos-->>User: SequenceDescriptor object
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
transformer_engine/jax/attention.py, line 841-842 (link)logic:
q_seg_idsandkv_seg_idsare used here but not defined until line 847. This will cause aNameErrorat runtime whensegment_posisNone.
2 files reviewed, 1 comment
|
/te-ci jax L0 L1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
transformer_engine/jax/attention.py, line 829-832 (link)style: f-strings in assert messages won't be evaluated until assertion fails
2 files reviewed, 1 comment
…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and lod balanced cases Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
499f4de to
217ea58
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
transformer_engine/jax/attention.py, line 833-837 (link)logic: warning is shown even for the valid case (BSHD without load balancing)
this will cause users to see warnings every time they correctly use
segment_pos=Nonewith BSHD layout
2 files reviewed, 1 comment
2. Add support for THD in from_segment_ids_and_pos() 3. Assert if load balanced segment_ids is passed to generate a segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (3)
-
transformer_engine/jax/attention.py, line 880-881 (link)syntax:
segment_idsis undefined, should beseg_ids(the parameter name on line 836) -
tests/jax/test_fused_attn.py, line 680-681 (link)syntax:
is_thdandis_load_balancedare keyword-only parameters (note the*on line 800 inattention.py), so they must be passed as keyword arguments -
tests/jax/test_fused_attn.py, line 709-710 (link)syntax:
is_thdandis_load_balancedare keyword-only parameters (note the*on line 800 inattention.py), so they must be passed as keyword arguments
2 files reviewed, 3 comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (3)
-
tests/jax/test_fused_attn.py, line 706-711 (link)logic: passing
segment_pos=Nonewithis_load_balanced=Truewill trigger assertion error onattention.py:830the BSHD + load balanced + SegmentIDs combination should either:
- pass explicit segment_pos (like THD does on line 677-678), or
- skip this format when load balanced
currently only works because default
mesh_shape=(1,1,1)meanscp_size=1 -
transformer_engine/jax/attention.py, line 794 (link)style: misleading TODO - THD layout (non load balanced) IS now supported by the implementation below (lines 837-878)
-
transformer_engine/jax/attention.py, line 839-840 (link)style: assumption not enforced - if sequence starts with padding (
seg_ids[0]=0),first_is_segment=Truewill incorrectly mark it as segment startconsider asserting
seg_ids[:, 0] != 0or handle leading padding explicitly
2 files reviewed, 3 comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…actually reordered segment_ids and segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
for more information, see https://pre-commit.ci
…ncing Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (2)
-
transformer_engine/jax/attention.py, line 833-844 (link)logic: contradictory assertions will always fail when
is_segment_ids_reordered=Trueline 833 asserts
not is_thd(meaning is_thd must be False)
line 839 assertsis_thd(meaning is_thd must be True)both assertions cannot be satisfied simultaneously, causing the function to always fail when
segment_pos=Noneandis_segment_ids_reordered=True -
transformer_engine/jax/attention.py, line 850-851 (link)style: assumption that first token is always a valid segment may be incorrect if THD format allows leading padding
if
seg_ids[0, 0] == 0(first token is padding), this assumption causes incorrect segment_pos calculation. verify THD format guarantees no leading padding or add validation
2 files reviewed, 2 comments
for more information, see https://pre-commit.ci Fix lint errors Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
4101df5 to
74a352e
Compare
for more information, see https://pre-commit.ci
|
/te-ci jax L0 L1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (3)
-
transformer_engine/jax/attention.py, line 833-844 (link)logic: contradictory assertions on lines 833 and 839 - both will execute when
is_segment_ids_reordered=True, but line 833 requiresis_thd=Falsewhile line 839 requiresis_thd=Truethis means when
is_segment_ids_reordered=True:- if
is_thd=True: line 833 fails with "not supported for load balanced reordered (Striped) THD" - if
is_thd=False: line 839 fails with "not supported for load balanced reordered (Dual Chunk) BSHD"
based on line 822 comments and PR description, only THD with load balancing should set
is_segment_ids_reordered=True, so line 839's assertion should be removed (it contradicts the intended behavior) - if
-
transformer_engine/jax/attention.py, line 856 (link)style: potential edge case:
(seg_ids[..., 1:] != 0)condition may not handle all segment transitions correctlyconsider sequence:
[1, 1, 0, 2, 2]where 0 is padding in the middle- transition from 1→0 at index 2:
(1 != 0) & (0 != 0)= False (correctly not marked as segment start) - transition from 0→2 at index 3:
(0 != 2) & (2 != 0)= True (marked as segment start)
however, the assumption on line 850 is that "the first token belongs to a segment and is not a padded token", which conflicts with the possibility of having padding in the middle. verify whether THD format allows mid-sequence padding, and if so, this logic needs adjustment. does THD format allow padding (segment_id=0) in the middle of a sequence, or is padding only at the end?
- transition from 1→0 at index 2:
-
tests/jax/test_fused_attn.py, line 687-689 (link)style: simplify boolean expression
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
2 files reviewed, 3 comments
|
/te-ci jax L0 L1 |
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending CI, thanks!
Description
SequenceDescriptor's
from_segment_ids_and_pos()accepts thesegment_idsand an optionalsegment_posas input. This class is supposed to serve as a convenience method to do two things:segment_idsandsegment_posin a SequenceDescriptor object for TE to use downstreamsegment_posis not passed, then calculate/extrapolate itIn it's current form, the second functionality gives incorrect results for THD + non-reordered and THD + reordered cases as it merely uses an
arangeto calculate thesegment_posnaively. This could result in incorrect masking for these cases.Type of change
Changes
This PR makes few changes:
from_segment_ids_and_pos():is_thdandis_segment_ids_reordered- the only cases that this function can currently guarantee to support is BSHD with and without load balancing and, THD without load balancing.from_segment_ids_and_pos(). However, if thesegment_posare reordered and passed tofrom_segment_ids_and_pos()it will assertfrom_segment_ids_and_pos(), it will assertfrom_segment_ids_and_pos()Impact on user of the API:
is_thdandis_segment_ids_reorderedare not Optional and hence they will cause aTypeErrorfor current users of this API - a breaking change. However, this is needed to ensure correct usage of this APIsegment_idsare reordered or not. It is expected that thesegment_idspassed will be reordered only for THD load balancing. For all other cases thesegment_idsshould not be reorderedChecklist: