[Mlir-commits] [mlir] ee80ffb - [mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 15 04:46:35 PST 2021
Author: Nicolas Vasilache
Date: 2021-11-15T12:41:52Z
New Revision: ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0
URL: https://github.com/llvm/llvm-project/commit/ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0
DIFF: https://github.com/llvm/llvm-project/commit/ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0.diff
LOG: [mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.
FMAOp -> LLVM conversion is done progressively by peeling off one dimension of the FMAOp at each pattern application. Declare bounded rewrite recursion on the pattern so that the rewriter can apply it repeatedly to the ops it produces.
Without this, FMAOps of rank 3 or higher do not lower to LLVM.
Differential Revision: https://reviews.llvm.org/D113886
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 966d3f3b8fce..65816a2d0580 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -752,6 +752,12 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
public:
using OpRewritePattern<FMAOp>::OpRewritePattern;
+ void initialize() {
+ // This pattern recursively unpacks one dimension at a time. The recursion is
+ // bounded because the vector rank strictly decreases on every application.
+ setHasBoundedRewriteRecursion();
+ }
+
LogicalResult matchAndRewrite(FMAOp op,
PatternRewriter &rewriter) const override {
auto vType = op.getVectorType();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7f619e7e8db8..7f775161545d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -941,10 +941,11 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
// -----
-func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
// CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
+ // CHECK-SAME: %[[C:.*]]: vector<1x1x1xf32>
// CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
// CHECK: "llvm.intr.fmuladd"
// CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
@@ -964,7 +965,11 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
- return %0, %1: vector<8xf32>, vector<2x4xf32>
+ // CHECK: %[[C0:.*]] = "llvm.intr.fmuladd"
+ // CHECK-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+ %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
+
+ return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
}
// -----
More information about the Mlir-commits mailing list