[Mlir-commits] [mlir] [mlir][Conversion] Raise SCF to Affine. (Ported from Polygeist) (PR #158267)
Terapines MLIR
llvmlistbot at llvm.org
Sun Sep 14 19:23:24 PDT 2025
https://github.com/terapines-osc-mlir updated https://github.com/llvm/llvm-project/pull/158267
>From 1ffcb0f0be811dfb06176a47d9c1b0ed51d1be83 Mon Sep 17 00:00:00 2001
From: Terapines MLIR <osc-mlir at terapines.com>
Date: Fri, 12 Sep 2025 18:52:55 +0800
Subject: [PATCH] [mlir][Conversion] Raise SCF to Affine. (Ported from
Polygeist)
Based on Polygeist commit 77c04bb2a7a2406ca9480bcc9e729b07d2c8d077
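
A rough sketch of the intended rewrites (hand-written illustration, not
taken from the tests added below): `-raise-scf-to-affine` rewrites loops
such as

  scf.for %i = %lb to %ub step %c1 {
    ...
  }

into `affine.for %i = %lb to %ub { ... }` when the bounds can be expressed
with affine dims/symbols, and `-affine-cfg` then raises `memref.load`,
`memref.store` and `scf.if` occurring inside affine loops to
`affine.load`/`affine.store`/`affine.if` when their indices and conditions
are affine. See the .mlir tests in this patch for the authoritative
behavior.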
---
mlir/include/mlir/Conversion/Passes.h | 4 +
mlir/include/mlir/Conversion/Passes.td | 18 +
.../mlir/Conversion/SCFToAffine/SCFToAffine.h | 14 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp | 1556 +++++++++++++++++
.../lib/Conversion/SCFToAffine/CMakeLists.txt | 19 +
mlir/lib/Conversion/SCFToAffine/Ops.cpp | 359 ++++
mlir/lib/Conversion/SCFToAffine/Ops.h | 114 ++
.../Conversion/SCFToAffine/RaiseToAffine.cpp | 296 ++++
.../Conversion/SCFToAffine/affinecfg.mlir | 166 ++
.../test/Conversion/SCFToAffine/affraise.mlir | 48 +
.../Conversion/SCFToAffine/affraise2.mlir | 31 +
.../Conversion/SCFToAffine/affraise3.mlir | 95 +
.../Conversion/SCFToAffine/raisescffor.mlir | 25 +
14 files changed, 2746 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
create mode 100644 mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
create mode 100644 mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/SCFToAffine/Ops.cpp
create mode 100644 mlir/lib/Conversion/SCFToAffine/Ops.h
create mode 100644 mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp
create mode 100644 mlir/test/Conversion/SCFToAffine/affinecfg.mlir
create mode 100644 mlir/test/Conversion/SCFToAffine/affraise.mlir
create mode 100644 mlir/test/Conversion/SCFToAffine/affraise2.mlir
create mode 100644 mlir/test/Conversion/SCFToAffine/affraise3.mlir
create mode 100644 mlir/test/Conversion/SCFToAffine/raisescffor.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..fa4bcb5bce5db 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
@@ -87,6 +88,9 @@
namespace mlir {
+std::unique_ptr<Pass> replaceAffineCFGPass();
+std::unique_ptr<Pass> createRaiseSCFToAffinePass();
+
/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Conversion/Passes.h.inc"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..85f49448e38da 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1029,6 +1029,24 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
}];
}
+//===----------------------------------------------------------------------===//
+// SCFToAffine
+//===----------------------------------------------------------------------===//
+def AffineCFG : Pass<"affine-cfg"> {
+  let summary = "Raise scf.if and memref load/store to affine equivalents";
+  let constructor = "mlir::replaceAffineCFGPass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+  ];
+}
+
+def RaiseSCFToAffine : Pass<"raise-scf-to-affine"> {
+ let summary = "Raise SCF to affine";
+ let constructor = "mlir::createRaiseSCFToAffinePass()";
+ let dependentDialects = [
+ "affine::AffineDialect",
+ "scf::SCFDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
new file mode 100644
index 0000000000000..372d19d60fdb3
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
@@ -0,0 +1,14 @@
+#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H
+#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_RAISESCFTOAFFINE
+#define GEN_PASS_DECL_AFFINECFG
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..d9da085378834 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -52,6 +52,7 @@ add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(PtrToLLVM)
add_subdirectory(ReconcileUnrealizedCasts)
+add_subdirectory(SCFToAffine)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
new file mode 100644
index 0000000000000..593eb02f2e1d3
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
@@ -0,0 +1,1556 @@
+#include "./Ops.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallSet.h"
+#include <deque>
+
+#define DEBUG_TYPE "affine-cfg"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace mlir::affine;
+
+namespace mlir {
+
+#define GEN_PASS_DEF_AFFINECFG
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+bool isValidIndex(Value val);
+
+bool isReadOnly(Operation *op);
+
+bool isValidSymbolInt(Value value, bool recur = true);
+bool isValidSymbolInt(Operation *defOp, bool recur) {
+ Attribute operandCst;
+ if (matchPattern(defOp, m_Constant(&operandCst)))
+ return true;
+
+ if (recur) {
+ if (isa<SelectOp, IndexCastOp, AddIOp, MulIOp, DivSIOp, DivUIOp, RemSIOp,
+ RemUIOp, SubIOp, CmpIOp, TruncIOp, ExtUIOp, ExtSIOp>(defOp))
+ if (llvm::all_of(defOp->getOperands(),
+ [&](Value v) { return isValidSymbolInt(v, recur); }))
+ return true;
+ if (auto ifOp = mlir::dyn_cast<scf::IfOp>(defOp)) {
+ if (isValidSymbolInt(ifOp.getCondition(), recur)) {
+ if (llvm::all_of(
+ ifOp.thenBlock()->without_terminator(),
+ [&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
+ llvm::all_of(
+ ifOp.elseBlock()->without_terminator(),
+ [&](Operation &o) { return isValidSymbolInt(&o, recur); }))
+ return true;
+ }
+ }
+ if (auto ifOp = dyn_cast<affine::AffineIfOp>(defOp)) {
+ if (llvm::all_of(ifOp.getOperands(),
+ [&](Value o) { return isValidSymbolInt(o, recur); }))
+ if (llvm::all_of(
+ ifOp.getThenBlock()->without_terminator(),
+ [&](Operation &o) { return isValidSymbolInt(&o, recur); }) &&
+ llvm::all_of(
+ ifOp.getElseBlock()->without_terminator(),
+ [&](Operation &o) { return isValidSymbolInt(&o, recur); }))
+ return true;
+ }
+ }
+ return false;
+}
+
+// isValidSymbol, even if not index
+bool isValidSymbolInt(Value value, bool recur) {
+ // Check that the value is a top level value.
+ if (affine::isTopLevelValue(value))
+ return true;
+
+ if (auto *defOp = value.getDefiningOp()) {
+ if (isValidSymbolInt(defOp, recur))
+ return true;
+ return affine::isValidSymbol(value, affine::getAffineScope(defOp));
+ }
+
+ return false;
+}
+
+struct AffineApplyNormalizer {
+ AffineApplyNormalizer(AffineMap map, ArrayRef<Value> operands,
+ PatternRewriter &rewriter, DominanceInfo &DI);
+
+ /// Returns the AffineMap resulting from normalization.
+ AffineMap getAffineMap() { return affineMap; }
+
+ SmallVector<Value, 8> getOperands() {
+ SmallVector<Value, 8> res(reorderedDims);
+ res.append(concatenatedSymbols.begin(), concatenatedSymbols.end());
+ return res;
+ }
+
+private:
+ /// Helper function to insert `v` into the coordinate system of the current
+ /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding
+ /// renumbered position.
+ AffineDimExpr renumberOneDim(Value v);
+
+ /// Maps of Value to position in `affineMap`.
+ DenseMap<Value, unsigned> dimValueToPosition;
+
+ /// Ordered dims and symbols matching positional dims and symbols in
+ /// `affineMap`.
+ SmallVector<Value, 8> reorderedDims;
+ SmallVector<Value, 8> concatenatedSymbols;
+
+ AffineMap affineMap;
+};
+
+static bool isAffineForArg(Value val) {
+ if (!mlir::isa<BlockArgument>(val))
+ return false;
+ Operation *parentOp =
+ mlir::cast<BlockArgument>(val).getOwner()->getParentOp();
+ return (
+ isa_and_nonnull<affine::AffineForOp, affine::AffineParallelOp>(parentOp));
+}
+
+static bool legalCondition(Value en, bool dim = false) {
+ if (en.getDefiningOp<affine::AffineApplyOp>())
+ return true;
+
+ if (!dim && !isValidSymbolInt(en, /*recur*/ false)) {
+ if (isValidIndex(en) || isValidSymbolInt(en, /*recur*/ true)) {
+ return true;
+ }
+ }
+
+ while (auto ic = en.getDefiningOp<IndexCastOp>())
+ en = ic.getIn();
+
+ if ((en.getDefiningOp<AddIOp>() || en.getDefiningOp<SubIOp>() ||
+ en.getDefiningOp<MulIOp>() || en.getDefiningOp<RemUIOp>() ||
+ en.getDefiningOp<RemSIOp>()) &&
+ (en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIntOp>() ||
+ en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIndexOp>()))
+ return true;
+ // if (auto IC = dyn_cast_or_null<IndexCastOp>(en.getDefiningOp())) {
+ // if (!outer || legalCondition(IC.getOperand(), false)) return true;
+ //}
+ if (!dim)
+ if (auto BA = dyn_cast<BlockArgument>(en)) {
+ if (isa<affine::AffineForOp, affine::AffineParallelOp>(
+ BA.getOwner()->getParentOp()))
+ return true;
+ }
+ return false;
+}
+
+/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
+/// keep a correspondence between the mathematical `map` and the `operands` of
+/// a given affine::AffineApplyOp. This correspondence is maintained by
+/// iterating over the operands and forming an `auxiliaryMap` that can be
+/// composed mathematically with `map`. To keep this correspondence in cases
+/// where symbols are produced by affine.apply operations, we perform a local
+/// rewrite of symbols as dims.
+///
+/// Rationale for locally rewriting symbols as dims:
+/// ================================================
+/// The mathematical composition of AffineMap must always concatenate symbols
+/// because it does not have enough information to do otherwise. For example,
+/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
+/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
+///
+/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
+/// applied to the same mlir::Value for both s0 and s1.
+/// As a consequence mathematical composition of AffineMap always concatenates
+/// symbols.
+///
+/// When AffineMaps are used in affine::AffineApplyOp however, they may specify
+/// composition via symbols, which is ambiguous mathematically. This corner case
+/// is handled by locally rewriting such symbols that come from
+/// affine::AffineApplyOp into dims and composing through dims.
+/// TODO: Composition via symbols comes at a significant code
+/// complexity. Alternatively we should investigate whether we want to
+/// explicitly disallow symbols coming from affine.apply and instead force the
+/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
+/// extra API calls for such uses, which haven't popped up until now) and the
+/// benefit potentially big: simpler and more maintainable code for a
+/// non-trivial, recursive, procedure.
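+///
+/// Illustrative sketch (hypothetical IR): a chain such as
+///   %a = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%n]
+///   %b = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%a]
+/// is handled by rewriting the %a symbol as a dim while composing, so the
+/// normalized apply computes %n * 4 + 1 directly from %n.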
+AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
+ ArrayRef<Value> operands,
+ PatternRewriter &rewriter,
+ DominanceInfo &DI) {
+ assert(map.getNumInputs() == operands.size() &&
+ "number of operands does not match the number of map inputs");
+
+ LLVM_DEBUG(map.print(llvm::dbgs() << "\nInput map: "));
+
+ SmallVector<Value, 8> addedValues;
+
+ llvm::SmallSet<unsigned, 1> symbolsToPromote;
+
+ unsigned numDims = map.getNumDims();
+ unsigned numSymbols = map.getNumSymbols();
+
+ SmallVector<AffineExpr, 8> dimReplacements;
+ SmallVector<AffineExpr, 8> symReplacements;
+
+ SmallVector<SmallVectorImpl<Value> *> opsTodos;
+ auto replaceOp = [&](Operation *oldOp, Operation *newOp) {
+ for (auto [oldV, newV] :
+ llvm::zip(oldOp->getResults(), newOp->getResults()))
+ for (auto *ops : opsTodos)
+ for (auto &op : *ops)
+ if (op == oldV)
+ op = newV;
+ };
+
+ std::function<Value(Value, bool)> fix = [&](Value v,
+ bool index) -> Value /*legal*/ {
+ if (isValidSymbolInt(v, /*recur*/ false))
+ return v;
+ if (index && isAffineForArg(v))
+ return v;
+ auto *op = v.getDefiningOp();
+ if (!op)
+ return nullptr;
+ if (isa<ConstantOp>(op) || isa<ConstantIndexOp>(op))
+ return v;
+ if (!isReadOnly(op)) {
+ return nullptr;
+ }
+ Operation *front = nullptr;
+ SmallVector<Value> ops;
+ opsTodos.push_back(&ops);
+ std::function<void(Operation *)> getAllOps = [&](Operation *todo) {
+ for (auto v : todo->getOperands()) {
+ if (llvm::all_of(op->getRegions(), [&](Region &r) {
+ return !r.isAncestor(v.getParentRegion());
+ }))
+ ops.push_back(v);
+ }
+ for (auto &r : todo->getRegions()) {
+ for (auto &b : r.getBlocks())
+ for (auto &o2 : b.without_terminator())
+ getAllOps(&o2);
+ }
+ };
+ getAllOps(op);
+ for (auto o : ops) {
+ Operation *next;
+ if (auto *op = o.getDefiningOp()) {
+ if (Value nv = fix(o, index)) {
+ op = nv.getDefiningOp();
+ } else {
+ return nullptr;
+ }
+ next = op->getNextNode();
+ } else {
+ auto ba = mlir::cast<BlockArgument>(o);
+ if (index && isAffineForArg(ba)) {
+ } else if (!isValidSymbolInt(o, /*recur*/ false)) {
+ return nullptr;
+ }
+ next = &ba.getOwner()->front();
+ }
+ if (front == nullptr)
+ front = next;
+ else if (DI.dominates(front, next))
+ front = next;
+ }
+ opsTodos.pop_back();
+ if (!front)
+ op->dump();
+ assert(front);
+ PatternRewriter::InsertionGuard B(rewriter);
+ rewriter.setInsertionPoint(front);
+ auto *cloned = rewriter.clone(*op);
+ replaceOp(op, cloned);
+ rewriter.replaceOp(op, cloned->getResults());
+ return cloned->getResult(0);
+ };
+ auto renumberOneSymbol = [&](Value v) {
+ for (auto i : llvm::enumerate(addedValues)) {
+ if (i.value() == v)
+ return getAffineSymbolExpr(i.index(), map.getContext());
+ }
+ auto expr = getAffineSymbolExpr(addedValues.size(), map.getContext());
+ addedValues.push_back(v);
+ return expr;
+ };
+
+  // Compose affine::AffineApplyOps and dispatch dims or symbols.
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ auto t = operands[i];
+ auto decast = t;
+ while (true) {
+ if (auto idx = decast.getDefiningOp<IndexCastOp>()) {
+ decast = idx.getIn();
+ continue;
+ }
+ if (auto idx = decast.getDefiningOp<ExtUIOp>()) {
+ decast = idx.getIn();
+ continue;
+ }
+ if (auto idx = decast.getDefiningOp<ExtSIOp>()) {
+ decast = idx.getIn();
+ continue;
+ }
+ break;
+ }
+
+ if (!isValidSymbolInt(t, /*recur*/ false)) {
+ t = decast;
+ }
+
+ // Only promote one at a time, lest we end up with two dimensions
+ // multiplying each other.
+
+ if (((!isValidSymbolInt(t, /*recur*/ false) &&
+ (t.getDefiningOp<AddIOp>() || t.getDefiningOp<SubIOp>() ||
+ (t.getDefiningOp<MulIOp>() &&
+ ((isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(1))) ||
+ (isValidIndex(t.getDefiningOp()->getOperand(1)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(0)))) &&
+ !(fix(t.getDefiningOp()->getOperand(0), false) &&
+ fix(t.getDefiningOp()->getOperand(1), false))
+
+ ) ||
+ ((t.getDefiningOp<DivUIOp>() || t.getDefiningOp<DivSIOp>()) &&
+ (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(1))) &&
+ (!(fix(t.getDefiningOp()->getOperand(0), false) &&
+ fix(t.getDefiningOp()->getOperand(1), false)))) ||
+ (t.getDefiningOp<DivSIOp>() &&
+ (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+ (t.getDefiningOp<RemUIOp>() &&
+ (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+ (t.getDefiningOp<RemSIOp>() &&
+ (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+ isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+ t.getDefiningOp<ConstantIntOp>() ||
+ t.getDefiningOp<ConstantIndexOp>())) ||
+ ((decast.getDefiningOp<AddIOp>() || decast.getDefiningOp<SubIOp>() ||
+ decast.getDefiningOp<MulIOp>() || decast.getDefiningOp<RemUIOp>() ||
+ decast.getDefiningOp<RemSIOp>()) &&
+ (decast.getDefiningOp()
+ ->getOperand(1)
+ .getDefiningOp<ConstantIntOp>() ||
+ decast.getDefiningOp()
+ ->getOperand(1)
+ .getDefiningOp<ConstantIndexOp>())))) {
+ t = decast;
+ LLVM_DEBUG(llvm::dbgs() << " Replacing: " << t << "\n");
+
+ AffineMap affineApplyMap;
+ SmallVector<Value, 8> affineApplyOperands;
+
+ if (auto op = t.getDefiningOp<AddIOp>()) {
+ affineApplyMap =
+ AffineMap::get(0, 2,
+ getAffineSymbolExpr(0, op.getContext()) +
+ getAffineSymbolExpr(1, op.getContext()));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ } else if (auto op = t.getDefiningOp<SubIOp>()) {
+ affineApplyMap =
+ AffineMap::get(0, 2,
+ getAffineSymbolExpr(0, op.getContext()) -
+ getAffineSymbolExpr(1, op.getContext()));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ } else if (auto op = t.getDefiningOp<MulIOp>()) {
+ if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else {
+ affineApplyMap =
+ AffineMap::get(0, 2,
+ getAffineSymbolExpr(0, op.getContext()) *
+ getAffineSymbolExpr(1, op.getContext()));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ }
+ } else if (auto op = t.getDefiningOp<DivSIOp>()) {
+ if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1,
+ getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+ affineApplyOperands.push_back(op.getLhs());
+ } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1,
+ getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+ affineApplyOperands.push_back(op.getLhs());
+ } else {
+ affineApplyMap = AffineMap::get(
+ 0, 2,
+ getAffineSymbolExpr(0, op.getContext())
+ .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ }
+ } else if (auto op = t.getDefiningOp<DivUIOp>()) {
+ if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1,
+ getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+ affineApplyOperands.push_back(op.getLhs());
+ } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1,
+ getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+ affineApplyOperands.push_back(op.getLhs());
+ } else {
+ affineApplyMap = AffineMap::get(
+ 0, 2,
+ getAffineSymbolExpr(0, op.getContext())
+ .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ }
+ } else if (auto op = t.getDefiningOp<RemSIOp>()) {
+ if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else {
+ affineApplyMap =
+ AffineMap::get(0, 2,
+ getAffineSymbolExpr(0, op.getContext()) %
+ getAffineSymbolExpr(1, op.getContext()));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ }
+ } else if (auto op = t.getDefiningOp<RemUIOp>()) {
+ if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+ affineApplyOperands.push_back(op.getLhs());
+ } else {
+ affineApplyMap =
+ AffineMap::get(0, 2,
+ getAffineSymbolExpr(0, op.getContext()) %
+ getAffineSymbolExpr(1, op.getContext()));
+ affineApplyOperands.push_back(op.getLhs());
+ affineApplyOperands.push_back(op.getRhs());
+ }
+ } else if (auto op = t.getDefiningOp<ConstantIntOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 0, getAffineConstantExpr(op.value(), op.getContext()));
+ } else if (auto op = t.getDefiningOp<ConstantIndexOp>()) {
+ affineApplyMap = AffineMap::get(
+ 0, 0, getAffineConstantExpr(op.value(), op.getContext()));
+ } else {
+        llvm_unreachable("unhandled operation kind in AffineApplyNormalizer");
+ }
+
+ SmallVector<AffineExpr, 0> dimRemapping;
+ unsigned numOtherSymbols = affineApplyOperands.size();
+ SmallVector<AffineExpr, 2> symRemapping(numOtherSymbols);
+ for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
+ symRemapping[idx] = renumberOneSymbol(affineApplyOperands[idx]);
+ }
+ affineApplyMap = affineApplyMap.replaceDimsAndSymbols(
+ dimRemapping, symRemapping, reorderedDims.size(), addedValues.size());
+
+ LLVM_DEBUG(affineApplyMap.print(
+ llvm::dbgs() << "\nRenumber into current normalizer: "));
+
+ if (i >= numDims)
+ symReplacements.push_back(affineApplyMap.getResult(0));
+ else
+ dimReplacements.push_back(affineApplyMap.getResult(0));
+
+ } else if (isAffineForArg(t)) {
+ if (i >= numDims)
+ symReplacements.push_back(renumberOneDim(t));
+ else
+ dimReplacements.push_back(renumberOneDim(t));
+ } else if (t.getDefiningOp<affine::AffineApplyOp>()) {
+ auto affineApply = t.getDefiningOp<affine::AffineApplyOp>();
+ // a. Compose affine.apply operations.
+ LLVM_DEBUG(affineApply->print(
+ llvm::dbgs() << "\nCompose affine::AffineApplyOp recursively: "));
+ AffineMap affineApplyMap = affineApply.getAffineMap();
+ SmallVector<Value, 8> affineApplyOperands(
+ affineApply.getOperands().begin(), affineApply.getOperands().end());
+
+ SmallVector<AffineExpr, 0> dimRemapping(affineApplyMap.getNumDims());
+
+ for (size_t i = 0; i < affineApplyMap.getNumDims(); ++i) {
+ assert(i < affineApplyOperands.size());
+ dimRemapping[i] = renumberOneDim(affineApplyOperands[i]);
+ }
+ unsigned numOtherSymbols = affineApplyOperands.size();
+ SmallVector<AffineExpr, 2> symRemapping(numOtherSymbols -
+ affineApplyMap.getNumDims());
+ for (unsigned idx = 0; idx < symRemapping.size(); ++idx) {
+ symRemapping[idx] = renumberOneSymbol(
+ affineApplyOperands[idx + affineApplyMap.getNumDims()]);
+ }
+ affineApplyMap = affineApplyMap.replaceDimsAndSymbols(
+ dimRemapping, symRemapping, reorderedDims.size(), addedValues.size());
+
+ LLVM_DEBUG(
+ affineApplyMap.print(llvm::dbgs() << "\nAffine apply fixup map: "));
+
+ if (i >= numDims)
+ symReplacements.push_back(affineApplyMap.getResult(0));
+ else
+ dimReplacements.push_back(affineApplyMap.getResult(0));
+ } else {
+ if (!isValidSymbolInt(t, /*recur*/ false)) {
+ if (t.getDefiningOp()) {
+ if ((t = fix(t, false))) {
+ assert(isValidSymbolInt(t, /*recur*/ false));
+ } else
+ llvm_unreachable("cannot move");
+ } else
+ llvm_unreachable("cannot move2");
+ }
+ if (i < numDims) {
+ // b. The mathematical composition of AffineMap composes dims.
+ dimReplacements.push_back(renumberOneDim(t));
+ } else {
+ // c. The mathematical composition of AffineMap concatenates symbols.
+ // Note that the map composition will put symbols already present
+ // in the map before any symbols coming from the auxiliary map, so
+ // we insert them before any symbols that are due to renumbering,
+ // and after the proper symbols we have seen already.
+ symReplacements.push_back(renumberOneSymbol(t));
+ }
+ }
+ }
+ for (auto v : addedValues)
+ concatenatedSymbols.push_back(v);
+
+ // Create the new map by replacing each symbol at pos by the next new dim.
+ unsigned numNewDims = reorderedDims.size();
+ unsigned numNewSymbols = addedValues.size();
+ assert(dimReplacements.size() == map.getNumDims());
+ assert(symReplacements.size() == map.getNumSymbols());
+  auto auxiliaryMap = map.replaceDimsAndSymbols(
+      dimReplacements, symReplacements, numNewDims, numNewSymbols);
+  LLVM_DEBUG(auxiliaryMap.print(llvm::dbgs() << "\nRewritten map: "));
+
+  affineMap = auxiliaryMap; // simplifyAffineMap(auxiliaryMap);
+
+ LLVM_DEBUG(affineMap.print(llvm::dbgs() << "\nSimplified result: "));
+ LLVM_DEBUG(llvm::dbgs() << "\n");
+}
+
+AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
+ DenseMap<Value, unsigned>::iterator iterPos;
+ bool inserted = false;
+ std::tie(iterPos, inserted) =
+ dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
+ if (inserted) {
+ reorderedDims.push_back(v);
+ }
+ return mlir::cast<AffineDimExpr>(
+ getAffineDimExpr(iterPos->second, v.getContext()));
+}
+
+static void composeAffineMapAndOperands(AffineMap *map,
+ SmallVectorImpl<Value> *operands,
+ PatternRewriter &rewriter,
+ DominanceInfo &di) {
+ AffineApplyNormalizer normalizer(*map, *operands, rewriter, di);
+ auto normalizedMap = normalizer.getAffineMap();
+ auto normalizedOperands = normalizer.getOperands();
+ affine::canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
+ *map = normalizedMap;
+ *operands = normalizedOperands;
+ assert(*map);
+}
+
+bool need(AffineMap *map, SmallVectorImpl<Value> *operands) {
+ assert(map->getNumInputs() == operands->size());
+ for (size_t i = 0; i < map->getNumInputs(); ++i) {
+ auto v = (*operands)[i];
+ if (legalCondition(v, i < map->getNumDims()))
+ return true;
+ }
+ return false;
+}
+bool need(IntegerSet *map, SmallVectorImpl<Value> *operands) {
+ for (size_t i = 0; i < map->getNumInputs(); ++i) {
+ auto v = (*operands)[i];
+ if (legalCondition(v, i < map->getNumDims()))
+ return true;
+ }
+ return false;
+}
+
+void fully2ComposeAffineMapAndOperands(PatternRewriter &builder, AffineMap *map,
+ SmallVectorImpl<Value> *operands,
+ DominanceInfo &di) {
+ IRMapping indexMap;
+ for (auto op : *operands) {
+ SmallVector<IndexCastOp> attempt;
+ auto idx0 = op.getDefiningOp<IndexCastOp>();
+    if (!idx0)
+      continue;
+    attempt.push_back(idx0);
+
+ for (auto &u : idx0.getIn().getUses()) {
+ if (auto idx = dyn_cast<IndexCastOp>(u.getOwner()))
+ if (di.dominates((Operation *)idx, &*builder.getInsertionPoint()))
+ attempt.push_back(idx);
+ }
+
+ for (auto idx : attempt) {
+ if (affine::isValidSymbol(idx)) {
+ indexMap.map(idx.getIn(), idx);
+ break;
+ }
+ }
+ }
+ assert(map->getNumInputs() == operands->size());
+ while (need(map, operands)) {
+ composeAffineMapAndOperands(map, operands, builder, di);
+ assert(map->getNumInputs() == operands->size());
+ }
+ *map = simplifyAffineMap(*map);
+ for (auto &op : *operands) {
+ if (!op.getType().isIndex()) {
+ Operation *toInsert;
+ if (auto *o = op.getDefiningOp())
+ toInsert = o->getNextNode();
+ else {
+ auto ba = mlir::cast<BlockArgument>(op);
+ toInsert = &ba.getOwner()->front();
+ }
+
+ if (auto v = indexMap.lookupOrNull(op))
+ op = v;
+ else {
+ PatternRewriter::InsertionGuard b(builder);
+ builder.setInsertionPoint(toInsert);
+ op = builder.create<IndexCastOp>(op.getLoc(), builder.getIndexType(),
+ op);
+ }
+ }
+ }
+}
+
+void fully2ComposeIntegerSetAndOperands(PatternRewriter &builder,
+ IntegerSet *set,
+ SmallVectorImpl<Value> *operands,
+ DominanceInfo &DI) {
+ IRMapping indexMap;
+ for (auto op : *operands) {
+ SmallVector<IndexCastOp> attempt;
+ auto idx0 = op.getDefiningOp<IndexCastOp>();
+    if (!idx0)
+      continue;
+    attempt.push_back(idx0);
+
+ for (auto &u : idx0.getIn().getUses()) {
+ if (auto idx = dyn_cast<IndexCastOp>(u.getOwner()))
+ if (DI.dominates((Operation *)idx, &*builder.getInsertionPoint()))
+ attempt.push_back(idx);
+ }
+
+ for (auto idx : attempt) {
+ if (affine::isValidSymbol(idx)) {
+ indexMap.map(idx.getIn(), idx);
+ break;
+ }
+ }
+ }
+ auto map = AffineMap::get(set->getNumDims(), set->getNumSymbols(),
+ set->getConstraints(), set->getContext());
+ while (need(&map, operands)) {
+ composeAffineMapAndOperands(&map, operands, builder, DI);
+ }
+ map = simplifyAffineMap(map);
+ *set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(),
+ map.getResults(), set->getEqFlags());
+ for (auto &op : *operands) {
+ if (!op.getType().isIndex()) {
+ Operation *toInsert;
+ if (auto *o = op.getDefiningOp())
+ toInsert = o->getNextNode();
+ else {
+ auto ba = mlir::cast<BlockArgument>(op);
+ toInsert = &ba.getOwner()->front();
+ }
+
+ if (auto v = indexMap.lookupOrNull(op))
+ op = v;
+ else {
+ PatternRewriter::InsertionGuard b(builder);
+ builder.setInsertionPoint(toInsert);
+ op = builder.create<IndexCastOp>(op.getLoc(), builder.getIndexType(),
+ op);
+ }
+ }
+ }
+}
+
+namespace {
+struct AffineCFG : public impl::AffineCFGBase<AffineCFG> {
+ void runOnOperation() override;
+};
+} // namespace
+
+static void setLocationAfter(PatternRewriter &b, mlir::Value val) {
+ if (val.getDefiningOp()) {
+ auto it = val.getDefiningOp()->getIterator();
+ it++;
+ b.setInsertionPoint(val.getDefiningOp()->getBlock(), it);
+ }
+ if (auto bop = dyn_cast<mlir::BlockArgument>(val))
+ b.setInsertionPoint(bop.getOwner(), bop.getOwner()->begin());
+}
+
+struct IndexCastMovement : public OpRewritePattern<IndexCastOp> {
+ using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.use_empty()) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ mlir::Value val = op.getOperand();
+ if (auto bop = dyn_cast<mlir::BlockArgument>(val)) {
+ if (op.getOperation()->getBlock() != bop.getOwner()) {
+ op.getOperation()->moveBefore(bop.getOwner(), bop.getOwner()->begin());
+ return success();
+ }
+ return failure();
+ }
+
+ if (val.getDefiningOp()) {
+ if (op.getOperation()->getBlock() != val.getDefiningOp()->getBlock()) {
+ auto it = val.getDefiningOp()->getIterator();
+ op.getOperation()->moveAfter(val.getDefiningOp()->getBlock(), it);
+ }
+ return failure();
+ }
+ return failure();
+ }
+};
+
+/*
+struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
+ using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.use_empty()) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<AddIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<AddIOp>(
+ op,
+ b.create<IndexCastOp>(op.getLoc(), op.getType(), iadd.getOperand(0)),
+ b2.create<IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<SubIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<SubIOp>(
+ op,
+ b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(0)),
+ b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<MulIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<MulIOp>(
+ op,
+ b.create<IndexCastOp>(op.getLoc(), op.getType(), iadd.getOperand(0)),
+ b2.create<IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<DivUIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<DivUIOp>(
+ op,
+ b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(0)),
+ b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<DivSIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<DivSIOp>(
+ op,
+ b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(0)),
+ b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<RemUIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<RemUIOp>(
+ op,
+ b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(0)),
+ b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<RemSIOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getOperand(0));
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getOperand(1));
+ rewriter.replaceOpWithNewOp<RemSIOp>(
+ op,
+ b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(0)),
+ b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getOperand(1)));
+ return success();
+ }
+ if (auto iadd = op.getOperand().getDefiningOp<SelectOp>()) {
+ PatternRewriter b(rewriter);
+ setLocationAfter(b, iadd.getTrueValue());
+ PatternRewriter b2(rewriter);
+ setLocationAfter(b2, iadd.getFalseValue());
+ auto cond = iadd.getCondition();
+ PatternRewriter b3(rewriter);
+ setLocationAfter(b3, cond);
+ if (auto cmp = iadd.getCondition().getDefiningOp<CmpIOp>()) {
+ if (cmp.getLhs() == iadd.getTrueValue() &&
+ cmp.getRhs() == iadd.getFalseValue()) {
+
+ auto truev = b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getTrueValue());
+ auto falsev = b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
+ iadd.getFalseValue());
+ cond = b3.create<CmpIOp>(cmp.getLoc(), cmp.getPredicate(), truev,
+ falsev);
+ rewriter.replaceOpWithNewOp<SelectOp>(op, cond, truev, falsev);
+ return success();
+ }
+ }
+ }
+ return failure();
+ }
+};
+*/
+
+struct CanonicalizeAffineApply
+ : public OpRewritePattern<affine::AffineApplyOp> {
+ using OpRewritePattern<affine::AffineApplyOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineApplyOp affineOp,
+ PatternRewriter &rewriter) const override {
+
+ SmallVector<Value, 4> mapOperands(affineOp.getMapOperands());
+ auto map = affineOp.getMap();
+ auto prevMap = map;
+
+ auto *scope = affine::getAffineScope(affineOp)->getParentOp();
+ DominanceInfo di(scope);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &map, &mapOperands, di);
+ affine::canonicalizeMapAndOperands(&map, &mapOperands);
+ map = removeDuplicateExprs(map);
+
+ if (map == prevMap)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(affineOp, map,
+ mapOperands);
+ return success();
+ }
+};
+
+struct CanonicalizeIndexCast : public OpRewritePattern<IndexCastOp> {
+ using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexCastOp indexcastOp,
+ PatternRewriter &rewriter) const override {
+
+ // Fold IndexCast(IndexCast(x)) -> x
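+    // E.g. (illustrative) %y = arith.index_cast %x : index to i64 followed by
+    // %z = arith.index_cast %y : i64 to index folds to just %x.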
+ auto cast = indexcastOp.getOperand().getDefiningOp<IndexCastOp>();
+ if (cast && cast.getOperand().getType() == indexcastOp.getType()) {
+ mlir::Value vals[] = {cast.getOperand()};
+ rewriter.replaceOp(indexcastOp, vals);
+ return success();
+ }
+
+ // Fold IndexCast(constant) -> constant
+ // A little hack because we go through int. Otherwise, the size
+ // of the constant might need to change.
+ if (auto cst = indexcastOp.getOperand().getDefiningOp<ConstantIntOp>()) {
+ rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexcastOp, cst.value());
+ return success();
+ }
+ return failure();
+ }
+};
+
+bool isValidIndex(Value val) {
+ if (isValidSymbolInt(val))
+ return true;
+
+ if (auto cast = val.getDefiningOp<IndexCastOp>())
+ return isValidIndex(cast.getOperand());
+
+ if (auto cast = val.getDefiningOp<ExtSIOp>())
+ return isValidIndex(cast.getOperand());
+
+ if (auto cast = val.getDefiningOp<ExtUIOp>())
+ return isValidIndex(cast.getOperand());
+
+ if (auto bop = val.getDefiningOp<AddIOp>())
+ return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1));
+
+ if (auto bop = val.getDefiningOp<MulIOp>())
+ return (isValidIndex(bop.getOperand(0)) &&
+ isValidSymbolInt(bop.getOperand(1))) ||
+ (isValidIndex(bop.getOperand(1)) &&
+ isValidSymbolInt(bop.getOperand(0)));
+
+ if (auto bop = val.getDefiningOp<DivSIOp>())
+ return (isValidIndex(bop.getOperand(0)) &&
+ isValidSymbolInt(bop.getOperand(1)));
+
+ if (auto bop = val.getDefiningOp<DivUIOp>())
+ return (isValidIndex(bop.getOperand(0)) &&
+ isValidSymbolInt(bop.getOperand(1)));
+
+ if (auto bop = val.getDefiningOp<RemSIOp>()) {
+ return (isValidIndex(bop.getOperand(0)) &&
+ bop.getOperand(1).getDefiningOp<arith::ConstantOp>());
+ }
+
+ if (auto bop = val.getDefiningOp<RemUIOp>())
+ return (isValidIndex(bop.getOperand(0)) &&
+ bop.getOperand(1).getDefiningOp<arith::ConstantOp>());
+
+ if (auto bop = val.getDefiningOp<SubIOp>())
+ return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1));
+
+ if (val.getDefiningOp<ConstantIndexOp>())
+ return true;
+
+ if (val.getDefiningOp<ConstantIntOp>())
+ return true;
+
+ if (auto ba = dyn_cast<BlockArgument>(val)) {
+ auto *owner = ba.getOwner();
+ assert(owner);
+
+ auto *parentOp = owner->getParentOp();
+ if (!parentOp) {
+ owner->dump();
+ llvm::errs() << " ba: " << ba << "\n";
+ }
+ assert(parentOp);
+ if (isa<FunctionOpInterface>(parentOp))
+ return true;
+ if (auto af = dyn_cast<affine::AffineForOp>(parentOp))
+ return af.getInductionVar() == ba;
+
+ // TODO ensure not a reduced var
+ if (isa<affine::AffineParallelOp>(parentOp))
+ return true;
+
+ if (isa<FunctionOpInterface>(parentOp))
+ return true;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "illegal isValidIndex: " << val << "\n");
+ return false;
+}
+
+// returns legality
+bool handleMinMax(Value start, SmallVectorImpl<Value> &out, bool &min,
+ bool &max) {
+
+ SmallVector<Value> todo = {start};
+ while (todo.size()) {
+ auto cur = todo.back();
+ todo.pop_back();
+ if (isValidIndex(cur)) {
+ out.push_back(cur);
+ continue;
+ }
+ if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+ // UB only has min of operands
+ if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+ if (cmp.getLhs() == selOp.getTrueValue() &&
+ cmp.getRhs() == selOp.getFalseValue()) {
+ todo.push_back(cmp.getLhs());
+ todo.push_back(cmp.getRhs());
+ if (cmp.getPredicate() == CmpIPredicate::sle ||
+ cmp.getPredicate() == CmpIPredicate::slt) {
+ min = true;
+ continue;
+ }
+ if (cmp.getPredicate() == CmpIPredicate::sge ||
+ cmp.getPredicate() == CmpIPredicate::sgt) {
+ max = true;
+ continue;
+ }
+ }
+ }
+ }
+ return false;
+ }
+ return !(min && max);
+}
+
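+// Illustrative sketch (hypothetical values): for `%c = arith.cmpi sle, %i, %n`
+// with %i and %n valid affine indices, `handle` records the inequality
+// s1 - s0 >= 0 over two fresh symbols (s0 bound to %i, s1 bound to %n), i.e.
+// the integer-set encoding of %i <= %n.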
+bool handle(PatternRewriter &b, CmpIOp cmpi, SmallVectorImpl<AffineExpr> &exprs,
+ SmallVectorImpl<bool> &eqflags, SmallVectorImpl<Value> &applies) {
+ SmallVector<Value> lhs;
+ bool lhsMin = false;
+ bool lhsMax = false;
+ if (!handleMinMax(cmpi.getLhs(), lhs, lhsMin, lhsMax)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "illegal lhs: " << cmpi.getLhs() << " - " << cmpi << "\n");
+ return false;
+ }
+ assert(lhs.size());
+ SmallVector<Value> rhs;
+ bool rhsMin = false;
+ bool rhsMax = false;
+ if (!handleMinMax(cmpi.getRhs(), rhs, rhsMin, rhsMax)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "illegal rhs: " << cmpi.getRhs() << " - " << cmpi << "\n");
+ return false;
+ }
+ assert(rhs.size());
+ for (auto &lhspack : lhs)
+ if (!mlir::isa<IndexType>(lhspack.getType())) {
+ lhspack = b.create<arith::IndexCastOp>(
+ cmpi.getLoc(), IndexType::get(cmpi.getContext()), lhspack);
+ }
+
+ for (auto &rhspack : rhs)
+ if (!mlir::isa<IndexType>(rhspack.getType())) {
+ rhspack = b.create<arith::IndexCastOp>(
+ cmpi.getLoc(), IndexType::get(cmpi.getContext()), rhspack);
+ }
+
+ switch (cmpi.getPredicate()) {
+ case CmpIPredicate::eq: {
+ if (lhsMin || lhsMax || rhsMin || rhsMax)
+ return false;
+ eqflags.push_back(true);
+
+ applies.push_back(lhs[0]);
+ applies.push_back(rhs[0]);
+ AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+ b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+ exprs.push_back(dims[0] - dims[1]);
+ } break;
+
+ case CmpIPredicate::ugt:
+ case CmpIPredicate::uge:
+ for (auto lhspack : lhs)
+ if (!valueCmp(Cmp::GE, lhspack, 0)) {
+ LLVM_DEBUG(llvm::dbgs() << "illegal greater lhs icmp: " << cmpi << " - "
+ << lhspack << "\n");
+ return false;
+ }
+ for (auto rhspack : rhs)
+ if (!valueCmp(Cmp::GE, rhspack, 0)) {
+ LLVM_DEBUG(llvm::dbgs() << "illegal greater rhs icmp: " << cmpi << " - "
+ << rhspack << "\n");
+ return false;
+ }
+    [[fallthrough]];
+
+ case CmpIPredicate::sge:
+ case CmpIPredicate::sgt: {
+ // if lhs >=? rhs
+ // if lhs is a min(a, b) both must be true and this is fine
+ // if lhs is a max(a, b) either may be true, and sets require and
+ // similarly if rhs is a max(), both must be true;
+ if (lhsMax || rhsMin)
+ return false;
+ for (auto lhspack : lhs)
+ for (auto rhspack : rhs) {
+ eqflags.push_back(false);
+ applies.push_back(lhspack);
+ applies.push_back(rhspack);
+ AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+ b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+ auto expr = dims[0] - dims[1];
+ if (cmpi.getPredicate() == CmpIPredicate::sgt ||
+ cmpi.getPredicate() == CmpIPredicate::ugt)
+ expr = expr - 1;
+ exprs.push_back(expr);
+ }
+ } break;
+
+ case CmpIPredicate::ult:
+ case CmpIPredicate::ule:
+ for (auto lhspack : lhs)
+ if (!valueCmp(Cmp::GE, lhspack, 0)) {
+ LLVM_DEBUG(llvm::dbgs() << "illegal less lhs icmp: " << cmpi << " - "
+ << lhspack << "\n");
+ return false;
+ }
+ for (auto rhspack : rhs)
+ if (!valueCmp(Cmp::GE, rhspack, 0)) {
+ LLVM_DEBUG(llvm::dbgs() << "illegal less rhs icmp: " << cmpi << " - "
+ << rhspack << "\n");
+ return false;
+ }
+    [[fallthrough]];
+
+ case CmpIPredicate::slt:
+ case CmpIPredicate::sle: {
+ if (lhsMin || rhsMax)
+ return false;
+ for (auto lhspack : lhs)
+ for (auto rhspack : rhs) {
+ eqflags.push_back(false);
+ applies.push_back(lhspack);
+ applies.push_back(rhspack);
+ AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+ b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+ auto expr = dims[1] - dims[0];
+ if (cmpi.getPredicate() == CmpIPredicate::slt ||
+ cmpi.getPredicate() == CmpIPredicate::ult)
+ expr = expr - 1;
+ exprs.push_back(expr);
+ }
+ } break;
+
+ case CmpIPredicate::ne:
+ LLVM_DEBUG(llvm::dbgs() << "illegal icmp: " << cmpi << "\n");
+ return false;
+ }
+ return true;
+}
+/*
+static void replaceStore(memref::StoreOp store,
+ const SmallVector<Value, 2> &newIndexes) {
+ auto memrefType = mlir::cast<MemRefType>(store.getMemRef().getType());
+ size_t rank = memrefType.getRank();
+ if (rank != newIndexes.size()) {
+ llvm::errs() << store << "\n";
+ }
+ assert(rank == newIndexes.size() && "Expect rank to match new indexes");
+
+ PatternRewriter builder(store);
+ Location loc = store.getLoc();
+ builder.create<affine::AffineStoreOp>(loc, store.getValueToStore(),
+store.getMemRef(), newIndexes); store.erase();
+}
+
+static void replaceLoad(memref::LoadOp load,
+ const SmallVector<Value, 2> &newIndexes) {
+ PatternRewriter builder(load);
+ Location loc = load.getLoc();
+
+ auto memrefType = mlir::cast<MemRefType>(load.getMemRef().getType());
+ size_t rank = memrefType.getRank();
+ if (rank != newIndexes.size()) {
+ llvm::errs() << load << "\n";
+ }
+ assert(rank == newIndexes.size() && "rank must equal new indexes size");
+
+ affine::AffineLoadOp affineLoad =
+ builder.create<affine::AffineLoadOp>(loc, load.getMemRef(), newIndexes);
+ load.getResult().replaceAllUsesWith(affineLoad.getResult());
+ load.erase();
+}
+*/
+struct MoveLoadToAffine : public OpRewritePattern<memref::LoadOp> {
+ using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::LoadOp load,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::all_of(load.getIndices(), isValidIndex))
+ return failure();
+
+ auto memrefType = mlir::cast<MemRefType>(load.getMemRef().getType());
+ int64_t rank = memrefType.getRank();
+
+ // Create identity map for memrefs with at least one dimension or () -> ()
+ // for zero-dimensional memrefs.
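+    // E.g. (illustrative) for a rank-2 memref this builds
+    // ()[s0, s1] -> (s0, s1), with the load indices passed as symbol operands.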
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ dimExprs.push_back(rewriter.getAffineSymbolExpr(i));
+ auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+ rewriter.getContext());
+
+ SmallVector<Value, 4> operands = load.getIndices();
+
+ if (map.getNumInputs() != operands.size()) {
+ llvm::errs() << " load: " << load << "\n";
+ }
+ auto *scope = affine::getAffineScope(load)->getParentOp();
+ DominanceInfo di(scope);
+ assert(map.getNumInputs() == operands.size());
+ fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+ assert(map.getNumInputs() == operands.size());
+ affine::canonicalizeMapAndOperands(&map, &operands);
+ assert(map.getNumInputs() == operands.size());
+
+ affine::AffineLoadOp affineLoad = affine::AffineLoadOp::create(
+ rewriter, load.getLoc(), load.getMemRef(), map, operands);
+ load.getResult().replaceAllUsesWith(affineLoad.getResult());
+ rewriter.eraseOp(load);
+ return success();
+ }
+};
+
+struct MoveStoreToAffine : public OpRewritePattern<memref::StoreOp> {
+ using OpRewritePattern<memref::StoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::StoreOp store,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::all_of(store.getIndices(), isValidIndex))
+ return failure();
+
+ auto memrefType = mlir::cast<MemRefType>(store.getMemRef().getType());
+ int64_t rank = memrefType.getRank();
+
+ // Create identity map for memrefs with at least one dimension or () -> ()
+ // for zero-dimensional memrefs.
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ dimExprs.push_back(rewriter.getAffineSymbolExpr(i));
+ auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+ rewriter.getContext());
+ SmallVector<Value, 4> operands = store.getIndices();
+
+ auto *scope = affine::getAffineScope(store)->getParentOp();
+ DominanceInfo di(scope);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+ affine::canonicalizeMapAndOperands(&map, &operands);
+
+ affine::AffineStoreOp::create(rewriter, store.getLoc(),
+ store.getValueToStore(), store.getMemRef(),
+ map, operands);
+ rewriter.eraseOp(store);
+ return success();
+ }
+};
+
+static bool areChanged(SmallVectorImpl<Value> &afterOperands,
+ SmallVectorImpl<Value> &beforeOperands) {
+ if (afterOperands.size() != beforeOperands.size())
+ return true;
+ if (!std::equal(afterOperands.begin(), afterOperands.end(),
+ beforeOperands.begin()))
+ return true;
+ return false;
+}
+
+template <typename T>
+struct AffineFixup : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+
+ /// Replace the affine op with another instance of it with the supplied
+ /// map and mapOperands.
+ void replaceAffineOp(PatternRewriter &rewriter, T affineOp, AffineMap map,
+ ArrayRef<Value> mapOperands) const;
+
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ auto map = op.getAffineMap();
+ SmallVector<Value, 4> operands = op.getMapOperands();
+
+ auto prevMap = map;
+ auto prevOperands = operands;
+
+ auto *scope = affine::getAffineScope(op)->getParentOp();
+ DominanceInfo di(scope);
+
+ assert(map.getNumInputs() == operands.size());
+ fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+ assert(map.getNumInputs() == operands.size());
+ affine::canonicalizeMapAndOperands(&map, &operands);
+ assert(map.getNumInputs() == operands.size());
+
+ if (map == prevMap && !areChanged(operands, prevOperands))
+ return failure();
+
+ replaceAffineOp(rewriter, op, map, operands);
+ return success();
+ }
+};
+
+// Specialize the template to account for the different build signatures for
+// affine load, store, and apply ops.
+template <>
+void AffineFixup<affine::AffineLoadOp>::replaceAffineOp(
+ PatternRewriter &rewriter, affine::AffineLoadOp load, AffineMap map,
+ ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(load, load.getMemRef(), map,
+ mapOperands);
+}
+template <>
+void AffineFixup<affine::AffinePrefetchOp>::replaceAffineOp(
+ PatternRewriter &rewriter, affine::AffinePrefetchOp prefetch, AffineMap map,
+ ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<affine::AffinePrefetchOp>(
+ prefetch, prefetch.getMemref(), map, mapOperands,
+ prefetch.getLocalityHint(), prefetch.getIsWrite(),
+ prefetch.getIsDataCache());
+}
+template <>
+void AffineFixup<affine::AffineStoreOp>::replaceAffineOp(
+ PatternRewriter &rewriter, affine::AffineStoreOp store, AffineMap map,
+ ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+ store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
+}
+template <>
+void AffineFixup<affine::AffineVectorLoadOp>::replaceAffineOp(
+ PatternRewriter &rewriter, affine::AffineVectorLoadOp vectorload,
+ AffineMap map, ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<affine::AffineVectorLoadOp>(
+ vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
+ mapOperands);
+}
+template <>
+void AffineFixup<affine::AffineVectorStoreOp>::replaceAffineOp(
+ PatternRewriter &rewriter, affine::AffineVectorStoreOp vectorstore,
+ AffineMap map, ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<affine::AffineVectorStoreOp>(
+ vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
+ mapOperands);
+}
+
+// Generic version for ops that don't have extra operands.
+template <typename AffineOpTy>
+void AffineFixup<AffineOpTy>::replaceAffineOp(
+ PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
+ ArrayRef<Value> mapOperands) const {
+ rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
+}
+
+struct CanonicalizeForBounds : public OpRewritePattern<affine::AffineForOp> {
+ using OpRewritePattern<affine::AffineForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineForOp forOp,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
+ SmallVector<Value, 4> origLbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value, 4> origUbOperands(forOp.getUpperBoundOperands());
+
+ auto lbMap = forOp.getLowerBoundMap();
+ auto ubMap = forOp.getUpperBoundMap();
+ auto prevLbMap = lbMap;
+ auto prevUbMap = ubMap;
+
+
+ auto *scope = affine::getAffineScope(forOp)->getParentOp();
+ DominanceInfo di(scope);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands, di);
+ affine::canonicalizeMapAndOperands(&lbMap, &lbOperands);
+ lbMap = removeDuplicateExprs(lbMap);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands, di);
+ affine::canonicalizeMapAndOperands(&ubMap, &ubOperands);
+ ubMap = removeDuplicateExprs(ubMap);
+
+    // Nothing to do if neither the bound maps nor their operands changed.
+ if ((lbMap == prevLbMap && ubMap == prevUbMap) &&
+ (!areChanged(lbOperands, origLbOperands)) &&
+ (!areChanged(ubOperands, origUbOperands)))
+ return failure();
+
+ if ((lbMap != prevLbMap) || areChanged(lbOperands, origLbOperands))
+ forOp.setLowerBound(lbOperands, lbMap);
+ if ((ubMap != prevUbMap) || areChanged(ubOperands, origUbOperands))
+ forOp.setUpperBound(ubOperands, ubMap);
+
+ return success();
+ }
+};
+
+struct CanonicalizeIfBounds : public OpRewritePattern<affine::AffineIfOp> {
+ using OpRewritePattern<affine::AffineIfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineIfOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value, 4> operands(op.getOperands());
+ SmallVector<Value, 4> origOperands(operands);
+
+ auto map = op.getIntegerSet();
+ auto prevMap = map;
+
+ auto *scope = affine::getAffineScope(op)->getParentOp();
+ DominanceInfo DI(scope);
+
+ fully2ComposeIntegerSetAndOperands(rewriter, &map, &operands, DI);
+ affine::canonicalizeSetAndOperands(&map, &operands);
+
+    // Nothing to do if neither the set nor the operands changed.
+ if (map == prevMap && !areChanged(operands, origOperands))
+ return failure();
+
+ op.setConditional(map, operands);
+
+ return success();
+ }
+};
+
+struct MoveIfToAffine : public OpRewritePattern<scf::IfOp> {
+ using OpRewritePattern<scf::IfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::IfOp ifOp,
+ PatternRewriter &rewriter) const override {
+ if (!ifOp->getParentOfType<affine::AffineForOp>() &&
+ !ifOp->getParentOfType<affine::AffineParallelOp>())
+ return failure();
+
+ std::vector<mlir::Type> types;
+ for (auto v : ifOp.getResults()) {
+ types.push_back(v.getType());
+ }
+
+ SmallVector<AffineExpr, 2> exprs;
+ SmallVector<bool, 2> eqflags;
+ SmallVector<Value, 4> applies;
+
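+    // Walk the condition, splitting conjunctions. Illustrative sketch
+    // (hypothetical IR): %cond = arith.andi %a, %b with %a and %b produced by
+    // arith.cmpi contributes one affine constraint per comparison; any other
+    // producer aborts the rewrite.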
+ std::deque<Value> todo = {ifOp.getCondition()};
+ while (todo.size()) {
+ auto cur = todo.front();
+ todo.pop_front();
+ if (auto cmpi = cur.getDefiningOp<CmpIOp>()) {
+ if (!handle(rewriter, cmpi, exprs, eqflags, applies)) {
+ return failure();
+ }
+ continue;
+ }
+ if (auto andi = cur.getDefiningOp<AndIOp>()) {
+ todo.push_back(andi.getOperand(0));
+ todo.push_back(andi.getOperand(1));
+ continue;
+ }
+ return failure();
+ }
+
+ auto *scope = affine::getAffineScope(ifOp)->getParentOp();
+ DominanceInfo di(scope);
+
+ auto iset =
+ IntegerSet::get(/*dim*/ 0, /*symbol*/ 2 * exprs.size(), exprs, eqflags);
+ fully2ComposeIntegerSetAndOperands(rewriter, &iset, &applies, di);
+ affine::canonicalizeSetAndOperands(&iset, &applies);
+ affine::AffineIfOp affineIfOp = affine::AffineIfOp::create(
+ rewriter, ifOp.getLoc(), types, iset, applies,
+ /*elseBlock=*/true);
+
+ rewriter.setInsertionPoint(ifOp.thenYield());
+ rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
+ ifOp.thenYield(), ifOp.thenYield().getOperands());
+
+ rewriter.eraseBlock(affineIfOp.getThenBlock());
+ rewriter.eraseBlock(affineIfOp.getElseBlock());
+ if (ifOp.getElseRegion().getBlocks().size()) {
+ rewriter.setInsertionPoint(ifOp.elseYield());
+ rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
+ ifOp.elseYield(), ifOp.elseYield().getOperands());
+ }
+
+ rewriter.inlineRegionBefore(ifOp.getThenRegion(),
+ affineIfOp.getThenRegion(),
+ affineIfOp.getThenRegion().begin());
+ rewriter.inlineRegionBefore(ifOp.getElseRegion(),
+ affineIfOp.getElseRegion(),
+ affineIfOp.getElseRegion().begin());
+
+ rewriter.replaceOp(ifOp, affineIfOp.getResults());
+ return success();
+ }
+};
+
+void AffineCFG::runOnOperation() {
+ mlir::RewritePatternSet rpl(getOperation()->getContext());
+ rpl.add</*SimplfyIntegerCastMath, */ CanonicalizeAffineApply,
+ CanonicalizeIndexCast,
+ /* IndexCastMovement,*/ AffineFixup<affine::AffineLoadOp>,
+          AffineFixup<affine::AffineStoreOp>, CanonicalizeIfBounds,
+          MoveStoreToAffine, MoveIfToAffine, MoveLoadToAffine,
+          CanonicalizeForBounds>(getOperation()->getContext());
+ GreedyRewriteConfig config;
+ (void)applyPatternsGreedily(getOperation(), std::move(rpl), config);
+}
+
+std::unique_ptr<Pass> mlir::replaceAffineCFGPass() {
+ return std::make_unique<AffineCFG>();
+}
diff --git a/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
new file mode 100644
index 0000000000000..8bc6d43ff199c
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRSCFToAffine
+ RaiseToAffine.cpp
+ AffineCFG.cpp
+ Ops.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRArithDialect
+ MLIRLLVMDialect
+ MLIRSCFDialect
+ MLIRSCFTransforms
+ MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/SCFToAffine/Ops.cpp b/mlir/lib/Conversion/SCFToAffine/Ops.cpp
new file mode 100644
index 0000000000000..fdd13bbb5f384
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/Ops.cpp
@@ -0,0 +1,359 @@
+#include "Ops.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Value.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
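+// Return true if `bval <cmp> val` is known to hold. The value is traced
+// through index_cast ops, integer constants, and the bounds of an enclosing
+// affine/scf for or parallel loop when it is an induction variable. A false
+// result means the relation could not be proven, not that it is known to be
+// violated.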
+bool valueCmp(Cmp cmp, Value bval, ValueOrInt val) {
+ if (auto icast = bval.getDefiningOp<IndexCastOp>()) {
+ return valueCmp(cmp, icast.getIn(), val);
+ }
+
+ IntegerAttr iattr;
+ if (matchPattern(bval, m_Constant(&iattr))) {
+ switch (cmp) {
+ case Cmp::EQ:
+ return val == iattr.getValue();
+ case Cmp::LT:
+ return val > iattr.getValue();
+ case Cmp::LE:
+ return val >= iattr.getValue();
+ case Cmp::GT:
+ return val < iattr.getValue();
+ case Cmp::GE:
+ return val <= iattr.getValue();
+ }
+ }
+
+ if (auto baval = dyn_cast<BlockArgument>(bval)) {
+ if (affine::AffineForOp afFor =
+ dyn_cast<affine::AffineForOp>(baval.getOwner()->getParentOp())) {
+ auto forLb = afFor.getLowerBoundMap().getResults()[baval.getArgNumber()];
+ auto forUb = afFor.getUpperBoundMap().getResults()[baval.getArgNumber()];
+ switch (cmp) {
+ // \forall i \in [LB, UB) == k => LB == k and UB == k+1
+ case Cmp::EQ: {
+ if (!valueCmp(Cmp::EQ, forLb, afFor.getLowerBoundMap().getNumDims(),
+ afFor.getLowerBoundOperands(), val))
+ return false;
+ if (!val.isValue) {
+ return valueCmp(Cmp::EQ, forUb, afFor.getUpperBoundMap().getNumDims(),
+ afFor.getUpperBoundOperands(), val.iVal + 1);
+ }
+ return false;
+ }
+ // \forall i \in [LB, UB) < k => UB <= k
+ case Cmp::LT: {
+ return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(),
+ afFor.getUpperBoundOperands(), val);
+ }
+ // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1
+ case Cmp::LE: {
+ if (!val.isValue) {
+ return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(),
+ afFor.getUpperBoundOperands(), val.iVal + 1);
+ }
+ return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(),
+ afFor.getUpperBoundOperands(), val);
+ }
+ // \forall i \in [LB, UB) > k => LB > k
+ case Cmp::GT: {
+ return valueCmp(Cmp::GT, forLb, afFor.getLowerBoundMap().getNumDims(),
+ afFor.getLowerBoundOperands(), val);
+ }
+ // \forall i \in [LB, UB) >= k => LB >= k
+ case Cmp::GE: {
+ return valueCmp(Cmp::GE, forLb, afFor.getLowerBoundMap().getNumDims(),
+ afFor.getLowerBoundOperands(), val);
+ }
+ }
+ }
+ if (affine::AffineParallelOp afFor = dyn_cast<affine::AffineParallelOp>(
+ baval.getOwner()->getParentOp())) {
+ switch (cmp) {
+ // \forall i \in [max(LB...), min(UB...)) == k => all(LB == k) and
+ // all(UB == k+1)
+ case Cmp::EQ: {
+ for (auto forLb :
+ afFor.getLowerBoundMap(baval.getArgNumber()).getResults())
+ if (!valueCmp(Cmp::EQ, forLb, afFor.getLowerBoundsMap().getNumDims(),
+ afFor.getLowerBoundsOperands(), val))
+ return false;
+ if (!val.isValue) {
+ for (auto forUb :
+ afFor.getUpperBoundMap(baval.getArgNumber()).getResults())
+ if (!valueCmp(Cmp::EQ, forUb,
+ afFor.getUpperBoundsMap().getNumDims(),
+ afFor.getUpperBoundsOperands(), val.iVal + 1))
+ return false;
+ return true;
+ }
+ return false;
+ }
+ // \forall i \in [max(LB...), min(UB...)) < k => any(UB <= k)
+ case Cmp::LT: {
+ for (auto forUb :
+ afFor.getUpperBoundMap(baval.getArgNumber()).getResults())
+ if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(),
+ afFor.getUpperBoundsOperands(), val))
+ return true;
+ return false;
+ }
+ // \forall i \in [max(LB...), min(UB...)) <= k => any(UB-1 <= k) =>
+ // any(UB <= k+1)
+ case Cmp::LE: {
+ if (!val.isValue) {
+ for (auto forUb :
+ afFor.getUpperBoundMap(baval.getArgNumber()).getResults())
+ if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(),
+ afFor.getUpperBoundsOperands(), val.iVal + 1))
+ return true;
+ return false;
+ }
+
+ for (auto forUb :
+ afFor.getUpperBoundMap(baval.getArgNumber()).getResults())
+ if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(),
+ afFor.getUpperBoundsOperands(), val))
+ return true;
+ return false;
+ }
+ // \forall i \in [max(LB...), min(UB...)) > k => any(LB > k)
+ case Cmp::GT: {
+ for (auto forLb :
+ afFor.getLowerBoundMap(baval.getArgNumber()).getResults())
+ if (valueCmp(Cmp::GT, forLb, afFor.getLowerBoundsMap().getNumDims(),
+ afFor.getLowerBoundsOperands(), val))
+ return true;
+ return false;
+ }
+ // \forall i \in [max(LB...), min(UB...)) >= k => any(LB >= k)
+ case Cmp::GE: {
+ for (auto forLb :
+ afFor.getLowerBoundMap(baval.getArgNumber()).getResults())
+ if (valueCmp(Cmp::GE, forLb, afFor.getLowerBoundsMap().getNumDims(),
+ afFor.getLowerBoundsOperands(), val))
+ return true;
+ return false;
+ }
+ }
+ }
+
+ if (scf::ForOp afFor =
+ dyn_cast<scf::ForOp>(baval.getOwner()->getParentOp())) {
+ if (baval.getArgNumber() == 0) {
+ auto forLb = afFor.getLowerBound();
+ auto forUb = afFor.getUpperBound();
+ switch (cmp) {
+ // \forall i \in [LB, UB) == k => LB == k and UB == k+1
+ case Cmp::EQ: {
+ if (!valueCmp(Cmp::EQ, forLb, val))
+ return false;
+ if (!val.isValue) {
+ return valueCmp(Cmp::EQ, forUb, val.iVal + 1);
+ }
+ return false;
+ }
+ // \forall i \in [LB, UB) < k => UB <= k
+ case Cmp::LT: {
+ return valueCmp(Cmp::LE, forUb, val);
+ }
+ // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1
+ case Cmp::LE: {
+ if (!val.isValue) {
+ return valueCmp(Cmp::LE, forUb, val.iVal + 1);
+ }
+ return valueCmp(Cmp::LE, forUb, val);
+ }
+ // \forall i \in [LB, UB) > k => LB > k
+ case Cmp::GT: {
+ return valueCmp(Cmp::GT, forLb, val);
+ }
+ // \forall i \in [LB, UB) >= k => LB >= k
+ case Cmp::GE: {
+ return valueCmp(Cmp::GE, forLb, val);
+ }
+ }
+ }
+ }
+
+ if (scf::ParallelOp afFor =
+ dyn_cast<scf::ParallelOp>(baval.getOwner()->getParentOp())) {
+ auto forLb = afFor.getLowerBound()[baval.getArgNumber()];
+ auto forUb = afFor.getUpperBound()[baval.getArgNumber()];
+ switch (cmp) {
+ // \forall i \in [LB, UB) == k => LB == k and UB == k+1
+ case Cmp::EQ: {
+ if (!valueCmp(Cmp::EQ, forLb, val))
+ return false;
+ if (!val.isValue) {
+ return valueCmp(Cmp::EQ, forUb, val.iVal + 1);
+ }
+ return false;
+ }
+ // \forall i \in [LB, UB) < k => UB <= k
+ case Cmp::LT: {
+ return valueCmp(Cmp::LE, forUb, val);
+ }
+ // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1
+ case Cmp::LE: {
+ if (!val.isValue) {
+ return valueCmp(Cmp::LE, forUb, val.iVal + 1);
+ }
+ return valueCmp(Cmp::LE, forUb, val);
+ }
+ // \forall i \in [LB, UB) > k => LB > k
+ case Cmp::GT: {
+ return valueCmp(Cmp::GT, forLb, val);
+ }
+ // \forall i \in [LB, UB) >= k => LB >= k
+ case Cmp::GE: {
+ return valueCmp(Cmp::GE, forLb, val);
+ }
+ }
+ }
+ }
+ if (val.isValue && val.vVal == bval) {
+ switch (cmp) {
+ case Cmp::EQ:
+ return true;
+ case Cmp::LT:
+ return false;
+ case Cmp::LE:
+ return true;
+ case Cmp::GT:
+ return false;
+ case Cmp::GE:
+ return true;
+ }
+ }
+ return false;
+}
+
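+// Return true if `expr <cmp> val` is known to hold, where the first `numDim`
+// entries of `operands` bind dimensions and the rest bind symbols. Constant,
+// dim, symbol, add, and (for val == 0) mul expressions are handled; anything
+// else conservatively returns false.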
+bool valueCmp(Cmp cmp, AffineExpr expr, size_t numDim, ValueRange operands,
+ ValueOrInt val) {
+
+ if (auto opd = mlir::dyn_cast<AffineConstantExpr>(expr)) {
+ switch (cmp) {
+ case Cmp::EQ:
+ return val == opd.getValue();
+ case Cmp::LT:
+ return val > opd.getValue();
+ case Cmp::LE:
+ return val >= opd.getValue();
+ case Cmp::GT:
+ return val < opd.getValue();
+ case Cmp::GE:
+ return val <= opd.getValue();
+ }
+ }
+ if (auto opd = mlir::dyn_cast<AffineDimExpr>(expr)) {
+ return valueCmp(cmp, operands[opd.getPosition()], val);
+ }
+ if (auto opd = mlir::dyn_cast<AffineSymbolExpr>(expr)) {
+ return valueCmp(cmp, operands[opd.getPosition() + numDim], val);
+ }
+
+ if (auto bop = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (bop.getKind() == AffineExprKind::Add) {
+ switch (cmp) {
+ case Cmp::EQ:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+ case Cmp::LT:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val)) ||
+ (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val));
+ case Cmp::LE:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+ case Cmp::GT:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val)) ||
+ (valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val));
+ case Cmp::GE:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+ }
+ }
+ if (bop.getKind() == AffineExprKind::Mul && val == 0) {
+ switch (cmp) {
+ case Cmp::EQ:
+ return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) ||
+ valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+ case Cmp::LT:
+ return (valueCmp(Cmp::LT, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(Cmp::GT, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(Cmp::GT, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::LT, bop.getRHS(), numDim, operands, val));
+ case Cmp::LE:
+ return valueCmp(Cmp::EQ, bop.getLHS(), numDim, operands, val) ||
+ valueCmp(Cmp::EQ, bop.getRHS(), numDim, operands, val) ||
+ ((valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val)) ||
+ (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val)));
+ case Cmp::GT:
+ return (valueCmp(Cmp::LT, bop.getLHS(), numDim, operands, val) &&
+ valueCmp(Cmp::LT, bop.getRHS(), numDim, operands, 0)) ||
+ (valueCmp(Cmp::GT, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::GT, bop.getRHS(), numDim, operands, val));
+ case Cmp::GE:
+ return valueCmp(Cmp::EQ, bop.getLHS(), numDim, operands, val) ||
+ valueCmp(Cmp::EQ, bop.getRHS(), numDim, operands, val) ||
+ ((valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val)) ||
+ (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+ valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val)));
+ }
+ }
+ }
+ return false;
+}
+
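+// Return true if `op` is known to only read memory. An op with recursive
+// memory effects is read-only iff every op nested in its regions is.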
+bool isReadOnly(Operation *op) {
+ bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+ if (hasRecursiveEffects) {
+ for (Region ®ion : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto &nestedOp : block)
+ if (!isReadOnly(&nestedOp))
+ return false;
+ }
+ }
+ return true;
+ }
+
+  // If the op implements the memory effect interface, inspect its effects to
+  // see whether it only reads memory.
+  if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    // The op is read-only if it has no effects or only read effects.
+ SmallVector<MemoryEffects::EffectInstance, 1> effects;
+ effectInterface.getEffects(effects);
+ return llvm::all_of(effects, [](const MemoryEffects::EffectInstance &it) {
+ return isa<MemoryEffects::Read>(it.getEffect());
+ });
+ }
+ return false;
+}
\ No newline at end of file
diff --git a/mlir/lib/Conversion/SCFToAffine/Ops.h b/mlir/lib/Conversion/SCFToAffine/Ops.h
new file mode 100644
index 0000000000000..d8ddae9c42aca
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/Ops.h
@@ -0,0 +1,114 @@
+#ifndef MLIR_CONVERSION_SCFTOAFFINE_OPS_H
+#define MLIR_CONVERSION_SCFTOAFFINE_OPS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+bool collectEffects(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+ bool ignoreBarriers);
+
+bool getEffectsBefore(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+ bool stopAtBarrier);
+
+bool getEffectsAfter(
+ mlir::Operation *op,
+ llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+ bool stopAtBarrier);
+
+bool isReadOnly(mlir::Operation *);
+bool isReadNone(mlir::Operation *);
+
+bool mayReadFrom(mlir::Operation *, mlir::Value);
+bool mayWriteTo(mlir::Operation *, mlir::Value, bool ignoreBarrier = false);
+
+bool mayAlias(mlir::MemoryEffects::EffectInstance a,
+ mlir::MemoryEffects::EffectInstance b);
+
+bool mayAlias(mlir::MemoryEffects::EffectInstance a, mlir::Value b);
+
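+// Holds either an SSA value or a compile-time integer. Constructing it from
+// a Value folds constant integers into the integer form; the integer
+// comparison operators conservatively return false while a non-constant
+// Value is held.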
+struct ValueOrInt {
+ bool isValue;
+ mlir::Value vVal;
+ int64_t iVal;
+ ValueOrInt(mlir::Value v) { initValue(v); }
+ void initValue(mlir::Value v) {
+ using namespace mlir;
+ if (v) {
+ IntegerAttr iattr;
+ if (matchPattern(v, m_Constant(&iattr))) {
+ iVal = iattr.getValue().getSExtValue();
+ vVal = nullptr;
+ isValue = false;
+ return;
+ }
+ }
+ isValue = true;
+ vVal = v;
+ }
+
+ ValueOrInt(size_t i) : isValue(false), vVal(), iVal(i) {}
+
+ bool operator>=(int64_t v) {
+ if (isValue)
+ return false;
+ return iVal >= v;
+ }
+ bool operator>(int64_t v) {
+ if (isValue)
+ return false;
+ return iVal > v;
+ }
+ bool operator==(int64_t v) {
+ if (isValue)
+ return false;
+ return iVal == v;
+ }
+ bool operator<(int64_t v) {
+ if (isValue)
+ return false;
+ return iVal < v;
+ }
+ bool operator<=(int64_t v) {
+ if (isValue)
+ return false;
+ return iVal <= v;
+ }
+ bool operator>=(const llvm::APInt &v) {
+ if (isValue)
+ return false;
+ return iVal >= v.getSExtValue();
+ }
+ bool operator>(const llvm::APInt &v) {
+ if (isValue)
+ return false;
+ return iVal > v.getSExtValue();
+ }
+ bool operator==(const llvm::APInt &v) {
+ if (isValue)
+ return false;
+ return iVal == v.getSExtValue();
+ }
+ bool operator<(const llvm::APInt &v) {
+ if (isValue)
+ return false;
+ return iVal < v.getSExtValue();
+ }
+ bool operator<=(const llvm::APInt &v) {
+ if (isValue)
+ return false;
+ return iVal <= v.getSExtValue();
+ }
+};
+
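+// Comparison predicates understood by the valueCmp helpers below.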
+enum class Cmp { EQ, LT, LE, GT, GE };
+
+bool valueCmp(Cmp cmp, mlir::AffineExpr expr, size_t numDim,
+ mlir::ValueRange operands, ValueOrInt val);
+
+bool valueCmp(Cmp cmp, mlir::Value bval, ValueOrInt val);
+#endif // MLIR_CONVERSION_SCFTOAFFINE_OPS_H
diff --git a/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp b/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp
new file mode 100644
index 0000000000000..fc8fe44b7216d
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp
@@ -0,0 +1,296 @@
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "raise-to-affine"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace affine;
+
+namespace mlir {
+
+#define GEN_PASS_DEF_RAISESCFTOAFFINE
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+bool isValidIndex(Value val);
+void fully2ComposeAffineMapAndOperands(PatternRewriter &builder, AffineMap *map,
+ SmallVectorImpl<Value> *operands,
+ DominanceInfo &di);
+
+namespace {
+struct RaiseSCFToAffine : public impl::RaiseSCFToAffineBase<RaiseSCFToAffine> {
+ void runOnOperation() override;
+};
+} // namespace
+
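+// Raise an scf.for to affine.for. The lower/upper bounds may be valid affine
+// indices or max/min trees written as arith.select over arith.cmpi sge/sle;
+// a non-constant (but symbol) step is handled by normalizing the loop to
+// start at 0 with a ceil-divided trip count and reconstructing the original
+// induction variable in the body.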
+struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
+ using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+ // TODO: remove me or rename me.
+ bool isAffine(scf::ForOp loop) const {
+    // Require the step to be a valid affine symbol (maybe too restrictive).
+ return affine::isValidSymbol(loop.getStep());
+ }
+
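+  // Constant step value, or 1 when the step is not a ConstantIndexOp (the
+  // loop has then been normalized so the affine step really is 1).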
+ int64_t getStep(mlir::Value value) const {
+ ConstantIndexOp cstOp = value.getDefiningOp<ConstantIndexOp>();
+ if (cstOp)
+ return cstOp.value();
+ return 1;
+ }
+
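+  // Build the identity map ()[s0, ..., s{rank-1}] -> (s0, ..., s{rank-1}).
+  // affine.for takes the max of lower-bound results and the min of
+  // upper-bound results, matching the select-based max/min trees collected
+  // in matchAndRewrite.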
+ AffineMap getMultiSymbolIdentity(Builder &b, unsigned rank) const {
+ SmallVector<AffineExpr, 4> dimExprs;
+ dimExprs.reserve(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ dimExprs.push_back(b.getAffineSymbolExpr(i));
+ return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+ b.getContext());
+ }
+ LogicalResult matchAndRewrite(scf::ForOp loop,
+ PatternRewriter &rewriter) const final {
+ if (isAffine(loop)) {
+ OpBuilder builder(loop);
+
+ SmallVector<Value> lbs;
+ {
+ SmallVector<Value> todo = {loop.getLowerBound()};
+ while (todo.size()) {
+ auto cur = todo.back();
+ todo.pop_back();
+ if (isValidIndex(cur)) {
+ lbs.push_back(cur);
+ continue;
+ }
+ if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+ // LB only has max of operands
+ if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+ if (cmp.getLhs() == selOp.getTrueValue() &&
+ cmp.getRhs() == selOp.getFalseValue() &&
+ cmp.getPredicate() == CmpIPredicate::sge) {
+ todo.push_back(cmp.getLhs());
+ todo.push_back(cmp.getRhs());
+ continue;
+ }
+ }
+ }
+ return failure();
+ }
+ }
+
+ SmallVector<Value> ubs;
+ {
+ SmallVector<Value> todo = {loop.getUpperBound()};
+ while (todo.size()) {
+ auto cur = todo.back();
+ todo.pop_back();
+ if (isValidIndex(cur)) {
+ ubs.push_back(cur);
+ continue;
+ }
+ if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+ // UB only has min of operands
+ if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+ if (cmp.getLhs() == selOp.getTrueValue() &&
+ cmp.getRhs() == selOp.getFalseValue() &&
+ cmp.getPredicate() == CmpIPredicate::sle) {
+ todo.push_back(cmp.getLhs());
+ todo.push_back(cmp.getRhs());
+ continue;
+ }
+ }
+ }
+ return failure();
+ }
+ }
+
+ bool rewrittenStep = false;
+ if (!loop.getStep().getDefiningOp<ConstantIndexOp>()) {
+ if (ubs.size() != 1 || lbs.size() != 1)
+ return failure();
+ ubs[0] = DivUIOp::create(
+ rewriter, loop.getLoc(),
+ AddIOp::create(
+ rewriter, loop.getLoc(),
+ SubIOp::create(
+ rewriter, loop.getLoc(), loop.getStep(),
+ ConstantIndexOp::create(rewriter, loop.getLoc(), 1)),
+ SubIOp::create(rewriter, loop.getLoc(), loop.getUpperBound(),
+ loop.getLowerBound())),
+ loop.getStep());
+ lbs[0] = ConstantIndexOp::create(rewriter, loop.getLoc(), 0);
+ rewrittenStep = true;
+ }
+
+ auto *scope = affine::getAffineScope(loop)->getParentOp();
+ DominanceInfo di(scope);
+
+ AffineMap lbMap = getMultiSymbolIdentity(builder, lbs.size());
+ {
+ fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbs, di);
+ affine::canonicalizeMapAndOperands(&lbMap, &lbs);
+ lbMap = removeDuplicateExprs(lbMap);
+ }
+ AffineMap ubMap = getMultiSymbolIdentity(builder, ubs.size());
+ {
+ fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubs, di);
+ affine::canonicalizeMapAndOperands(&ubMap, &ubs);
+ ubMap = removeDuplicateExprs(ubMap);
+ }
+
+ affine::AffineForOp affineLoop = affine::AffineForOp::create(
+ rewriter, loop.getLoc(), lbs, lbMap, ubs, ubMap,
+ getStep(loop.getStep()), loop.getInits());
+
+ auto mergedYieldOp =
+ cast<scf::YieldOp>(loop.getRegion().front().getTerminator());
+
+ Block &newBlock = affineLoop.getRegion().front();
+
+      // AffineForOp::build adds a terminator only when no iter args are
+      // provided; erase it here since the original body (and its yield) is
+      // merged in below.
+ if (affineLoop.getNumIterOperands() == 0) {
+ auto *affineYieldOp = newBlock.getTerminator();
+ rewriter.eraseOp(affineYieldOp);
+ }
+
+ SmallVector<Value> vals;
+ rewriter.setInsertionPointToStart(&affineLoop.getRegion().front());
+ for (Value arg : affineLoop.getRegion().front().getArguments()) {
+ if (rewrittenStep && arg == affineLoop.getInductionVar()) {
+ arg = AddIOp::create(
+ rewriter, loop.getLoc(), loop.getLowerBound(),
+ MulIOp::create(rewriter, loop.getLoc(), arg, loop.getStep()));
+ }
+ vals.push_back(arg);
+ }
+ assert(vals.size() == loop.getRegion().front().getNumArguments());
+ rewriter.mergeBlocks(&loop.getRegion().front(),
+ &affineLoop.getRegion().front(), vals);
+
+ rewriter.setInsertionPoint(mergedYieldOp);
+ affine::AffineYieldOp::create(rewriter, mergedYieldOp.getLoc(),
+ mergedYieldOp.getOperands());
+ rewriter.eraseOp(mergedYieldOp);
+
+ rewriter.replaceOp(loop, affineLoop.getResults());
+
+ return success();
+ }
+ return failure();
+ }
+};
+
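+// Raise an scf.parallel that produces no results to affine.parallel when all
+// bounds are valid affine indices and all steps are constant.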
+struct ParallelOpRaising : public OpRewritePattern<scf::ParallelOp> {
+ using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
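+  // Compose any affine.apply producers into the bound maps of the new
+  // affine.parallel and canonicalize the resulting maps and operands.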
+ void canonicalizeLoopBounds(PatternRewriter &rewriter,
+ affine::AffineParallelOp forOp) const {
+ SmallVector<Value, 4> lbOperands(forOp.getLowerBoundsOperands());
+ SmallVector<Value, 4> ubOperands(forOp.getUpperBoundsOperands());
+
+ auto lbMap = forOp.getLowerBoundsMap();
+ auto ubMap = forOp.getUpperBoundsMap();
+
+ auto *scope = affine::getAffineScope(forOp)->getParentOp();
+ DominanceInfo di(scope);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands, di);
+ affine::canonicalizeMapAndOperands(&lbMap, &lbOperands);
+
+ fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands, di);
+ affine::canonicalizeMapAndOperands(&ubMap, &ubOperands);
+
+ forOp.setLowerBounds(lbOperands, lbMap);
+ forOp.setUpperBounds(ubOperands, ubMap);
+ }
+
+ LogicalResult matchAndRewrite(scf::ParallelOp loop,
+ PatternRewriter &rewriter) const final {
+ OpBuilder builder(loop);
+
+ if (loop.getResults().size())
+ return failure();
+
+ if (!llvm::all_of(loop.getLowerBound(), isValidIndex)) {
+ return failure();
+ }
+
+ if (!llvm::all_of(loop.getUpperBound(), isValidIndex)) {
+ return failure();
+ }
+
+ SmallVector<int64_t> steps;
+ for (auto step : loop.getStep())
+ if (auto cst = step.getDefiningOp<ConstantIndexOp>())
+ steps.push_back(cst.value());
+ else
+ return failure();
+
+ ArrayRef<AtomicRMWKind> reductions;
+ SmallVector<AffineMap> bounds;
+ for (size_t i = 0; i < loop.getLowerBound().size(); i++)
+ bounds.push_back(AffineMap::get(
+ /*dimCount=*/0, /*symbolCount=*/loop.getLowerBound().size(),
+ builder.getAffineSymbolExpr(i)));
+ affine::AffineParallelOp affineLoop = affine::AffineParallelOp::create(
+ rewriter, loop.getLoc(), loop.getResultTypes(), reductions, bounds,
+ loop.getLowerBound(), bounds, loop.getUpperBound(),
+        steps);
+
+ canonicalizeLoopBounds(rewriter, affineLoop);
+
+ auto mergedReduceOp =
+ cast<scf::ReduceOp>(loop.getRegion().front().getTerminator());
+
+ Block &newBlock = affineLoop.getRegion().front();
+
+    // ::build adds the terminator only when no results are requested; erase
+    // it here since the original body is merged in below.
+ if (affineLoop.getResults().size() == 0) {
+ auto *affineYieldOp = newBlock.getTerminator();
+ rewriter.eraseOp(affineYieldOp);
+ }
+
+ SmallVector<Value> vals;
+ for (Value arg : affineLoop.getRegion().front().getArguments()) {
+ vals.push_back(arg);
+ }
+ rewriter.mergeBlocks(&loop.getRegion().front(),
+ &affineLoop.getRegion().front(), vals);
+
+ rewriter.setInsertionPoint(mergedReduceOp);
+ affine::AffineYieldOp::create(rewriter, mergedReduceOp.getLoc(),
+ mergedReduceOp.getOperands());
+ rewriter.eraseOp(mergedReduceOp);
+
+ rewriter.replaceOp(loop, affineLoop.getResults());
+
+ return success();
+ }
+};
+
+void RaiseSCFToAffine::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<ForOpRaising, ParallelOpRaising>(&getContext());
+
+ GreedyRewriteConfig config;
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+}
+
+std::unique_ptr<Pass> mlir::createRaiseSCFToAffinePass() {
+ return std::make_unique<RaiseSCFToAffine>();
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/SCFToAffine/affinecfg.mlir b/mlir/test/Conversion/SCFToAffine/affinecfg.mlir
new file mode 100644
index 0000000000000..2b3dc6f3a34a0
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affinecfg.mlir
@@ -0,0 +1,166 @@
+// RUN: mlir-opt --affine-cfg --split-input-file %s | FileCheck %s
+
+module {
+ func.func @_Z7runTestiPPc(%arg0: index, %arg2: memref<?xi32>) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %1 = arith.addi %arg0, %c1 : index
+ affine.for %arg3 = 0 to 2 {
+ %2 = arith.muli %arg3, %1 : index
+ affine.for %arg4 = 0 to 2 {
+ %3 = arith.addi %2, %arg4 : index
+ memref.store %c0_i32, %arg2[%3] : memref<?xi32>
+ }
+ }
+ return
+ }
+
+}
+
+
+// CHECK: func.func @_Z7runTestiPPc(%[[arg0:.+]]: index, %[[arg1:.+]]: memref<?xi32>) {
+// CHECK-NEXT: %[[c0_i32:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: affine.for %[[arg2:.+]] = 0 to 2 {
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 0 to 2 {
+// CHECK-NEXT: affine.store %[[c0_i32]], %[[arg1]][%[[arg3]] + %[[arg2]] * (symbol(%[[arg0]]) + 1)] : memref<?xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+module {
+func.func @kernel_nussinov(%arg0: i32, %arg2: memref<i32>) {
+ %c0 = arith.constant 0 : index
+ %true = arith.constant true
+ %c1_i32 = arith.constant 1 : i32
+ %c59 = arith.constant 59 : index
+ %c100_i32 = arith.constant 100 : i32
+ affine.for %arg3 = 0 to 60 {
+ %0 = arith.subi %c59, %arg3 : index
+ %1 = arith.index_cast %0 : index to i32
+ %2 = arith.cmpi slt, %1, %c100_i32 : i32
+ scf.if %2 {
+ affine.store %arg0, %arg2[] : memref<i32>
+ }
+ }
+ return
+}
+}
+
+// CHECK: #set = affine_set<(d0) : (d0 + 40 >= 0)>
+// CHECK: func.func @kernel_nussinov(%[[arg0:.+]]: i32, %[[arg1:.+]]: memref<i32>) {
+// CHECK-NEXT: affine.for %[[arg2:.+]] = 0 to 60 {
+// CHECK-NEXT: affine.if #set(%[[arg2]]) {
+// CHECK-NEXT: affine.store %[[arg0]], %[[arg1]][] : memref<i32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+// -----
+
+module {
+ func.func private @run()
+
+ func.func @minif(%arg4: i32, %arg5 : i32, %arg10 : index) {
+ %c0_i32 = arith.constant 0 : i32
+
+ affine.for %i = 0 to 10 {
+ %70 = arith.index_cast %arg10 : index to i32
+ %71 = arith.muli %70, %arg5 : i32
+ %73 = arith.divui %71, %arg5 : i32
+ %75 = arith.muli %73, %arg5 : i32
+ %79 = arith.subi %arg4, %75 : i32
+ %81 = arith.cmpi sle, %arg5, %79 : i32
+ %83 = arith.select %81, %arg5, %79 : i32
+ %92 = arith.cmpi slt, %c0_i32, %83 : i32
+ scf.if %92 {
+ func.call @run() : () -> ()
+ scf.yield
+ }
+ }
+ return
+ }
+}
+
+// CHECK: #set = affine_set<()[s0] : (s0 - 1 >= 0)>
+// CHECK: func.func @minif(%[[arg0:.+]]: i32, %[[arg1:.+]]: i32, %[[arg2:.+]]: index) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg2]] : index to i32
+// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.muli %[[V2]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V4:.+]] = arith.subi %[[arg0]], %[[V3]] : i32
+// CHECK-NEXT: %[[V5:.+]] = arith.cmpi sle, %[[arg1]], %[[V4]] : i32
+// CHECK-NEXT: %[[V6:.+]] = arith.select %[[V5]], %[[arg1]], %[[V4]] : i32
+// CHECK-NEXT: %[[V7:.+]] = arith.index_cast %[[V6]] : i32 to index
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set()[%[[V7]]] {
+// CHECK-NEXT: func.call @run() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+module {
+ llvm.func @atoi(!llvm.ptr) -> i32
+func.func @_Z7runTestiPPc(%arg0: i32, %39: memref<?xi32>, %arg1: !llvm.ptr) attributes {llvm.linkage = #llvm.linkage<external>} {
+ %c2_i32 = arith.constant 2 : i32
+ %c16_i32 = arith.constant 16 : i32
+ %58 = llvm.call @atoi(%arg1) : (!llvm.ptr) -> i32
+ %40 = arith.divsi %58, %c16_i32 : i32
+ affine.for %arg2 = 1 to 10 {
+ %62 = arith.index_cast %arg2 : index to i32
+ %67 = arith.muli %58, %62 : i32
+ %69 = arith.addi %67, %40 : i32
+ %75 = arith.addi %69, %58 : i32
+ %76 = arith.index_cast %75 : i32 to index
+ memref.store %c2_i32, %39[%76] : memref<?xi32>
+ }
+ return
+}
+}
+
+// CHECK: func.func @_Z7runTestiPPc(%[[arg0:.+]]: i32, %[[arg1:.+]]: memref<?xi32>, %[[arg2:.+]]: !llvm.ptr) attributes {llvm.linkage = #llvm.linkage<external>} {
+// CHECK-NEXT: %[[c2_i32:.+]] = arith.constant 2 : i32
+// CHECK-NEXT: %[[c16_i32:.+]] = arith.constant 16 : i32
+// CHECK-NEXT: %[[V0:.+]] = llvm.call @atoi(%[[arg2]]) : (!llvm.ptr) -> i32
+// CHECK-NEXT: %[[V1:.+]] = arith.index_cast %[[V0]] : i32 to index
+// CHECK-NEXT: %[[V2:.+]] = arith.divsi %[[V0]], %[[c16_i32]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.index_cast %[[V2]] : i32 to index
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 1 to 10 {
+// CHECK-NEXT: affine.store %[[c2_i32]], %[[arg1]][%[[arg3]] * symbol(%[[V1]]) + symbol(%[[V1]]) + symbol(%[[V3]])] : memref<?xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+module {
+ func.func @c(%71: memref<?xf32>, %39: i64) {
+ affine.parallel (%arg2, %arg3) = (0, 0) to (42, 512) {
+ %262 = arith.index_cast %arg2 : index to i32
+ %a264 = arith.extsi %262 : i32 to i64
+ %268 = arith.cmpi slt, %a264, %39 : i64
+ scf.if %268 {
+ "test.something"() : () -> ()
+ }
+ }
+ return
+ }
+}
+
+// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)>
+// CHECK: func.func @c(%[[arg0:.+]]: memref<?xf32>, %[[arg1:.+]]: i64) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i64 to index
+// CHECK-NEXT: affine.parallel (%[[arg2:.+]], %[[arg3:.+]]) = (0, 0) to (42, 512) {
+// CHECK-NEXT: affine.if #set(%[[arg2]])[%[[V0]]] {
+// CHECK-NEXT: "test.something"() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
diff --git a/mlir/test/Conversion/SCFToAffine/affraise.mlir b/mlir/test/Conversion/SCFToAffine/affraise.mlir
new file mode 100644
index 0000000000000..8c16682a2933b
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+ func.func @withinif(%arg0: memref<?xf64>, %arg1: i32, %arg2: memref<?xf64>, %arg3: i1) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.if %arg3 {
+ %3 = arith.index_cast %arg1 : i32 to index
+ scf.for %arg6 = %c1 to %3 step %c1 {
+ %4 = memref.load %arg0[%arg6] : memref<?xf64>
+ memref.store %4, %arg2[%arg6] : memref<?xf64>
+ }
+ }
+ return
+ }
+ func.func @aff(%c : i1, %arg0: i32) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.if %c {
+ %75 = arith.index_cast %arg0 : i32 to index
+ scf.parallel (%arg5) = (%c0) to (%75) step (%c1) -> () {
+ "test.op"() : () -> ()
+ }
+ }
+ return
+ }
+}
+
+// CHECK: func.func @withinif(%[[arg0:.+]]: memref<?xf64>, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref<?xf64>, %[[arg3:.+]]: i1) {
+// CHECK-DAG: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: scf.if %[[arg3]] {
+// CHECK-NEXT: affine.for %[[arg4:.+]] = 1 to %[[V0]] {
+// CHECK-NEXT: %[[V1:.+]] = memref.load %[[arg0]][%[[arg4]]] : memref<?xf64>
+// CHECK-NEXT: memref.store %[[V1]], %[[arg2]][%[[arg4]]] : memref<?xf64>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: func.func @aff(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: scf.if %[[arg0]] {
+// CHECK-NEXT: affine.parallel (%[[arg2:.+]]) = (0) to (symbol(%[[V0]])) {
+// CHECK-NEXT: "test.op"() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Conversion/SCFToAffine/affraise2.mlir b/mlir/test/Conversion/SCFToAffine/affraise2.mlir
new file mode 100644
index 0000000000000..2da1e3713dd96
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise2.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+ func.func @main(%12 : i1, %14 : i32, %18 : memref<?xf32>, %19 : memref<?xf32> ) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ scf.if %12 {
+ %15 = arith.index_cast %14 : i32 to index
+ %16 = arith.muli %15, %c4 : index
+ %17 = arith.divui %16, %c4 : index
+ scf.for %arg2 = %c0 to %17 step %c1 {
+ %20 = memref.load %19[%arg2] : memref<?xf32>
+ memref.store %20, %18[%arg2] : memref<?xf32>
+ }
+ }
+ return
+ }
+}
+
+// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref<?xf32>, %[[arg3:.+]]: memref<?xf32>) {
+// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index
+// CHECK-NEXT: scf.if %[[arg0]] {
+// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index
+// CHECK-NEXT: affine.for %[[arg4:.+]] = 0 to %[[V2]] {
+// CHECK-NEXT: %[[a:.+]] = memref.load %[[arg3]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: memref.store %[[a]], %[[arg2]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
diff --git a/mlir/test/Conversion/SCFToAffine/affraise3.mlir b/mlir/test/Conversion/SCFToAffine/affraise3.mlir
new file mode 100644
index 0000000000000..f253c384622c3
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise3.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+ func.func @slt(%arg0: index) {
+ affine.for %arg1 = 0 to 10 {
+ %c = arith.cmpi slt, %arg1, %arg0 : index
+ scf.if %c {
+ "test.run"(%arg1) : (index) -> ()
+ }
+ }
+ return
+ }
+ func.func @sle(%arg0: index) {
+ affine.for %arg1 = 0 to 10 {
+ %c = arith.cmpi sle, %arg1, %arg0 : index
+ scf.if %c {
+ "test.run"(%arg1) : (index) -> ()
+ }
+ }
+ return
+ }
+ func.func @sgt(%arg0: index) {
+ affine.for %arg1 = 0 to 10 {
+ %c = arith.cmpi sgt, %arg1, %arg0 : index
+ scf.if %c {
+ "test.run"(%arg1) : (index) -> ()
+ }
+ }
+ return
+ }
+ func.func @sge(%arg0: index) {
+ affine.for %arg1 = 0 to 10 {
+ %c = arith.cmpi sge, %arg1, %arg0 : index
+ scf.if %c {
+ "test.run"(%arg1) : (index) -> ()
+ }
+ }
+ return
+ }
+}
+
+// -d0 + s0 - 1 >= 0 =>
+// -d0 >= 1 - s0
+// d0 <= s0 - 1
+// d0 < s0
+// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)>
+
+
+// -d0 + s0 >= 0 =>
+// -d0 >= - s0
+// d0 <= s0
+// CHECK: #set1 = affine_set<(d0)[s0] : (-d0 + s0 >= 0)>
+
+// d0 - s0 - 1 >= 0 =>
+// d0 >= s0 + 1
+// d0 > s0
+// CHECK: #set2 = affine_set<(d0)[s0] : (d0 - s0 - 1 >= 0)>
+
+// d0 - s0 >= 0 =>
+// d0 >= s0
+// CHECK: #set3 = affine_set<(d0)[s0] : (d0 - s0 >= 0)>
+
+// CHECK: func.func @slt(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sle(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set1(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sgt(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set2(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sge(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set3(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
diff --git a/mlir/test/Conversion/SCFToAffine/raisescffor.mlir b/mlir/test/Conversion/SCFToAffine/raisescffor.mlir
new file mode 100644
index 0000000000000..5eb78ff3079c8
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/raisescffor.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --raise-scf-to-affine %s | FileCheck %s
+
+module {
+ func.func private @_Z12kernel5_initPc(%0: index, %arg0: memref<index>) {
+ %c10 = arith.constant 10 : index
+ %c0 = arith.constant 0 : index
+ scf.for %arg1 = %c0 to %c10 step %0 {
+ memref.store %c10, %arg0[] : memref<index>
+ }
+ return
+ }
+}
+
+// CHECK-LABEL: func.func private @_Z12kernel5_initPc(
+// CHECK-SAME: %[[VAL_0:.*]]: index,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<index>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_0]], %[[VAL_2]] : index
+// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index
+// CHECK: %[[VAL_6:.*]] = arith.divui %[[VAL_5]], %[[VAL_0]] : index
+// CHECK: affine.for %[[VAL_7:.*]] = 0 to %[[VAL_6]] {
+// CHECK: memref.store %[[VAL_3]], %[[VAL_1]][] : memref<index>
+// CHECK: }
+