[Mlir-commits] [mlir] [MLIR] normalize-memrefs: Normalize memref.alloca (PR #123293)

Matthias Gehre llvmlistbot at llvm.org
Thu Jan 16 23:29:49 PST 2025


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/123293

The pass previously handled only `memref.alloc`; this extends it to also handle `memref.alloca`.
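
For reference, a minimal before/after sketch of the new behavior, mirroring the test case added below (SSA value names are illustrative):

```mlir
// Before normalization: an alloca with a non-identity (shifted) layout map.
%A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
%v = affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>

// After the pass: the layout map is folded into the shape (indices 0..63 map
// to 1..64, so 65 elements) and all indexing uses are rewritten against the
// identity layout.
%A = memref.alloca() : memref<65xf32>
%v = affine.load %A[symbol(%idx) + 1] : memref<65xf32>
```

This mirrors what the pass already does for `memref.alloc` and can be exercised with `mlir-opt -normalize-memrefs`.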

From 3f89bdf0f5c05927862cd3280578f3070844d8e9 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Tue, 1 Oct 2024 21:37:04 +0200
Subject: [PATCH] normalize-memrefs: Normalize memref.alloca

The pass previously handled only memref.alloc;
this extends it to also handle memref.alloca.
---
 mlir/include/mlir/Dialect/Affine/Utils.h      |  8 ++++-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       | 21 ++++++++----
 .../MemRef/Transforms/NormalizeMemRefs.cpp    | 34 ++++++++++++++-----
 .../Dialect/MemRef/normalize-memrefs.mlir     | 14 ++++++++
 4 files changed, 60 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0f801ebb6f5898..ff1900bc8f2ebc 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -31,6 +31,7 @@ class FuncOp;
 
 namespace memref {
 class AllocOp;
+class AllocaOp;
 } // namespace memref
 
 namespace affine {
@@ -245,7 +246,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef,
 /// Rewrites the memref defined by this alloc op to have an identity layout map
 /// and updates all its indexing uses. Returns failure if any of its uses
 /// escape (while leaving the IR in a valid state).
-LogicalResult normalizeMemRef(memref::AllocOp *op);
+template <typename AllocLikeOp>
+LogicalResult normalizeMemRef(AllocLikeOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+extern template LogicalResult
+normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
 
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 9e3257a62b12fb..7ef016f88be375 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1737,9 +1737,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput,
 ///   %c4 = arith.constant 4 : index
 ///   %1 = affine.apply #map1(%c4, %0)
 ///   %2 = affine.apply #map2(%c4, %0)
+template <typename AllocLikeOp>
 static void createNewDynamicSizes(MemRefType oldMemRefType,
                                   MemRefType newMemRefType, AffineMap map,
-                                  memref::AllocOp *allocOp, OpBuilder b,
+                                  AllocLikeOp *allocOp, OpBuilder b,
                                   SmallVectorImpl<Value> &newDynamicSizes) {
   // Create new input for AffineApplyOp.
   SmallVector<Value, 4> inAffineApply;
@@ -1786,7 +1787,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
 }
 
 // TODO: Currently works for static memrefs with a single layout map.
-LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
+template <typename AllocLikeOp>
+LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) {
   MemRefType memrefType = allocOp->getType();
   OpBuilder b(*allocOp);
 
@@ -1802,7 +1804,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
 
   SmallVector<Value, 4> symbolOperands(allocOp->getSymbolOperands());
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
-  memref::AllocOp newAlloc;
+  AllocLikeOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
   // supported for normalizing dynamic memrefs.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
@@ -1814,11 +1816,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
                           newDynamicSizes);
     // Add the new dynamic sizes in new AllocOp.
     newAlloc =
-        b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                  newDynamicSizes, allocOp->getAlignmentAttr());
+        b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType, newDynamicSizes,
+                              allocOp->getAlignmentAttr());
   } else {
-    newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
-                                         allocOp->getAlignmentAttr());
+    newAlloc = b.create<AllocLikeOp>(allocOp->getLoc(), newMemRefType,
+                                     allocOp->getAlignmentAttr());
   }
   // Replace all uses of the old memref.
   if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc,
@@ -1843,6 +1845,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) {
   return success();
 }
 
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocaOp>(memref::AllocaOp *op);
+template LogicalResult
+mlir::affine::normalizeMemRef<memref::AllocOp>(memref::AllocOp *op);
+
 MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
index 33772ccb7dd9d3..08b853fe65b857 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -151,11 +151,11 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
   });
 }
 
-/// Check whether all the uses of AllocOps, CallOps and function arguments of a
-/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
-/// or ReturnOp. Only if these constraints are satisfied will the function
-/// become a candidate for normalization. When the uses of a memref are
-/// non-normalizable and the memref map layout is trivial (identity), we can
+/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function
+/// arguments of a function are either of dereferencing type or are uses in:
+/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will
+/// the function become a candidate for normalization. When the uses of a memref
+/// are non-normalizable and the memref map layout is trivial (identity), we can
 /// still label the entire function as normalizable. We assume external
 /// functions to be normalizable.
 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
@@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
           .wasInterrupted())
     return false;
 
+  if (funcOp
+          .walk([&](memref::AllocaOp allocaOp) -> WalkResult {
+            Value oldMemRef = allocaOp.getResult();
+            if (!allocaOp.getType().getLayout().isIdentity() &&
+                !isMemRefNormalizable(oldMemRef.getUsers()))
+              return WalkResult::interrupt();
+            return WalkResult::advance();
+          })
+          .wasInterrupted())
+    return false;
+
   if (funcOp
           .walk([&](func::CallOp callOp) -> WalkResult {
             for (unsigned resIndex :
@@ -335,18 +346,23 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
 }
 
 /// Normalizes the memrefs within a function which includes those arising as a
-/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
-/// is used to help update function's signature after normalization.
+/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp
+/// argument is used to help update function's signature after normalization.
 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
                                               ModuleOp moduleOp) {
   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
-  // alloc ops first and then process since normalizeMemRef replaces/erases ops
-  // during memref rewriting.
+  // alloc/alloca ops first and then process since normalizeMemRef
+  // replaces/erases ops during memref rewriting.
   SmallVector<memref::AllocOp, 4> allocOps;
   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
   for (memref::AllocOp allocOp : allocOps)
     (void)normalizeMemRef(&allocOp);
 
+  SmallVector<memref::AllocaOp> allocaOps;
+  funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); });
+  for (memref::AllocaOp allocaOp : allocaOps)
+    (void)normalizeMemRef(&allocaOp);
+
   // We use this OpBuilder to create new memref layout later.
   OpBuilder b(funcOp);
 
diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
index 6d20ccbf2ca055..e93a1a4ebae532 100644
--- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
+++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir
@@ -31,6 +31,20 @@ func.func @permute() {
 // CHECK-NEXT: memref.dealloc [[MEM]]
 // CHECK-NEXT: return
 
+// CHECK-LABEL: func @alloca
+func.func @alloca(%idx : index) {
+  // CHECK-NEXT: memref.alloca() : memref<65xf32>
+  %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+  affine.for %i = 0 to 64 {
+    %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>>
+    "prevent.dce"(%1) : (f32) -> ()
+    // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @shift
 func.func @shift(%idx : index) {
   // CHECK-NEXT: memref.alloc() : memref<65xf32>
