[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement reifyResultShapes for FatRawBufferCastOp (PR #171839)

Zhewen Yu llvmlistbot at llvm.org
Thu Dec 11 13:03:38 PST 2025


https://github.com/Yu-Zhewen updated https://github.com/llvm/llvm-project/pull/171839

>From c4a442f82e4dfad3943f45e05f8614573c5c13a3 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 06:35:24 -0800
Subject: [PATCH 1/3] [MLIR][AMDGPU] Implement reifyResultShapes for FatRawBufferCastOp

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 ++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 18 ++++++++++
 .../resolve-shaped-type-result-dims.mlir      | 33 +++++++++++++++++++
 3 files changed, 53 insertions(+)
 create mode 100644 mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 6fbc90ded5824..7c866e32e46b5 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -377,6 +377,8 @@ def AMDGPU_FatRawBufferCastOp :
     AMDGPU_Op<"fat_raw_buffer_cast",
       [Pure,
        DeclareOpInterfaceMethods<InferTypeOpInterface>,
+       DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+                                 ["reifyResultShapes"]>,
        ViewLikeOpInterface, AttrSizedOperandSegments]>,
     Arguments<(ins AnyMemRef:$source,
       Optional<I64>:$validBytes,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..d3c3087da8b33 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
@@ -146,6 +147,23 @@ LogicalResult FatRawBufferCastOp::inferReturnTypes(
   return success();
 }
 
+LogicalResult FatRawBufferCastOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  Value source = getSource();
+  auto sourceType = cast<MemRefType>(source.getType());
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> shapes;
+  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+    if (sourceType.isDynamicDim(i)) {
+      shapes.push_back(builder.createOrFold<memref::DimOp>(loc, source, i));
+    } else {
+      shapes.push_back(builder.getIndexAttr(sourceType.getDimSize(i)));
+    }
+  }
+  reifiedReturnShapes.push_back(std::move(shapes));
+  return success();
+}
+
 LogicalResult FatRawBufferCastOp::verify() {
   FailureOr<MemRefType> expectedResultType =
       getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
diff --git a/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
new file mode 100644
index 0000000000000..6c7351a8923f3
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func.func @fat_raw_buffer_cast_static_dim(%arg0: memref<2x3xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<2x3xf32>
+      to memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+  %d0 = memref.dim %cast, %c0 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+  %d1 = memref.dim %cast, %c1 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+  return %d0, %d1 : index, index
+}
+//      CHECK: func @fat_raw_buffer_cast_static_dim
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[C3:.+]] = arith.constant 3 : index
+//      CHECK:   return %[[C2]], %[[C3]]
+
+// -----
+
+func.func @fat_raw_buffer_cast_dynamic_dim(%arg0: memref<4x?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<4x?xf32>
+      to memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+  %d0 = memref.dim %cast, %c0 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+  %d1 = memref.dim %cast, %c1 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+  return %d0, %d1 : index, index
+}
+//      CHECK: func @fat_raw_buffer_cast_dynamic_dim
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x?xf32>
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
+//      CHECK:   %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//      CHECK:   return %[[C4]], %[[D1]]

>From f7a1dd63d7f3602fee52d5874ac90a9204babd05 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 11:06:24 -0800
Subject: [PATCH 2/3] Fix include order to satisfy clang-format

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d3c3087da8b33..ddc59197fa673 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,9 +13,9 @@
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"

>From 7c017e1611f234ef34573002321cdd7d26999cf6 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 12:57:34 -0800
Subject: [PATCH 3/3] Implement reifyDimOfResult instead of reifyResultShapes

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 21 +++++++------------
 2 files changed, 9 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7c866e32e46b5..cd98b0a1f940d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -378,7 +378,7 @@ def AMDGPU_FatRawBufferCastOp :
       [Pure,
        DeclareOpInterfaceMethods<InferTypeOpInterface>,
        DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
-                                 ["reifyResultShapes"]>,
+                                 ["reifyDimOfResult"]>,
        ViewLikeOpInterface, AttrSizedOperandSegments]>,
     Arguments<(ins AnyMemRef:$source,
       Optional<I64>:$validBytes,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index ddc59197fa673..e943046255665 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -147,21 +147,16 @@ LogicalResult FatRawBufferCastOp::inferReturnTypes(
   return success();
 }
 
-LogicalResult FatRawBufferCastOp::reifyResultShapes(
-    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(OpBuilder &builder,
+                                                             int resultIndex,
+                                                             int dim) {
+  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
   Value source = getSource();
   auto sourceType = cast<MemRefType>(source.getType());
-  Location loc = getLoc();
-  SmallVector<OpFoldResult> shapes;
-  for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
-    if (sourceType.isDynamicDim(i)) {
-      shapes.push_back(builder.createOrFold<memref::DimOp>(loc, source, i));
-    } else {
-      shapes.push_back(builder.getIndexAttr(sourceType.getDimSize(i)));
-    }
-  }
-  reifiedReturnShapes.push_back(std::move(shapes));
-  return success();
+  if (sourceType.isDynamicDim(dim))
+    return OpFoldResult(
+        builder.createOrFold<memref::DimOp>(getLoc(), source, dim));
+  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
 }
 
 LogicalResult FatRawBufferCastOp::verify() {



More information about the Mlir-commits mailing list