[Mlir-commits] [mlir] 7aabf47 - [mlir][affine] Add pass --affine-raise-from-memref (#138004)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 6 00:17:00 PDT 2025


Author: Clément Fournier
Date: 2025-05-06T09:16:57+02:00
New Revision: 7aabf47522625e227433cc9603e0b6858c5dd66d

URL: https://github.com/llvm/llvm-project/commit/7aabf47522625e227433cc9603e0b6858c5dd66d
DIFF: https://github.com/llvm/llvm-project/commit/7aabf47522625e227433cc9603e0b6858c5dd66d.diff

LOG: [mlir][affine] Add pass --affine-raise-from-memref (#138004)

This adds a pass that converts memref.load/store into affine.load/store.
This is useful as those memref operators are ignored by passes like
--affine-scalrep as they don't implement the
Affine[Read/Write]OpInterface. Doing this allows you to put as much of
your program as possible in affine form before you apply affine
optimization passes.

This also slightly changes the implementation of affine::isValidDim. The
previous implementation allowed values from the iter_args of affine
loops to be used as valid dims. I think this doesn't make sense and what
was meant is just the induction vars. In the real world, there is little
reason to find an index in the iter_args, but I wrote that in my tests
and found out it was treated as an affine dim, so corrected that.

Co-authored-by: Oleksandr "Alex" Zinenko <git at ozinenko.com>

Rebased from #114032.

Added: 
    mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
    mlir/test/Dialect/Affine/raise-memref.mlir

Modified: 
    mlir/include/mlir/Dialect/Affine/Passes.h
    mlir/include/mlir/Dialect/Affine/Passes.td
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 96bd3c6a9a7bc..2f70f24dd3ef2 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -23,6 +23,9 @@ namespace mlir {
 namespace func {
 class FuncOp;
 } // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
 
 namespace affine {
 class AffineForOp;
@@ -45,6 +48,13 @@ createSimplifyAffineStructuresPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createAffineLoopInvariantCodeMotionPass();
 
+/// Creates a pass to convert all parallel affine.for's into 1-d affine.parallel
+/// ops.
+std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
 /// Apply normalization transformations to affine loop-like ops. If
 /// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
 /// loop is replaced by its loop body).

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 0b8d5b7d94861..67f9138589c47 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -396,6 +396,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
   let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
 }
 
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+  let summary = "Turn some memref operators to affine operators where supported";
+  let description = [{
+    Raise memref.load and memref.store to affine.load and affine.store, inferring
+    the affine map of those operators if needed. This allows passes like --affine-scalrep
+    to optimize those loads and stores (forwarding them or eliminating them).
+    They can be turned back to memref dialect ops with --lower-affine.
+  }];
+  let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+  let dependentDialects = ["affine::AffineDialect"];
+}
+
 def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
   let summary = "Simplify affine expressions in maps/sets and normalize "
                 "memrefs";

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index eb23403a68813..65f85444e70db 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -294,10 +294,12 @@ bool mlir::affine::isValidDim(Value value) {
     return isValidDim(value, getAffineScope(defOp));
 
   // This value has to be a block argument for an op that has the
-  // `AffineScope` trait or for an affine.for or affine.parallel.
+  // `AffineScope` trait or an induction var of an affine.for or
+  // affine.parallel.
+  if (isAffineInductionVar(value))
+    return true;
   auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
-                      isa<AffineForOp, AffineParallelOp>(parentOp));
+  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
 }
 
 // Value can be used as a dimension id iff it meets one of the following
@@ -316,10 +318,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
 
   auto *op = value.getDefiningOp();
   if (!op) {
-    // This value has to be a block argument for an affine.for or an
+    // This value has to be an induction var for an affine.for or an
     // affine.parallel.
-    auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
-    return isa<AffineForOp, AffineParallelOp>(parentOp);
+    return isAffineInductionVar(value);
   }
 
   // Affine apply operation is ok if all of its operands are ok.

diff  --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9f..1c82822b2bd7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   LoopUnroll.cpp
   LoopUnrollAndJam.cpp
   PipelineDataTransfer.cpp
+  RaiseMemrefDialect.cpp
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp

diff  --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 0000000000000..491d2e03c36bc
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,187 @@
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+/// Return the position of `value` in `dims`, appending it first when it
+/// is not present yet but `isValidElement` accepts it. Yields
+/// std::nullopt when the value is neither present nor acceptable.
+/// `dims` holds the dimension or the symbol operands of the affine
+/// map under construction; map results refer to those operands by
+/// position, which is the index handed back here.
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+                function_ref<bool(Value)> isValidElement) {
+  // An operand registered by an earlier lookup keeps its position.
+  Value *existing = std::find(dims.begin(), dims.end(), value);
+  if (existing != dims.end())
+    return {std::distance(dims.begin(), existing)};
+
+  // Unknown values the predicate rejects cannot become map operands.
+  if (!isValidElement(value))
+    return std::nullopt;
+
+  // Register the value as the new last operand and report where it
+  // landed.
+  size_t position = dims.size();
+  dims.push_back(value);
+  return position;
+}
+
+/// Translate `value` into an affine expression; a null AffineExpr is
+/// returned when that is not possible. Integer constants become constant
+/// expressions, arith.addi and arith.muli are translated operand by
+/// operand, and any other value is looked up in (or appended to)
+/// `affineSymbols` first and `affineDims` second.
+static AffineExpr toAffineExpr(Value value,
+                               llvm::SmallVectorImpl<Value> &affineDims,
+                               llvm::SmallVectorImpl<Value> &affineSymbols) {
+  using namespace matchers;
+  MLIRContext *context = value.getContext();
+
+  // Integer constants map directly onto affine constant expressions.
+  IntegerAttr::ValueType constant;
+  if (matchPattern(value, m_ConstantInt(&constant)))
+    return getAffineConstantExpr(constant.getSExtValue(), context);
+
+  Operation *producer = value.getDefiningOp();
+  const bool isAdd = llvm::isa_and_nonnull<arith::AddIOp>(producer);
+  if (isAdd || llvm::isa_and_nonnull<arith::MulIOp>(producer)) {
+    // TODO: replace recursion with explicit stack.
+    // Recursing is tolerable for now: only arith.addi and arith.muli
+    // are followed, so the recursion always terminates, and index
+    // arithmetic is usually as shallow as `a + b * c`.
+    //
+    // Both operands are translated, left one first, before either
+    // result is inspected.
+    AffineExpr lhs =
+        toAffineExpr(producer->getOperand(0), affineDims, affineSymbols);
+    AffineExpr rhs =
+        toAffineExpr(producer->getOperand(1), affineDims, affineSymbols);
+    if (!lhs || !rhs)
+      return {};
+
+    if (isAdd)
+      return getAffineBinaryOpExpr(mlir::AffineExprKind::Add, lhs, rhs);
+
+    // A product stays affine only when at least one factor is symbolic
+    // or constant; otherwise give up.
+    if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
+      return {};
+    return getAffineBinaryOpExpr(mlir::AffineExprKind::Mul, lhs, rhs);
+  }
+
+  // Any other value must be usable as a map operand. Symbols are tried
+  // before dimensions.
+  auto acceptsSymbol = [](Value v) { return affine::isValidSymbol(v); };
+  if (auto symbolIx = findInListOrAdd(value, affineSymbols, acceptsSymbol))
+    return getAffineSymbolExpr(*symbolIx, context);
+
+  auto acceptsDim = [](Value v) { return affine::isValidDim(v); };
+  if (auto dimIx = findInListOrAdd(value, affineDims, acceptsDim))
+    return getAffineDimExpr(*dimIx, context);
+
+  // Neither a constant, a supported arithmetic op, nor a valid operand.
+  return {};
+}
+
+/// Build the affine map (and its operands, dims first, then symbols) that
+/// addresses `indices`. `map` and `mapArgs` are written only on success.
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+                        llvm::SmallVectorImpl<Value> &mapArgs) {
+  SmallVector<Value> dimOperands;
+  SmallVector<Value> symbolOperands;
+  SmallVector<AffineExpr> exprs;
+  for (Value index : indices) {
+    // A single non-affine index makes the whole access non-raisable.
+    AffineExpr expr = toAffineExpr(index, dimOperands, symbolOperands);
+    if (!expr)
+      return failure();
+    exprs.push_back(expr);
+  }
+  map = AffineMap::get(dimOperands.size(), symbolOperands.size(), exprs, ctx);
+  // Map operands are ordered with all dims before all symbols.
+  dimOperands.append(symbolOperands);
+  mapArgs.swap(dimOperands);
+  return success();
+}
+
+/// Rewrites each memref.load/store whose indices are affine into affine ops.
+struct RaiseMemrefDialect
+    : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+    Operation *op = getOperation();
+    IRRewriter rewriter(ctx);
+    AffineMap map;
+    SmallVector<Value> mapArgs;
+    op->walk([&](Operation *op) {
+      rewriter.setInsertionPoint(op);
+      // `op` comes from the walk and is never null: plain dyn_cast is enough.
+      if (auto store = llvm::dyn_cast<memref::StoreOp>(op)) {
+        if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineStoreOp>(
+              op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+          return;
+        }
+        // Stream the operation itself; streaming the raw pointer would
+        // print only its address.
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << *op << "\n");
+      } else if (auto load = llvm::dyn_cast<memref::LoadOp>(op)) {
+        if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+                                              mapArgs))) {
+          rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+                                                    mapArgs);
+          return;
+        }
+        LLVM_DEBUG(llvm::dbgs()
+                   << "[affine] Cannot raise memref op: " << *op << "\n");
+      }
+    });
+  }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+  return std::make_unique<RaiseMemrefDialect>(); // new pass instance per call
+}

diff  --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 0000000000000..8dc24d99db3e2
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,138 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL:    func @reduce_window_max(
+func.func @reduce_window_max() { // every access below is affine in the loop IVs
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloc() : memref<1x8x8x64xf32>
+  %1 = memref.alloc() : memref<1x18x18x64xf32>
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+        }
+      }
+    }
+  }
+  affine.for %arg0 = 0 to 1 {
+    affine.for %arg1 = 0 to 8 {
+      affine.for %arg2 = 0 to 8 {
+        affine.for %arg3 = 0 to 64 {
+          affine.for %arg4 = 0 to 1 {
+            affine.for %arg5 = 0 to 3 {
+              affine.for %arg6 = 0 to 3 {
+                affine.for %arg7 = 0 to 1 {
+                  %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                  %21 = arith.addi %arg0, %arg4 : index
+                  %22 = arith.constant 2 : index
+                  %23 = arith.muli %arg1, %22 : index
+                  %24 = arith.addi %23, %arg5 : index
+                  %25 = arith.muli %arg2, %22 : index
+                  %26 = arith.addi %25, %arg6 : index
+                  %27 = arith.addi %arg3, %arg7 : index
+                  %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32> // indices come from arith.addi/arith.muli on IVs and the constant 2
+                  %4 = arith.cmpf ogt, %2, %3 : f32
+                  %5 = arith.select %4, %2, %3 : f32
+                  memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:        %[[cst:.*]] = arith.constant 0
+// CHECK:        %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK:        %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK:        affine.for %[[arg0:.*]] =
+// CHECK:          affine.for %[[arg1:.*]] =
+// CHECK:            affine.for %[[arg2:.*]] =
+// CHECK:              affine.for %[[arg3:.*]] =
+// CHECK:                affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK:        affine.for %[[a0:.*]] =
+// CHECK:          affine.for %[[a1:.*]] =
+// CHECK:            affine.for %[[a2:.*]] =
+// CHECK:              affine.for %[[a3:.*]] =
+// CHECK:                affine.for %[[a4:.*]] =
+// CHECK:                  affine.for %[[a5:.*]] =
+// CHECK:                    affine.for %[[a6:.*]] =
+// CHECK:                      affine.for %[[a7:.*]] =
+// CHECK:                        %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK:                        %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
+// CHECK:                        %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK:                        %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK:                        affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+
+// CHECK-LABEL:    func @symbols(
+func.func @symbols(%N : index) { // %N is expected to show up as symbol(%arg0) after raising
+  %0 = memref.alloc() : memref<1024x1024xf32>
+  %1 = memref.alloc() : memref<1024x1024xf32>
+  %2 = memref.alloc() : memref<1024x1024xf32>
+  %cst1 = arith.constant 1 : index
+  %cst2 = arith.constant 2 : index
+  affine.for %i = 0 to %N {
+    affine.for %j = 0 to %N {
+      %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+      %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+        %12 = arith.muli %N, %cst2 : index
+        %13 = arith.addi %12, %cst1 : index
+        %14 = arith.addi %13, %j : index
+        %5 = memref.load %0[%i, %12] : memref<1024x1024xf32> // %12 is %N * 2
+        %6 = memref.load %1[%14, %j] : memref<1024x1024xf32> // %14 is %N * 2 + 1 + %j
+        %8 = arith.mulf %5, %6 : f32
+        %9 = arith.addf %7, %8 : f32
+        %4 = arith.addi %N, %cst1 : index
+        %11 = arith.addi %ax, %cst1 : index
+        memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // the index is an expression of the symbol %N, so this is raised
+        memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // %11 is computed from the iter_arg %ax, so this cannot be raised
+        %something = "ab.v"() : () -> index
+        memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // the index is the result of an unregistered op, so this cannot be raised
+        affine.yield %11 : index
+      }
+    }
+  }
+  return
+}
+
+// CHECK:          %[[cst1:.*]] = arith.constant 1 : index
+// CHECK:          %[[v0:.*]] = memref.alloc() : memref<
+// CHECK:          %[[v1:.*]] = memref.alloc() : memref<
+// CHECK:          %[[v2:.*]] = memref.alloc() : memref<
+// CHECK:          affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK:             affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK:                %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK:                affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK:                  %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK:                  %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK:                  %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK:                  %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK:                  %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK:                  affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK:                  memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK:                  %[[lhs7:.*]] = "ab.v"
+// CHECK:                  memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK:                  affine.yield %[[lhs6]]
+
+
+// CHECK-LABEL:    func @non_affine(
+func.func @non_affine(%N : index) { // neither access below is expected to be raised
+  %2 = memref.alloc() : memref<1024x1024xf32>
+  affine.for %i = 0 to %N {
+    affine.for %j = 0 to %N {
+      %ij = arith.muli %i, %j : index // product of two induction variables: not affine
+      %7 = memref.load %2[%i, %ij] : memref<1024x1024xf32>
+      memref.store %7, %2[%ij, %ij] : memref<1024x1024xf32>
+    }
+  }
+  return
+}
+
+// CHECK:          affine.for %[[i:.*]] =
+// CHECK:             affine.for %[[j:.*]] =
+// CHECK:                  %[[ij:.*]] = arith.muli %[[i]], %[[j]]
+// CHECK:                  %[[v:.*]] = memref.load %{{.*}}[%[[i]], %[[ij]]]
+// CHECK:                  memref.store %[[v]], %{{.*}}[%[[ij]], %[[ij]]]


        


More information about the Mlir-commits mailing list