[Mlir-commits] [mlir] 681f929 - [mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and lowers to `llvm.intrin.fmuladd`

Nicolas Vasilache llvmlistbot at llvm.org
Fri Feb 7 12:48:33 PST 2020


Author: Nicolas Vasilache
Date: 2020-02-07T15:44:53-05:00
New Revision: 681f929f591616ca048aa470d030d985b6719216

URL: https://github.com/llvm/llvm-project/commit/681f929f591616ca048aa470d030d985b6719216
DIFF: https://github.com/llvm/llvm-project/commit/681f929f591616ca048aa470d030d985b6719216.diff

LOG: [mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and lowers to `llvm.intrin.fmuladd`

Summary:
The `vector.fma` operation is portable enough across targets that we do not want
to keep it wrapped under `vector.outerproduct` and `llvm.intrin.fmuladd`.
This revision lifts the op into the vector dialect and implements the lowering to LLVM by using two patterns:
1. a pattern that lowers from n-D to (n-1)-D by unrolling when n > 1
2. a pattern that converts from 1-D to the proper LLVM representation

Reviewers: ftynse, stellaraccident, aartbik, dcaballe, jsetoain, tetuante

Reviewed By: aartbik

Subscribers: fhahn, dcaballe, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74075

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/VectorOps/VectorOps.td
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
    mlir/test/Dialect/VectorOps/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 6dfb1dbea75f..e7e165aa381f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -388,6 +388,38 @@ def Vector_ExtractSlicesOp :
   }];
 }
 
+def Vector_FMAOp :
+  Op<Vector_Dialect, "fma", [NoSideEffect,
+                             AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
+    Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
+    Results<(outs AnyVector:$result)> {
+  let summary = "vector fused multiply-add";
+  let description = [{
+    Multiply-add expressions operate on n-D vectors and compute a fused
+    pointwise multiply-and-accumulate: `$result = $lhs * $rhs + $acc`.
+    All operands and result have the same vector type. The semantics
+    of the operation correspond to those of the `llvm.fma`
+    [intrinsic](https://llvm.org/docs/LangRef.html#int-fma). In the
+    particular case of lowering to LLVM, this is guaranteed to lower
+    to the `llvm.fma.*` intrinsic.
+
+    Example:
+    
+    ```
+      %3 = vector.fma %0, %1, %2: vector<8x16xf32>
+    ```
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc",
+    "build(b, result, lhs.getType(), lhs, rhs, acc);">];
+  let extraClassDeclaration = [{
+    VectorType getVectorType() { return lhs().getType().cast<VectorType>(); }
+  }];
+}
+
 def Vector_InsertElementOp :
   Vector_Op<"insertelement", [NoSideEffect,
      PredOpTrait<"source operand and result have same element type",

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b08ba1d9587e..a3d724bea834 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -410,6 +410,41 @@ class VectorExtractOpConversion : public LLVMOpLowering {
   }
 };
 
+/// Conversion pattern that turns a vector.fma on a 1-D vector
+/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
+/// This does not match vectors of rank n >= 2.
+///
+/// Example:
+/// ```
+///  vector.fma %a, %a, %a : vector<8xf32>
+/// ```
+/// is converted to:
+/// ```
+///  llvm.intr.fma %va, %va, %va:
+///    (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
+///    -> !llvm<"<8 x float>">
+/// ```
+class VectorFMAOp1DConversion : public LLVMOpLowering {
+public:
+  explicit VectorFMAOp1DConversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::FMAOp::getOperationName(), context,
+                       typeConverter) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto adaptor = vector::FMAOpOperandAdaptor(operands);
+    vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
+    VectorType vType = fmaOp.getVectorType();
+    if (vType.getRank() != 1)
+      return matchFailure();
+    rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
+                                             adaptor.acc());
+    return matchSuccess();
+  }
+};
+
 class VectorInsertElementOpConversion : public LLVMOpLowering {
 public:
   explicit VectorInsertElementOpConversion(MLIRContext *context,
@@ -502,6 +537,54 @@ class VectorInsertOpConversion : public LLVMOpLowering {
   }
 };
 
+/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
+///
+/// Example:
+/// ```
+///   %d = vector.fma %a, %b, %c : vector<2x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+///  %r = splat %f0: vector<2x4xf32>
+///  %va = vector.extract %a[0] : vector<2x4xf32>
+///  %vb = vector.extract %b[0] : vector<2x4xf32>
+///  %vc = vector.extract %c[0] : vector<2x4xf32>
+///  %vd = vector.fma %va, %vb, %vc : vector<4xf32>
+///  %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
+///  %va2 = vector.extract %a[1] : vector<2x4xf32>
+///  %vb2 = vector.extract %b[1] : vector<2x4xf32>
+///  %vc2 = vector.extract %c[1] : vector<2x4xf32>
+///  %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
+///  %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
+///  // %r3 holds the final value.
+/// ```
+class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
+public:
+  using OpRewritePattern<FMAOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(FMAOp op,
+                                     PatternRewriter &rewriter) const override {
+    auto vType = op.getVectorType();
+    if (vType.getRank() < 2)
+      return matchFailure();
+
+    auto loc = op.getLoc();
+    auto elemType = vType.getElementType();
+    Value zero = rewriter.create<ConstantOp>(loc, elemType,
+                                             rewriter.getZeroAttr(elemType));
+    Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+    for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
+      Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
+      Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
+      Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
+      Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
+      desc = rewriter.create<InsertOp>(loc, fma, desc, i);
+    }
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 // When ranks are different, InsertStridedSlice needs to extract a properly
 // ranked vector from the destination vector into which to insert. This pattern
 // only takes care of this part and forwards the rest of the conversion to
@@ -969,14 +1052,16 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+  patterns.insert<VectorFMAOpNDRewritePattern,
+                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
   patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                   VectorExtractElementOpConversion, VectorExtractOpConversion,
-                  VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
-                  VectorPrintOpConversion>(ctx, converter);
+                  VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+                  VectorInsertOpConversion, VectorOuterProductOpConversion,
+                  VectorTypeCastOpConversion, VectorPrintOpConversion>(
+      ctx, converter);
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index efbb8aa2e54b..d1535d59593c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -637,3 +637,29 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
 //      CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
 //      CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
 //      CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">
+
+// CHECK-LABEL: llvm.func @vector_fma(
+//  CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
+//  CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+  //         CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
+  //    CHECK-SAME:   (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+  %0 = vector.fma %a, %a, %a : vector<8xf32>
+  
+  //       CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
+  //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  //       CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+  //       CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
+  //  CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+  //       CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
+  %1 = vector.fma %b, %b, %b : vector<2x4xf32>
+  
+  return %0, %1: vector<8xf32>, vector<2x4xf32>
+}
+        

diff  --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index 6c90cb6f431a..ff0078310af2 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -268,3 +268,12 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
 
   return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
 }
+
+// CHECK-LABEL: @vector_fma
+func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+  // CHECK: vector.fma %{{.*}} : vector<8xf32>
+  vector.fma %a, %a, %a : vector<8xf32>
+  // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
+  vector.fma %b, %b, %b : vector<8x4xf32>
+  return
+}


        


More information about the Mlir-commits mailing list