[Mlir-commits] [mlir] fa596c6 - [mlir][Vector] Fix reordering of floating point adds during lowering of `vector.contract`.
Mahesh Ravishankar
llvmlistbot at llvm.org
Mon Jun 27 22:27:13 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-06-28T05:26:39Z
New Revision: fa596c6921159af50e69cc3be189d951521a9eb9
URL: https://github.com/llvm/llvm-project/commit/fa596c6921159af50e69cc3be189d951521a9eb9
DIFF: https://github.com/llvm/llvm-project/commit/fa596c6921159af50e69cc3be189d951521a9eb9.diff
LOG: [mlir][Vector] Fix reordering of floating point adds during lowering of `vector.contract`.
Adding the accumulator value after the `vector.contract` changes the
precision of the operation. This change makes sure the accumulator is
carried through to `vector.reduction` (and from there down to LLVM).
Differential Revision: https://reviews.llvm.org/D128674
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 152e5387c2e0a..57c02c9a35ba3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -91,7 +91,7 @@ def Vector_ContractionOp :
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
Vector_AffineMapArrayAttr:$indexing_maps,
- ArrayAttr:$iterator_types,
+ ArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
"CombiningKind::ADD">:$kind)>,
Results<(outs AnyType)> {
@@ -280,8 +280,7 @@ def Vector_ReductionOp :
let description = [{
Reduces an 1-D vector "horizontally" into a scalar using the given
operation (add/mul/min/max for int/fp and and/or/xor for int only).
- Some reductions (add/mul for fp) also allow an optional fused
- accumulator.
+ Reductions also allow an optional fused accumulator.
Note that these operations are restricted to 1-D vectors to remain
close to the corresponding LLVM intrinsics:
@@ -1760,7 +1759,7 @@ def Vector_GatherOp :
Vector_Op<"gather">,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+ VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$pass_thru)>,
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1826,7 +1825,7 @@ def Vector_ScatterOp :
Vector_Op<"scatter">,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
+ VectorOfRankAndType<[1], [AnyInteger, Index]>:$index_vec,
VectorOfRankAndType<[1], [I1]>:$mask,
VectorOfRank<[1]>:$valueToStore)> {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a164c7d167dc6..fa4920486aad8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -362,6 +362,37 @@ class VectorCompressStoreOpConversion
}
};
+/// Helper method to lower a `vector.reduction` op that performs an arithmetic
+/// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
+/// and `ScalarOp` is the scalar operation used to add the accumulation value if
+/// non-null.
+template <class VectorOp, class ScalarOp>
+static Value createIntegerReductionArithmeticOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator) {
+ Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+ if (accumulator)
+ result = rewriter.create<ScalarOp>(loc, accumulator, result);
+ return result;
+}
+
+/// Helper method to lower a `vector.reduction` operation that performs
+/// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
+/// intrinsic to use and `predicate` is the predicate to use to compare+combine
+/// the accumulator value if non-null.
+template <class VectorOp>
+static Value createIntegerReductionComparisonOpLowering(
+ ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+ Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
+ Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+ if (accumulator) {
+ Value cmp =
+ rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
+ result = rewriter.create<LLVM::SelectOp>(loc, cmp, accumulator, result);
+ }
+ return result;
+}
+
/// Conversion pattern for all vector reductions.
class VectorReductionOpConversion
: public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -377,38 +408,68 @@ class VectorReductionOpConversion
auto kind = reductionOp.getKind();
Type eltType = reductionOp.getDest().getType();
Type llvmType = typeConverter->convertType(eltType);
- Value operand = adaptor.getOperands()[0];
+ Value operand = adaptor.getVector();
+ Value acc = adaptor.getAcc();
+ Location loc = reductionOp.getLoc();
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
- if (kind == vector::CombiningKind::ADD)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_add>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::MUL)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::MINUI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MINSI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MAXUI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::MAXSI)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
- reductionOp, llvmType, operand);
- else if (kind == vector::CombiningKind::AND)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_and>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::OR)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_or>(reductionOp,
- llvmType, operand);
- else if (kind == vector::CombiningKind::XOR)
- rewriter.replaceOpWithNewOp<LLVM::vector_reduce_xor>(reductionOp,
- llvmType, operand);
- else
+ Value result;
+ switch (kind) {
+ case vector::CombiningKind::ADD:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_add,
+ LLVM::AddOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::MUL:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_mul,
+ LLVM::MulOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::MINUI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_umin>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::ule);
+ break;
+ case vector::CombiningKind::MINSI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_smin>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::sle);
+ break;
+ case vector::CombiningKind::MAXUI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_umax>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::uge);
+ break;
+ case vector::CombiningKind::MAXSI:
+ result = createIntegerReductionComparisonOpLowering<
+ LLVM::vector_reduce_smax>(rewriter, loc, llvmType, operand, acc,
+ LLVM::ICmpPredicate::sge);
+ break;
+ case vector::CombiningKind::AND:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_and,
+ LLVM::AndOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::OR:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_or,
+ LLVM::OrOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ case vector::CombiningKind::XOR:
+ result =
+ createIntegerReductionArithmeticOpLowering<LLVM::vector_reduce_xor,
+ LLVM::XOrOp>(
+ rewriter, loc, llvmType, operand, acc);
+ break;
+ default:
return failure();
+ }
+ rewriter.replaceOp(reductionOp, result);
+
return success();
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ebf36627b6308..8332c0b8b260b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -403,15 +403,6 @@ LogicalResult ReductionOp::verify() {
<< eltType << "' for kind '" << stringifyCombiningKind(getKind())
<< "'";
- // Verify optional accumulator.
- if (getAcc()) {
- if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL)
- return emitOpError("no accumulator for reduction kind: ")
- << stringifyCombiningKind(getKind());
- if (!eltType.isa<FloatType>())
- return emitOpError("no accumulator for type: ") << eltType;
- }
-
return success();
}
@@ -1969,7 +1960,7 @@ LogicalResult InsertOp::verify() {
(static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
static_cast<unsigned>(destVectorType.getRank())))
return emitOpError("expected position attribute rank + source rank to "
- "match dest vector rank");
+ "match dest vector rank");
if (!srcVectorType &&
(positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
return emitOpError(
@@ -2302,8 +2293,7 @@ LogicalResult ReshapeOp::verify() {
int64_t numFixedVectorSizes = fixedVectorSizes.size();
if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
- return emitError("invalid input shape for vector type ")
- << inputVectorType;
+ return emitError("invalid input shape for vector type ") << inputVectorType;
if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
return emitError("invalid output shape for vector type ")
@@ -2396,24 +2386,29 @@ LogicalResult ExtractStridedSliceOp::verify() {
auto sizes = getSizesAttr();
auto strides = getStridesAttr();
if (offsets.size() != sizes.size() || offsets.size() != strides.size())
- return emitOpError("expected offsets, sizes and strides attributes of same size");
+ return emitOpError(
+ "expected offsets, sizes and strides attributes of same size");
auto shape = type.getShape();
auto offName = getOffsetsAttrName();
auto sizesName = getSizesAttrName();
auto stridesName = getStridesAttrName();
- if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
- failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
+ if (failed(
+ isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
+ failed(
+ isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
stridesName)) ||
- failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
+ failed(
+ isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
/*halfOpen=*/false,
/*min=*/1)) ||
- failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName,
+ failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
+ stridesName,
/*halfOpen=*/false)) ||
- failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape,
- offName, sizesName,
+ failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
+ shape, offName, sizesName,
/*halfOpen=*/false)))
return failure();
@@ -4223,7 +4218,7 @@ LogicalResult BitCastOp::verify() {
if (sourceVectorType.getRank() == 0) {
if (sourceElementBits != resultElementBits)
return emitOpError("source/result bitwidth of the 0-D vector element "
- "types must be equal");
+ "types must be equal");
} else if (sourceElementBits * sourceVectorType.getShape().back() !=
resultElementBits * resultVectorType.getShape().back()) {
return emitOpError(
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 67635d69ddeaf..338ffd053486a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1875,10 +1875,9 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
assert(rhsType.getRank() == 1 && "corrupt contraction");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
- Value res = rewriter.create<vector::ReductionOp>(loc, kind, m);
if (auto acc = op.getAcc())
- res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
- return res;
+ return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
+ return rewriter.create<vector::ReductionOp>(loc, kind, m);
}
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index adad1ea016ad2..0d8406c793d67 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1178,6 +1178,206 @@ func.func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
// -----
+func.func @reduce_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <add>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.add"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.add %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_mul_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <mul>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_mul_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <mul>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_mul_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.mul"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.mul %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <minui>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <minui>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minui_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umin"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "ule" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxui_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <maxui>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxui_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <maxui>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxui_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.umax"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "uge" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minsi_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <minsi>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_minsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <minsi>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_minsi_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smin"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "sle" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxsi_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <maxsi>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_maxsi_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <maxsi>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_maxsi_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.smax"(%[[A]])
+// CHECK: %[[S:.*]] = llvm.icmp "sge" %[[ACC]], %[[R]]
+// CHECK: %[[V:.*]] = llvm.select %[[S]], %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_and_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <and>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_and_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <and>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_and_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.and"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.and %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_or_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <or>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_or_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <or>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_or_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.or"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.or %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_xor_i32(%arg0: vector<16xi32>) -> i32 {
+ %0 = vector.reduction <xor>, %arg0 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+// CHECK: return %[[V]] : i32
+
+// -----
+
+func.func @reduce_xor_acc_i32(%arg0: vector<16xi32>, %arg1 : i32) -> i32 {
+ %0 = vector.reduction <xor>, %arg0, %arg1 : vector<16xi32> into i32
+ return %0 : i32
+}
+// CHECK-LABEL: @reduce_xor_acc_i32(
+// CHECK-SAME: %[[A:.*]]: vector<16xi32>, %[[ACC:.*]]: i32)
+// CHECK: %[[R:.*]] = "llvm.intr.vector.reduce.xor"(%[[A]])
+// CHECK: %[[V:.*]] = llvm.xor %[[ACC]], %[[R]]
+// CHECK: return %[[V]] : i32
+
+// -----
+
func.func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
%0 = vector.reduction <add>, %arg0 : vector<16xi64> into i64
return %0 : i64
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 4dd9388b5cdb6..243e83e8ceb6a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1116,27 +1116,6 @@ func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32)
// -----
-func.func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
- // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}}
- %0 = vector.reduction <minf>, %arg0, %arg1 : vector<16xf32> into f32
-}
-
-// -----
-
-func.func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 {
- // expected-error@+1 {{'vector.reduction' op no accumulator for type: 'i32'}}
- %0 = vector.reduction <add>, %arg0, %arg1 : vector<16xi32> into i32
-}
-
-// -----
-
-func.func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
- // expected-error@+1 {{'vector.reduction' op unsupported reduction type}}
- %0 = vector.reduction <xor>, %arg0 : vector<16xf32> into f32
-}
-
-// -----
-
func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
// expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
%0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index 70f86fd4dc6dd..4123ef3b75135 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -19,9 +19,8 @@
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
-// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]] : vector<4xf32> into f32
-// CHECK: %[[ACC:.*]] = arith.addf %[[R]], %[[C]] : f32
-// CHECK: return %[[ACC]] : f32
+// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xf32> into f32
+// CHECK: return %[[R]] : f32
func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
%0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -34,9 +33,8 @@ func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2:
// CHECK-SAME: %[[B:.*1]]: vector<4xi32>,
// CHECK-SAME: %[[C:.*2]]: i32
// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32>
-// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]] : vector<4xi32> into i32
-// CHECK: %[[ACC:.*]] = arith.addi %[[R]], %[[C]] : i32
-// CHECK: return %[[ACC]] : i32
+// CHECK: %[[R:.*]] = vector.reduction <add>, %[[F]], %[[C]] : vector<4xi32> into i32
+// CHECK: return %[[R]] : i32
func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 {
%0 = vector.contract #dotp_trait %arg0, %arg1, %arg2
@@ -72,7 +70,7 @@ func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %ar
func.func @extract_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3xf32>,
- %arg2: vector<2xf32>) -> vector<2xf32> {
+ %arg2: vector<2xf32>) -> vector<2xf32> {
%0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
: vector<2x3xf32>, vector<3xf32> into vector<2xf32>
return %0 : vector<2xf32>
@@ -95,7 +93,7 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>,
// CHECK: return %[[T10]] : vector<2xi32>
func.func @extract_contract2_int(%arg0: vector<2x3xi32>,
%arg1: vector<3xi32>,
- %arg2: vector<2xi32>) -> vector<2xi32> {
+ %arg2: vector<2xi32>) -> vector<2xi32> {
%0 = vector.contract #matvec_trait %arg0, %arg1, %arg2
: vector<2x3xi32>, vector<3xi32> into vector<2xi32>
return %0 : vector<2xi32>
@@ -201,18 +199,16 @@ func.func @extract_contract4(%arg0: vector<2x2xf32>,
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]] : vector<3xf32> into f32
-// CHECK: %[[T4:.*]] = arith.addf %[[T3]], %[[C]] : f32
+// CHECK: %[[T3:.*]] = vector.reduction <add>, %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32>
-// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]] : vector<3xf32> into f32
-// CHECK: %[[T9:.*]] = arith.addf %[[T8]], %[[T4]] : f32
-// CHECK: return %[[T9]] : f32
+// CHECK: %[[T8:.*]] = vector.reduction <add>, %[[T7]], %[[T3]] : vector<3xf32> into f32
+// CHECK: return %[[T8]] : f32
func.func @full_contract1(%arg0: vector<2x3xf32>,
%arg1: vector<2x3xf32>,
- %arg2: f32) -> f32 {
+ %arg2: f32) -> f32 {
%0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
: vector<2x3xf32>, vector<2x3xf32> into f32
return %0 : f32
@@ -241,8 +237,7 @@ func.func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32>
-// CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]] : vector<3xf32> into f32
-// CHECK: %[[ACC0:.*]] = arith.addf %[[T11]], %[[C]] : f32
+// CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32
//
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf
@@ -252,13 +247,12 @@ func.func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32>
-// CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]] : vector<3xf32> into f32
-// CHECK: %[[ACC1:.*]] = arith.addf %[[T23]], %[[ACC0]] : f32
-// CHECK: return %[[ACC1]] : f32
+// CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32
+// CHECK: return %[[T23]] : f32
func.func @full_contract2(%arg0: vector<2x3xf32>,
%arg1: vector<3x2xf32>,
- %arg2: f32) -> f32 {
+ %arg2: f32) -> f32 {
%0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
: vector<2x3xf32>, vector<3x2xf32> into f32
return %0 : f32
More information about the Mlir-commits
mailing list