[flang-commits] [flang] [flang][hlfir][NFC] Fix mlir misuse in LowerHLFIRIntrinsics (PR #83293)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed Feb 28 08:58:19 PST 2024


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/83293

In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp should not be using rewrite patterns with the dialect conversion driver.

The intention of this pass is to lower HLFIR intrinsic operations into FIR, which I think best fits dialect conversion, so I have changed all of these rewrite patterns into conversion patterns. Taking this approach also avoids test suite churn: switching to the GreedyPatternRewriter instead would also perform canonicalization, changing the output of existing tests.

One remaining misuse of the MLIR API is that we replace values of one type with values of a different (although safe) type, e.g.
!hlfir.expr<2xi32> -> !hlfir.expr<?xi32>. There isn't a convenient way to perform this conversion in IR at the moment because fir.convert does not accept !hlfir.expr.

From 114444c8c976ccff7031563d2ea4478a7262a540 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Wed, 28 Feb 2024 16:34:29 +0000
Subject: [PATCH] [flang][hlfir][NFC] Fix mlir misuse in LowerHLFIRIntrinsics

In #83253 @matthias-springer pointed out that LowerHLFIRIntrinsics.cpp
should not be using rewrite patterns with the dialect conversion driver.

The intention of this pass is to lower HLFIR intrinsic operations into
FIR, which I think best fits dialect conversion, so I have changed all
of these rewrite patterns into conversion patterns. Taking this approach
also avoids test suite churn: switching to the GreedyPatternRewriter
instead would also perform canonicalization, changing the output of
existing tests.

One remaining misuse of the MLIR API is that we replace values of one
type with values of a different (although safe) type, e.g.
!hlfir.expr<2xi32> -> !hlfir.expr<?xi32>. There isn't a convenient way
to perform this conversion in IR at the moment because fir.convert does
not accept !hlfir.expr.
---
 .../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp | 99 ++++++++++---------
 1 file changed, 51 insertions(+), 48 deletions(-)

diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
index 377cc44392028f..b2e02376599636 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
@@ -36,10 +36,10 @@ namespace {
 /// Base class for passes converting transformational intrinsic operations into
 /// runtime calls
 template <class OP>
-class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
+class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
 public:
   explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
-      : mlir::OpRewritePattern<OP>{ctx} {
+      : mlir::OpConversionPattern<OP>{ctx} {
     // required for cases where intrinsics are chained together e.g.
     // matmul(matmul(a, b), c)
     // because converting the inner operation then invalidates the
@@ -145,7 +145,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
   void processReturnValue(mlir::Operation *op,
                           const fir::ExtendedValue &resultExv, bool mustBeFreed,
                           fir::FirOpBuilder &builder,
-                          mlir::PatternRewriter &rewriter) const {
+                          mlir::ConversionPatternRewriter &rewriter) const {
     mlir::Location loc = op->getLoc();
 
     mlir::Value firBase = fir::getBase(resultExv);
@@ -176,13 +176,9 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
           rewriter.eraseOp(use);
       }
     }
-    // TODO: This entire pass should be a greedy pattern rewrite or a manual
-    // IR traversal. A dialect conversion cannot be used here because
-    // `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not
-    // suitable because "op->getResult(0)" and "base" can have different types.
-    // In such a case, the dialect conversion will attempt to convert the type,
-    // but no type converter is specified in this pass. Also note that all
-    // patterns in this pass are actually rewrite patterns.
+    // the types might not match exactly (but are safe)
+    // e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32>
+    // TODO: is this allowed by MLIR?
     op->getResult(0).replaceAllUsesWith(base);
     rewriter.replaceOp(op, base);
   }
@@ -203,48 +199,53 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
       typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
   using HlfirIntrinsicConversion<OP>::lowerArguments;
   using HlfirIntrinsicConversion<OP>::processReturnValue;
+  using Adaptor = typename OP::Adaptor;
 
 protected:
-  auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildNumericalArgs(mlir::Operation *operation, Adaptor adaptor,
+                          mlir::Type i32, mlir::Type logicalType,
                           mlir::PatternRewriter &rewriter,
                           std::string opName) const {
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
-    inArgs.push_back({operation.getDim(), i32});
-    inArgs.push_back({operation.getMask(), logicalType});
+    inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+    inArgs.push_back({adaptor.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
-  auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildMinMaxLocArgs(mlir::Operation *operation, Adaptor adaptor,
+                          mlir::Type i32, mlir::Type logicalType,
                           mlir::PatternRewriter &rewriter, std::string opName,
                           fir::FirOpBuilder builder) const {
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({operation.getArray(), operation.getArray().getType()});
-    inArgs.push_back({operation.getDim(), i32});
-    inArgs.push_back({operation.getMask(), logicalType});
+    inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
+    inArgs.push_back({adaptor.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
     mlir::Value kind = builder.createIntegerConstant(
-        operation->getLoc(), i32, getKindForType(operation.getType()));
+        operation->getLoc(), i32,
+        getKindForType(operation->getResult(0).getType()));
     inArgs.push_back({kind, i32});
-    inArgs.push_back({operation.getBack(), i32});
+    inArgs.push_back({adaptor.getBack(), i32});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
-  auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
+  auto buildLogicalArgs(mlir::Operation *operation, Adaptor adaptor,
+                        mlir::Type i32, mlir::Type logicalType,
                         mlir::PatternRewriter &rewriter,
                         std::string opName) const {
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
-    inArgs.push_back({operation.getMask(), logicalType});
-    inArgs.push_back({operation.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
+    inArgs.push_back({adaptor.getDim(), i32});
     auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
     return lowerArguments(operation, inArgs, rewriter, argLowering);
   };
 
 public:
   mlir::LogicalResult
-  matchAndRewrite(OP operation,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(OP operation, Adaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     std::string opName;
     if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
       opName = "sum";
@@ -279,13 +280,15 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
                   std::is_same_v<OP, hlfir::ProductOp> ||
                   std::is_same_v<OP, hlfir::MaxvalOp> ||
                   std::is_same_v<OP, hlfir::MinvalOp>) {
-      args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
+      args = buildNumericalArgs(operation, adaptor, i32, logicalType, rewriter,
+                                opName);
     } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
                          std::is_same_v<OP, hlfir::MaxlocOp>) {
-      args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
-                                builder);
+      args = buildMinMaxLocArgs(operation, adaptor, i32, logicalType, rewriter,
+                                opName, builder);
     } else {
-      args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
+      args = buildLogicalArgs(operation, adaptor, i32, logicalType, rewriter,
+                              opName);
     }
 
     mlir::Type scalarResultType =
@@ -319,8 +322,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
   using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::CountOp count,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::CountOp count, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, count.getOperation()};
     const mlir::Location &loc = count->getLoc();
 
@@ -329,8 +332,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
         builder.getContext(), builder.getKindMap().defaultLogicalKind());
 
     llvm::SmallVector<IntrinsicArgument, 3> inArgs;
-    inArgs.push_back({count.getMask(), logicalType});
-    inArgs.push_back({count.getDim(), i32});
+    inArgs.push_back({adaptor.getMask(), logicalType});
+    inArgs.push_back({adaptor.getDim(), i32});
     mlir::Value kind = builder.createIntegerConstant(
         count->getLoc(), i32, getKindForType(count.getType()));
     inArgs.push_back({kind, i32});
@@ -353,13 +356,13 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
   using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MatmulOp matmul,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
     const mlir::Location &loc = matmul->getLoc();
 
-    mlir::Value lhs = matmul.getLhs();
-    mlir::Value rhs = matmul.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});
@@ -384,13 +387,13 @@ struct DotProductOpConversion
   using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::DotProductOp dotProduct,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::DotProductOp dotProduct, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
     const mlir::Location &loc = dotProduct->getLoc();
 
-    mlir::Value lhs = dotProduct.getLhs();
-    mlir::Value rhs = dotProduct.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});
@@ -415,12 +418,12 @@ class TransposeOpConversion
   using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::TransposeOp transpose,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
     const mlir::Location &loc = transpose->getLoc();
 
-    mlir::Value arg = transpose.getArray();
+    mlir::Value arg = adaptor.getArray();
     llvm::SmallVector<IntrinsicArgument, 1> inArgs;
     inArgs.push_back({arg, arg.getType()});
 
@@ -445,13 +448,13 @@ struct MatmulTransposeOpConversion
       hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;
 
   mlir::LogicalResult
-  matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
-                  mlir::PatternRewriter &rewriter) const override {
+  matchAndRewrite(hlfir::MatmulTransposeOp multranspose, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
     fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
     const mlir::Location &loc = multranspose->getLoc();
 
-    mlir::Value lhs = multranspose.getLhs();
-    mlir::Value rhs = multranspose.getRhs();
+    mlir::Value lhs = adaptor.getLhs();
+    mlir::Value rhs = adaptor.getRhs();
     llvm::SmallVector<IntrinsicArgument, 2> inArgs;
     inArgs.push_back({lhs, lhs.getType()});
     inArgs.push_back({rhs, rhs.getType()});



More information about the flang-commits mailing list