@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 16, 2025

Description

SequenceDescriptor's from_segment_ids_and_pos() accepts segment_ids and an optional segment_pos as input. It is meant to serve as a convenience method that does two things:

  1. Stuff the segment_ids and segment_pos in a SequenceDescriptor object for TE to use downstream
  2. If the segment_pos is not passed, then calculate/extrapolate it

In its current form, the second functionality gives incorrect results for the THD + non-reordered and THD + reordered cases, because it naively uses an arange to calculate segment_pos. This can result in incorrect masking for these cases.
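To make the failure concrete, here is a minimal sketch in plain Python (not the TE code). It assumes a single packed THD row in which segment id 0 marks padding; the intended positions are written out by hand, and the value 0 at the padded slot is an assumption of this sketch.

```python
# One packed (THD) row holding two segments; id 0 is assumed to mean padding.
segment_ids = [1, 1, 1, 2, 2, 0]

# Naive default: a plain arange over the row, ignoring segment boundaries.
naive_pos = list(range(len(segment_ids)))

# Positions written out by hand: they restart at 0 where segment 2 begins.
# The value 0 at the padded slot is an assumption of this sketch.
intended_pos = [0, 1, 2, 0, 1, 0]

# Indices where the naive result disagrees with the intended one.
mismatches = [i for i, (a, b) in enumerate(zip(naive_pos, intended_pos)) if a != b]
print(mismatches)
```

Every token after the first segment boundary gets a wrong position, which is what can lead to incorrect masking downstream.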

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR makes a few changes:

  1. Added two new args to from_segment_ids_and_pos(): is_thd and is_segment_ids_reordered. The only cases this function can currently guarantee to support are BSHD with and without load balancing, and THD without load balancing.
  2. BSHD with load balancing is supported natively because the segment_ids and segment_pos are not reordered before being passed to from_segment_ids_and_pos(). However, if the segment_pos are reordered and passed to from_segment_ids_and_pos(), it will assert
  3. If a THD + reordered use case calls from_segment_ids_and_pos(), it will assert
  4. The fused attn tests were modified to account for these two new args
    • For THD fused attn non-CP tests, segment_pos=None is passed so as to exercise the newly added THD path in from_segment_ids_and_pos()
    • For THD fused attn CP tests, segment_pos is explicitly passed (as before, not a new change)
    • For BSHD fused attn CP tests, segment_pos=None is passed so as to exercise the default BSHD path to generate segment_pos (as before, not a new change)

Impact on user of the API:

  1. These two new args, is_thd and is_segment_ids_reordered, are not Optional, so they will cause a TypeError for current users of this API, which makes this a breaking change. However, this is needed to ensure correct usage of this API
  2. The user is now expected to let the API know whether this is a THD or BSHD layout and whether the segment_ids are reordered or not. It is expected that the segment_ids passed will be reordered only for THD load balancing. For all other cases the segment_ids should not be reordered
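The supported combinations can be summarised as a small decision function. This is a sketch of one reading of the description, not the code in attention.py: the helper name pick_segment_pos_path and its return strings are hypothetical, and it assumes that reordered segment_ids are acceptable only when segment_pos is passed explicitly.

```python
def pick_segment_pos_path(segment_pos_given, *, is_thd, is_segment_ids_reordered):
    """Hypothetical helper: name the path a call is expected to take."""
    if segment_pos_given:
        # An explicit segment_pos is used as-is; nothing is generated.
        return "use provided segment_pos"
    # No segment_pos was given, so it would have to be generated.
    assert not is_segment_ids_reordered, (
        "segment_pos cannot be derived from reordered segment_ids; pass it explicitly"
    )
    return "generate THD segment_pos" if is_thd else "generate BSHD segment_pos"


bshd = pick_segment_pos_path(False, is_thd=False, is_segment_ids_reordered=False)
thd = pick_segment_pos_path(False, is_thd=True, is_segment_ids_reordered=False)
thd_lb = pick_segment_pos_path(True, is_thd=True, is_segment_ids_reordered=True)
```

Calling it with segment_pos_given=False and is_segment_ids_reordered=True raises an AssertionError, which corresponds to the unsupported THD + load balancing case.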

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Dec 16, 2025
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani added attention jax bug Something isn't working labels Dec 16, 2025
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 17, 2025 02:16
@greptile-apps
Contributor

greptile-apps bot commented Dec 17, 2025

Greptile Summary

This PR fixes incorrect segment position calculation in from_segment_ids_and_pos() by adding proper THD (packed) layout support and introducing two new required parameters: is_thd and is_segment_ids_reordered. Previously, the function naively used arange() for all cases, which produced incorrect masking for THD layouts.

Key Changes:

  • Added is_thd and is_segment_ids_reordered parameters (breaking change - now required, not optional)
  • Implemented proper THD segment position calculation that detects segment boundaries and handles padding
  • Added assertions to prevent unsupported combinations (THD + reordered segment_ids without explicit segment_pos)
  • Updated tests to pass new required parameters based on layout and context parallelism settings

Critical Issue Found:

  • Lines 833 and 839 in attention.py contain contradictory assertions that will always fail when is_segment_ids_reordered=True - one requires is_thd=False while the other requires is_thd=True
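A reduced illustration of why such a pair of checks can never pass. The function check_reordered and its messages are hypothetical stand-ins for the two assertions described above, not a copy of attention.py.

```python
def check_reordered(is_thd, is_segment_ids_reordered):
    # Hypothetical reduction of the two checks described in the review.
    if is_segment_ids_reordered:
        assert not is_thd, "THD + reordered not supported"
        assert is_thd, "BSHD + reordered not supported"


def try_both_layouts():
    # Try both layouts with reordering enabled and record what happens.
    results = []
    for is_thd in (True, False):
        try:
            check_reordered(is_thd, True)
            results.append("ok")
        except AssertionError as err:
            results.append(str(err))
    return results


outcomes = try_both_layouts()
```

Neither layout yields "ok": the first assertion rejects is_thd=True and the second rejects is_thd=False.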

Impact:

  • Breaking API change requiring users to specify layout type and reordering status
  • Supported: BSHD (with/without load balancing) and THD (without load balancing) when segment_pos=None
  • Unsupported: THD + load balancing requires explicit segment_pos parameter

Confidence Score: 1/5

  • This PR contains a critical logic error that will cause runtime failures
  • Score of 1 reflects the contradictory assertions on lines 833-844 of attention.py that make is_segment_ids_reordered=True impossible to use - the code will always fail assertion checks regardless of is_thd value. This is a blocking issue that needs to be fixed before merge.
  • Pay critical attention to transformer_engine/jax/attention.py lines 833-844 - the contradictory assertion logic must be fixed

Important Files Changed

Filename | Overview
transformer_engine/jax/attention.py | Contains critical logic error in assertion checks (lines 833 and 839) - contradictory assertions will always fail when is_segment_ids_reordered=True. New THD segment position calculation logic added.
tests/jax/test_fused_attn.py | Test updates correctly pass new required parameters is_thd and is_segment_ids_reordered to from_segment_ids_and_pos(). Logic for conditionally passing segment_pos appears correct.

Sequence Diagram

sequenceDiagram
    participant User
    participant SequenceDescriptor
    participant from_segment_ids_and_pos
    participant generate_default_pos
    
    User->>from_segment_ids_and_pos: segment_ids, segment_pos=None, is_thd, is_segment_ids_reordered
    
    alt segment_pos is None
        from_segment_ids_and_pos->>from_segment_ids_and_pos: Check is_segment_ids_reordered
        
        alt is_segment_ids_reordered=True
            from_segment_ids_and_pos->>from_segment_ids_and_pos: assert not is_thd (line 833)
            Note over from_segment_ids_and_pos: Reject THD + reordered<br/>without explicit segment_pos
            from_segment_ids_and_pos->>from_segment_ids_and_pos: assert is_thd (line 839) ❌
            Note over from_segment_ids_and_pos: CRITICAL BUG:<br/>Contradictory assertion!
        end
        
        from_segment_ids_and_pos->>generate_default_pos: seg_ids
        
        alt is_thd=True
            generate_default_pos->>generate_default_pos: Calculate segment starts
            generate_default_pos->>generate_default_pos: Calculate segment_start_offsets
            generate_default_pos->>generate_default_pos: Find last non-zero index
            generate_default_pos->>generate_default_pos: Compute positions and mask padding
            generate_default_pos-->>from_segment_ids_and_pos: THD segment_pos
        else is_thd=False (BSHD)
            generate_default_pos->>generate_default_pos: Simple arange broadcast
            generate_default_pos-->>from_segment_ids_and_pos: BSHD segment_pos
        end
        
    else segment_pos provided
        from_segment_ids_and_pos->>from_segment_ids_and_pos: Use provided segment_pos
    end
    
    from_segment_ids_and_pos-->>User: SequenceDescriptor object
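
The THD branch of the diagram lists four steps: segment starts, segment start offsets, last non-zero index, and positions with padding masked. Below is a plain-Python sketch of one way those steps could fit together. It is not taken from attention.py; the name thd_segment_pos is hypothetical, and it assumes a single 1D row, id 0 for padding, padding only at the end, and a first token that opens a segment.

```python
def thd_segment_pos(seg_ids):
    """Hypothetical sketch: per-segment positions for one packed row."""
    n = len(seg_ids)
    # 1. Segment starts: the first token, or a non-padding token whose id
    #    differs from the previous token's id.
    is_start = [
        i == 0 or (seg_ids[i] != seg_ids[i - 1] and seg_ids[i] != 0)
        for i in range(n)
    ]
    # 2. Start offset per token: index of the most recent segment start.
    start_offsets, last_start = [], 0
    for i in range(n):
        if is_start[i]:
            last_start = i
        start_offsets.append(last_start)
    # 3. Index of the last non-padding token (-1 if the row is all padding).
    last_valid = max((i for i in range(n) if seg_ids[i] != 0), default=-1)
    # 4. Position = distance from the segment start; trailing padding gets 0.
    return [i - start_offsets[i] if i <= last_valid else 0 for i in range(n)]


positions = thd_segment_pos([1, 1, 1, 2, 2, 0, 0])
```

Under these assumptions, padding that appears in the middle of a row is not masked by step 4, since only tokens after the last non-zero index are zeroed.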

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 841-842 (link)

    logic: q_seg_ids and kv_seg_ids are used here but not defined until line 847. This will cause a NameError at runtime when segment_pos is None.

2 files reviewed, 1 comment

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 829-832 (link)

    style: f-strings in assert messages won't be evaluated until assertion fails
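
    A short demonstration of that behaviour in plain Python. The helper describe and the list calls are hypothetical and exist only to record when the message expression is evaluated.

```python
calls = []


def describe(x):
    # Record every evaluation so the laziness of the message is observable.
    calls.append(x)
    return f"value={x}"


# Condition holds: the message expression is never evaluated.
assert True, f"unexpected {describe(1)}"

# Condition fails: the f-string is built only now.
try:
    assert False, f"unexpected {describe(2)}"
except AssertionError as err:
    message = str(err)
```

    After running this, calls contains only 2, so an f-string in an assert message costs nothing on the passing path.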

2 files reviewed, 1 comment

KshitijLakhani and others added 7 commits December 22, 2025 11:30
…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/incorrect-sequence-descr-from-seg-ids-only branch from 499f4de to 217ea58 Compare December 22, 2025 19:31
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/attention.py, line 833-837 (link)

    logic: warning is shown even for the valid case (BSHD without load balancing)

    this will cause users to see warnings every time they correctly use segment_pos=None with BSHD layout

2 files reviewed, 1 comment

Kshitij Janardan Lakhani and others added 2 commits December 22, 2025 16:42
2. Add support for THD in from_segment_ids_and_pos()
3. Assert if load balanced segment_ids is passed to generate a segment_pos

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/jax/attention.py, line 880-881 (link)

    syntax: segment_ids is undefined, should be seg_ids (the parameter name on line 836)

  2. tests/jax/test_fused_attn.py, line 680-681 (link)

    syntax: is_thd and is_load_balanced are keyword-only parameters (note the * on line 800 in attention.py), so they must be passed as keyword arguments

  3. tests/jax/test_fused_attn.py, line 709-710 (link)

    syntax: is_thd and is_load_balanced are keyword-only parameters (note the * on line 800 in attention.py), so they must be passed as keyword arguments
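
    For reference, this is how a bare * in a Python signature behaves. The function build_descriptor is a hypothetical stand-in with the same parameter shape as described in these comments, not the real from_segment_ids_and_pos().

```python
def build_descriptor(segment_ids, segment_pos=None, *, is_thd, is_load_balanced):
    # Hypothetical stand-in: everything after the bare * is keyword-only.
    return {"is_thd": is_thd, "is_load_balanced": is_load_balanced}


# Keyword arguments are accepted.
ok = build_descriptor([1, 1, 0], None, is_thd=True, is_load_balanced=False)

# Passing the same values positionally raises a TypeError.
try:
    build_descriptor([1, 1, 0], None, True, False)
    positional_error = None
except TypeError as err:
    positional_error = type(err).__name__
```

    The same TypeError is raised when a required keyword-only argument is omitted entirely.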

2 files reviewed, 3 comments

Kshitij Janardan Lakhani added 2 commits December 22, 2025 16:55
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani changed the title [JAX] Fix incorrect calculation of segment pos from segment ids [JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API Dec 23, 2025
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. tests/jax/test_fused_attn.py, line 706-711 (link)

    logic: passing segment_pos=None with is_load_balanced=True will trigger assertion error on attention.py:830

    the BSHD + load balanced + SegmentIDs combination should either:

    1. pass explicit segment_pos (like THD does on line 677-678), or
    2. skip this format when load balanced

    currently only works because default mesh_shape=(1,1,1) means cp_size=1

  2. transformer_engine/jax/attention.py, line 794 (link)

    style: misleading TODO - THD layout (non load balanced) IS now supported by the implementation below (lines 837-878)

  3. transformer_engine/jax/attention.py, line 839-840 (link)

    style: assumption not enforced - if sequence starts with padding (seg_ids[0]=0), first_is_segment=True will incorrectly mark it as segment start

    consider asserting seg_ids[:, 0] != 0 or handle leading padding explicitly

2 files reviewed, 3 comments

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as draft December 23, 2025 01:15
…actually reordered segment_ids and segment_pos

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
pre-commit-ci bot and others added 3 commits December 23, 2025 22:18
…ncing

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 24, 2025 00:23
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (2)

  1. transformer_engine/jax/attention.py, line 833-844 (link)

    logic: contradictory assertions will always fail when is_segment_ids_reordered=True

    line 833 asserts not is_thd (meaning is_thd must be False)
    line 839 asserts is_thd (meaning is_thd must be True)

    both assertions cannot be satisfied simultaneously, causing the function to always fail when segment_pos=None and is_segment_ids_reordered=True

  2. transformer_engine/jax/attention.py, line 850-851 (link)

    style: assumption that first token is always a valid segment may be incorrect if THD format allows leading padding

    if seg_ids[0, 0] == 0 (first token is padding), this assumption causes incorrect segment_pos calculation. verify THD format guarantees no leading padding or add validation

2 files reviewed, 2 comments

for more information, see https://pre-commit.ci

Fix lint errors

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/incorrect-sequence-descr-from-seg-ids-only branch from 4101df5 to 74a352e Compare December 24, 2025 01:46
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (3)

  1. transformer_engine/jax/attention.py, line 833-844 (link)

    logic: contradictory assertions on lines 833 and 839 - both will execute when is_segment_ids_reordered=True, but line 833 requires is_thd=False while line 839 requires is_thd=True

    this means when is_segment_ids_reordered=True:

    • if is_thd=True: line 833 fails with "not supported for load balanced reordered (Striped) THD"
    • if is_thd=False: line 839 fails with "not supported for load balanced reordered (Dual Chunk) BSHD"

    based on line 822 comments and PR description, only THD with load balancing should set is_segment_ids_reordered=True, so line 839's assertion should be removed (it contradicts the intended behavior)

  2. transformer_engine/jax/attention.py, line 856 (link)

    style: potential edge case: (seg_ids[..., 1:] != 0) condition may not handle all segment transitions correctly

    consider sequence: [1, 1, 0, 2, 2] where 0 is padding in the middle

    • transition from 1→0 at index 2: (1 != 0) & (0 != 0) = False (correctly not marked as segment start)
    • transition from 0→2 at index 3: (0 != 2) & (2 != 0) = True (marked as segment start)

    however, the assumption on line 850 is that "the first token belongs to a segment and is not a padded token", which does not account for padding in the middle of a sequence. verify whether THD format allows mid-sequence padding; if it does, this logic needs adjustment. does THD format allow padding (segment_id=0) in the middle of a sequence, or is padding only at the end?

  3. tests/jax/test_fused_attn.py, line 687-689 (link)

    style: simplify boolean expression

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
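
    The boundary condition discussed in comment 2 can be checked by hand with a small plain-Python sketch. The function segment_start_mask is hypothetical; it mirrors the quoted condition (id changed and current token is not padding) and, like the reviewed code, assumes a non-empty row whose first token opens a segment.

```python
def segment_start_mask(seg_ids):
    # The first token is assumed to open a segment (the assumption in question).
    mask = [True]
    for prev, cur in zip(seg_ids, seg_ids[1:]):
        # Mirrors the quoted condition: id changed AND current token is not padding.
        mask.append(prev != cur and cur != 0)
    return mask


mid_padding = segment_start_mask([1, 1, 0, 2, 2])
leading_padding = segment_start_mask([0, 0, 1, 1])
```

    For [1, 1, 0, 2, 2] the mask marks indices 0 and 3, matching the transitions worked out above. For [0, 0, 1, 1] index 0 is marked as a segment start even though it is padding, which is the leading-padding concern raised in the earlier review rounds.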

2 files reviewed, 3 comments

@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM pending CI, thanks!
