[Mlir-commits] [mlir] [MemRef] Add dim reification for AssumeAlignmentOp (PR #174477)

Jorn Tuyls llvmlistbot at llvm.org
Tue Jan 6 00:11:58 PST 2026


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/174477

From 19b9245e80c07eb4027f8d3de3684cbbb2416e83 Mon Sep 17 00:00:00 2001
From: Jorn Tuyls <jorn.tuyls at gmail.com>
Date: Mon, 5 Jan 2026 14:40:05 -0600
Subject: [PATCH] [MemRef] Add dim reification for AssumeAlignmentOp

---
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  4 ++-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  7 ++++
 mlir/test/Dialect/MemRef/resolve-dim-ops.mlir | 35 +++++++++++++++++++
 3 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 0bf22928f6900..45122788bd2d4 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -149,7 +149,9 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment", [
       Pure,
       ViewLikeOpInterface,
       SameOperandsAndResultType,
-      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>
+      DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
+      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+                                ["reifyDimOfResult"]>
     ]> {
   let summary =
       "assumption that gives alignment information to the input memref";
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 7bc6ae5f21e8b..c2e71669362c0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -606,6 +606,13 @@ AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
   return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
 }
 
+FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
+                                                            int resultIndex,
+                                                            int dim) {
+  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
+  return getMixedSize(builder, getLoc(), getMemref(), dim);
+}
+
 //===----------------------------------------------------------------------===//
 // DistinctObjectsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
index e354eb91d7557..374e47fb34b48 100644
--- a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -97,3 +97,38 @@ func.func @iter_to_init_arg_loop_like(
   }
   return %result : tensor<?x?xf32>
 }
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with static dims
+// CHECK-LABEL: func @dim_of_assume_alignment_static(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<2x3xf32>
+//  CHECK-DAG:    %[[C2:.*]] = arith.constant 2 : index
+//  CHECK-DAG:    %[[C3:.*]] = arith.constant 3 : index
+//       CHECK:   return %[[C2]], %[[C3]] : index, index
+func.func @dim_of_assume_alignment_static(%arg0: memref<2x3xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<2x3xf32>
+  %d0 = memref.dim %0, %c0 : memref<2x3xf32>
+  %d1 = memref.dim %0, %c1 : memref<2x3xf32>
+  return %d0, %d1 : index, index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.assume_alignment) with dynamic dims
+// CHECK-LABEL: func @dim_of_assume_alignment_dynamic(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<4x?xf32>
+//  CHECK-DAG:    %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:    %[[C4:.*]] = arith.constant 4 : index
+//       CHECK:   %[[D1:.*]] = memref.dim %[[MEM]], %[[C1]]
+//       CHECK:   return %[[C4]], %[[D1]] : index, index
+func.func @dim_of_assume_alignment_dynamic(%arg0: memref<4x?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = memref.assume_alignment %arg0, 64 : memref<4x?xf32>
+  %d0 = memref.dim %0, %c0 : memref<4x?xf32>
+  %d1 = memref.dim %0, %c1 : memref<4x?xf32>
+  return %d0, %d1 : index, index
+}



More information about the Mlir-commits mailing list