[Codegen][GPU] Fix shared memory estimation for multi-buffering #23736
Yu-Zhewen merged 2 commits into iree-org:main from
Conversation
Signed-off-by: Yu-Zhewen <[email protected]>
@bangtianliu @RattataKing we should mirror this in the tuner (eventually)
compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
// CHECK-REMARKS-DIRECT-LOAD-3: [Analysis] SharedMemoryUsage
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Category:deduceMMASchedule
// CHECK-REMARKS-DIRECT-LOAD-3-SAME: Remark=34816
Same here: check whether we only need one, or whether something is not right.
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
// Account for multi-buffering when using direct loads.
int64_t numBuffers =
    (useDirectLoad && prefetchNumStages > 0) ? prefetchNumStages : 1;

int64_t lhsSharedMemoryUsed = numBuffers * tileM * tileK * lhsBitwidth;
int64_t rhsSharedMemoryUsed =
    numBuffers * numRhs * tileN * tileK * rhsBitwidth;
int64_t aScaleSharedMemoryUsed =
    numBuffers * tileM * tileKo * lhsScaleBitwidth;
int64_t bScaleSharedMemoryUsed =
    numBuffers * numRhs * tileN * tileKo * rhsScaleBitwidth;
I think it'd be cleaner to multiply by numBuffers once at the very end, right before the return, since we are scaling everything uniformly.
Also, add an explicit comment referencing ROCDLPrefetchSharedMemoryPass: the current multi-buffering decision is that the number of buffers equals numStages in direct-load mode.
Help me inspect another scenario too: at this point, is useDirectLoad a done deal, or can it be overridden in some way later? I don't want us to be overly aggressive in our estimation if that's the case.
The only place useDirectLoad gets overridden is the scaled matmul case discussed above (@lialan is looking into this as an orthogonal issue). By the time we reach calculateOperandsSharedMemoryUsedInBytes, the flag reflects the final decision, so there's no risk of over-estimation.
compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse_gfx950.mlir
Signed-off-by: Yu-Zhewen <[email protected]>
Max191 left a comment:
Sorry, I just missed the merge, but could you address my comment as a followup?
bool canUpcastAcc = false, bool useDirectLoad = false,
int64_t prefetchNumStages = 0, bool mustBeAligned = true,
Please update the function docs to explain what the new params do.
Also, I think the naming here is a bit confusing. The decision to use direct load is independent of the decision to use multi-buffering, AFAIU; it just happens that the two are linked today. Perhaps instead of these params we could have a single int64_t param called numMultiBufferingStages (defaulting to 1), and then we just multiply by that value in the shared mem computation.
Please update the function docs to explain what the new params do.
Thanks, will do.
Also, I think the naming here is a bit confusing. The decision to use direct load is independent of the decision to use multi-buffering, AFAIU; it just happens that the two are linked today. Perhaps instead of these params we could have a single int64_t param called numMultiBufferingStages (defaulting to 1), and then we just multiply by that value in the shared mem computation.
I'd prefer to keep the two separate parameters, as they correspond to two user-facing flags. Also, prefetchNumStages has different resource implications (multi-buffering via LDS or VGPRs) depending on the load mode.
useDirectLoad and prefetchNumStages are passed through calculateOperandsSharedMemoryUsedInBytes, so that it can account for multi-buffering when using direct loads with prefetching.
warning emitted.
prefetchNumStages when useDirectLoad is enabled. For scaled matmul, shared memory usage is unchanged since direct load is forced off.