[Mlir-commits] [mlir] e83b7b9 - [mlir] [VectorOps] Implement vector.reduce operation
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 11:32:09 PST 2020
Author: aartbik
Date: 2020-02-11T11:31:59-08:00
New Revision: e83b7b99da2e0385c567cd3883cad66fb5ce271c
URL: https://github.com/llvm/llvm-project/commit/e83b7b99da2e0385c567cd3883cad66fb5ce271c
DIFF: https://github.com/llvm/llvm-project/commit/e83b7b99da2e0385c567cd3883cad66fb5ce271c.diff
LOG: [mlir] [VectorOps] Implement vector.reduce operation
Summary:
This new operation operates on 1-D vectors and
forms the bridge between vector.contract and
llvm intrinsics for vector reductions.
Reviewers: nicolasvasilache, andydavis1, ftynse
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74370
Added:
Modified:
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index de7007e3b509..074a6d005376 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -183,6 +183,39 @@ def Vector_ContractionOp :
}];
}
+def Vector_ReductionOp :
+ Vector_Op<"reduction", [NoSideEffect,
+ PredOpTrait<"source operand and result have same element type",
+ TCresVTEtIsSameAsOpBase<0, 0>>]>,
+ Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+ Results<(outs AnyType:$dest)> {
+ let summary = "reduction operation";
+ let description = [{
+ Reduces an 1-D vector "horizontally" into a scalar using the given
+ operation (add/mul/min/max for int/fp and and/or/xor for int only).
+ Note that these operations are restricted to 1-D vectors to remain
+ close to the corresponding LLVM intrinsics:
+
+ http://llvm.org/docs/LangRef.html#experimental-vector-reduction-intrinsics
+
+ Examples:
+ ```
+ %1 = vector.reduction "add", %0 : vector<16xf32> into f32
+
+ %3 = vector.reduction "xor", %2 : vector<4xi32> into i32
+ ```
+ }];
+ let verifier = [{ return ::verify(*this); }];
+ let assemblyFormat = [{
+ $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
+ }];
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return vector().getType().cast<VectorType>();
+ }
+ }];
+}
+
def Vector_BroadcastOp :
Vector_Op<"broadcast", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a3d724bea834..9fcad2f5063a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -124,6 +124,7 @@ static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
}
namespace {
+
class VectorBroadcastOpConversion : public LLVMOpLowering {
public:
explicit VectorBroadcastOpConversion(MLIRContext *context,
@@ -272,6 +273,73 @@ class VectorBroadcastOpConversion : public LLVMOpLowering {
}
};
+class VectorReductionOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorReductionOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::ReductionOp::getOperationName(), context,
+ typeConverter) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reductionOp = cast<vector::ReductionOp>(op);
+ auto kind = reductionOp.kind();
+ Type eltType = reductionOp.dest().getType();
+ Type llvmType = lowering.convertType(eltType);
+ if (eltType.isInteger(32) || eltType.isInteger(64)) {
+ // Integer reductions: add/mul/min/max/and/or/xor.
+ if (kind == "add")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_add>(
+ op, llvmType, operands[0]);
+ else if (kind == "mul")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_mul>(
+ op, llvmType, operands[0]);
+ else if (kind == "min")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smin>(
+ op, llvmType, operands[0]);
+ else if (kind == "max")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_smax>(
+ op, llvmType, operands[0]);
+ else if (kind == "and")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_and>(
+ op, llvmType, operands[0]);
+ else if (kind == "or")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_or>(
+ op, llvmType, operands[0]);
+ else if (kind == "xor")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_xor>(
+ op, llvmType, operands[0]);
+ else
+ return matchFailure();
+ return matchSuccess();
+
+ } else if (eltType.isF32() || eltType.isF64()) {
+ // Floating-point reductions: add/mul/min/max
+ if (kind == "add") {
+ Value zero = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
+ op, llvmType, zero, operands[0]);
+ } else if (kind == "mul") {
+ Value one = rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
+ op, llvmType, one, operands[0]);
+ } else if (kind == "min")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
+ op, llvmType, operands[0]);
+ else if (kind == "max")
+ rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmax>(
+ op, llvmType, operands[0]);
+ else
+ return matchFailure();
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
class VectorShuffleOpConversion : public LLVMOpLowering {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1056,12 +1124,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertStridedSliceOpDifferentRankRewritePattern,
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
- patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
- VectorInsertOpConversion, VectorOuterProductOpConversion,
- VectorTypeCastOpConversion, VectorPrintOpConversion>(
- ctx, converter);
+ patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
+ VectorShuffleOpConversion, VectorExtractElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
+ VectorInsertElementOpConversion, VectorInsertOpConversion,
+ VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+ VectorPrintOpConversion>(ctx, converter);
}
namespace {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index a987a54f5ea1..b4d7aee70b17 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -60,6 +60,33 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
return builder.getI64ArrayAttr(values);
}
+//===----------------------------------------------------------------------===//
+// ReductionOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(ReductionOp op) {
+ // Verify for 1-D vector.
+ int64_t rank = op.getVectorType().getRank();
+ if (rank != 1)
+ return op.emitOpError("unsupported reduction rank: ") << rank;
+
+ // Verify supported reduction kind.
+ auto kind = op.kind();
+ Type eltType = op.dest().getType();
+ if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
+ if (eltType.isF32() || eltType.isF64() || eltType.isInteger(32) ||
+ eltType.isInteger(64))
+ return success();
+ return op.emitOpError("unsupported reduction type");
+ }
+ if (kind == "and" || kind == "or" || kind == "xor") {
+ if (eltType.isInteger(32) || eltType.isInteger(64))
+ return success();
+ return op.emitOpError("unsupported reduction type");
+ }
+ return op.emitOpError("unknown reduction kind: ") << kind;
+}
+
//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d1535d59593c..5159031339aa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -645,7 +645,7 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
// CHECK: "llvm.intr.fma"(%[[A]], %[[A]], %[[A]]) :
// CHECK-SAME: (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>">
%0 = vector.fma %a, %a, %a : vector<8xf32>
-
+
// CHECK: %[[b00:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b01:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
// CHECK: %[[b02:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
@@ -659,7 +659,45 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>) -> (vector<8xf32>, vect
// CHECK-SAME: (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
// CHECK: llvm.insertvalue %[[B1]], {{.*}}[1] : !llvm<"[2 x <4 x float>]">
%1 = vector.fma %b, %b, %b : vector<2x4xf32>
-
+
return %0, %1: vector<8xf32>, vector<2x4xf32>
}
-
+
+func @reduce_f32(%arg0: vector<16xf32>) -> f32 {
+ %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
+ return %0 : f32
+}
+// CHECK-LABEL: llvm.func @reduce_f32
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK: llvm.return %[[V]] : !llvm.float
+
+func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
+ %0 = vector.reduction "add", %arg0 : vector<16xf64> into f64
+ return %0 : f64
+}
+// CHECK-LABEL: llvm.func @reduce_f64
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK: llvm.return %[[V]] : !llvm.double
+
+func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction "add", %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: llvm.func @reduce_i32
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i32>">
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
+// CHECK: llvm.return %[[V]] : !llvm.i32
+
+func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
+ %0 = vector.reduction "add", %arg0 : vector<16xi64> into i64
+ return %0 : i64
+}
+// CHECK-LABEL: llvm.func @reduce_i64
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x i64>">
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
+// CHECK: llvm.return %[[V]] : !llvm.i64
+
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index a1fee2cf9a62..2a45820be7b0 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -990,3 +990,31 @@ func @shape_cast_different_tuple_sizes(
%1 = vector.shape_cast %arg1 : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
tuple<vector<20x2xf32>>
}
+
+// -----
+
+func @reduce_unknown_kind(%arg0: vector<16xf32>) -> f32 {
+ // expected-error@+1 {{'vector.reduction' op unknown reduction kind: joho}}
+ %0 = vector.reduction "joho", %arg0 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
+ // expected-error@+1 {{'vector.reduction' op failed to verify that source operand and result have same element type}}
+ %0 = vector.reduction "add", %arg0 : vector<16xf32> into i32
+}
+
+// -----
+
+func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
+ // expected-error@+1 {{'vector.reduction' op unsupported reduction type}}
+ %0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
+ // expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
+ %0 = vector.reduction "add", %arg0 : vector<4x16xf32> into f32
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index ff0078310af2..bb5ca6eb8538 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -277,3 +277,37 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
vector.fma %b, %b, %b : vector<8x4xf32>
return
}
+
+// CHECK-LABEL: reduce_fp
+func @reduce_fp(%arg0: vector<16xf32>) -> f32 {
+ // CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32
+ vector.reduction "add", %arg0 : vector<16xf32> into f32
+ // CHECK: vector.reduction "mul", %{{.*}} : vector<16xf32> into f32
+ vector.reduction "mul", %arg0 : vector<16xf32> into f32
+ // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32
+ vector.reduction "min", %arg0 : vector<16xf32> into f32
+ // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32
+ %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32
+ // CHECK: return %[[X]] : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: reduce_int
+func @reduce_int(%arg0: vector<16xi32>) -> i32 {
+ // CHECK: vector.reduction "add", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "add", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "mul", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "min", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "max", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "and", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "or", %arg0 : vector<16xi32> into i32
+ // CHECK: %[[X:.*]] = vector.reduction "xor", %{{.*}} : vector<16xi32> into i32
+ %0 = vector.reduction "xor", %arg0 : vector<16xi32> into i32
+ // CHECK: return %[[X]] : i32
+ return %0 : i32
+}
More information about the Mlir-commits
mailing list