[Mlir-commits] [mlir] 0d92470 - [mlir] [VectorOps] Merge VectorReduction/VectorReductionV2 into one Op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 5 13:07:51 PST 2020
Author: aartbik
Date: 2020-03-05T13:07:31-08:00
New Revision: 0d924700a610c59555ede59325d3c51515679903
URL: https://github.com/llvm/llvm-project/commit/0d924700a610c59555ede59325d3c51515679903
DIFF: https://github.com/llvm/llvm-project/commit/0d924700a610c59555ede59325d3c51515679903.diff
LOG: [mlir] [VectorOps] Merge VectorReduction/VectorReductionV2 into one Op
Summary:
Paying off some technical debt in VectorOps, where I had introduced a special
op with a fused accumulator for reduction to avoid some issues around
printing and parsing an optional accumulator. This CL merges the two
into one op again and does things the right way (it would still be nice
to have "assemblyFormat" support for optional operands, though...).
Reviewers: nicolasvasilache, andydavis1, ftynse, rriddle
Reviewed By: nicolasvasilache
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D75699
Added:
Modified:
mlir/include/mlir/Dialect/VectorOps/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/VectorOps/VectorOps.cpp
mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
mlir/test/Dialect/VectorOps/invalid.mlir
mlir/test/Dialect/VectorOps/ops.mlir
mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index 7ab4ac17045a..aee269555bd3 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -187,12 +187,15 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [NoSideEffect,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
- Arguments<(ins StrAttr:$kind, AnyVector:$vector)>,
+ Arguments<(ins StrAttr:$kind, AnyVector:$vector, Variadic<AnyType>:$acc)>,
Results<(outs AnyType:$dest)> {
let summary = "reduction operation";
let description = [{
Reduces an 1-D vector "horizontally" into a scalar using the given
operation (add/mul/min/max for int/fp and and/or/xor for int only).
+ Some reductions (add/mul for fp) also allow an optional fused
+ accumulator.
+
Note that these operations are restricted to 1-D vectors to remain
close to the corresponding LLVM intrinsics:
@@ -203,34 +206,9 @@ def Vector_ReductionOp :
%1 = vector.reduction "add", %0 : vector<16xf32> into f32
%3 = vector.reduction "xor", %2 : vector<4xi32> into i32
- ```
- }];
- let verifier = [{ return ::verify(*this); }];
- let assemblyFormat = [{
- $kind `,` $vector attr-dict `:` type($vector) `into` type($dest)
- }];
- let extraClassDeclaration = [{
- VectorType getVectorType() {
- return vector().getType().cast<VectorType>();
- }
- }];
-}
-// TODO(ajcbik): quick version with "fused" accumulator; next step
-// will merge Reduction/ReductionV2 into one with
-// an optional accumulator instead
-def Vector_ReductionV2Op :
- Vector_Op<"reductionv2", [NoSideEffect]>,
- Arguments<(ins StrAttr:$kind, VectorOf<[F32, F64]>:$vector, AnyType:$acc)>,
- Results<(outs AnyType:$dest)> {
- let summary = "reduction operation";
- let description = [{
- As vector.reduction, but with a fused accumulator (add/mul for fp only).
- }];
- let verifier = ?;
- let assemblyFormat = [{
- $kind `,` $vector `,` $acc attr-dict `:`
- type($vector) `,` type($acc) `into` type($dest)
+ %4 = vector.reduction "mul", %0, %1 : vector<16xf32> into f32
+ ```
}];
let extraClassDeclaration = [{
VectorType getVectorType() {
@@ -469,7 +447,7 @@ def Vector_FMAOp :
to the `llvm.fma.*` intrinsic.
Example:
-
+
```
%3 = vector.fma %0, %1, %2: vector<8x16xf32>
```
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b9182ea1118f..a075848a9ac7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -319,15 +319,22 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
} else if (eltType.isF32() || eltType.isF64()) {
// Floating-point reductions: add/mul/min/max
if (kind == "add") {
- Value zero = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
+ // Optional accumulator (or zero).
+ Value acc = operands.size() > 1 ? operands[1]
+ : rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType,
+ rewriter.getZeroAttr(eltType));
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
- op, llvmType, zero, operands[0]);
+ op, llvmType, acc, operands[0]);
} else if (kind == "mul") {
- Value one = rewriter.create<LLVM::ConstantOp>(
- op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
+ // Optional accumulator (or one).
+ Value acc = operands.size() > 1
+ ? operands[1]
+ : rewriter.create<LLVM::ConstantOp>(
+ op->getLoc(), llvmType,
+ rewriter.getFloatAttr(eltType, 1.0));
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
- op, llvmType, one, operands[0]);
+ op, llvmType, acc, operands[0]);
} else if (kind == "min")
rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
op, llvmType, operands[0]);
@@ -342,33 +349,6 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern {
}
};
-// TODO(ajcbik): merge Reduction and ReductionV2
-class VectorReductionV2OpConversion : public ConvertToLLVMPattern {
-public:
- explicit VectorReductionV2OpConversion(MLIRContext *context,
- LLVMTypeConverter &typeConverter)
- : ConvertToLLVMPattern(vector::ReductionV2Op::getOperationName(), context,
- typeConverter) {}
- PatternMatchResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- auto reductionOp = cast<vector::ReductionV2Op>(op);
- auto kind = reductionOp.kind();
- Type eltType = reductionOp.dest().getType();
- Type llvmType = typeConverter.convertType(eltType);
- if (kind == "add") {
- rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
- op, llvmType, operands[1], operands[0]);
- return matchSuccess();
- } else if (kind == "mul") {
- rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
- op, llvmType, operands[1], operands[0]);
- return matchSuccess();
- }
- return matchFailure();
- }
-};
-
class VectorShuffleOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorShuffleOpConversion(MLIRContext *context,
@@ -1154,12 +1134,11 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertStridedSliceOpSameRankRewritePattern,
VectorStridedSliceOpConversion>(ctx);
patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
- VectorReductionV2OpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
- VectorInsertOpConversion, VectorOuterProductOpConversion,
- VectorTypeCastOpConversion, VectorPrintOpConversion>(
- ctx, converter);
+ VectorShuffleOpConversion, VectorExtractElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
+ VectorInsertElementOpConversion, VectorInsertOpConversion,
+ VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+ VectorPrintOpConversion>(ctx, converter);
}
namespace {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 345ee9a9dfe5..ca7163997cc9 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -74,17 +74,54 @@ static LogicalResult verify(ReductionOp op) {
auto kind = op.kind();
Type eltType = op.dest().getType();
if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
- if (eltType.isF32() || eltType.isF64() || eltType.isSignlessInteger(32) ||
- eltType.isSignlessInteger(64))
- return success();
- return op.emitOpError("unsupported reduction type");
+ if (!eltType.isF32() && !eltType.isF64() &&
+ !eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
+ return op.emitOpError("unsupported reduction type");
+ } else if (kind == "and" || kind == "or" || kind == "xor") {
+ if (!eltType.isSignlessInteger(32) && !eltType.isSignlessInteger(64))
+ return op.emitOpError("unsupported reduction type");
+ } else {
+ return op.emitOpError("unknown reduction kind: ") << kind;
}
- if (kind == "and" || kind == "or" || kind == "xor") {
- if (eltType.isSignlessInteger(32) || eltType.isSignlessInteger(64))
- return success();
- return op.emitOpError("unsupported reduction type");
+
+ // Verify optional accumulator.
+ if (!op.acc().empty()) {
+ if (kind != "add" && kind != "mul")
+ return op.emitOpError("no accumulator for reduction kind: ") << kind;
+ if (!eltType.isF32() && !eltType.isF64())
+ return op.emitOpError("no accumulator for type: ") << eltType;
}
- return op.emitOpError("unknown reduction kind: ") << kind;
+
+ return success();
+}
+
+static ParseResult parseReductionOp(OpAsmParser &parser,
+ OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 2> operandsInfo;
+ Type redType;
+ Type resType;
+ Attribute attr;
+ if (parser.parseAttribute(attr, "kind", result.attributes) ||
+ parser.parseComma() || parser.parseOperandList(operandsInfo) ||
+ parser.parseColonType(redType) ||
+ parser.parseKeywordType("into", resType) ||
+ (operandsInfo.size() > 0 &&
+ parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
+ (operandsInfo.size() > 1 &&
+ parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
+ parser.addTypeToList(resType, result.types))
+ return failure();
+ if (operandsInfo.size() < 1 || operandsInfo.size() > 2)
+ return parser.emitError(parser.getNameLoc(),
+ "unsupported number of operands");
+ return success();
+}
+
+static void print(OpAsmPrinter &p, ReductionOp op) {
+ p << op.getOperationName() << " \"" << op.kind() << "\", " << op.vector();
+ if (!op.acc().empty())
+ p << ", " << op.acc();
+ p << " : " << op.vector().getType() << " into " << op.dest().getType();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 4c626b068aa6..8764d487dfb9 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -864,7 +864,7 @@ class InsertSlicesOpLowering : public OpRewritePattern<vector::InsertSlicesOp> {
}
};
-/// Progressive lowering of ConstractionOp.
+/// Progressive lowering of ContractionOp.
/// One:
/// %x = vector.contract with at least one free/batch dimension
/// is replaced by:
@@ -1017,8 +1017,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
Value zero = zeroVector(loc, lhsType, rewriter);
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
StringAttr kind = rewriter.getStringAttr("add");
- return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
- op.acc());
+ return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
+ op.acc());
}
// Construct new iterator types and affine map array attribute.
SmallVector<AffineMap, 4> lowIndexingMaps;
@@ -1067,9 +1067,8 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
SmallVector<Attribute, 4> results;
for (auto it : llvm::enumerate(iteratorTypes)) {
int64_t idx = it.index();
- if (idx == index) {
+ if (idx == index)
continue;
- }
results.push_back(it.value());
}
return results;
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 3743a5cfcf3c..91f6850779a9 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -1007,6 +1007,34 @@ func @reduce_elt_type_mismatch(%arg0: vector<16xf32>) -> i32 {
// -----
+func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
+ // expected-error@+1 {{'vector.reduction' op attribute 'kind' failed to satisfy constraint: string attribute}}
+ %0 = vector.reduction 1234, %arg0 : vector<16xf32> into i32
+}
+
+// -----
+
+func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+ // expected-error@+1 {{'vector.reduction' unsupported number of operands}}
+ %0 = vector.reduction "add", %arg0, %arg1, %arg1 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+ // expected-error@+1 {{'vector.reduction' op no accumulator for reduction kind: min}}
+ %0 = vector.reduction "min", %arg0, %arg1 : vector<16xf32> into f32
+}
+
+// -----
+
+func @reduce_unsupported_accumulator_type(%arg0: vector<16xi32>, %arg1: i32) -> i32 {
+ // expected-error@+1 {{'vector.reduction' op no accumulator for type: 'i32'}}
+ %0 = vector.reduction "add", %arg0, %arg1 : vector<16xi32> into i32
+}
+
+// -----
+
func @reduce_unsupported_type(%arg0: vector<16xf32>) -> f32 {
// expected-error@+1 {{'vector.reduction' op unsupported reduction type}}
%0 = vector.reduction "xor", %arg0 : vector<16xf32> into f32
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index bb5ca6eb8538..f286b932a472 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -279,11 +279,15 @@ func @vector_fma(%a: vector<8xf32>, %b: vector<8x4xf32>) {
}
// CHECK-LABEL: reduce_fp
-func @reduce_fp(%arg0: vector<16xf32>) -> f32 {
+func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// CHECK: vector.reduction "add", %{{.*}} : vector<16xf32> into f32
vector.reduction "add", %arg0 : vector<16xf32> into f32
+ // CHECK: vector.reduction "add", %{{.*}}, %{{.*}} : vector<16xf32> into f32
+ vector.reduction "add", %arg0, %arg1 : vector<16xf32> into f32
// CHECK: vector.reduction "mul", %{{.*}} : vector<16xf32> into f32
vector.reduction "mul", %arg0 : vector<16xf32> into f32
+ // CHECK: vector.reduction "mul", %{{.*}}, %{{.*}} : vector<16xf32> into f32
+ vector.reduction "mul", %arg0, %arg1 : vector<16xf32> into f32
// CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32
vector.reduction "min", %arg0 : vector<16xf32> into f32
// CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
index 362c85a38d09..275fd0841a60 100644
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -16,7 +16,7 @@
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
-// CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]]
+// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
// CHECK: return %[[R]] : f32
func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 {
@@ -44,12 +44,12 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
@@ -80,12 +80,12 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
@@ -123,7 +123,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
-// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32>
@@ -133,7 +133,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
-// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32>
@@ -146,7 +146,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
-// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32>
@@ -156,7 +156,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
-// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32
+// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
// CHECK: return %[[T45]] : vector<2x2xf32>
@@ -187,11 +187,11 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T7:.*]] = vector.reductionv2 "add", %[[T6]], %[[T3]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
// CHECK: return %[[T7]] : f32
func @full_contract1(%arg0: vector<2x3xf32>,
@@ -228,7 +228,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T11:.*]] = vector.reductionv2 "add", %[[T10]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
// CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
@@ -240,7 +240,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
-// CHECK: %[[T23:.*]] = vector.reductionv2 "add", %[[T22]], %[[T11]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK: return %[[T23]] : f32
func @full_contract2(%arg0: vector<2x3xf32>,
More information about the Mlir-commits mailing list