[Mlir-commits] [mlir] [MemRef] Add dim reification for AssumeAlignmentOp (PR #174477)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 5 12:43:34 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Jorn Tuyls (jtuyls)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/174477.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+3-1) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+11) 
- (modified) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+35) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..45122788bd2d4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType,
-      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+                                ["reifyDimOfResult"]>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7bc6ae5f21e8b..24089f4370c8a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -606,6 +606,24 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
 
+/// Reifies extent `dim` of the single result. The result has the same type as
+/// the `memref` operand (SameOperandsAndResultType), so a static extent is
+/// returned as an index attribute and a dynamic one as `memref.dim` on the
+/// operand. An out-of-range `dim` (e.g. from a `memref.dim` with an
+/// out-of-bounds constant index) yields failure instead of indexing the shape.
+FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
+                                                            int resultIndex,
+                                                            int dim) {
+  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
+  Value source = getMemref();
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (dim < 0 || dim >= sourceType.getRank())
+    return failure();
+  if (sourceType.isDynamicDim(dim))
+    return OpFoldResult(builder.createOrFold<DimOp>(getLoc(), source, dim));
+  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
+}
+
 //===----------------------------------------------------------------------===//
 // DistinctObjectsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..374e47fb34b48 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,43 @@ func.func @iter_to_init_arg_loop_like(
   }
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// Test case: memref.dim(memref.assume_alignment) with static dims.
+// NOTE(review): static extents are presumably folded by memref.dim itself, so
+// this case may not reach AssumeAlignmentOp::reifyDimOfResult; it mainly acts
+// as a regression check - confirm whether it adds coverage.
+// CHECK-LABEL: func @dim_of_assume_alignment_static(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
+//  CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[C2]], %[[C3]] : index, index
+func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
+  %d0 = memref.dim %0, %c0 : memref<2x3xf32>
+  %d1 = memref.dim %0, %c1 : memref<2x3xf32>
+  return %d0, %d1 : index, index
+}
+
+// -----
+
+// Test case: memref.dim(memref.assume_alignment) with dynamic dims. The
+// dynamic extent (dim 1) must resolve to a memref.dim on the assume_alignment
+// operand, while the static extent (dim 0) resolves to a constant.
+// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
+//  CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
+//       CHECK:   return %[[C4]], %[[D1]] : index, index
+func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
+  %d0 = memref.dim %0, %c0 : memref<4x?xf32>
+  %d1 = memref.dim %0, %c1 : memref<4x?xf32>
+  return %d0, %d1 : index, index
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/174477


More information about the Mlir-commits mailing list