[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement reifyResultShapes for FatRawBufferCastOp (PR #171839)
Zhewen Yu
llvmlistbot at llvm.org
Thu Dec 11 13:03:38 PST 2025
https://github.com/Yu-Zhewen updated https://github.com/llvm/llvm-project/pull/171839
>From c4a442f82e4dfad3943f45e05f8614573c5c13a3 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 06:35:24 -0800
Subject: [PATCH 1/3] first commit
Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 ++
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 18 ++++++++++
.../resolve-shaped-type-result-dims.mlir | 33 +++++++++++++++++++
3 files changed, 53 insertions(+)
create mode 100644 mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6fbc90ded5824..7c866e32e46b5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -377,6 +377,8 @@ def AMDGPU_FatRawBufferCastOp :
AMDGPU_Op<"fat_raw_buffer_cast",
[Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
ViewLikeOpInterface, AttrSizedOperandSegments]>,
Arguments<(ins AnyMemRef:$source,
Optional<I64>:$validBytes,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..d3c3087da8b33 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
@@ -146,6 +147,23 @@ LogicalResult FatRawBufferCastOp::inferReturnTypes(
return success();
}
+LogicalResult FatRawBufferCastOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ Value source = getSource();
+ auto sourceType = cast<MemRefType>(source.getType());
+ Location loc = getLoc();
+ SmallVector<OpFoldResult> shapes;
+ for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ if (sourceType.isDynamicDim(i)) {
+ shapes.push_back(builder.createOrFold<memref::DimOp>(loc, source, i));
+ } else {
+ shapes.push_back(builder.getIndexAttr(sourceType.getDimSize(i)));
+ }
+ }
+ reifiedReturnShapes.push_back(std::move(shapes));
+ return success();
+}
+
LogicalResult FatRawBufferCastOp::verify() {
FailureOr<MemRefType> expectedResultType =
getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
diff --git a/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
new file mode 100644
index 0000000000000..6c7351a8923f3
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func.func @fat_raw_buffer_cast_static_dim(%arg0: memref<2x3xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<2x3xf32>
+ to memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d0 = memref.dim %cast, %c0 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d1 = memref.dim %cast, %c1 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ return %d0, %d1 : index, index
+}
+// CHECK: func @fat_raw_buffer_cast_static_dim
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: return %[[C2]], %[[C3]]
+
+// -----
+
+func.func @fat_raw_buffer_cast_dynamic_dim(%arg0: memref<4x?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<4x?xf32>
+ to memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d0 = memref.dim %cast, %c0 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d1 = memref.dim %cast, %c1 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ return %d0, %d1 : index, index
+}
+// CHECK: func @fat_raw_buffer_cast_dynamic_dim
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x?xf32>
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: return %[[C4]], %[[D1]]
>From f7a1dd63d7f3602fee52d5874ac90a9204babd05 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 11:06:24 -0800
Subject: [PATCH 2/3] fix format
Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d3c3087da8b33..ddc59197fa673 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,9 +13,9 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
>From 7c017e1611f234ef34573002321cdd7d26999cf6 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 12:57:34 -0800
Subject: [PATCH 3/3] use reifyDimOfResult
Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +-
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 21 +++++++------------
2 files changed, 9 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7c866e32e46b5..cd98b0a1f940d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -378,7 +378,7 @@ def AMDGPU_FatRawBufferCastOp :
[Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
- ["reifyResultShapes"]>,
+ ["reifyDimOfResult"]>,
ViewLikeOpInterface, AttrSizedOperandSegments]>,
Arguments<(ins AnyMemRef:$source,
Optional<I64>:$validBytes,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index ddc59197fa673..e943046255665 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -147,21 +147,16 @@ LogicalResult FatRawBufferCastOp::inferReturnTypes(
return success();
}
-LogicalResult FatRawBufferCastOp::reifyResultShapes(
- OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex,
+ int dim) {
+ assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
Value source = getSource();
auto sourceType = cast<MemRefType>(source.getType());
- Location loc = getLoc();
- SmallVector<OpFoldResult> shapes;
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
- if (sourceType.isDynamicDim(i)) {
- shapes.push_back(builder.createOrFold<memref::DimOp>(loc, source, i));
- } else {
- shapes.push_back(builder.getIndexAttr(sourceType.getDimSize(i)));
- }
- }
- reifiedReturnShapes.push_back(std::move(shapes));
- return success();
+ if (sourceType.isDynamicDim(dim))
+ return OpFoldResult(
+ builder.createOrFold<memref::DimOp>(getLoc(), source, dim));
+ return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
}
LogicalResult FatRawBufferCastOp::verify() {
More information about the Mlir-commits mailing list