[flang-commits] [flang] a351a60 - [flang][hlfir] add matmul canonicalizer
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Fri Mar 17 02:30:59 PDT 2023
Author: Tom Eccles
Date: 2023-03-17T09:30:04Z
New Revision: a351a60ebae456735ec32808f311a6e9cf5e751e
URL: https://github.com/llvm/llvm-project/commit/a351a60ebae456735ec32808f311a6e9cf5e751e
DIFF: https://github.com/llvm/llvm-project/commit/a351a60ebae456735ec32808f311a6e9cf5e751e.diff
LOG: [flang][hlfir] add matmul canonicalizer
Rewrite MATMUL(TRANSPOSE(a), b) into hlfir.matmul_transpose, which will be lowered to a new runtime call.
A canonicalizer was chosen because
- Alternative: a new pass for rewriting chained intrinsics - this
would add a lot of unnecessary boilerplate.
- Alternative: including this in the HLFIR Intrinsic Lowering pass -
I wanted to keep these two concerns separate rather than give the
intrinsic lowering pass a second, complicating purpose.
With this change, the MLIR built-in canonicalization pass should be run
before the HLFIR Intrinsic Lowering pass.
Depends on D145504, D145957
Reviewed By: clementval, vzakhari
Differential Revision: https://reviews.llvm.org/D145959
Added:
Modified:
flang/include/flang/Optimizer/HLFIR/HLFIROps.h
flang/include/flang/Optimizer/HLFIR/HLFIROps.td
flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
flang/test/HLFIR/mul_transpose.f90
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
index 94ea77260e340..88f50fd110db5 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.h
@@ -15,6 +15,7 @@
#include "flang/Optimizer/Dialect/FortranVariableInterface.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 86318f64bfe9e..15c5735c1e8eb 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -382,6 +382,9 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
$lhs $rhs attr-dict `:` functional-type(operands, results)
}];
+ // MATMUL(TRANSPOSE(...), ...) => hlfir.matmul_transpose
+ let hasCanonicalizeMethod = 1;
+
let hasVerifier = 1;
}
diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
index 0114c12117e29..ef3b4d57e1f7f 100644
--- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
+++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
#include <optional>
#include <tuple>
@@ -638,6 +639,52 @@ mlir::LogicalResult hlfir::MatmulOp::verify() {
return mlir::success();
}
+mlir::LogicalResult
+hlfir::MatmulOp::canonicalize(MatmulOp matmulOp,
+                              mlir::PatternRewriter &rewriter) {
+  // The only uses of the transposed matrix may be this hlfir.matmul and
+  // hlfir.destroy ops: any other user would still need the transpose result.
+  auto isOtherwiseUnused = [&](hlfir::TransposeOp transposeOp) -> bool {
+    std::size_t numUses = 0;
+    for (mlir::Operation *user : transposeOp.getResult().getUsers()) {
+      ++numUses;
+      if (user != matmulOp && !mlir::isa<hlfir::DestroyOp>(user))
+        return false; // some other use!
+    }
+    return numUses <= 2;
+  };
+
+  mlir::Value lhs = matmulOp.getLhs();
+  // Rewrite MATMUL(TRANSPOSE(lhs), rhs) => hlfir.matmul_transpose lhs, rhs
+  if (auto transposeOp = lhs.getDefiningOp<hlfir::TransposeOp>()) {
+    // rhs must not be the transpose result too: that op is erased below.
+    if (matmulOp.getRhs() != lhs && isOtherwiseUnused(transposeOp)) {
+      mlir::Location loc = matmulOp.getLoc();
+      mlir::Type resultTy = matmulOp.getResult().getType();
+      auto matmulTransposeOp = rewriter.create<hlfir::MatmulTransposeOp>(
+          loc, resultTy, transposeOp.getArray(), matmulOp.getRhs());
+
+      // The hlfir.destroy of the matmul result is kept: replaceOp redirects
+      // it to the new intrinsic result, which needs it anyway.
+      rewriter.replaceOp(matmulOp, matmulTransposeOp.getResult());
+
+      // But the hlfir.destroy ops of the (now removed) transpose result must
+      // go. Snapshot the users first: erasing an op while walking the use
+      // list it is linked into would advance through freed memory.
+      llvm::SmallVector<mlir::Operation *> users(
+          transposeOp->getResult(0).getUsers());
+      for (mlir::Operation *user : users)
+        if (auto destroyOp = mlir::dyn_cast<hlfir::DestroyOp>(user))
+          rewriter.eraseOp(destroyOp);
+      rewriter.eraseOp(transposeOp);
+
+      return mlir::success();
+    }
+  }
+
+  return mlir::failure();
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/HLFIR/mul_transpose.f90 b/flang/test/HLFIR/mul_transpose.f90
index f0fa187ffda4f..ab3742f88043e 100644
--- a/flang/test/HLFIR/mul_transpose.f90
+++ b/flang/test/HLFIR/mul_transpose.f90
@@ -1,4 +1,5 @@
! RUN: bbc -emit-fir -hlfir %s -o - | FileCheck --check-prefix CHECK-BASE --check-prefix CHECK-ALL %s
+! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --canonicalize | FileCheck --check-prefix CHECK-CANONICAL --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | FileCheck --check-prefix CHECK-LOWERING --check-prefix CHECK-ALL %s
! RUN: bbc -emit-fir -hlfir %s -o - | fir-opt --lower-hlfir-intrinsics | fir-opt --bufferize-hlfir | FileCheck --check-prefix CHECK-BUFFERING --check-prefix CHECK-ALL %s
@@ -22,6 +23,10 @@ subroutine mul_transpose(a, b, res)
! CHECK-BASE-NEXT: hlfir.destroy %[[MATMUL_RES]]
! CHECK-BASE-NEXT: hlfir.destroy %[[TRANSPOSE_RES]]
+! CHECK-CANONICAL-NEXT: %[[CHAIN_RES:.*]] = hlfir.matmul_transpose %[[A_DECL]]#0 %[[B_DECL]]#0 : (!fir.ref<!fir.array<2x1xf32>>, !fir.ref<!fir.array<2x2xf32>>) -> !hlfir.expr<1x2xf32>
+! CHECK-CANONICAL-NEXT: hlfir.assign %[[CHAIN_RES]] to %[[RES_DECL]]#0 : !hlfir.expr<1x2xf32>, !fir.ref<!fir.array<1x2xf32>>
+! CHECK-CANONICAL-NEXT: hlfir.destroy %[[CHAIN_RES]]
+
! CHECK-LOWERING: %[[A_BOX:.*]] = fir.embox %[[A_DECL]]#1(%{{.*}})
! CHECK-LOWERING: %[[TRANSPOSE_CONV_RES:.*]] = fir.convert %[[TRANSPOSE_RES_BOX:.*]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf32>>>>) -> !fir.ref<!fir.box<none>>
! CHECK-LOWERING: %[[A_BOX_CONV:.*]] = fir.convert %[[A_BOX]] : (!fir.box<!fir.array<2x1xf32>>) -> !fir.box<none>
More information about the flang-commits
mailing list