[Mlir-commits] [mlir] [mlir][affine] Add pass --affine-raise-from-memref (PR #114032)
Clément Fournier
llvmlistbot at llvm.org
Mon Dec 2 06:28:37 PST 2024
https://github.com/oowekyala updated https://github.com/llvm/llvm-project/pull/114032
>From 6127ad98805773749b774d67b4bbdc1d09e77d20 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Mon, 28 Oct 2024 16:51:26 +0100
Subject: [PATCH 1/2] Add --affine-raise-from-memref
Restrict isValidDim to induction vars, and not iter_args
---
mlir/include/mlir/Dialect/Affine/Passes.h | 6 +
mlir/include/mlir/Dialect/Affine/Passes.td | 12 ++
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 13 +-
.../Dialect/Affine/Transforms/CMakeLists.txt | 1 +
.../Affine/Transforms/DecomposeAffineOps.cpp | 11 ++
.../Affine/Transforms/RaiseMemrefDialect.cpp | 168 ++++++++++++++++++
mlir/test/Dialect/Affine/raise-memref.mlir | 130 ++++++++++++++
7 files changed, 335 insertions(+), 6 deletions(-)
create mode 100644 mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
create mode 100644 mlir/test/Dialect/Affine/raise-memref.mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index e152101236dc7a..c1b9c30d302dd0 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -22,6 +22,9 @@ namespace mlir {
namespace func {
class FuncOp;
} // namespace func
+namespace memref {
+class MemRefDialect;
+} // namespace memref
namespace affine {
class AffineForOp;
@@ -48,6 +51,9 @@ createAffineLoopInvariantCodeMotionPass();
/// ops.
std::unique_ptr<OperationPass<func::FuncOp>> createAffineParallelizePass();
+/// Creates a pass that converts some memref operators to affine operators.
+std::unique_ptr<OperationPass<func::FuncOp>> createRaiseMemrefToAffine();
+
/// Apply normalization transformations to affine loop-like ops. If
/// `promoteSingleIter` is true, single iteration loops are promoted (i.e., the
/// loop is replaced by its loop body).
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 77073aa29da73e..43ce9dae93441a 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -397,6 +397,18 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
}
+def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
+ let summary = "Turn some memref operators to affine operators where supported";
+ let description = [{
+ Raise memref.load and memref.store to affine.load and affine.store, inferring
+ the affine map of those operators if needed. This allows passes like --affine-scalrep
+ to optimize those loads and stores (forwarding them or eliminating them).
+ They can be turned back to memref dialect ops with --lower-affine.
+ }];
+ let constructor = "mlir::affine::createRaiseMemrefToAffine()";
+ let dependentDialects = ["memref::MemRefDialect"];
+}
+
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
let summary = "Simplify affine expressions in maps/sets and normalize "
"memrefs";
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..42b7625925df81 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -284,10 +284,12 @@ bool mlir::affine::isValidDim(Value value) {
return isValidDim(value, getAffineScope(defOp));
// This value has to be a block argument for an op that has the
- // `AffineScope` trait or for an affine.for or affine.parallel.
+ // `AffineScope` trait or an induction var of an affine.for or
+ // affine.parallel.
+ if (isAffineInductionVar(value))
+ return true;
auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
- isa<AffineForOp, AffineParallelOp>(parentOp));
+ return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}
// Value can be used as a dimension id iff it meets one of the following
@@ -306,10 +308,9 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
auto *op = value.getDefiningOp();
if (!op) {
- // This value has to be a block argument for an affine.for or an
+ // This value has to be an induction var for an affine.for or an
// affine.parallel.
- auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
- return isa<AffineForOp, AffineParallelOp>(parentOp);
+ return isAffineInductionVar(value);
}
// Affine apply operation is ok if all of its operands are ok.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9fa..1c82822b2bd7f9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
LoopUnroll.cpp
LoopUnrollAndJam.cpp
PipelineDataTransfer.cpp
+ RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
SimplifyAffineStructures.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index f28fb3acb7db7f..4d5ff5765ccc96 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -13,9 +13,20 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
using namespace mlir;
using namespace mlir::affine;
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
new file mode 100644
index 00000000000000..2fd47549000001
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -0,0 +1,168 @@
+
+
+#include "mlir/Dialect/Affine/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
+#include <memory>
+#include <optional>
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_RAISEMEMREFDIALECT
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "raise-memref-to-affine"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+namespace {
+
+static std::optional<size_t>
+findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
+ const std::function<bool(Value)> &isValidElement) {
+
+ Value *loopIV = std::find(dims.begin(), dims.end(), value);
+ if (loopIV != dims.end()) {
+ // found an IV that already has an index
+ return {std::distance(dims.begin(), loopIV)};
+ }
+ if (isValidElement(value)) {
+ // push this IV in the parameters
+ size_t idx = dims.size();
+ dims.push_back(value);
+ return idx;
+ }
+ return std::nullopt;
+}
+
+static LogicalResult toAffineExpr(Value value, AffineExpr &result,
+ llvm::SmallVectorImpl<Value> &affineDims,
+ llvm::SmallVectorImpl<Value> &affineSymbols) {
+ using namespace matchers;
+ IntegerAttr::ValueType cst;
+ if (matchPattern(value, m_ConstantInt(&cst))) {
+ result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
+ return success();
+ }
+ Value lhs;
+ Value rhs;
+ if (matchPattern(value, m_Op<arith::AddIOp>(m_Any(&lhs), m_Any(&rhs))) ||
+ matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
+ AffineExpr lhsE;
+ AffineExpr rhsE;
+ if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
+ succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
+ AffineExprKind kind;
+ if (isa<arith::AddIOp>(value.getDefiningOp())) {
+ kind = mlir::AffineExprKind::Add;
+ } else {
+ kind = mlir::AffineExprKind::Mul;
+ }
+ result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
+ return success();
+ }
+ }
+
+ if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
+ return affine::isValidSymbol(v);
+ })) {
+ result = getAffineSymbolExpr(*dimIx, value.getContext());
+ return success();
+ }
+
+ if (auto dimIx = findInListOrAdd(
+ value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
+
+ result = getAffineDimExpr(*dimIx, value.getContext());
+ return success();
+ }
+
+ return failure();
+}
+
+static LogicalResult
+computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
+ llvm::SmallVectorImpl<Value> &mapArgs) {
+ llvm::SmallVector<AffineExpr> results;
+ llvm::SmallVector<Value, 2> symbols;
+ llvm::SmallVector<Value, 8> dims;
+
+ for (auto indexExpr : indices) {
+ if (failed(
+ toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
+ return failure();
+ }
+ }
+
+ map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
+
+ dims.append(symbols);
+ mapArgs.swap(dims);
+ return success();
+}
+
+struct RaiseMemrefDialect
+ : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
+
+ void runOnOperation() override {
+ auto *ctx = &getContext();
+ Operation *op = getOperation();
+ IRRewriter rewriter(ctx);
+ AffineMap map;
+ SmallVector<Value> mapArgs;
+ op->walk([&](Operation *op) {
+ rewriter.setInsertionPoint(op);
+ if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
+
+ if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+ }
+
+ } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+
+ if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
+ mapArgs))) {
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
+ mapArgs);
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+ }
+ }
+ });
+ }
+};
+
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::affine::createRaiseMemrefToAffine() {
+ return std::make_unique<RaiseMemrefDialect>();
+}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
new file mode 100644
index 00000000000000..d529e2c0c907a6
--- /dev/null
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @reduce_window_max() {
+func.func @reduce_window_max() {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = memref.alloc() : memref<1x8x8x64xf32>
+ %1 = memref.alloc() : memref<1x18x18x64xf32>
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ memref.store %cst, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ affine.for %arg0 = 0 to 1 {
+ affine.for %arg1 = 0 to 8 {
+ affine.for %arg2 = 0 to 8 {
+ affine.for %arg3 = 0 to 64 {
+ affine.for %arg4 = 0 to 1 {
+ affine.for %arg5 = 0 to 3 {
+ affine.for %arg6 = 0 to 3 {
+ affine.for %arg7 = 0 to 1 {
+ %2 = memref.load %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ %21 = arith.addi %arg0, %arg4 : index
+ %22 = arith.constant 2 : index
+ %23 = arith.muli %arg1, %22 : index
+ %24 = arith.addi %23, %arg5 : index
+ %25 = arith.muli %arg2, %22 : index
+ %26 = arith.addi %25, %arg6 : index
+ %27 = arith.addi %arg3, %arg7 : index
+ %3 = memref.load %1[%21, %24, %26, %27] : memref<1x18x18x64xf32>
+ %4 = arith.cmpf ogt, %2, %3 : f32
+ %5 = arith.select %4, %2, %3 : f32
+ memref.store %5, %0[%arg0, %arg1, %arg2, %arg3] : memref<1x8x8x64xf32>
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
+// CHECK: affine.for %[[arg0:.*]] = 0 to 1 {
+// CHECK: affine.for %[[arg1:.*]] = 0 to 8 {
+// CHECK: affine.for %[[arg2:.*]] = 0 to 8 {
+// CHECK: affine.for %[[arg3:.*]] = 0 to 64 {
+// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: affine.for %[[a0:.*]] = 0 to 1 {
+// CHECK: affine.for %[[a1:.*]] = 0 to 8 {
+// CHECK: affine.for %[[a2:.*]] = 0 to 8 {
+// CHECK: affine.for %[[a3:.*]] = 0 to 64 {
+// CHECK: affine.for %[[a4:.*]] = 0 to 1 {
+// CHECK: affine.for %[[a5:.*]] = 0 to 3 {
+// CHECK: affine.for %[[a6:.*]] = 0 to 3 {
+// CHECK: affine.for %[[a7:.*]] = 0 to 1 {
+// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
+// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
+// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
+// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
+// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+func.func @symbols(%N : index) {
+ %0 = memref.alloc() : memref<1024x1024xf32>
+ %1 = memref.alloc() : memref<1024x1024xf32>
+ %2 = memref.alloc() : memref<1024x1024xf32>
+ %cst1 = arith.constant 1 : index
+ %cst2 = arith.constant 2 : index
+ affine.for %i = 0 to %N {
+ affine.for %j = 0 to %N {
+ %7 = memref.load %2[%i, %j] : memref<1024x1024xf32>
+ %10 = affine.for %k = 0 to %N iter_args(%ax = %cst1) -> index {
+ %12 = arith.muli %N, %cst2 : index
+ %13 = arith.addi %12, %cst1 : index
+ %14 = arith.addi %13, %j : index
+ %5 = memref.load %0[%i, %12] : memref<1024x1024xf32>
+ %6 = memref.load %1[%14, %j] : memref<1024x1024xf32>
+ %8 = arith.mulf %5, %6 : f32
+ %9 = arith.addf %7, %8 : f32
+ %4 = arith.addi %N, %cst1 : index
+ %11 = arith.addi %ax, %cst1 : index
+ memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
+ memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
+ %something = "ab.v"() : () -> index
+ memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
+ affine.yield %11 : index
+ }
+ }
+ }
+ return
+}
+
+// CHECK: %[[cst1:.*]] = arith.constant 1 : index
+// CHECK: %[[v0:.*]] = memref.alloc() : memref<
+// CHECK: %[[v1:.*]] = memref.alloc() : memref<
+// CHECK: %[[v2:.*]] = memref.alloc() : memref<
+// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
+// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
+// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
+// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
+// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
+// CHECK-NEXT: %[[lhs7:.*]] = "ab.v"
+// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
+// CHECK-NEXT: affine.yield %[[lhs6]]
>From 341cd98d5fd65062942a2d67791d395947f5a9e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Fournier?= <clement.fournier at tu-dresden.de>
Date: Fri, 29 Nov 2024 13:50:04 +0100
Subject: [PATCH 2/2] Address review comments
---
mlir/include/mlir/Dialect/Affine/Passes.td | 2 +-
.../Affine/Transforms/DecomposeAffineOps.cpp | 11 ---
.../Affine/Transforms/RaiseMemrefDialect.cpp | 92 ++++++++++---------
mlir/test/Dialect/Affine/raise-memref.mlir | 78 +++++++---------
4 files changed, 82 insertions(+), 101 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 43ce9dae93441a..a77bcac5ed407f 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -406,7 +406,7 @@ def RaiseMemrefDialect : Pass<"affine-raise-from-memref", "func::FuncOp"> {
They can be turned back to memref dialect ops with --lower-affine.
}];
let constructor = "mlir::affine::createRaiseMemrefToAffine()";
- let dependentDialects = ["memref::MemRefDialect"];
+ let dependentDialects = ["affine::AffineDialect"];
}
def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 4d5ff5765ccc96..f28fb3acb7db7f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -13,20 +13,9 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/LogicalResult.h"
-#include <algorithm>
-#include <cstddef>
-#include <functional>
-#include <iterator>
using namespace mlir;
using namespace mlir::affine;
diff --git a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
index 2fd47549000001..a6e961a6d64390 100644
--- a/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp
@@ -1,29 +1,27 @@
-
+//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functionality to convert memref load and store ops to
+// the corresponding affine ops, inferring the affine map as needed.
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
-#include "llvm/Support/LogicalResult.h"
-#include <algorithm>
-#include <cstddef>
-#include <functional>
-#include <iterator>
-#include <memory>
-#include <optional>
namespace mlir {
namespace affine {
@@ -39,17 +37,24 @@ using namespace mlir::affine;
namespace {
+/// Find the index of the given value in the `dims` list,
+/// and append it if it was not already in the list. The
+/// dims list is a list of symbols or dimensions of the
+/// affine map. Within the results of an affine map, they
+/// are identified by their index, which is why we need
+/// this function.
static std::optional<size_t>
findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
- const std::function<bool(Value)> &isValidElement) {
+ function_ref<bool(Value)> isValidElement) {
Value *loopIV = std::find(dims.begin(), dims.end(), value);
if (loopIV != dims.end()) {
- // found an IV that already has an index
+ // We found an IV that already has an index, return that index.
return {std::distance(dims.begin(), loopIV)};
}
if (isValidElement(value)) {
- // push this IV in the parameters
+ // This is a valid element for the dim/symbol list, push this as a
+ // parameter.
size_t idx = dims.size();
dims.push_back(value);
return idx;
@@ -57,14 +62,15 @@ findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
return std::nullopt;
}
-static LogicalResult toAffineExpr(Value value, AffineExpr &result,
- llvm::SmallVectorImpl<Value> &affineDims,
- llvm::SmallVectorImpl<Value> &affineSymbols) {
+/// Convert a value to an affine expr if possible. Adds dims and symbols
+/// if needed.
+static AffineExpr toAffineExpr(Value value,
+ llvm::SmallVectorImpl<Value> &affineDims,
+ llvm::SmallVectorImpl<Value> &affineSymbols) {
using namespace matchers;
IntegerAttr::ValueType cst;
if (matchPattern(value, m_ConstantInt(&cst))) {
- result = getAffineConstantExpr(cst.getSExtValue(), value.getContext());
- return success();
+ return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
}
Value lhs;
Value rhs;
@@ -72,48 +78,46 @@ static LogicalResult toAffineExpr(Value value, AffineExpr &result,
matchPattern(value, m_Op<arith::MulIOp>(m_Any(&lhs), m_Any(&rhs)))) {
AffineExpr lhsE;
AffineExpr rhsE;
- if (succeeded(toAffineExpr(lhs, lhsE, affineDims, affineSymbols)) &&
- succeeded(toAffineExpr(rhs, rhsE, affineDims, affineSymbols))) {
+ if ((lhsE = toAffineExpr(lhs, affineDims, affineSymbols)) &&
+ (rhsE = toAffineExpr(rhs, affineDims, affineSymbols))) {
AffineExprKind kind;
if (isa<arith::AddIOp>(value.getDefiningOp())) {
kind = mlir::AffineExprKind::Add;
} else {
kind = mlir::AffineExprKind::Mul;
}
- result = getAffineBinaryOpExpr(kind, lhsE, rhsE);
- return success();
+ return getAffineBinaryOpExpr(kind, lhsE, rhsE);
}
}
if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
return affine::isValidSymbol(v);
})) {
- result = getAffineSymbolExpr(*dimIx, value.getContext());
- return success();
+ return getAffineSymbolExpr(*dimIx, value.getContext());
}
if (auto dimIx = findInListOrAdd(
value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
- result = getAffineDimExpr(*dimIx, value.getContext());
- return success();
+ return getAffineDimExpr(*dimIx, value.getContext());
}
- return failure();
+ return {};
}
static LogicalResult
computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
llvm::SmallVectorImpl<Value> &mapArgs) {
- llvm::SmallVector<AffineExpr> results;
- llvm::SmallVector<Value, 2> symbols;
- llvm::SmallVector<Value, 8> dims;
+ SmallVector<AffineExpr> results;
+ SmallVector<Value> symbols;
+ SmallVector<Value> dims;
- for (auto indexExpr : indices) {
- if (failed(
- toAffineExpr(indexExpr, results.emplace_back(), dims, symbols))) {
+ for (Value indexExpr : indices) {
+ AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
+ if (!res) {
return failure();
}
+ results.push_back(res);
}
map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
@@ -140,21 +144,21 @@ struct RaiseMemrefDialect
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineStoreOp>(
op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
- } else {
- LLVM_DEBUG(llvm::dbgs()
- << "[affine] Cannot raise memref op: " << op << "\n");
+ return;
}
- } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
+ } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
mapArgs))) {
rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
mapArgs);
- } else {
- LLVM_DEBUG(llvm::dbgs()
- << "[affine] Cannot raise memref op: " << op << "\n");
+ return;
}
+ LLVM_DEBUG(llvm::dbgs()
+ << "[affine] Cannot raise memref op: " << op << "\n");
}
});
}
diff --git a/mlir/test/Dialect/Affine/raise-memref.mlir b/mlir/test/Dialect/Affine/raise-memref.mlir
index d529e2c0c907a6..d8f2aaab4839e0 100644
--- a/mlir/test/Dialect/Affine/raise-memref.mlir
+++ b/mlir/test/Dialect/Affine/raise-memref.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -allow-unregistered-dialect -affine-raise-from-memref --canonicalize | FileCheck %s
-// CHECK-LABEL: func @reduce_window_max() {
+// CHECK-LABEL: func @reduce_window_max(
func.func @reduce_window_max() {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloc() : memref<1x8x8x64xf32>
@@ -45,41 +45,29 @@ func.func @reduce_window_max() {
return
}
-// CHECK: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[cst:.*]] = arith.constant 0
// CHECK: %[[v0:.*]] = memref.alloc() : memref<1x8x8x64xf32>
// CHECK: %[[v1:.*]] = memref.alloc() : memref<1x18x18x64xf32>
-// CHECK: affine.for %[[arg0:.*]] = 0 to 1 {
-// CHECK: affine.for %[[arg1:.*]] = 0 to 8 {
-// CHECK: affine.for %[[arg2:.*]] = 0 to 8 {
-// CHECK: affine.for %[[arg3:.*]] = 0 to 64 {
-// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] : memref<1x8x8x64xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: affine.for %[[a0:.*]] = 0 to 1 {
-// CHECK: affine.for %[[a1:.*]] = 0 to 8 {
-// CHECK: affine.for %[[a2:.*]] = 0 to 8 {
-// CHECK: affine.for %[[a3:.*]] = 0 to 64 {
-// CHECK: affine.for %[[a4:.*]] = 0 to 1 {
-// CHECK: affine.for %[[a5:.*]] = 0 to 3 {
-// CHECK: affine.for %[[a6:.*]] = 0 to 3 {
-// CHECK: affine.for %[[a7:.*]] = 0 to 1 {
-// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
-// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] : memref<1x18x18x64xf32>
+// CHECK: affine.for %[[arg0:.*]] =
+// CHECK: affine.for %[[arg1:.*]] =
+// CHECK: affine.for %[[arg2:.*]] =
+// CHECK: affine.for %[[arg3:.*]] =
+// CHECK: affine.store %[[cst]], %[[v0]][%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]] :
+// CHECK: affine.for %[[a0:.*]] =
+// CHECK: affine.for %[[a1:.*]] =
+// CHECK: affine.for %[[a2:.*]] =
+// CHECK: affine.for %[[a3:.*]] =
+// CHECK: affine.for %[[a4:.*]] =
+// CHECK: affine.for %[[a5:.*]] =
+// CHECK: affine.for %[[a6:.*]] =
+// CHECK: affine.for %[[a7:.*]] =
+// CHECK: %[[lhs:.*]] = affine.load %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK: %[[rhs:.*]] = affine.load %[[v1]][%[[a0]] + %[[a4]], %[[a1]] * 2 + %[[a5]], %[[a2]] * 2 + %[[a6]], %[[a3]] + %[[a7]]] :
// CHECK: %[[res:.*]] = arith.cmpf ogt, %[[lhs]], %[[rhs]] : f32
// CHECK: %[[sel:.*]] = arith.select %[[res]], %[[lhs]], %[[rhs]] : f32
-// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] : memref<1x8x8x64xf32>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
+// CHECK: affine.store %[[sel]], %[[v0]][%[[a0]], %[[a1]], %[[a2]], %[[a3]]] :
+// CHECK-LABEL: func @symbols(
func.func @symbols(%N : index) {
%0 = memref.alloc() : memref<1024x1024xf32>
%1 = memref.alloc() : memref<1024x1024xf32>
@@ -100,7 +88,7 @@ func.func @symbols(%N : index) {
%4 = arith.addi %N, %cst1 : index
%11 = arith.addi %ax, %cst1 : index
memref.store %9, %2[%i, %4] : memref<1024x1024xf32> // this uses an expression of the symbol
- memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be lowered
+ memref.store %9, %2[%i, %11] : memref<1024x1024xf32> // this uses an iter_args and cannot be raised
%something = "ab.v"() : () -> index
memref.store %9, %2[%i, %something] : memref<1024x1024xf32> // this cannot be lowered
affine.yield %11 : index
@@ -115,16 +103,16 @@ func.func @symbols(%N : index) {
// CHECK: %[[v1:.*]] = memref.alloc() : memref<
// CHECK: %[[v2:.*]] = memref.alloc() : memref<
// CHECK: affine.for %[[a1:.*]] = 0 to %arg0 {
-// CHECK-NEXT: affine.for %[[a2:.*]] = 0 to %arg0 {
-// CHECK-NEXT: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
-// CHECK-NEXT: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
-// CHECK-NEXT: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] : memref<1024x1024xf32>
-// CHECK-NEXT: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] : memref<1024x1024xf32>
-// CHECK-NEXT: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
-// CHECK-NEXT: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
-// CHECK-NEXT: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
-// CHECK-NEXT: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] : memref<1024x1024xf32>
-// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] : memref<1024x1024xf32>
-// CHECK-NEXT: %[[lhs7:.*]] = "ab.v"
-// CHECK-NEXT: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] : memref<1024x1024xf32>
-// CHECK-NEXT: affine.yield %[[lhs6]]
+// CHECK: affine.for %[[a2:.*]] = 0 to %arg0 {
+// CHECK: %[[lhs:.*]] = affine.load %{{.*}}[%[[a1]], %[[a2]]] : memref<1024x1024xf32>
+// CHECK: affine.for %[[a3:.*]] = 0 to %arg0 iter_args(%[[a4:.*]] = %[[cst1]]) -> (index) {
+// CHECK: %[[lhs2:.*]] = affine.load %{{.*}}[%[[a1]], symbol(%arg0) * 2] :
+// CHECK: %[[lhs3:.*]] = affine.load %{{.*}}[%[[a2]] + symbol(%arg0) * 2 + 1, %[[a2]]] :
+// CHECK: %[[lhs4:.*]] = arith.mulf %[[lhs2]], %[[lhs3]]
+// CHECK: %[[lhs5:.*]] = arith.addf %[[lhs]], %[[lhs4]]
+// CHECK: %[[lhs6:.*]] = arith.addi %[[a4]], %[[cst1]]
+// CHECK: affine.store %[[lhs5]], %{{.*}}[%[[a1]], symbol(%arg0) + 1] :
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs6]]] :
+// CHECK: %[[lhs7:.*]] = "ab.v"
+// CHECK: memref.store %[[lhs5]], %{{.*}}[%[[a1]], %[[lhs7]]] :
+// CHECK: affine.yield %[[lhs6]]
More information about the Mlir-commits mailing list