[Mlir-commits] [mlir] 3bffe60 - [mlir][VectorOps] Lower vector.fma to llvm.fmuladd instead of llvm.fma

Benjamin Kramer llvmlistbot at llvm.org
Mon Jul 13 03:37:07 PDT 2020


Author: Benjamin Kramer
Date: 2020-07-13T12:26:03+02:00
New Revision: 3bffe6022cc96a38ad0f6ada5f5b7b41eca5796e

URL: https://github.com/llvm/llvm-project/commit/3bffe6022cc96a38ad0f6ada5f5b7b41eca5796e
DIFF: https://github.com/llvm/llvm-project/commit/3bffe6022cc96a38ad0f6ada5f5b7b41eca5796e.diff

LOG: [mlir][VectorOps] Lower vector.fma to llvm.fmuladd instead of llvm.fma

Summary:
These are semantically equivalent, but fmuladd allows decaying the op
into fmul+fadd if there is no fma instruction available. llvm.fma lowers
to scalar calls to libm fmaf, which is a lot slower.

Reviewers: nicolasvasilache, aartbik, ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, Kayjukh, jurahul, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D83666

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 96a8fa4c6f22..2be2bd9bb7d0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -481,7 +481,7 @@ class VectorExtractOpConversion : public ConvertToLLVMPattern {
 /// ```
 /// is converted to:
 /// ```
-///  llvm.intr.fma %va, %va, %va:
+///  llvm.intr.fmuladd %va, %va, %va:
 ///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
 ///    -> !llvm<"<8 x float>">
 /// ```
@@ -500,8 +500,8 @@ class VectorFMAOp1DConversion : public ConvertToLLVMPattern {
     VectorType vType = fmaOp.getVectorType();
     if (vType.getRank() != 1)
       return failure();
-    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
-                                             adaptor.acc());
+    rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(op, adaptor.lhs(),
+                                                 adaptor.rhs(), adaptor.acc());
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09162aa0236b..829edf5f66f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -236,7 +236,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
 //      CHECK: %[[T5:.*]] = llvm.insertelement %[[T2]], %[[T3]][%[[T4]] : !llvm.i32] : !llvm<"<3 x float>">
 //      CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T3]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
 //      CHECK: %[[T7:.*]] = llvm.extractvalue %[[C]][0] : !llvm<"[2 x <3 x float>]">
-//      CHECK: %[[T8:.*]] = "llvm.intr.fma"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
+//      CHECK: %[[T8:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T7]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
 //      CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T0]][0] : !llvm<"[2 x <3 x float>]">
 //      CHECK: %[[T10:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64
 //      CHECK: %[[T11:.*]] = llvm.extractelement %[[A]][%[[T10]] : !llvm.i64] : !llvm<"<2 x float>">
@@ -245,7 +245,7 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
 //      CHECK: %[[T14:.*]] = llvm.insertelement %[[T11]], %[[T12]][%[[T13]] : !llvm.i32] : !llvm<"<3 x float>">
 //      CHECK: %[[T15:.*]] = llvm.shufflevector %[[T14]], %[[T12]] [0 : i32, 0 : i32, 0 : i32] : !llvm<"<3 x float>">, !llvm<"<3 x float>">
 //      CHECK: %[[T16:.*]] = llvm.extractvalue %[[C]][1] : !llvm<"[2 x <3 x float>]">
-//      CHECK: %[[T17:.*]] = "llvm.intr.fma"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
+//      CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T15]], %[[B]], %[[T16]]) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">)
 //      CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T9]][1] : !llvm<"[2 x <3 x float>]">
 //      CHECK: llvm.return %[[T18]] : !llvm<"[2 x <3 x float>]">
 
@@ -688,20 +688,20 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
 //  CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
 //  CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
 func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
-  //         CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
+  //         CHECK: "llvm.intr.fmuladd"(%[[A]], %[[A]], %[[A]]) :
   //    CHECK-SAME:   (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
   %0 = vector.fma %a, %a, %a : vector<8xf32>
 
   //       CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
-  //       CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
+  //       CHECK: %[[B0:.*]] = "llvm.intr.fmuladd"(%[[b00]], %[[b01]], %[[b02]]) :
   //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
   //       CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
   //       CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
-  //       CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
+  //       CHECK: %[[B1:.*]] = "llvm.intr.fmuladd"(%[[b10]], %[[b11]], %[[b12]]) :
   //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
   //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
   %1 = vector.fma %b, %b, %b : vector<2x4xf32>


        


More information about the Mlir-commits mailing list