[flang-commits] [flang] 80c27ab - [fir] Add affine demotion pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Fri Oct 8 05:03:34 PDT 2021


Author: Rajan Walia
Date: 2021-10-08T14:03:27+02:00
New Revision: 80c27abb2f74aa439ba381148186407ebae1793e

URL: https://github.com/llvm/llvm-project/commit/80c27abb2f74aa439ba381148186407ebae1793e
DIFF: https://github.com/llvm/llvm-project/commit/80c27abb2f74aa439ba381148186407ebae1793e.diff

LOG: [fir] Add affine demotion pass

Add affine demotion pass.
Affine dialect's default lowering for loads and stores is different from
fir as it uses the `memref` type. The `memref` type is not compatible with
the Fortran runtime. Therefore, conversion of memory operations back to
`fir.load` and `fir.store` with `!fir.ref<?>` types is required.

This patch is part of the upstreaming effort from fir-dev branch.

Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: V Donaldson <vdonaldson at nvidia.com>
Co-authored-by: Rajan Walia <walrajan at gmail.com>
Co-authored-by: Sourabh Singh Tomar <SourabhSingh.Tomar at amd.com>
Co-authored-by: Valentin Clement <clementval at gmail.com>

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D111257

Added: 
    flang/lib/Optimizer/Transforms/AffineDemotion.cpp
    flang/test/Fir/affine-demotion.fir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 2b17a7e4c11f..f89e80c889e9 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -22,8 +22,13 @@ class Region;
 
 namespace fir {
 
-std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
+//===----------------------------------------------------------------------===//
+// Passes defined in Passes.td
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
+std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 
 /// Support for inlining on FIR.
 bool canLegallyInline(mlir::Operation *op, mlir::Region *reg,

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 64ee44a9c36e..1929480dc5ec 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -41,6 +41,20 @@ def AffineDialectPromotion : FunctionPass<"promote-to-affine"> {
   ];
 }
 
+def AffineDialectDemotion : FunctionPass<"demote-affine"> {
+  let summary = "Converts `affine.{load,store}` back to fir operations";
+  let description = [{
+    Affine dialect's default lowering for loads and stores is different from
+    fir as it uses the `memref` type. The `memref` type is not compatible with
+    the Fortran runtime. Therefore, conversion of memory operations back to
+    `fir.load` and `fir.store` with `!fir.ref<?>` types is required.
+  }];
+  let constructor = "::fir::createAffineDemotionPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect", "mlir::AffineDialect"
+  ];
+}
+
 def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
   let summary = "Convert name for external interoperability";
   let description = [{

diff --git a/flang/lib/Optimizer/Transforms/AffineDemotion.cpp b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
new file mode 100644
index 000000000000..29bff0f60901
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/AffineDemotion.cpp
@@ -0,0 +1,162 @@
+//===-- AffineDemotion.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-affine-demotion"
+
+using namespace fir;
+
+namespace {
+
+class AffineLoadConversion : public OpRewritePattern<mlir::AffineLoadOp> {
+public:
+  using OpRewritePattern<mlir::AffineLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::AffineLoadOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> indices(op.getMapOperands());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return failure();
+
+    auto coorOp = rewriter.create<fir::CoordinateOp>(
+        op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
+        op.getMemRef(), *maybeExpandedMap);
+
+    rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
+    return success();
+  }
+};
+
+class AffineStoreConversion : public OpRewritePattern<mlir::AffineStoreOp> {
+public:
+  using OpRewritePattern<mlir::AffineStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::AffineStoreOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> indices(op.getMapOperands());
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
+    if (!maybeExpandedMap)
+      return failure();
+
+    auto coorOp = rewriter.create<fir::CoordinateOp>(
+        op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
+        op.getMemRef(), *maybeExpandedMap);
+    rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValueToStore(),
+                                              coorOp.getResult());
+    return success();
+  }
+};
+
+class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  mlir::LogicalResult
+  matchAndRewrite(fir::ConvertOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (op.res().getType().isa<mlir::MemRefType>()) {
+      // Index calculation has moved into the affine maps, so references to
+      // sequence types are still converted, but to a flat array of unknown
+      // extent. This loses the static shape of arrays with known dimensions:
+      // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
+      // !fir.ref<!fir.array<?xi32>>
+      if (auto refTy = op.value().getType().dyn_cast<fir::ReferenceType>())
+        if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) {
+          fir::SequenceType::Shape flatShape = {
+              fir::SequenceType::getUnknownExtent()};
+          auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
+          auto flatTy = fir::ReferenceType::get(flatArrTy);
+          rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, op.value());
+          return success();
+        }
+      rewriter.startRootUpdate(op->getParentOp());
+      op.getResult().replaceAllUsesWith(op.value());
+      rewriter.finalizeRootUpdate(op->getParentOp());
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+
+mlir::Type convertMemRef(mlir::MemRefType type) {
+  return fir::SequenceType::get(
+      SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()),
+      type.getElementType());
+}
+
+class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  mlir::LogicalResult
+  matchAndRewrite(memref::AllocOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
+                                               op.memref());
+    return success();
+  }
+};
+
+class AffineDialectDemotion
+    : public AffineDialectDemotionBase<AffineDialectDemotion> {
+public:
+  void runOnFunction() override {
+    auto *context = &getContext();
+    auto function = getFunction();
+    LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
+               function.print(llvm::dbgs()););
+
+    mlir::OwningRewritePatternList patterns(context);
+    patterns.insert<ConvertConversion>(context);
+    patterns.insert<AffineLoadConversion>(context);
+    patterns.insert<AffineStoreConversion>(context);
+    patterns.insert<StdAllocConversion>(context);
+    mlir::ConversionTarget target(*context);
+    target.addIllegalOp<memref::AllocOp>();
+    target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
+      if (op.res().getType().isa<mlir::MemRefType>())
+        return false;
+      return true;
+    });
+    target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
+                           mlir::StandardOpsDialect>();
+
+    if (mlir::failed(mlir::applyPartialConversion(function, target,
+                                                  std::move(patterns)))) {
+      mlir::emitError(mlir::UnknownLoc::get(context),
+                      "error in converting affine dialect\n");
+      signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
+  return std::make_unique<AffineDialectDemotion>();
+}

diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 2ae6fbc95fcf..dbb0d46aa95d 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_flang_library(FIRTransforms
   AffinePromotion.cpp
+  AffineDemotion.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
 

diff --git a/flang/test/Fir/affine-demotion.fir b/flang/test/Fir/affine-demotion.fir
new file mode 100644
index 000000000000..a57e30d12c2e
--- /dev/null
+++ b/flang/test/Fir/affine-demotion.fir
@@ -0,0 +1,68 @@
+// Test affine demotion pass
+
+// RUN: fir-opt --split-input-file --demote-affine %s | FileCheck %s
+
+#map0 = affine_map<()[s0, s1] -> (s1 - s0 + 1)>
+#map1 = affine_map<()[s0] -> (s0 + 1)>
+#map2 = affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>
+module  {
+  func @calc(%arg0: !fir.ref<!fir.array<?xf32>>, %arg1: !fir.ref<!fir.array<?xf32>>, %arg2: !fir.ref<!fir.array<?xf32>>) {
+    %c1 = constant 1 : index
+    %c100 = constant 100 : index
+    %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+    %1 = affine.apply #map0()[%c1, %c100]
+    %2 = fir.alloca !fir.array<?xf32>, %1
+    %3 = fir.convert %arg0 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+    %4 = fir.convert %arg1 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+    %5 = fir.convert %2 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+    affine.for %arg3 = %c1 to #map1()[%c100] {
+      %7 = affine.apply #map2(%arg3)[%c1, %c100, %c1]
+      %8 = affine.load %3[%7] : memref<?xf32>
+      %9 = affine.load %4[%7] : memref<?xf32>
+      %10 = addf %8, %9 : f32
+      affine.store %10, %5[%7] : memref<?xf32>
+    }
+    %6 = fir.convert %arg2 : (!fir.ref<!fir.array<?xf32>>) -> memref<?xf32>
+    affine.for %arg3 = %c1 to #map1()[%c100] {
+      %7 = affine.apply #map2(%arg3)[%c1, %c100, %c1]
+      %8 = affine.load %5[%7] : memref<?xf32>
+      %9 = affine.load %4[%7] : memref<?xf32>
+      %10 = mulf %8, %9 : f32
+      affine.store %10, %6[%7] : memref<?xf32>
+    }
+    return
+  }
+}
+
+// CHECK:  func @calc(%[[VAL_0:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_1:.*]]: !fir.ref<!fir.array<?xf32>>, %[[VAL_2:.*]]: !fir.ref<!fir.array<?xf32>>) {
+// CHECK:    %[[VAL_3:.*]] = constant 1 : index
+// CHECK:    %[[VAL_4:.*]] = constant 100 : index
+// CHECK:    %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK:    %[[VAL_6:.*]] = constant 100 : index
+// CHECK:    %[[VAL_7:.*]] = fir.alloca !fir.array<?xf32>, %[[VAL_6]]
+// CHECK:    %[[VAL_8:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:    %[[VAL_9:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:    %[[VAL_10:.*]] = fir.convert %[[VAL_7]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:    affine.for %[[VAL_11:.*]] = 1 to 101 {
+// CHECK:      %[[VAL_12:.*]] = affine.apply #map(%[[VAL_11]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:      %[[VAL_13:.*]] = fir.coordinate_of %[[VAL_8]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<f32>
+// CHECK:      %[[VAL_15:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<f32>
+// CHECK:      %[[VAL_17:.*]] = addf %[[VAL_14]], %[[VAL_16]] : f32
+// CHECK:      %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_12]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      fir.store %[[VAL_17]] to %[[VAL_18]] : !fir.ref<f32>
+// CHECK:    }
+// CHECK:    %[[VAL_19:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>>
+// CHECK:    affine.for %[[VAL_20:.*]] = 1 to 101 {
+// CHECK:      %[[VAL_21:.*]] = affine.apply #map(%[[VAL_20]]){{\[}}%[[VAL_3]], %[[VAL_4]], %[[VAL_3]]]
+// CHECK:      %[[VAL_22:.*]] = fir.coordinate_of %[[VAL_10]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      %[[VAL_23:.*]] = fir.load %[[VAL_22]] : !fir.ref<f32>
+// CHECK:      %[[VAL_24:.*]] = fir.coordinate_of %[[VAL_9]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      %[[VAL_25:.*]] = fir.load %[[VAL_24]] : !fir.ref<f32>
+// CHECK:      %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32
+// CHECK:      %[[VAL_27:.*]] = fir.coordinate_of %[[VAL_19]], %[[VAL_21]] : (!fir.ref<!fir.array<?xf32>>, index) -> !fir.ref<f32>
+// CHECK:      fir.store %[[VAL_26]] to %[[VAL_27]] : !fir.ref<f32>
+// CHECK:    }
+// CHECK:    return
+// CHECK:  }


        


More information about the flang-commits mailing list