[Mlir-commits] [mlir] 499ad45 - [mlir][VectorOps] Expose and use llvm.intrin.fma*

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 7 12:38:58 PST 2020


Author: Nicolas Vasilache
Date: 2020-02-07T15:38:40-05:00
New Revision: 499ad45877b930325b641d18e7b8b71094116e49

URL: https://github.com/llvm/llvm-project/commit/499ad45877b930325b641d18e7b8b71094116e49
DIFF: https://github.com/llvm/llvm-project/commit/499ad45877b930325b641d18e7b8b71094116e49.diff

LOG: [mlir][VectorOps] Expose and use llvm.intrin.fma*

Summary:
This revision exposes the portable `llvm.fma` intrinsic in LLVMOps and uses it
in lieu of `llvm.fmuladd` when lowering the `vector.outerproduct` op to LLVM.
This guarantees proper `fma` instructions will be emitted if the target ISA
supports them.

`llvm.fmuladd` does not have this guarantee in its semantics, despite evidence
that the proper x86 instructions are emitted.

For more details, see https://llvm.org/docs/LangRef.html#llvm-fmuladd-intrinsic.

Reviewers: ftynse, aartbik, dcaballe, fhahn

Reviewed By: aartbik

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74219

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/Linalg/llvm.mlir
    mlir/test/Target/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 33679cd17455..72daab5dc194 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -732,6 +732,7 @@ def LLVM_FAbsOp : LLVM_UnaryIntrinsicOp<"fabs">;
 def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">;
 def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">;
 def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">;
+def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">;
 def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">;
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 

diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index b2671d980619..6dfb1dbea75f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -567,21 +567,25 @@ def Vector_OuterProductOp :
     Results<(outs AnyVector)> {
   let summary = "vector outerproduct with optional fused add";
   let description = [{
-    Takes 2 1-D vectors and returns the 2-D vector containing the outer product.
+    Takes 2 1-D vectors and returns the 2-D vector containing the outer-product.
 
     An optional extra 2-D vector argument may be specified in which case the
-    operation returns the sum of the outer product and the extra vector. When
-    lowered to the LLVMIR dialect, this form emits `llvm.intr.fmuladd`, which
-    can lower to actual `fma` instructions in LLVM.
+    operation returns the sum of the outer-product and the extra vector. In this
+    multiply-accumulate scenario, the rounding mode is that obtained by
+    guaranteeing that a fused-multiply add operation is emitted. When lowered to
+    the LLVMIR dialect, this form emits `llvm.intr.fma`, which is guaranteed to
+    lower to actual `fma` instructions on x86.
 
-    Examples
+    Examples:
 
+    ```
       %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32>
       return %2: vector<4x8xf32>
 
       %3 = vector.outerproduct %0, %1, %2:
         vector<4xf32>, vector<8xf32>, vector<4x8xf32>
       return %3: vector<4x8xf32>
+    ```
   }];
   let extraClassDeclaration = [{
     VectorType getOperandVectorTypeLHS() {

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 7df164be87de..b08ba1d9587e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -674,9 +674,9 @@ class VectorOuterProductOpConversion : public LLVMOpLowering {
             loc, vRHS, acc, rewriter.getI64ArrayAttr(d));
       // 3. Compute aD outer b (plus accD, if relevant).
       Value aOuterbD =
-          accD ? rewriter.create<LLVM::FMulAddOp>(loc, vRHS, aD, b, accD)
-                     .getResult()
-               : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
+          accD
+              ? rewriter.create<LLVM::FMAOp>(loc, vRHS, aD, b, accD).getResult()
+              : rewriter.create<LLVM::FMulOp>(loc, aD, b).getResult();
       // 4. Insert as value `d` in the descriptor.
       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayOfVectType,
                                                   desc, aOuterbD,

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 18525127a66c..efbb8aa2e54b 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -222,11 +222,11 @@ func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector
 //       CHECK:   llvm.mlir.undef : !llvm<"[2 x <3 x float>]">
 //       CHECK:   llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
 //       CHECK:   llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
 //       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]">
 //       CHECK:   llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
 //       CHECK:   llvm.extractvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
-//       CHECK:   "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
+//       CHECK:   "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>">
 //       CHECK:   llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]">
 //       CHECK:   llvm.return {{.*}} : !llvm<"[2 x <3 x float>]">
 

diff  --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 7c1d02b3fc23..4a5f15b319b4 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -171,7 +171,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
 //   LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
 //   LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
 //   LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
-//   LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+//   LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
 //   LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
 
 

diff  --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index 343b1f083b59..fcc110215229 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -9,6 +9,10 @@ llvm.func @intrinsics(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<8 x
   "llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
   // CHECK: call <8 x float> @llvm.fmuladd.v8f32
   "llvm.intr.fmuladd"(%arg2, %arg2, %arg2) : (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+  // CHECK: call float @llvm.fma.f32
+  "llvm.intr.fma"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float
+  // CHECK: call <8 x float> @llvm.fma.v8f32
+  "llvm.intr.fma"(%arg2, %arg2, %arg2) : (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
   // CHECK: call void @llvm.prefetch.p0i8(i8* %3, i32 0, i32 3, i32 1)
   "llvm.intr.prefetch"(%arg3, %c0, %c3, %c1) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm.i32) -> ()
   llvm.return
@@ -96,23 +100,25 @@ llvm.func @copysign_test(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm<"<
 }
 
 // Check that intrinsics are declared with appropriate types.
-// CHECK: declare float @llvm.fmuladd.f32(float, float, float)
-// CHECK: declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
-// CHECK: declare void @llvm.prefetch.p0i8(i8* nocapture readonly, i32 immarg, i32 immarg, i32)
-// CHECK: declare float @llvm.exp.f32(float)
-// CHECK: declare <8 x float> @llvm.exp.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.log.f32(float)
-// CHECK: declare <8 x float> @llvm.log.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.log10.f32(float)
-// CHECK: declare <8 x float> @llvm.log10.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.log2.f32(float)
-// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.fabs.f32(float)
-// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.sqrt.f32(float)
-// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.ceil.f32(float)
-// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.cos.f32(float)
-// CHECK: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
-// CHECK: declare float @llvm.copysign.f32(float, float)
+// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
+// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
+// CHECK-DAG: declare float @llvm.fmuladd.f32(float, float, float)
+// CHECK-DAG: declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
+// CHECK-DAG: declare void @llvm.prefetch.p0i8(i8* nocapture readonly, i32 immarg, i32 immarg, i32)
+// CHECK-DAG: declare float @llvm.exp.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.exp.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.log.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.log.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.log10.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.log10.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.log2.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.fabs.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.sqrt.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.ceil.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.cos.f32(float)
+// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
+// CHECK-DAG: declare float @llvm.copysign.f32(float, float)


        


More information about the Mlir-commits mailing list