[Mlir-commits] [mlir] 681f929 - [mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and lowers to `llvm.intrin.fmuladd`
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Feb 7 12:48:33 PST 2020
Author: Nicolas Vasilache
Date: 2020-02-07T15:44:53-05:00
New Revision: 681f929f591616ca048aa470d030d985b6719216
URL: https://github.com/llvm/llvm-project/commit/681f929f591616ca048aa470d030d985b6719216
DIFF: https://github.com/llvm/llvm-project/commit/681f929f591616ca048aa470d030d985b6719216.diff
LOG: [mlir][VectorOps] Introduce a `vector.fma` op that works on n-D vectors and lowers to `llvm.intrin.fmuladd`
Summary:
The `vector.fma` operation is portable enough across targets that we do not want
to keep it wrapped under `vector.outerproduct` and `llvm.intrin.fmuladd`.
This revision lifts the op into the vector dialect and implements the lowering to LLVM by using two patterns:
1. a pattern that lowers from n-D to (n-1)-D by unrolling when n >= 2
2. a pattern that converts from 1-D to the proper LLVM representation
Reviewers: ftynse, stellaraccident, aartbik, dcaballe, jsetoain, tetuante
Reviewed By: aartbik
Subscribers: fhahn, dcaballe, merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74075
Added:
Modified:
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/VectorOps/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 6dfb1dbea75f..e7e165aa381f 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -388,6 +388,38 @@ def Vector_ExtractSlicesOp :
}];
}
+def Vector_FMAOp :
+ Op<Vector_Dialect, "fma", [NoSideEffect,
+ AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
+ Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc)>,
+ Results<(outs AnyVector:$result)> {
+ let summary = "vector fused multiply-add";
+ let description = [{
+ Multiply-add expressions operate on n-D vectors and compute a fused
+ pointwise multiply-and-accumulate: `$result = $lhs * $rhs + $acc`.
+ All operands and result have the same vector type. The semantics
+ of the operation correspond to those of the `llvm.fma`
+ [intrinsic](https://llvm.org/docs/LangRef.html#int-fma). In the
+ particular case of lowering to LLVM, this is guaranteed to lower
+ to the `llvm.fma.*` intrinsic.
+
+ Example:
+
+ ```
+ %3 = vector.fma %0, %1, %2: vector<8x16xf32>
+ ```
+ }];
+ // Fully specified by traits.
+ let verifier = ?;
+ let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)";
+ let builders = [OpBuilder<
+ "Builder *b, OperationState &result, Value lhs, Value rhs, Value acc",
+ "build(b, result, lhs.getType(), lhs, rhs, acc);">];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() { return lhs().getType().cast<VectorType>(); }
+ }];
+}
+
def Vector_InsertElementOp :
Vector_Op<"insertelement", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b08ba1d9587e..a3d724bea834 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -410,6 +410,41 @@ class VectorExtractOpConversion : public LLVMOpLowering {
}
};
+/// Conversion pattern that turns a vector.fma on a 1-D vector
+/// into an llvm.intr.fma. This is a trivial 1-1 conversion.
+/// This does not match vectors of rank n >= 2.
+///
+/// Example:
+/// ```
+/// vector.fma %a, %a, %a : vector<8xf32>
+/// ```
+/// is converted to:
+/// ```
+/// llvm.intr.fma %va, %va, %va:
+/// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">)
+/// -> !llvm<"<8 x float>">
+/// ```
+class VectorFMAOp1DConversion : public LLVMOpLowering {
+public:
+ explicit VectorFMAOp1DConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::FMAOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto adaptor = vector::FMAOpOperandAdaptor(operands);
+ vector::FMAOp fmaOp = cast<vector::FMAOp>(op);
+ VectorType vType = fmaOp.getVectorType();
+ if (vType.getRank() != 1)
+ return matchFailure();
+ rewriter.replaceOpWithNewOp<LLVM::FMAOp>(op, adaptor.lhs(), adaptor.rhs(),
+ adaptor.acc());
+ return matchSuccess();
+ }
+};
+
class VectorInsertElementOpConversion : public LLVMOpLowering {
public:
explicit VectorInsertElementOpConversion(MLIRContext *context,
@@ -502,6 +537,54 @@ class VectorInsertOpConversion : public LLVMOpLowering {
}
};
+/// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1.
+///
+/// Example:
+/// ```
+/// %d = vector.fma %a, %b, %c : vector<2x4xf32>
+/// ```
+/// is rewritten into:
+/// ```
+/// %r = splat %f0: vector<2x4xf32>
+/// %va = vector.extract %a[0] : vector<2x4xf32>
+/// %vb = vector.extract %b[0] : vector<2x4xf32>
+/// %vc = vector.extract %c[0] : vector<2x4xf32>
+/// %vd = vector.fma %va, %vb, %vc : vector<4xf32>
+/// %r2 = vector.insert %vd, %r[0] : vector<4xf32> into vector<2x4xf32>
+/// %va2 = vector.extract %a[1] : vector<2x4xf32>
+/// %vb2 = vector.extract %b[1] : vector<2x4xf32>
+/// %vc2 = vector.extract %c[1] : vector<2x4xf32>
+/// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32>
+/// %r3 = vector.insert %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32>
+/// // %r3 holds the final value.
+/// ```
+class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
+public:
+ using OpRewritePattern<FMAOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(FMAOp op,
+ PatternRewriter &rewriter) const override {
+ auto vType = op.getVectorType();
+ if (vType.getRank() < 2)
+ return matchFailure();
+
+ auto loc = op.getLoc();
+ auto elemType = vType.getElementType();
+ Value zero = rewriter.create<ConstantOp>(loc, elemType,
+ rewriter.getZeroAttr(elemType));
+ Value desc = rewriter.create<SplatOp>(loc, vType, zero);
+ for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) {
+ Value extrLHS = rewriter.create<ExtractOp>(loc, op.lhs(), i);
+ Value extrRHS = rewriter.create<ExtractOp>(loc, op.rhs(), i);
+ Value extrACC = rewriter.create<ExtractOp>(loc, op.acc(), i);
+ Value fma = rewriter.create<FMAOp>(loc, extrLHS, extrRHS, extrACC);
+ desc = rewriter.create<InsertOp>(loc, fma, desc, i);
+ }
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
// When ranks are different, InsertStridedSlice needs to extract a properly
// ranked vector from the destination vector into which to insert. This pattern
// only takes care of this part and forwards the rest of the conversion to
@@ -969,14 +1052,16 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.insert<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+ patterns.insert<VectorFMAOpNDRewritePattern,
+ VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion,
- VectorPrintOpConversion>(ctx, converter);
+ VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorInsertOpConversion, VectorOuterProductOpConversion,
+ VectorTypeCastOpConversion, VectorPrintOpConversion>(
+ ctx, converter);
}
namespace {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index efbb8aa2e54b..d1535d59593c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -637,3 +637,29 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">
+
+// CHECK-LABEL: llvm.func @vector_fma(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<8 x float>">, %[[B:.*]]: !llvm<"[2 x <4 x float>]">)
+// CHECK-SAME: -> !llvm<"{ <8 x float>, [2 x <4 x float>] }"> {
+func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vector<2x4xf32>) {
+ // CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
+ // CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
+ %0 = vector.fma %a, %a, %a : vector<8xf32>
+
+ // CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[B0:.*]] = "llvm.intr.fma"(%[[b00]], %[[b01]], %[[b02]]) :
+ // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+ // CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[b10:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[b11:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[b12:.*]] = llvm.extractvalue %[[B]][1] : !llvm<"[2 x <4 x float>]">
+ // CHECK: %[[B1:.*]] = "llvm.intr.fma"(%[[b10]], %[[b11]], %[[b12]]) :
+ // CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
+ // CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
+ %1 = vector.fma %b, %b, %b : vector<2x4xf32>
+
+ return %0, %1: vector<8xf32>, vector<2x4xf32>
+}
+
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index 6c90cb6f431a..ff0078310af2 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -268,3 +268,12 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
return %0, %1 : vector<15x2xf32>, tuple<vector<20x2xf32>, vector<12x2xf32>>
}
+
+// CHECK-LABEL: @vector_fma
+func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
+ // CHECK: vector.fma %{{.*}} : vector<8xf32>
+ vector.fma %a, %a, %a : vector<8xf32>
+ // CHECK: vector.fma %{{.*}} : vector<8x4xf32>
+ vector.fma %b, %b, %b : vector<8x4xf32>
+ return
+}
More information about the Mlir-commits
mailing list