[Mlir-commits] [mlir] [MemRef] Add dim reification for AssumeAlignmentOp (PR #174477)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 5 12:43:34 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Jorn Tuyls (jtuyls)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/174477.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+3-1)
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11)
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+35)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..45122788bd2d4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
Pure,
ViewLikeOpInterface,
SameOperandsAndResultType,
- DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+ DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyDimOfResult"]>
]> {
let summary =
"assumption that gives alignment information to the input memref";
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7bc6ae5f21e8b..24089f4370c8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -606,6 +606,17 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
+/// Reifies the size of dimension `dim` of this op's single result.
+///
+/// The result has the same type as the `memref` operand
+/// (SameOperandsAndResultType), so the size is read from the operand:
+/// - a static size is returned as an index attribute;
+/// - a dynamic size is returned as a `memref.dim` on the operand (folded when
+///   possible).
+///
+/// Returns failure when the operand is not a ranked memref or when `dim` is
+/// outside [0, rank); both are reachable from valid IR (`memref.dim` on an
+/// unranked memref, or an out-of-bounds constant index, which is UB at runtime
+/// rather than a verifier error) and must not crash the rewriter.
+FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
+                                                            int resultIndex,
+                                                            int dim) {
+  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
+  Value source = getMemref();
+  // dyn_cast rather than cast: the operand may be an unranked memref.
+  auto sourceType = dyn_cast<MemRefType>(source.getType());
+  if (!sourceType)
+    return failure();
+  // Guard before isDynamicDim/getDimSize, which assert on out-of-range dims.
+  if (dim < 0 || dim >= sourceType.getRank())
+    return failure();
+  if (sourceType.isDynamicDim(dim))
+    return OpFoldResult(builder.createOrFold<DimOp>(getLoc(), source, dim));
+  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
+}
+
//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..374e47fb34b48 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
}
return %result : tensor<?x?xf32>
}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
+// CHECK-LABEL: func @dim_of_assume_alignment_static(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: return %[[C2]], %[[C3]] : index, index
+func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
+ %d0 = memref.dim %0, %c0 : memref<2x3xf32>
+ %d1 = memref.dim %0, %c1 : memref<2x3xf32>
+ return %d0, %d1 : index, index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
+// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
+// CHECK: return %[[C4]], %[[D1]] : index, index
+func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
+ %d0 = memref.dim %0, %c0 : memref<4x?xf32>
+ %d1 = memref.dim %0, %c1 : memref<4x?xf32>
+ return %d0, %d1 : index, index
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/174477
More information about the Mlir-commits mailing list