Skip to content

[Codegen][GPU] Fix shared memory estimation for multi-buffering #23736

Merged
Yu-Zhewen merged 2 commits intoiree-org:mainfrom
Yu-Zhewen:num_stages_shared_mem
Mar 11, 2026
Merged

[Codegen][GPU] Fix shared memory estimation for multi-buffering #23736
Yu-Zhewen merged 2 commits intoiree-org:mainfrom
Yu-Zhewen:num_stages_shared_mem

Conversation

@Yu-Zhewen
Copy link
Contributor

  • Pass useDirectLoad and prefetchNumStages through calculateOperandsSharedMemoryUsedInBytes, so that it can account
    for multi-buffering when using direct loads with prefetching.
  • Properly guard direct load flag for scaled matmuls as forced off (Scaled matmul fails to compile when using global load DMA promotion #22119) to avoid inconsistency, by overriding the flag with a
    warning emitted.
  • As such, for regular matmul the shared memory usage scales with prefetchNumStages when useDirectLoad is enabled. For scaled matmul, shared memory usage is unchanged since direct load is forced off.

Signed-off-by: Yu-Zhewen <[email protected]>
@Yu-Zhewen Yu-Zhewen marked this pull request as ready for review March 11, 2026 14:59
@Yu-Zhewen Yu-Zhewen requested review from jerryyin and lialan March 11, 2026 14:59
@kuhar
Copy link
Member

kuhar commented Mar 11, 2026

@bangtianliu @RattataKing we should mirror this in the tuner (eventually)


// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=34816
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, find if we only need 1, or something is not right.

Comment on lines +87 to +97
// Account for multi-buffering when using direct loads.
int64_t numBuffers =
(useDirectLoad && prefetchNumStages > 0) ? prefetchNumStages : 1;

int64_t lhsSharedMemoryUsed = numBuffers * tileM * tileK * lhsBitwidth;
int64_t rhsSharedMemoryUsed =
numBuffers * numRhs * tileN * tileK * rhsBitwidth;
int64_t aScaleSharedMemoryUsed =
numBuffers * tileM * tileKo * lhsScaleBitwidth;
int64_t bScaleSharedMemoryUsed =
numBuffers * numRhs * tileN * tileKo * rhsScaleBitwidth;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be cleaner if we just multiply numBuffers at the very end before return if we are scaling numBuffers * uniformly.

Also, add a explicit comment about the ROCDLPrefetchSharedMemoryPass that the current decisions of multiple buffering is the num of multi-buffer equals to numStages in direct load mode.

Help me inspect another scenario too: At this point is the useDirectLoad a done deal or will it be overridden in some way later? I don't want us to be overly aggressive in our estimation if that's the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only place useDirectLoad gets overridden is the scaled matmul case discussed above (@lialan is looking into this as an orthogonal issue). By the time we reach calculateOperandsSharedMemoryUsedInBytes, the flag reflects the final decision, so there's no risk of over-estimation.

Signed-off-by: Yu-Zhewen <[email protected]>
@jerryyin jerryyin self-requested a review March 11, 2026 18:11
Copy link
Member

@jerryyin jerryyin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@lialan lialan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@Yu-Zhewen Yu-Zhewen merged commit 8206e32 into iree-org:main Mar 11, 2026
56 checks passed
Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I just missed the merge, but could you address my comment as a followup?

Comment on lines +159 to +160
bool canUpcastAcc = false, bool useDirectLoad = false,
int64_t prefetchNumStages = 0, bool mustBeAligned = true,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the function docs to explain what the new params do.

Also, I think the naming here is a bit confusing. The decision to use direct load is independent to the decision to use multi-buffering AFAIU. It just happens that the decision is linked today. Perhaps instead of these params we could just have a single int64_t param called numMultiBufferingStages (defaulting to 1), and then we just multiply by that value in the shared mem computation.

Copy link
Contributor Author

@Yu-Zhewen Yu-Zhewen Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the function docs to explain what the new params do.

Thanks, will do.

Also, I think the naming here is a bit confusing. The decision to use direct load is independent to the decision to use multi-buffering AFAIU. It just happens that the decision is linked today. Perhaps instead of these params we could just have a single int64_t param called numMultiBufferingStages (defaulting to 1), and then we just multiply by that value in the shared mem computation.

I'd prefer to keep the two separate parameters, as they correspond to two user-facing flags. Also, prefetchNumStages has different resource implications (multi-buffering via LDS or VGPRs) depending on the load mode.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants