[Mlir-commits] [mlir] 85fe8e0 - [mlir] Add mlir::LLVM::FastmathFlags to LLVM intrinsic vector reductions

Stella Stamenova llvmlistbot at llvm.org
Fri Mar 10 09:14:56 PST 2023


Author: Brandon Myers
Date: 2023-03-10T09:14:16-08:00
New Revision: 85fe8e01a0005539b7c137b026afce6a60fdaf01

URL: https://github.com/llvm/llvm-project/commit/85fe8e01a0005539b7c137b026afce6a60fdaf01
DIFF: https://github.com/llvm/llvm-project/commit/85fe8e01a0005539b7c137b026afce6a60fdaf01.diff

LOG: [mlir] Add mlir::LLVM::FastmathFlags to LLVM intrinsic vector reductions

Rationale:
The LLVM dialect supports passing fastmath flags from floating-point ops to LLVM IR instructions. However, not all LLVM ops have the required attribute. This change adds support for fastmath flags to `llvm.intr.vector.reduce.{fmin,fmax}`. One scenario where this is useful is lowering `llvm.intr.vector.reduce.{fmin,fmax}` to LLVM IR with the `nnan` (no NaNs) flag so that it may be lowered to a shuffle reduction (see https://github.com/llvm/llvm-project/blob/115c7beda74f3cfaf83b91d14bc97a39bff4cf19/llvm/lib/CodeGen/ExpandReductions.cpp#L159).

Changes:

  - Make `LLVM_VecReductionF` implement the `FastmathFlagsInterface`; change is modeled on `LLVM_UnaryIntrOpF`
  - Add an assembly format for `LLVM_VecReductionF` ops. The purpose is to keep existing functionality: avoid printing the fastmath flags attribute when it has its default value (`none`). Change is modeled on `LLVM_UnaryIntrOpBase`

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D145692

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/LLVMIR/roundtrip.mlir
    mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
    mlir/test/Target/LLVMIR/Import/fastmath.ll
    mlir/test/Target/LLVMIR/Import/intrinsic.ll
    mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
    mlir/test/Target/LLVMIR/llvmir-invalid.mlir
    mlir/test/Target/LLVMIR/llvmir.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index d12564b6f245a..1b62ce0ca3e6e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -381,16 +381,27 @@ def LLVM_StackRestoreOp : LLVM_ZeroResultIntrOp<"stackrestore"> {
 //
 
 // LLVM vector reduction over a single vector.
-class LLVM_VecReductionBase<string mnem, Type element>
+class LLVM_VecReductionBase<string mnem, Type element, bit requiresFastmath=0>
     : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
-                           [Pure, SameOperandsAndResultElementType]>,
-      Arguments<(ins LLVM_VectorOf<element>)>;
+                           [Pure, SameOperandsAndResultElementType],
+                           requiresFastmath> {
+      dag commonArgs = (ins LLVM_VectorOf<element>:$in);
+}
 
 class LLVM_VecReductionF<string mnem>
-    : LLVM_VecReductionBase<mnem, AnyFloat>;
+    : LLVM_VecReductionBase<mnem, AnyFloat, /*requiresFastmath=*/1> {
+  dag fmfArg = (
+    ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+  let arguments = !con(commonArgs, fmfArg);
+
+  let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
+      "functional-type(operands, results)";
+}
 
 class LLVM_VecReductionI<string mnem>
-    : LLVM_VecReductionBase<mnem, AnySignlessInteger>;
+    : LLVM_VecReductionBase<mnem, AnySignlessInteger> {
+      let arguments = commonArgs;
+}
 
 // LLVM vector reduction over a single vector, with an initial value,
 // and with permission to reassociate the reduction operations.

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 0b9b86ade3323..62d165c8b3998 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1374,7 +1374,7 @@ func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmax_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmax"(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
 //      CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
 //      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
 //      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
@@ -1390,7 +1390,7 @@ func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
 }
 // CHECK-LABEL: @reduce_fmin_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-//      CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmin"(%[[A]]) : (vector<16xf32>) -> f32
+//      CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
 //      CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
 //      CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
 //      CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32

diff  --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 943dc874169fe..c9db19b409e25 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -516,6 +516,11 @@ func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f
   %11 = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
 // CHECK: {{.*}} = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
   %12 = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fmin(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %13 = llvm.intr.vector.reduce.fmin(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %14 = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   return
 }
 

diff  --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
index 1febc6c483ae4..f2e59e1d54824 100644
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
@@ -23,13 +23,13 @@ module {
     %12 = llvm.mlir.constant(3 : i64) : i64
     %v = llvm.insertelement %3, %11[%12 : i64] : vector<4xf32>
 
-    %max = "llvm.intr.vector.reduce.fmax"(%v)
+    %max = llvm.intr.vector.reduce.fmax(%v)
         : (vector<4xf32>) -> f32
     llvm.call @printF32(%max) : (f32) -> ()
     llvm.call @printNewline() : () -> ()
     // CHECK: 4
 
-    %min = "llvm.intr.vector.reduce.fmin"(%v)
+    %min = llvm.intr.vector.reduce.fmin(%v)
         : (vector<4xf32>) -> f32
     llvm.call @printF32(%min) : (f32) -> ()
     llvm.call @printNewline() : () -> ()

diff  --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll
index c54677505bd79..8457c9f1c894d 100644
--- a/mlir/test/Target/LLVMIR/Import/fastmath.ll
+++ b/mlir/test/Target/LLVMIR/Import/fastmath.ll
@@ -41,9 +41,11 @@ declare float @llvm.exp.f32(float)
 declare float @llvm.powi.f32.i32(float, i32)
 declare float @llvm.pow.f32(float, float)
 declare float @llvm.fmuladd.f32(float, float, float)
+declare float @llvm.vector.reduce.fmin.v2f32(<2 x float>)
+declare float @llvm.vector.reduce.fmax.v2f32(<2 x float>)
 
 ; CHECK-LABEL: @fastmath_intr
-define void @fastmath_intr(float %arg1, i32 %arg2) {
+define void @fastmath_intr(float %arg1, i32 %arg2, <2 x float> %arg3) {
   ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf>} : (f32) -> f32
   %1 = call nnan ninf float @llvm.exp.f32(float %arg1)
   ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
@@ -52,5 +54,10 @@ define void @fastmath_intr(float %arg1, i32 %arg2) {
   %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1)
   ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
   %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmin({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %5 = call nnan float @llvm.vector.reduce.fmin.v2f32(<2 x float> %arg3)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmax({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %6 = call nnan float @llvm.vector.reduce.fmax.v2f32(<2 x float> %arg3)
+
   ret void
 }

diff  --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 987fc0c63e006..7caca3cd55aed 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -231,9 +231,9 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) {
   %4 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
   ; CHECK: "llvm.intr.vector.reduce.and"(%{{.*}}) : (vector<8xi32>) -> i32
   %5 = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> %2)
-  ; CHECK: "llvm.intr.vector.reduce.fmax"(%{{.*}}) : (vector<8xf32>) -> f32
+  ; CHECK: llvm.intr.vector.reduce.fmax(%{{.*}}) : (vector<8xf32>) -> f32
   %6 = call float @llvm.vector.reduce.fmax.v8f32(<8 x float> %1)
-  ; CHECK: "llvm.intr.vector.reduce.fmin"(%{{.*}}) : (vector<8xf32>) -> f32
+  ; CHECK: llvm.intr.vector.reduce.fmin(%{{.*}}) : (vector<8xf32>) -> f32
   %7 = call float @llvm.vector.reduce.fmin.v8f32(<8 x float> %1)
   ; CHECK: "llvm.intr.vector.reduce.mul"(%{{.*}}) : (vector<8xi32>) -> i32
   %8 = call i32 @llvm.vector.reduce.mul.v8i32(<8 x i32> %2)

diff  --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 681ba0c2e5485..2536e21fbe0c0 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -251,9 +251,9 @@ llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi
   // CHECK: call i32 @llvm.vector.reduce.and.v8i32
   "llvm.intr.vector.reduce.and"(%arg2) : (vector<8xi32>) -> i32
   // CHECK: call float @llvm.vector.reduce.fmax.v8f32
-  "llvm.intr.vector.reduce.fmax"(%arg1) : (vector<8xf32>) -> f32
+  llvm.intr.vector.reduce.fmax(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call float @llvm.vector.reduce.fmin.v8f32
-  "llvm.intr.vector.reduce.fmin"(%arg1) : (vector<8xf32>) -> f32
+  llvm.intr.vector.reduce.fmin(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call i32 @llvm.vector.reduce.mul.v8i32
   "llvm.intr.vector.reduce.mul"(%arg2) : (vector<8xi32>) -> i32
   // CHECK: call i32 @llvm.vector.reduce.or.v8i32

diff  --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 19e0b1501e090..a373bd62f2247 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -172,7 +172,7 @@ llvm.func @vec_reduce_add_intr_wrong_type(%arg0 : vector<4xi32>) -> f32 {
 
 llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
   // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector of floating-point}}
-  %0 = "llvm.intr.vector.reduce.fmax"(%arg0) : (vector<4xi32>) -> i32
+  %0 = llvm.intr.vector.reduce.fmax(%arg0) : (vector<4xi32>) -> i32
   llvm.return %0 : i32
 }
 

diff  --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index d19f49d4b2938..06fe2f4d524ec 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1871,7 +1871,7 @@ llvm.func @useInlineAsm(%arg0: i32) {
 llvm.func @fastmathFlagsFunc(f32) -> f32
 
 // CHECK-LABEL: @fastmathFlags
-llvm.func @fastmathFlags(%arg0: f32) {
+llvm.func @fastmathFlags(%arg0: f32, %arg1 : vector<2xf32>) {
 // CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}}
 // CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}}
 // CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}}
@@ -1918,6 +1918,12 @@ llvm.func @fastmathFlags(%arg0: f32) {
   %19 = "llvm.intr.powi"(%arg0, %exp) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
 // CHECK: call afn float @llvm.powi.f32.i32(float {{.*}}, i32 {{.*}})
   %20 = "llvm.intr.powi"(%arg0, %exp) {fastmathFlags = #llvm.fastmath<afn>} : (f32, i32) -> f32
+
+// CHECK: call nnan float @llvm.vector.reduce.fmax.v2f32(<2 x float> {{.*}})
+// CHECK: call nnan float @llvm.vector.reduce.fmin.v2f32(<2 x float> {{.*}})
+  %21 = llvm.intr.vector.reduce.fmax(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %22 = llvm.intr.vector.reduce.fmin(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+
   llvm.return
 }
 


        


More information about the Mlir-commits mailing list