[Mlir-commits] [mlir] ceb1b32 - [mlir] [VectorOps] Add the ability to mark FP reductions with "reassociate" attribute

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 26 11:04:48 PDT 2020


Author: aartbik
Date: 2020-06-26T11:03:14-07:00
New Revision: ceb1b327b53c82203deaa0b0ede3fb07ec8f823a

URL: https://github.com/llvm/llvm-project/commit/ceb1b327b53c82203deaa0b0ede3fb07ec8f823a
DIFF: https://github.com/llvm/llvm-project/commit/ceb1b327b53c82203deaa0b0ede3fb07ec8f823a.diff

LOG: [mlir] [VectorOps] Add the ability to mark FP reductions with "reassociate" attribute

Rationale:
In general, passing "fastmath" from MLIR to LLVM backend is not supported, and even just providing such a feature for experimentation is under debate. However, passing fine-grained fastmath-related attributes on individual operations is generally accepted. This CL introduces an option to instruct the vector-to-llvm lowering phase to annotate floating-point reductions with the "reassociate" fastmath attribute, which allows the LLVM backend to use SIMD implementations for such constructs. Other lowering passes can start using this mechanism right away in cases where reassociation is allowed.

Benefit:
For some microbenchmarks on x86-avx2, speedups over 20 were observed for longer vectors (due to cleaner, spill-free and SIMD-exploiting code).

Usage:
mlir-opt --convert-vector-to-llvm="reassociate-fp-reductions"

Reviewed By: ftynse, mehdi_amini

Differential Revision: https://reviews.llvm.org/D82624

Added: 
    mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Target/llvmir-intrinsics.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 48149ced5403..89b63e82d471 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -308,6 +308,11 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
   let summary = "Lower the operations from the vector dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertVectorToLLVMPass()";
+  let options = [
+    Option<"reassociateFPReductions", "reassociate-fp-reductions",
+           "bool", /*default=*/"false",
+           "Allows llvm to reassociate floating-point reductions for speed">
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index e09a74d5be83..cdff18802616 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,8 +23,9 @@ void populateVectorToLLVMMatrixConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
-void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                            OwningRewritePatternList &patterns);
+void populateVectorToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool reassociateFPReductions = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass();

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 48eecb4eed87..d88b372dbf43 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -214,10 +214,30 @@ class LLVM_VectorReduction<string mnem>
     : LLVM_OneResultIntrOp<"experimental.vector.reduce." # mnem, [], [0], []>,
       Arguments<(ins LLVM_Type)>;
 
-// LLVM vector reduction over a single vector, with an initial value.
+// LLVM vector reduction over a single vector, with an initial value,
+// and with permission to reassociate the reduction operations.
 class LLVM_VectorReductionV2<string mnem>
-    : LLVM_OneResultIntrOp<"experimental.vector.reduce.v2." # mnem,
-                           [0], [1], []>,
-      Arguments<(ins LLVM_Type, LLVM_Type)>;
+    : LLVM_OpBase<LLVM_Dialect, "intr.experimental.vector.reduce.v2." # mnem, []>,
+      Results<(outs LLVM_Type:$res)>,
+      Arguments<(ins LLVM_Type, LLVM_Type,
+                 DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module,
+        llvm::Intrinsic::experimental_vector_reduce_v2_}] # mnem # [{,
+        { }] # StrJoin<!listconcat(
+            ListIntSubst<LLVM_IntrPatterns.result, [0]>.lst,
+            ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst)>.result # [{
+        });
+    auto operands = lookupValues(opInst.getOperands());
+    llvm::FastMathFlags origFM = builder.getFastMathFlags();
+    llvm::FastMathFlags tempFM = origFM;
+    tempFM.setAllowReassoc($reassoc);
+    builder.setFastMathFlags(tempFM);  // set fastmath flag
+    $res = builder.CreateCall(fn, operands);
+    builder.setFastMathFlags(origFM);  // restore fastmath flag
+  }];
+}
 
 #endif  // LLVMIR_OP_BASE

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index bd9ec9363383..6b43a1e1fdc1 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -255,9 +255,11 @@ class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern {
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
-                                       LLVMTypeConverter &typeConverter)
+                                       LLVMTypeConverter &typeConverter,
+                                       bool reassociateFP)
       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
-                             typeConverter) {}
+                             typeConverter),
+        reassociateFPReductions(reassociateFP) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -302,7 +304,8 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
                                               op->getLoc(), llvmType,
                                               rewriter.getZeroAttr(eltType));
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
-            op, llvmType, acc, operands[0]);
+            op, llvmType, acc, operands[0],
+            rewriter.getBoolAttr(reassociateFPReductions));
       } else if (kind == "mul") {
         // Optional accumulator (or one).
         Value acc = operands.size() > 1
@@ -311,7 +314,8 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
                               op->getLoc(), llvmType,
                               rewriter.getFloatAttr(eltType, 1.0));
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
-            op, llvmType, acc, operands[0]);
+            op, llvmType, acc, operands[0],
+            rewriter.getBoolAttr(reassociateFPReductions));
       } else if (kind == "min")
         rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
             op, llvmType, operands[0]);
@@ -324,6 +328,9 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
     }
     return failure();
   }
+
+private:
+  const bool reassociateFPReductions;
 };
 
 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
@@ -1139,16 +1146,18 @@ class VectorStridedSliceOpConversion
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool reassociateFPReductions) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   // clang-format off
   patterns.insert<VectorFMAOpNDRewritePattern,
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
+  patterns.insert<VectorReductionOpConversion>(
+      ctx, converter, reassociateFPReductions);
   patterns
-      .insert<VectorReductionOpConversion,
-              VectorShuffleOpConversion,
+      .insert<VectorShuffleOpConversion,
               VectorExtractElementOpConversion,
               VectorExtractOpConversion,
               VectorFMAOp1DConversion,
@@ -1190,7 +1199,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext());
   OwningRewritePatternList patterns;
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
-  populateVectorToLLVMConversionPatterns(converter, patterns);
+  populateVectorToLLVMConversionPatterns(converter, patterns,
+                                         reassociateFPReductions);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateStdToLLVMConversionPatterns(converter, patterns);
 

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
new file mode 100644
index 000000000000..eab211ffaaf4
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions' | FileCheck %s --check-prefix=REASSOC
+
+//
+// CHECK-LABEL: llvm.func @reduce_add_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+//      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+//      CHECK: llvm.return %[[V]] : !llvm.float
+//
+// REASSOC-LABEL: llvm.func @reduce_add_f32(
+// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+//      REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+//      REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+//      REASSOC: llvm.return %[[V]] : !llvm.float
+//
+func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
+  return %0 : f32
+}
+
+//
+// CHECK-LABEL: llvm.func @reduce_mul_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+//      CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+//      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+//      CHECK: llvm.return %[[V]] : !llvm.float
+//
+// REASSOC-LABEL: llvm.func @reduce_mul_f32(
+// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+//      REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+//      REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]])
+// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+//      REASSOC: llvm.return %[[V]] : !llvm.float
+//
+func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.reduction "mul", %arg0 : vector<16xf32> into f32
+  return %0 : f32
+}

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 8351fb23c078..2b2adf0cca64 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -721,6 +721,7 @@ func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
 // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
 //      CHECK: llvm.return %[[V]] : !llvm.float
 
 func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
@@ -731,6 +732,7 @@ func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
 // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">)
 //      CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
 //      CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm<"<16 x double>">) -> !llvm.double
 //      CHECK: llvm.return %[[V]] : !llvm.double
 
 func @reduce_i32(%arg0: vector<16xi32>) -> i32 {

diff  --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
index 45292124aedc..ffbbf359b380 100644
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -161,6 +161,10 @@ llvm.func @vector_reductions(%arg0: !llvm.float, %arg1: !llvm<"<8 x float>">, %a
   "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
   // CHECK: call float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32
   "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fadd.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
   // CHECK: call i32 @llvm.experimental.vector.reduce.xor.v8i32
   "llvm.intr.experimental.vector.reduce.xor"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
   llvm.return


        


More information about the Mlir-commits mailing list