[Mlir-commits] [mlir] ee80ffb - [mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Nov 15 04:46:35 PST 2021


Author: Nicolas Vasilache
Date: 2021-11-15T12:41:52Z
New Revision: ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0

URL: https://github.com/llvm/llvm-project/commit/ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0
DIFF: https://github.com/llvm/llvm-project/commit/ee80ffbf9aa4ceb515233bdd841d3c8eae80c4d0.diff

LOG: [mlir][Linalg] Add bounded recursion declaration to FMAOp -> LLVM conversion.

FMAOp -> LLVM conversion is done progressively by peeling off one dimension of the FMAOp per pattern application. Declare the pattern as having bounded rewrite recursion so that the rewriter can apply it repeatedly, including to the lower-rank FMAOps it creates.

Without this, FMAOps of rank 3 or higher do not lower to LLVM.
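To illustrate why the declaration matters, here is a toy, self-contained C++ model. It is not MLIR code: `ToyFma`, `ToyPattern`, `PeelLeadingDim` and `runDriver` are hypothetical names. The driver loosely mimics, as we understand it, MLIR's dialect conversion refusing to reapply a pattern while legalizing ops that the same pattern created, unless the pattern declares bounded recursion. Under that assumption, a rank-2 FMA still lowers without the flag (its rank-1 pieces go to a separate 1-D lowering), while a rank-3 FMA stalls at rank 2.

```cpp
#include <cassert>
#include <deque>
#include <vector>

// Toy stand-in for an N-D vector.fma op: only the operand shape matters here.
struct ToyFma {
  std::vector<long> shape;
};

// Toy pattern base class with a "bounded recursion" flag. Hypothetical: it only
// mirrors the idea behind MLIR's setHasBoundedRewriteRecursion().
class ToyPattern {
 public:
  virtual ~ToyPattern() = default;
  bool hasBoundedRecursion() const { return bounded_; }
  // On success, appends the replacement ops to `out` and returns true.
  virtual bool matchAndRewrite(const ToyFma &op,
                               std::vector<ToyFma> &out) const = 0;

 protected:
  void setHasBoundedRecursion() { bounded_ = true; }

 private:
  bool bounded_ = false;
};

// Peels the leading dimension: one rank-(n-1) FMA per leading index.
class PeelLeadingDim : public ToyPattern {
 public:
  explicit PeelLeadingDim(bool declareBounded) {
    if (declareBounded) setHasBoundedRecursion();
  }
  bool matchAndRewrite(const ToyFma &op,
                       std::vector<ToyFma> &out) const override {
    if (op.shape.size() < 2) return false;  // rank 1 is left to a 1-D lowering
    std::vector<long> sub(op.shape.begin() + 1, op.shape.end());
    for (long i = 0; i < op.shape[0]; ++i) out.push_back(ToyFma{sub});
    return true;
  }
};

// Toy driver: the pattern may not rewrite an op it created itself unless it
// declares bounded recursion. Returns the ops left once nothing more applies.
std::vector<ToyFma> runDriver(const ToyPattern &pattern, const ToyFma &root) {
  assert(!root.shape.empty() && "an FMA operand has rank >= 1");
  struct Item {
    ToyFma op;
    bool createdByPattern;
  };
  std::deque<Item> worklist;
  worklist.push_back({root, false});
  std::vector<ToyFma> remaining;
  while (!worklist.empty()) {
    Item item = worklist.front();
    worklist.pop_front();
    std::vector<ToyFma> produced;
    bool allowed = !item.createdByPattern || pattern.hasBoundedRecursion();
    if (allowed && pattern.matchAndRewrite(item.op, produced)) {
      for (const ToyFma &op : produced) worklist.push_back({op, true});
    } else {
      remaining.push_back(item.op);
    }
  }
  return remaining;
}
```

In this toy model, a 2x3x4 operand ends as six rank-1 FMAs when the flag is set, but as two rank-2 FMAs without it, while a 2x4 operand reaches rank 1 either way. That matches the observation above that only rank 3 and higher was affected.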

Differential Revision: https://reviews.llvm.org/D113886

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 966d3f3b8fce..65816a2d0580 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -752,6 +752,12 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
 public:
   using OpRewritePattern<FMAOp>::OpRewritePattern;
 
+  void initialize() {
+    // This pattern recursively unpacks one dimension at a time. The recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
   LogicalResult matchAndRewrite(FMAOp op,
                                 PatternRewriter &rewriter) const override {
     auto vType = op.getVectorType();

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 7f619e7e8db8..7f775161545d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -941,10 +941,11 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
 
 // -----
 
-func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
   //  CHECK-SAME: %[[B:.*]]: vector<2x4xf32>
+  //  CHECK-SAME: %[[C:.*]]: vector<1x1x1xf32>
   //       CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>>
   //       CHECK: "llvm.intr.fmuladd"
   //  CHECK-SAME:   (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
@@ -964,7 +965,11 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
   //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm.array<2 x vector<4xf32>>
   %1 = vector.fma %b, %b, %b : vector<2x4xf32>
 
-  return %0, %1: vector<8xf32>, vector<2x4xf32>
+  //       CHECK: %[[C0:.*]] = "llvm.intr.fmuladd"
+  //  CHECK-SAME:   (vector<1xf32>, vector<1xf32>, vector<1xf32>) -> vector<1xf32>
+  %2 = vector.fma %c, %c, %c : vector<1x1x1xf32>
+
+  return %0, %1, %2: vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>
 }
 
 // -----
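The expectation added to the test can be worked out by hand: for `vector<1x1x1xf32>`, the N-D pattern has to fire twice (rank 3 to 2, then rank 2 to 1) before a single 1-D `llvm.intr.fmuladd` on `vector<1xf32>` remains. The helpers below (`vectorType`, `peelTrace`, `numFmuladds`) are hypothetical and not part of MLIR; they sketch that arithmetic, assuming each peel keeps the innermost dimension, so the number of fmuladd calls is the product of all but the innermost dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Formats a shape as an MLIR-style f32 vector type, e.g. {2, 4} -> "vector<2x4xf32>".
std::string vectorType(const std::vector<long> &shape) {
  std::string s = "vector<";
  for (long d : shape) s += std::to_string(d) + "x";
  return s + "f32>";
}

// Operand types seen as one leading dimension is peeled per step,
// ending at the rank-1 type that a separate 1-D lowering handles.
std::vector<std::string> peelTrace(std::vector<long> shape) {
  assert(!shape.empty() && "vector types have rank >= 1");
  std::vector<std::string> trace{vectorType(shape)};
  while (shape.size() > 1) {
    shape.erase(shape.begin());
    trace.push_back(vectorType(shape));
  }
  return trace;
}

// Number of rank-1 FMAs, i.e. llvm.intr.fmuladd calls, left after peeling:
// the product of every dimension except the innermost one.
long numFmuladds(const std::vector<long> &shape) {
  long n = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) n *= shape[i];
  return n;
}
```

For `vector<2x4xf32>` the trace has two entries and yields two fmuladds on `vector<4xf32>`, consistent with the two results inserted into `!llvm.array<2 x vector<4xf32>>` in the existing checks; only from rank 3 upward does the N-D pattern need to fire on its own output.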



