[flang-commits] [flang] a351a60 - [flang][hlfir] add matmul canonicalizer

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri Mar 17 02:30:59 PDT 2023


Author: Tom Eccles
Date: 2023-03-17T09:30:04Z
New Revision: a351a60ebae456735ec32808f311a6e9cf5e751e

URL: https://github.com/llvm/llvm-project/commit/a351a60ebae456735ec32808f311a6e9cf5e751e
DIFF: https://github.com/llvm/llvm-project/commit/a351a60ebae456735ec32808f311a6e9cf5e751e.diff

LOG: [flang][hlfir] add matmul canonicalizer

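This adds a canonicalization pattern for hlfir.matmul that rewrites MATMUL(TRANSPOSE(a), b) into hlfir.matmul_transpose a, b. hlfir.matmul_transpose will be lowered to a new runtime call.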

A canonicalizer was chosen over the following alternatives:
  - Alternative: a new pass for rewriting chained intrinsics - this
    would add a lot of unnecessary boilerplate.
  - Alternative: including this in the HLFIR Intrinsic Lowering pass -
    I wanted to separate these two concerns: not adding a second purpose
    complicating the intrinsic lowering pass.

With this change, the MLIR built-in canonicalization pass should be run
before the HLFIR Intrinsic Lowering pass.

Depends on D145504, D145957

Reviewed By: clementval, vzakhari

Differential Revision: https://reviews.llvm.org/D145959

Added: 
    

Modified: 
    flang/include/flang/Optimizer/HLFIR/HLFIROps.h
    flang/include/flang/Optimizer/HLFIR/HLFIROps.td
    flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
    flang/test/HLFIR/mul_transpose.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
index 94ea77260e340..88f50fd110db5 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
@@ -15,6 +15,7 @@
 #include "flang/Optimizer/Dialect/FortranVariableInterface.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff  --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 86318f64bfe9e..15c5735c1e8eb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -382,6 +382,9 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
     $lhs $rhs attr-dict `:` functional-type(operands, results)
   }];
 
+  // MATMUL(TRANSPOSE(...), ...) => hlfir.matmul_transpose
+  let hasCanonicalizeMethod = 1;
+
   let hasVerifier = 1;
 }
 

diff  --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 0114c12117e29..ef3b4d57e1f7f 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
 #include <optional>
 #include <tuple>
 
@@ -638,6 +639,52 @@ mlir::LogicalResult hlfir::MatmulOp::verify() {
   return mlir::success();
 }
 
+// MATMUL(TRANSPOSE(a), b) => hlfir.matmul_transpose a, b
+// Only applies when the hlfir.transpose result is otherwise unused, so that
+// the transpose can be erased entirely.
+mlir::LogicalResult
+hlfir::MatmulOp::canonicalize(MatmulOp matmulOp,
+                              mlir::PatternRewriter &rewriter) {
+  auto transposeOp = matmulOp.getLhs().getDefiningOp<hlfir::TransposeOp>();
+  if (!transposeOp)
+    return mlir::failure();
+
+  // The transposed matrix may only be used once, by this matmul (MATMUL(t, t)
+  // must not match: its rhs would still use the erased transpose), plus at
+  // most one hlfir.destroy.
+  std::size_t numMatmulUses = 0;
+  std::size_t numDestroys = 0;
+  for (mlir::Operation *user : transposeOp.getResult().getUsers()) {
+    if (user == matmulOp)
+      ++numMatmulUses;
+    else if (mlir::isa<hlfir::DestroyOp>(user))
+      ++numDestroys;
+    else
+      return mlir::failure(); // some other use of the transposed matrix
+  }
+  if (numMatmulUses != 1 || numDestroys > 1)
+    return mlir::failure();
+
+  auto matmulTransposeOp = rewriter.create<hlfir::MatmulTransposeOp>(
+      matmulOp.getLoc(), matmulOp.getResult().getType(),
+      transposeOp.getArray(), matmulOp.getRhs());
+
+  // The hlfir.destroy of the matmul result is kept: replaceOp redirects it to
+  // the new hlfir.matmul_transpose result, which needs it just the same.
+  rewriter.replaceOp(matmulOp, matmulTransposeOp.getResult());
+
+  // The hlfir.transpose is removed entirely, so its hlfir.destroy must go too.
+  // Erasing a user invalidates the use-list iterator currently pointing at it,
+  // hence make_early_inc_range.
+  for (mlir::Operation *user :
+       llvm::make_early_inc_range(transposeOp.getResult().getUsers()))
+    if (auto destroyOp = mlir::dyn_cast<hlfir::DestroyOp>(user))
+      rewriter.eraseOp(destroyOp);
+  rewriter.eraseOp(transposeOp);
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
index f0fa187ffda4f..ab3742f88043e 100644
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,4 +1,5 @@
 ! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
 ! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
 
@@ -22,6 +23,10 @@ subroutine mul_transpose(a, b, res)
 ! CHECK-BASE-NEXT:      hlfir.destroy %[[MATMUL_RES]]
 ! CHECK-BASE-NEXT:      hlfir.destroy %[[TRANSPOSE_RES]]
 
+! CHECK-CANONICAL-NEXT: %[[CHAIN_RES:.*]] = hlfir.matmul_transpose %[[A_DECL]]#0 %[[B_DECL]]#0 : (!fir.ref<!fir.array<2x1xf32>>, !fir.ref<!fir.array<2x2xf32>>) -> !hlfir.expr<1x2xf32>
+! CHECK-CANONICAL-NEXT: hlfir.assign %[[CHAIN_RES]] to %[[RES_DECL]]#0 : !hlfir.expr<1x2xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-CANONICAL-NEXT: hlfir.destroy %[[CHAIN_RES]]
+
 ! CHECK-LOWERING:       %[[A_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
 ! CHECK-LOWERING:       %[[TRANSPOSE_CONV_RES:.*]] = fir.convert %[[TRANSPOSE_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
 ! CHECK-LOWERING:       %[[A_BOX_CONV:.*]] = fir.convert %[[A_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>


        


More information about the flang-commits mailing list