[Mlir-commits] [mlir] [MLIR][AMDGPU] Implement reifyResultShapes for FatRawBufferCastOp (PR #171839)
Zhewen Yu
llvmlistbot at llvm.org
Thu Dec 11 09:57:57 PST 2025
https://github.com/Yu-Zhewen updated https://github.com/llvm/llvm-project/pull/171839
>From fcb6523a1688253308eb463168b3c72fe148e0b3 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Thu, 11 Dec 2025 06:35:24 -0800
Subject: [PATCH] first commit
Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 ++
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 18 ++++++++++
.../resolve-shaped-type-result-dims.mlir | 33 +++++++++++++++++++
3 files changed, 53 insertions(+)
create mode 100644 mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 56160d3e8fe85..572b937ee9317 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -372,6 +372,8 @@ def AMDGPU_FatRawBufferCastOp :
AMDGPU_Op<"fat_raw_buffer_cast",
[Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>,
ViewLikeOpInterface, AttrSizedOperandSegments]>,
Arguments<(ins AnyMemRef:$source,
Optional<I64>:$validBytes,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index b7a665b0f5367..d3c3087da8b33 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
@@ -146,6 +147,23 @@ LogicalResult FatRawBufferCastOp::inferReturnTypes(
return success();
}
+/// Reifies the shape of the single result from the dimensions of the source
+/// memref. memref::getMixedSizes yields an index attribute for every static
+/// dimension and a (folded) memref.dim on the source for every dynamic one,
+/// so no hand-written per-dimension loop is needed here.
+///
+/// \param builder Builder used to materialize memref.dim ops for dynamic dims.
+/// \param reifiedReturnShapes Receives one shape entry for the result.
+/// \returns success() unconditionally.
+LogicalResult FatRawBufferCastOp::reifyResultShapes(
+    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.push_back(
+      memref::getMixedSizes(builder, getLoc(), getSource()));
+  return success();
+}
+
LogicalResult FatRawBufferCastOp::verify() {
FailureOr<MemRefType> expectedResultType =
getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
diff --git a/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
new file mode 100644
index 0000000000000..6c7351a8923f3
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+// Static shape: both memref.dim ops on the cast should resolve to constants.
+func.func @fat_raw_buffer_cast_static_dim(%arg0: memref<2x3xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<2x3xf32>
+ to memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d0 = memref.dim %cast, %c0 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d1 = memref.dim %cast, %c1 : memref<2x3xf32, #amdgpu.address_space<fat_raw_buffer>>
+ return %d0, %d1 : index, index
+}
+// CHECK: func @fat_raw_buffer_cast_static_dim
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK: return %[[C2]], %[[C3]]
+
+// -----
+// Mixed shape: dim 0 should become a constant, dim 1 a memref.dim of %arg0.
+func.func @fat_raw_buffer_cast_dynamic_dim(%arg0: memref<4x?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cast = amdgpu.fat_raw_buffer_cast %arg0 : memref<4x?xf32>
+ to memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d0 = memref.dim %cast, %c0 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ %d1 = memref.dim %cast, %c1 : memref<4x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+ return %d0, %d1 : index, index
+}
+// CHECK: func @fat_raw_buffer_cast_dynamic_dim
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x?xf32>
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: return %[[C4]], %[[D1]]
More information about the Mlir-commits
mailing list