[flang-commits] [flang] d70938b - [fir] Add affine promotion pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Oct 8 02:47:31 PDT 2021


Author: Jean Perier
Date: 2021-10-08T11:47:23+02:00
New Revision: d70938bbad0a659b9a142c493be7a71d1a6af0bb

URL: https://github.com/llvm/llvm-project/commit/d70938bbad0a659b9a142c493be7a71d1a6af0bb
DIFF: https://github.com/llvm/llvm-project/commit/d70938bbad0a659b9a142c493be7a71d1a6af0bb.diff

LOG: [fir] Add affine promotion pass

Convert fir operations which satisfy affine constraints to the affine
dialect.

This patch is part of the upstreaming effort from fir-dev branch.

Co-authored-by: V Donaldson <vdonaldson at nvidia.com>
Co-authored-by: Rajan Walia <walrajan at gmail.com>
Co-authored-by: Sourabh Singh Tomar <SourabhSingh.Tomar at amd.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>

Reviewed By: schweitz, awarzynski

Differential Revision: https://reviews.llvm.org/D111155

Added: 
    flang/lib/Optimizer/Transforms/AffinePromotion.cpp
    flang/test/Fir/affine-promotion.fir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt
    flang/lib/Optimizer/Transforms/PassDetail.h

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 881cef29a8b4c..2b17a7e4c11f8 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -22,6 +22,7 @@ class Region;
 
 namespace fir {
 
+/// Creates the `promote-to-affine` pass, which promotes fir.do_loop / fir.if
+/// operations to affine.for / affine.if where possible.
+std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
 
 /// Support for inlining on FIR.

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 680f80e06d8cc..64ee44a9c36ea 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,6 +16,31 @@
 
 include "mlir/Pass/PassBase.td"
 
+// Function pass implemented by the AffineDialectPromotion class in
+// flang/lib/Optimizer/Transforms/AffinePromotion.cpp.
+def AffineDialectPromotion : FunctionPass<"promote-to-affine"> {
+  let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
+  let description = [{
+    Convert fir operations which satisfy affine constraints to the affine
+    dialect.
+
+    `fir.do_loop` will be converted to `affine.for` if the loops inside the body
+    can be converted and the indices for memory loads and stores satisfy
+    `affine.apply` criteria for symbols and dimensions.
+
+    `fir.if` will be converted to `affine.if` where possible. `affine.if`'s
+    condition uses an integer set (==, >=) and an analysis is done to determine
+    the fir condition's parent operations to construct the integer set.
+
+    `fir.load` (`fir.store`) will be converted to `affine.load` (`affine.store`)
+    where possible. This conversion includes adding a dummy `fir.convert` cast
+    to adapt values of type `!fir.ref<!fir.array>` to `memref`. This is done
+    because the affine dialect presently only understands the `memref` type.
+  }];
+  let constructor = "::fir::createPromoteToAffinePass()";
+  // Dialects whose operations the pass creates (fir.convert, std constants,
+  // affine operations).
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect", "mlir::AffineDialect"
+  ];
+}
+
 def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
   let summary = "Convert name for external interoperability";
   let description = [{

diff  --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
new file mode 100644
index 0000000000000..16b66721496b2
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -0,0 +1,609 @@
+//===-- AffinePromotion.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-affine-promotion"
+
+using namespace fir;
+
+namespace {
+struct AffineLoopAnalysis;
+struct AffineIfAnalysis;
+
+/// Stores the analysis objects for the fir.do_loop and fir.if operations of a
+/// function. The analyses are used twice: first by the conversion target to
+/// mark the operations that must be rewritten, then by the rewrite patterns.
+struct AffineFunctionAnalysis {
+  explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) {
+    // Analysing a loop inserts the analyses of its nested loops into
+    // loopAnalysisMap while the entry of the outer loop is still being
+    // constructed in place by try_emplace. Reserve room for every loop up
+    // front so those nested insertions can never rehash the map and leave the
+    // bucket under construction dangling.
+    unsigned numLoops = 0;
+    funcOp.walk([&](fir::DoLoopOp) { ++numLoops; });
+    loopAnalysisMap.reserve(numLoops);
+    for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
+      loopAnalysisMap.try_emplace(op, op, *this);
+
+    // Loops and ifs the recursive analysis never reached (e.g. a fir.if
+    // directly in the function body, or the remaining children of a loop that
+    // was rejected early) get a default, non-promotable entry. The legality
+    // callbacks then find an analysis instead of emitting an error.
+    funcOp.walk([&](fir::DoLoopOp op) { loopAnalysisMap.try_emplace(op); });
+    funcOp.walk([&](fir::IfOp op) { ifAnalysisMap.try_emplace(op); });
+  }
+
+  /// Returns the recorded analysis of \p op, or a default (non-promotable) one
+  /// after emitting an error when none was recorded.
+  AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
+
+  /// Returns the recorded analysis of \p op, or a default (non-promotable) one
+  /// after emitting an error when none was recorded.
+  AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
+
+  llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
+  llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
+};
+} // namespace
+
+/// Checks one index operand of an array_coor used by the memory operation
+/// \p op. The index is accepted only when it is a block argument of a block
+/// whose parent is a fir.do_loop (i.e. a loop-body argument such as the
+/// induction variable). Anything else is rejected with a debug trace.
+static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
+  auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>();
+  if (!blockArg) {
+    LLVM_DEBUG(
+        llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
+                        "induction variable (not a block argument)\n";
+        op->dump(); coordinate.getDefiningOp()->dump());
+    return false;
+  }
+  mlir::Operation *argOwner = blockArg.getOwner()->getParentOp();
+  if (!isa<fir::DoLoopOp>(argOwner)) {
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
+                               "loop induction variable (owner not loopOp)\n";
+               op->dump());
+    return false;
+  }
+  return true;
+}
+
+namespace {
+/// Decides whether a fir.do_loop can be promoted to affine.for. The verdict is
+/// computed once, at construction. Analysing a loop also records analyses for
+/// the fir.do_loop and fir.if operations nested directly in its body.
+struct AffineLoopAnalysis {
+  AffineLoopAnalysis() = default;
+
+  explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
+      : legality(analyzeLoop(op, afa)) {}
+
+  bool canPromoteToAffine() { return legality; }
+
+private:
+  /// Analyses (and records) each loop nested directly in the body. It stops at
+  /// the first loop that cannot be promoted. On success it also records an
+  /// analysis for every fir.if nested directly in the body.
+  bool analyzeBody(fir::DoLoopOp loopOperation,
+                   AffineFunctionAnalysis &functionAnalysis) {
+    auto &loopMap = functionAnalysis.loopAnalysisMap;
+    for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
+      auto it = loopMap.find(loopOp.getOperation());
+      if (it == loopMap.end()) {
+        // Run the nested analysis *before* inserting it. It inserts further
+        // entries into loopMap; doing that while try_emplace constructs this
+        // entry in place could rehash the map and invalidate the bucket.
+        AffineLoopAnalysis childAnalysis(loopOp, functionAnalysis);
+        it = loopMap.try_emplace(loopOp, childAnalysis).first;
+      }
+      if (!it->getSecond().canPromoteToAffine())
+        return false;
+    }
+    for (auto ifOp : loopOperation.getOps<fir::IfOp>())
+      functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
+    return true;
+  }
+
+  bool analyzeLoop(fir::DoLoopOp loopOperation,
+                   AffineFunctionAnalysis &functionAnalysis) {
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
+    // The results check comes last so that the body (and thus the nested
+    // loops and ifs) is still analysed for loops that produce results.
+    return analyzeMemoryAccess(loopOperation) &&
+           analyzeBody(loopOperation, functionAnalysis) &&
+           analyzeResults(loopOperation);
+  }
+
+  /// The affine.for built by the rewrite has no results, so a loop producing
+  /// results cannot be replaced by it (replaceOp would get no values).
+  bool analyzeResults(fir::DoLoopOp loopOperation) {
+    if (loopOperation.getNumResults() == 0)
+      return true;
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
+                               "loop has results\n";
+               loopOperation.dump(););
+    return false;
+  }
+
+  /// The rewrite (createAffineOps) reinterprets the array base as a memref of
+  /// the element type and derives the symbols from the shape operand. It only
+  /// supports !fir.ref<!fir.array<...>> bases shaped by fir.shape or
+  /// fir.shape_shift, and it does not take the slice operand into account.
+  static bool isPromotableArrayCoor(fir::ArrayCoorOp acoOp) {
+    auto refType = acoOp.memref().getType().dyn_cast<fir::ReferenceType>();
+    if (!refType || !refType.getEleTy().isa<fir::SequenceType>())
+      return false;
+    mlir::Value shape = acoOp.shape();
+    if (!shape || !(shape.getDefiningOp<fir::ShapeOp>() ||
+                    shape.getDefiningOp<fir::ShapeShiftOp>()))
+      return false;
+    return !acoOp.slice();
+  }
+
+  /// A memory reference is promotable when it is produced by a supported
+  /// fir.array_coor whose indices are all fir.do_loop body arguments.
+  bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
+    if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
+      if (acoOp.memref().getType().isa<fir::BoxType>()) {
+        // TODO: Look if and how fir.box can be promoted to affine.
+        LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
+                                   "array memory operation uses fir.box\n";
+                   op->dump(); acoOp.dump(););
+        return false;
+      }
+      if (!isPromotableArrayCoor(acoOp)) {
+        LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
+                                   "unsupported array_coor base, shape or "
+                                   "slice\n";
+                   op->dump(); acoOp.dump(););
+        return false;
+      }
+      bool canPromote = true;
+      for (auto coordinate : acoOp.indices())
+        canPromote = canPromote && analyzeCoordinate(coordinate, op);
+      return canPromote;
+    }
+    if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
+      LLVM_DEBUG(llvm::dbgs()
+                     << "AffineLoopAnalysis: cannot promote loop, "
+                        "array memory operation uses non ArrayCoorOp\n";
+                 op->dump(); coOp.dump(););
+
+      return false;
+    }
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
+                               "reference for array load\n";
+               op->dump(););
+    return false;
+  }
+
+  /// Every fir.load / fir.store directly in the loop body must be promotable.
+  bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
+    for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
+      if (!analyzeReference(loadOp.memref(), loadOp))
+        return false;
+    for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
+      if (!analyzeReference(storeOp.memref(), storeOp))
+        return false;
+    return true;
+  }
+
+  bool legality{};
+};
+} // namespace
+
+AffineLoopAnalysis
+AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
+  // Return the recorded analysis. A miss is reported as an error on `op` and
+  // answered with a default-constructed (non-promotable) analysis.
+  auto found = loopAnalysisMap.find_as(op);
+  if (found != loopAnalysisMap.end())
+    return found->getSecond();
+
+  LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
+             op.dump(););
+  op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
+  return {};
+}
+
+namespace {
+/// Tries to express a fir.if condition as an mlir::IntegerSet usable by
+/// affine.if.
+///
+/// Only conditions produced by a CmpIOp with predicate slt, sle, sgt, sge or
+/// eq are handled. Both compared values must be built from integer
+/// sub/add/mul/unsigned-rem operations, integer constants and block
+/// arguments. Block arguments of fir.do_loop or affine.for bodies become
+/// dimensions; any other block argument becomes a symbol.
+///
+/// On success hasIntegerSet() is true and getAffineArgs() returns the set
+/// operands: all dimensions first, then all symbols.
+struct AffineIfCondition {
+  using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
+
+  explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
+    if (auto condDef = firCondition.getDefiningOp<mlir::CmpIOp>())
+      fromCmpIOp(condDef);
+  }
+
+  bool hasIntegerSet() const { return integerSet.hasValue(); }
+
+  mlir::IntegerSet getIntegerSet() const {
+    assert(hasIntegerSet() && "integer set is missing");
+    return integerSet.getValue();
+  }
+
+  mlir::ValueRange getAffineArgs() const { return affineArgs; }
+
+private:
+  MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
+                                 mlir::Value rhs) {
+    return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
+  }
+
+  MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
+                                 MaybeAffineExpr rhs) {
+    if (lhs.hasValue() && rhs.hasValue())
+      return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue());
+    return {};
+  }
+
+  MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
+
+  MaybeAffineExpr toAffineExpr(int64_t value) {
+    return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
+  }
+
+  /// Returns an AffineExpr if \p value is the result of operations that can
+  /// be expressed in an affine expression: -, +, *, unsigned rem, integer
+  /// constants. Block arguments of a fir.do_loop or affine.for body are used
+  /// as dimensions; other block arguments are used as symbols.
+  MaybeAffineExpr toAffineExpr(mlir::Value value) {
+    if (auto op = value.getDefiningOp<mlir::SubIOp>())
+      return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()),
+                            affineBinaryOp(mlir::AffineExprKind::Mul,
+                                           toAffineExpr(op.rhs()),
+                                           toAffineExpr(-1)));
+    if (auto op = value.getDefiningOp<mlir::AddIOp>())
+      return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs());
+    if (auto op = value.getDefiningOp<mlir::MulIOp>())
+      return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs());
+    if (auto op = value.getDefiningOp<mlir::UnsignedRemIOp>())
+      return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs());
+    if (auto op = value.getDefiningOp<mlir::ConstantOp>())
+      if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
+        return toAffineExpr(intConstant.getInt());
+    if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
+      mlir::Operation *owner = blockArg.getOwner()->getParentOp();
+      if (isa<fir::DoLoopOp>(owner) || isa<mlir::AffineForOp>(owner)) {
+        dimArgs.push_back(value);
+        return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
+      }
+      symArgs.push_back(value);
+      return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
+    }
+    return {};
+  }
+
+  void fromCmpIOp(mlir::CmpIOp cmpOp) {
+    auto lhsAffine = toAffineExpr(cmpOp.lhs());
+    auto rhsAffine = toAffineExpr(cmpOp.rhs());
+    if (!lhsAffine.hasValue() || !rhsAffine.hasValue())
+      return;
+    auto constraintPair = constraint(
+        cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue());
+    if (!constraintPair)
+      return;
+    integerSet = mlir::IntegerSet::get(dimCount, symCount,
+                                       {constraintPair.getValue().first},
+                                       {constraintPair.getValue().second});
+    // The operands of an integer set list all dimensions before all symbols,
+    // whatever order the block arguments were encountered in.
+    affineArgs.assign(dimArgs.begin(), dimArgs.end());
+    affineArgs.append(symArgs.begin(), symArgs.end());
+  }
+
+  /// Turns `lhs <predicate> rhs` into one integer-set constraint over
+  /// basic = rhs - lhs. The bool is true for an equality (expr == 0) and
+  /// false for an inequality (expr >= 0).
+  llvm::Optional<std::pair<AffineExpr, bool>>
+  constraint(mlir::CmpIPredicate predicate, mlir::AffineExpr basic) {
+    switch (predicate) {
+    case mlir::CmpIPredicate::slt: // rhs - lhs - 1 >= 0
+      return {std::make_pair(basic - 1, false)};
+    case mlir::CmpIPredicate::sle: // rhs - lhs >= 0
+      return {std::make_pair(basic, false)};
+    case mlir::CmpIPredicate::sgt: // lhs - rhs - 1 >= 0
+      return {std::make_pair(-1 - basic, false)};
+    case mlir::CmpIPredicate::sge: // lhs - rhs >= 0
+      return {std::make_pair(0 - basic, false)};
+    case mlir::CmpIPredicate::eq: // rhs - lhs == 0
+      return {std::make_pair(basic, true)};
+    default:
+      return {};
+    }
+  }
+
+  /// Operands of the integer set, in set order (dimensions, then symbols).
+  llvm::SmallVector<mlir::Value> affineArgs;
+  /// Block arguments mapped to dimensions, in dimension-number order.
+  llvm::SmallVector<mlir::Value> dimArgs;
+  /// Block arguments mapped to symbols, in symbol-number order.
+  llvm::SmallVector<mlir::Value> symArgs;
+  llvm::Optional<mlir::IntegerSet> integerSet;
+  mlir::Value firCondition;
+  unsigned symCount{0u};
+  unsigned dimCount{0u};
+};
+} // namespace
+
+namespace {
+/// Analysis for affine promotion of fir.if. An if is promotable when it has
+/// no results and its condition can be expressed as an affine integer set.
+struct AffineIfAnalysis {
+  AffineIfAnalysis() = default;
+
+  explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
+      : legality(analyzeIf(op, afa)) {}
+
+  bool canPromoteToAffine() { return legality; }
+
+private:
+  bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
+    if (op.getNumResults() != 0) {
+      LLVM_DEBUG(llvm::dbgs()
+                     << "AffineIfAnalysis: not promoting as op has results\n";);
+      return false;
+    }
+    // The rewrite pattern fails when no integer set can be built for the
+    // condition. Such an if must stay legal; otherwise the partial conversion
+    // fails to legalize it and the whole pass fails.
+    if (!AffineIfCondition(op.condition()).hasIntegerSet()) {
+      LLVM_DEBUG(llvm::dbgs() << "AffineIfAnalysis: not promoting as the "
+                                 "condition is not affine\n";);
+      return false;
+    }
+    return true;
+  }
+
+  bool legality{};
+};
+} // namespace
+
+AffineIfAnalysis
+AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
+  // Return the recorded analysis. A miss is reported as an error on `op` and
+  // answered with a default-constructed (non-promotable) analysis.
+  auto found = ifAnalysisMap.find_as(op);
+  if (found != ifAnalysisMap.end())
+    return found->getSecond();
+
+  LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
+             op.dump(););
+  op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
+  return {};
+}
+
+/// Builds the affine map that linearises the indices of a fir.array_coor.
+///
+/// Array dimension k contributes one map dimension (its index ik) and three
+/// symbols, in this order: lower bound lbk, extent extk and stride stk. The
+/// result is
+///   sum_k (ik * stk - lbk) * ext0 * ... * ext(k-1)
+/// e.g. for one dimension: affine_map<(i)[lb, ext, st] -> (i * st - lb)>.
+static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
+                                                 MLIRContext *context) {
+  mlir::AffineExpr linearIndex = mlir::getAffineConstantExpr(0, context);
+  mlir::AffineExpr extentProduct = mlir::getAffineConstantExpr(1, context);
+  for (unsigned dim = 0; dim != dimensions; ++dim) {
+    const unsigned symbolBase = 3 * dim;
+    mlir::AffineExpr index = mlir::getAffineDimExpr(dim, context);
+    mlir::AffineExpr lb = mlir::getAffineSymbolExpr(symbolBase, context);
+    mlir::AffineExpr extent =
+        mlir::getAffineSymbolExpr(symbolBase + 1, context);
+    mlir::AffineExpr step =
+        mlir::getAffineSymbolExpr(symbolBase + 2, context);
+    mlir::AffineExpr term = (index * step - lb) * extentProduct;
+    linearIndex = term + linearIndex;
+    extentProduct = extentProduct * extent;
+  }
+  return mlir::AffineMap::get(dimensions, 3 * dimensions, linearIndex);
+}
+
+/// Returns the integer held by the ConstantOp defining \p value. Returns None
+/// when \p value is not defined by a ConstantOp or the constant is not an
+/// IntegerAttr.
+static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
+  auto constOp = value.getDefiningOp<ConstantOp>();
+  if (!constOp)
+    return {};
+  auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>();
+  if (!intAttr)
+    return {};
+  return intAttr.getInt();
+}
+
+/// Returns the element type of the !fir.ref<!fir.array<...>> addressed by
+/// \p op. For any other base type it emits an error on \p op and returns a
+/// null type.
+static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
+  if (auto refTy = op.memref().getType().dyn_cast_or_null<ReferenceType>())
+    if (auto arrTy = refTy.getEleTy().dyn_cast_or_null<SequenceType>())
+      return arrTy.getEleTy();
+
+  op.emitError(
+      "AffineLoopConversion: array type in coordinate operation not valid\n");
+  return {};
+}
+
+/// fir.shape: each extent becomes the triple (1, extent, 1), i.e. a lower
+/// bound of 1 and a stride of 1. A single index constant 1 is created (even
+/// for a rank-0 shape) and shared by all triples.
+static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
+                              SmallVectorImpl<mlir::Value> &indexArgs,
+                              mlir::PatternRewriter &rewriter) {
+  mlir::Value unit = rewriter.create<mlir::ConstantOp>(
+      acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
+  for (mlir::Value extent : shape.extents())
+    indexArgs.append({unit, extent, unit});
+}
+
+/// fir.shape_shift: each (lower bound, extent) pair becomes the triple
+/// (lower bound, extent, 1). A single index constant 1 is created and shared
+/// by all triples.
+static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
+                              SmallVectorImpl<mlir::Value> &indexArgs,
+                              mlir::PatternRewriter &rewriter) {
+  mlir::Value unit = rewriter.create<mlir::ConstantOp>(
+      acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
+  auto pairs = shape.pairs();
+  for (unsigned pos = 0, count = pairs.size(); pos < count; pos += 2)
+    indexArgs.append({pairs[pos], pairs[pos + 1], unit});
+}
+
+/// fir.slice: the triples are forwarded unchanged, three operands per
+/// dimension.
+/// NOTE(review): the middle operand of a slice triple is presumably an upper
+/// bound, while createArrayIndexAffineMap expects an extent there; confirm.
+static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
+                              SmallVectorImpl<mlir::Value> &indexArgs,
+                              mlir::PatternRewriter &rewriter) {
+  auto triples = slice.triples();
+  indexArgs.append(triples.begin(), triples.end());
+}
+
+/// Appends the per-dimension symbol operands of \p acoOp's index map. The
+/// handler is chosen from the operation defining the array_coor shape
+/// operand. Nothing is appended when that operation is none of fir.shape,
+/// fir.shape_shift or fir.slice.
+static void populateIndexArgs(fir::ArrayCoorOp acoOp,
+                              SmallVectorImpl<mlir::Value> &indexArgs,
+                              mlir::PatternRewriter &rewriter) {
+  mlir::Value shapeOperand = acoOp.shape();
+  if (auto shapeOp = shapeOperand.getDefiningOp<ShapeOp>()) {
+    populateIndexArgs(acoOp, shapeOp, indexArgs, rewriter);
+  } else if (auto shiftOp = shapeOperand.getDefiningOp<ShapeShiftOp>()) {
+    populateIndexArgs(acoOp, shiftOp, indexArgs, rewriter);
+  } else if (auto sliceOp = shapeOperand.getDefiningOp<SliceOp>()) {
+    populateIndexArgs(acoOp, sliceOp, indexArgs, rewriter);
+  }
+}
+
+/// Lowers the fir.array_coor defining \p arrayRef into two operations:
+///  - an affine.apply of createArrayIndexAffineMap computing the linear
+///    element index, and
+///  - a fir.convert of the array base to a one-dimensional, dynamically sized
+///    memref of the element type.
+/// Returns both operations as (apply, convert).
+static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
+createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
+  auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
+  auto loc = acoOp.getLoc();
+  auto indices = acoOp.indices();
+
+  // Map operands: the indices (dimensions), then the per-dimension symbols.
+  SmallVector<mlir::Value> mapOperands(indices.begin(), indices.end());
+  populateIndexArgs(acoOp, mapOperands, rewriter);
+
+  mlir::AffineMap indexMap =
+      createArrayIndexAffineMap(indices.size(), acoOp.getContext());
+  auto applyOp =
+      rewriter.create<mlir::AffineApplyOp>(loc, indexMap, mapOperands);
+
+  mlir::Type elementType = coordinateArrayElement(acoOp);
+  auto memrefType = mlir::MemRefType::get({-1}, elementType);
+  auto convertOp =
+      rewriter.create<fir::ConvertOp>(loc, memrefType, acoOp.memref());
+  return {applyOp, convertOp};
+}
+
+/// Replaces \p loadOp with an affine.load that reads through the memref and
+/// index built by createAffineOps. New operations go right before the load.
+static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
+  rewriter.setInsertionPoint(loadOp);
+  std::pair<mlir::AffineApplyOp, fir::ConvertOp> ops =
+      createAffineOps(loadOp.memref(), rewriter);
+  mlir::Value memref = ops.second.getResult();
+  mlir::Value index = ops.first.getResult();
+  rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(loadOp, memref, index);
+}
+
+/// Replaces \p storeOp with an affine.store of the same value through the
+/// memref and index built by createAffineOps. New operations go right before
+/// the store.
+static void rewriteStore(fir::StoreOp storeOp,
+                         mlir::PatternRewriter &rewriter) {
+  rewriter.setInsertionPoint(storeOp);
+  std::pair<mlir::AffineApplyOp, fir::ConvertOp> ops =
+      createAffineOps(storeOp.memref(), rewriter);
+  mlir::Value memref = ops.second.getResult();
+  mlir::Value index = ops.first.getResult();
+  rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(),
+                                                   memref, index);
+}
+
+/// Rewrites the fir.load / fir.store operations found directly in \p block
+/// into affine.load / affine.store. Only accesses whose address is produced
+/// by a fir.array_coor are rewritten, because createAffineOps relies on that.
+/// Other memory operations are left untouched.
+static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
+  // Early-increment iteration: the current operation is replaced (and may be
+  // erased) by the rewrite, which must not invalidate the loop iterator.
+  for (auto &bodyOp : llvm::make_early_inc_range(block->getOperations())) {
+    if (auto loadOp = llvm::dyn_cast<fir::LoadOp>(&bodyOp)) {
+      if (loadOp.memref().getDefiningOp<ArrayCoorOp>())
+        rewriteLoad(loadOp, rewriter);
+    } else if (auto storeOp = llvm::dyn_cast<fir::StoreOp>(&bodyOp)) {
+      if (storeOp.memref().getDefiningOp<ArrayCoorOp>())
+        rewriteStore(storeOp, rewriter);
+    }
+  }
+}
+
+namespace {
+/// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
+/// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
+/// loads and stores to affine.
+class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
+      : OpRewritePattern(context), functionAnalysis(afa) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::DoLoopOp loop,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
+               loop.dump(););
+    // Only fetched so that a missing analysis gets reported.
+    LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
+        functionAnalysis.getChildLoopAnalysis(loop);
+    auto &loopOps = loop.getBody()->getOperations();
+    auto loopAndIndex = createAffineFor(loop, rewriter);
+    auto affineFor = loopAndIndex.first;
+    auto inductionVar = loopAndIndex.second;
+
+    // Move the body, except its terminator, in front of the affine.for
+    // terminator.
+    rewriter.startRootUpdate(affineFor.getOperation());
+    affineFor.getBody()->getOperations().splice(
+        std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
+        std::prev(loopOps.end()));
+    rewriter.finalizeRootUpdate(affineFor.getOperation());
+
+    rewriter.startRootUpdate(loop.getOperation());
+    loop.getInductionVar().replaceAllUsesWith(inductionVar);
+    rewriter.finalizeRootUpdate(loop.getOperation());
+
+    rewriteMemoryOps(affineFor.getBody(), rewriter);
+
+    LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
+               affineFor.dump(););
+    rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
+    return success();
+  }
+
+private:
+  /// Creates the affine.for replacing \p op. Returns it together with the
+  /// value that replaces the fir.do_loop induction variable.
+  std::pair<mlir::AffineForOp, mlir::Value>
+  createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
+    if (auto constantStep = constantIntegerLike(op.step()))
+      if (constantStep.getValue() > 0)
+        return positiveConstantStep(op, constantStep.getValue(), rewriter);
+    return genericBounds(op, rewriter);
+  }
+
+  /// Positive compile-time constant step: iterate over [lb, ub + 1) with that
+  /// step. The affine induction variable is then the loop index itself.
+  std::pair<mlir::AffineForOp, mlir::Value>
+  positiveConstantStep(fir::DoLoopOp op, int64_t step,
+                       mlir::PatternRewriter &rewriter) const {
+    auto affineFor = rewriter.create<mlir::AffineForOp>(
+        op.getLoc(), ValueRange(op.lowerBound()),
+        mlir::AffineMap::get(0, 1,
+                             mlir::getAffineSymbolExpr(0, op.getContext())),
+        ValueRange(op.upperBound()),
+        mlir::AffineMap::get(0, 1,
+                             1 + mlir::getAffineSymbolExpr(0, op.getContext())),
+        step);
+    return std::make_pair(affineFor, affineFor.getInductionVar());
+  }
+
+  /// Any other step: iterate a normalized counter iv over [0, tripCount),
+  /// where tripCount = (ub - lb + step) floordiv step. The loop index is
+  /// recomputed as lb + iv * step at the start of the body.
+  std::pair<mlir::AffineForOp, mlir::Value>
+  genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
+    auto *context = op.getContext();
+    auto lowerBound = mlir::getAffineSymbolExpr(0, context);
+    auto upperBound = mlir::getAffineSymbolExpr(1, context);
+    auto step = mlir::getAffineSymbolExpr(2, context);
+    mlir::AffineMap tripCountMap = mlir::AffineMap::get(
+        0, 3, (upperBound - lowerBound + step).floorDiv(step));
+    auto tripCount = rewriter.create<mlir::AffineApplyOp>(
+        op.getLoc(), tripCountMap,
+        ValueRange({op.lowerBound(), op.upperBound(), op.step()}));
+    // index = lb + iv * step, with dimension [iv] and symbols [lb, step].
+    auto actualIndexMap = mlir::AffineMap::get(
+        1, 2,
+        lowerBound + mlir::getAffineDimExpr(0, context) *
+                         mlir::getAffineSymbolExpr(1, context));
+
+    // affine.for upper bounds are exclusive, so the trip count is used as-is.
+    auto affineFor = rewriter.create<mlir::AffineForOp>(
+        op.getLoc(), ValueRange(), AffineMap::getConstantMap(0, context),
+        tripCount.getResult(),
+        mlir::AffineMap::get(0, 1, mlir::getAffineSymbolExpr(0, context)), 1);
+    rewriter.setInsertionPointToStart(affineFor.getBody());
+    auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
+        op.getLoc(), actualIndexMap,
+        ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()}));
+    return std::make_pair(affineFor, actualIndex.getResult());
+  }
+
+  AffineFunctionAnalysis &functionAnalysis;
+};
+
+/// Convert `fir.if` to `affine.if`. The operations of both branches are moved
+/// into the affine.if, and their array loads/stores are promoted to affine
+/// memory operations.
+class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
+      : OpRewritePattern(context) {}
+  mlir::LogicalResult
+  matchAndRewrite(fir::IfOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
+               op.dump(););
+    auto &ifOps = op.thenRegion().front().getOperations();
+    auto affineCondition = AffineIfCondition(op.condition());
+    if (!affineCondition.hasIntegerSet()) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+              << "AffineIfConversion: couldn't calculate affine condition\n";);
+      return failure();
+    }
+    const bool hasElse = !op.elseRegion().empty();
+    auto affineIf = rewriter.create<mlir::AffineIfOp>(
+        op.getLoc(), affineCondition.getIntegerSet(),
+        affineCondition.getAffineArgs(), hasElse);
+    // Move each branch, except its terminator, in front of the corresponding
+    // affine.if terminator.
+    rewriter.startRootUpdate(affineIf);
+    affineIf.getThenBlock()->getOperations().splice(
+        std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
+        std::prev(ifOps.end()));
+    if (hasElse) {
+      auto &otherOps = op.elseRegion().front().getOperations();
+      affineIf.getElseBlock()->getOperations().splice(
+          std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
+          std::prev(otherOps.end()));
+    }
+    rewriter.finalizeRootUpdate(affineIf);
+    // Promote the memory operations of both branches, not just the then one.
+    rewriteMemoryOps(affineIf.getThenBlock(), rewriter);
+    if (hasElse)
+      rewriteMemoryOps(affineIf.getElseBlock(), rewriter);
+
+    LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
+               affineIf.dump(););
+    rewriter.replaceOp(op, affineIf.getOperation()->getResults());
+    return success();
+  }
+};
+
+/// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
+/// where such a promotion is possible.
+class AffineDialectPromotion
+    : public AffineDialectPromotionBase<AffineDialectPromotion> {
+public:
+  void runOnFunction() override {
+    auto *context = &getContext();
+    auto function = getFunction();
+    // Analyses are intentionally not marked as preserved: the patterns below
+    // rewrite the function, so analyses cached for it must be invalidated.
+    AffineFunctionAnalysis functionAnalysis(function);
+    mlir::OwningRewritePatternList patterns(context);
+    patterns.insert<AffineIfConversion>(context, functionAnalysis);
+    patterns.insert<AffineLoopConversion>(context, functionAnalysis);
+
+    // Everything is legal except the fir.if / fir.do_loop operations that the
+    // analysis found promotable; those must be rewritten by the patterns.
+    mlir::ConversionTarget target(*context);
+    target.addLegalDialect<mlir::AffineDialect, FIROpsDialect,
+                           mlir::scf::SCFDialect, mlir::StandardOpsDialect>();
+    target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
+      return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
+    });
+    target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
+                                               fir::DoLoopOp op) {
+      return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
+    });
+
+    LLVM_DEBUG(llvm::dbgs()
+                   << "AffineDialectPromotion: running promotion on: \n";
+               function.print(llvm::dbgs()););
+    // apply the patterns
+    if (mlir::failed(mlir::applyPartialConversion(function, target,
+                                                  std::move(patterns)))) {
+      function.emitError("error in converting to affine dialect\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+/// Creates the affine promotion pass, which promotes fir.do_loop / fir.if
+/// operations to the affine dialect where possible.
+std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
+  std::unique_ptr<mlir::Pass> pass =
+      std::make_unique<AffineDialectPromotion>();
+  return pass;
+}

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 59a1a63c89f79..2ae6fbc95fcf6 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_flang_library(FIRTransforms
+  AffinePromotion.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
 

diff  --git a/flang/lib/Optimizer/Transforms/PassDetail.h b/flang/lib/Optimizer/Transforms/PassDetail.h
index 3155c08b0aadf..02d203f17097d 100644
--- a/flang/lib/Optimizer/Transforms/PassDetail.h
+++ b/flang/lib/Optimizer/Transforms/PassDetail.h
@@ -9,6 +9,7 @@
 #define FORTRAN_OPTMIZER_TRANSFORMS_PASSDETAIL_H
 
 #include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"

diff  --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir
new file mode 100644
index 0000000000000..be6f022cf2857
--- /dev/null
+++ b/flang/test/Fir/affine-promotion.fir
@@ -0,0 +1,133 @@
+// Tests for the FIR-to-Affine promotion pass (--promote-to-affine).
+
+// RUN: fir-opt --split-input-file --promote-to-affine --affine-loop-invariant-code-motion --cse %s | FileCheck %s
+
+!arr_d1 = type !fir.ref<!fir.array<?xf32>>
+#arr_len = affine_map<()[j1,k1] -> (k1 - j1 + 1)>
+// Two sequential fir.do_loop loops reading/writing 1-D arrays via fir.array_coor; the CHECK lines expect affine.for loops with affine.load/affine.store on memref views.
+func @loop_with_load_and_store(%a1: !arr_d1, %a2: !arr_d1, %a3: !arr_d1) {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index  // unused; absent from the CHECK output
+  %len = constant 100 : index
+  %dims = fir.shape %len : (index) -> !fir.shape<1>
+  %siz = affine.apply #arr_len()[%c1,%len]  // %len - %c1 + 1
+  %t1 = fir.alloca !fir.array<?xf32>, %siz
+  // First loop: t1(i) = a1(i) + a2(i)
+  fir.do_loop %i = %c1 to %len step %c1 {
+    %a1_idx = fir.array_coor %a1(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+    %a1_v = fir.load %a1_idx : !fir.ref<f32>
+
+    %a2_idx = fir.array_coor %a2(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+    %a2_v = fir.load %a2_idx : !fir.ref<f32>
+
+    %v = addf %a1_v, %a2_v : f32
+    %t1_idx = fir.array_coor %t1(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+
+    fir.store %v to %t1_idx : !fir.ref<f32>
+  }
+  fir.do_loop %i = %c1 to %len step %c1 {  // Second loop: a3(i) = t1(i) * a2(i)
+    %t1_idx = fir.array_coor %t1(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+    %t1_v = fir.load %t1_idx : !fir.ref<f32>
+
+    %a2_idx = fir.array_coor %a2(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+    %a2_v = fir.load %a2_idx : !fir.ref<f32>
+
+    %v = mulf %t1_v, %a2_v : f32
+    %a3_idx = fir.array_coor %a3(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+
+    fir.store %v to %a3_idx : !fir.ref<f32>
+  }
+  return
+}
+
+// CHECK:  func @loop_with_load_and_store(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_2:.*]]: !fir.ref<!fir.array<?xf32>>) {
+// CHECK:    %[[VAL_3:.*]] = constant 1 : index
+// CHECK:    %[[VAL_4:.*]] = constant 100 : index
+// CHECK:    %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:    %[[VAL_6:.*]] = affine.apply #map0(){{\[}}%[[VAL_3]], %[[VAL_4]]]
+// CHECK:    %[[VAL_7:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_6]]
+// CHECK:    %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+// CHECK:    %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+// CHECK:    %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+// CHECK:    affine.for %[[VAL_11:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] {
+// CHECK:      %[[VAL_12:.*]] = affine.apply #map2(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:      %[[VAL_13:.*]] = affine.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK:      %[[VAL_14:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK:      %[[VAL_15:.*]] = addf %[[VAL_13]], %[[VAL_14]] : f32
+// CHECK:      affine.store %[[VAL_15]], %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<?xf32>
+// CHECK:    }
+// CHECK:    %[[VAL_16:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+// CHECK:    affine.for %[[VAL_17:.*]] = %[[VAL_3]] to #map1(){{\[}}%[[VAL_4]]] {
+// CHECK:      %[[VAL_18:.*]] = affine.apply #map2(%[[VAL_17]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:      %[[VAL_19:.*]] = affine.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:      %[[VAL_20:.*]] = affine.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:      %[[VAL_21:.*]] = mulf %[[VAL_19]], %[[VAL_20]] : f32
+// CHECK:      affine.store %[[VAL_21]], %[[VAL_16]]{{\[}}%[[VAL_18]]] : memref<?xf32>
+// CHECK:    }
+// CHECK:    return
+// CHECK:  }
+
+// -----
+
+!arr_d1 = type !fir.ref<!fir.array<?xf32>>
+
+// Triple loop nest; the store to a(i) is guarded by a fir.if on i - 2 > 0, which the CHECK lines expect to become affine.if.
+func @loop_with_if(%a: !arr_d1, %v: f32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %len = constant 100 : index
+  %dims = fir.shape %len : (index) -> !fir.shape<1>
+
+  fir.do_loop %i = %c1 to %len step %c1 {
+    fir.do_loop %j = %c1 to %len step %c1 {
+      fir.do_loop %k = %c1 to %len step %c1 {
+        %im2 = subi %i, %c2 : index
+        %cond = cmpi "sgt", %im2, %c0 : index
+        fir.if %cond {
+          %a_idx = fir.array_coor %a(%dims) %i
+            : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+          fir.store %v to %a_idx : !fir.ref<f32>
+        }  // end fir.if; the stores to a(j) and a(k) below are unconditional
+        %aj_idx = fir.array_coor %a(%dims) %j
+          : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+        fir.store %v to %aj_idx : !fir.ref<f32>
+        %ak_idx = fir.array_coor %a(%dims) %k
+          : (!arr_d1, !fir.shape<1>, index) -> !fir.ref<f32>
+        fir.store %v to %ak_idx : !fir.ref<f32>
+      }
+    }
+  }
+  return
+}
+
+// CHECK: func @loop_with_if(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: f32) {
+// CHECK:   %[[VAL_2:.*]] = constant 0 : index
+// CHECK:   %[[VAL_3:.*]] = constant 1 : index
+// CHECK:   %[[VAL_4:.*]] = constant 2 : index
+// CHECK:   %[[VAL_5:.*]] = constant 100 : index
+// CHECK:   %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1>
+// CHECK:   %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+// CHECK:   affine.for %[[VAL_8:.*]] = %[[VAL_3]] to #map0(){{\[}}%[[VAL_5]]] {
+// CHECK:     %[[VAL_9:.*]] = affine.apply #map1(%[[VAL_8]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:     affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_9]]] : memref<?xf32>
+// CHECK:   }
+// CHECK:   affine.for %[[VAL_10:.*]] = %[[VAL_3]] to #map0(){{\[}}%[[VAL_5]]] {
+// CHECK:     %[[VAL_11:.*]] = affine.apply #map1(%[[VAL_10]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:     affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf32>
+// CHECK:   }
+// CHECK:   affine.for %[[VAL_12:.*]] = %[[VAL_3]] to #map0(){{\[}}%[[VAL_5]]] {
+// CHECK:     %[[VAL_13:.*]] = subi %[[VAL_12]], %[[VAL_4]] : index
+// CHECK:     affine.if #set(%[[VAL_12]]) {
+// CHECK:       %[[VAL_14:.*]] = affine.apply #map1(%[[VAL_12]]){{\[}}%[[VAL_3]], %[[VAL_5]], %[[VAL_3]]]
+// CHECK:       affine.store %[[VAL_1]], %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref<?xf32>
+// CHECK:     }
+// CHECK:   }
+// CHECK:   return
+// CHECK: }


        


More information about the flang-commits mailing list