[Mlir-commits] [mlir] eaf2588 - [mlir][Linalg] Add support for min/max reduction vectorization in linalg.generic
Diego Caballero
llvmlistbot at llvm.org
Tue Oct 5 15:51:25 PDT 2021
Author: Diego Caballero
Date: 2021-10-05T22:47:20Z
New Revision: eaf2588a51bf2d36a6aec573cd92c351062fa7d5
URL: https://github.com/llvm/llvm-project/commit/eaf2588a51bf2d36a6aec573cd92c351062fa7d5
DIFF: https://github.com/llvm/llvm-project/commit/eaf2588a51bf2d36a6aec573cd92c351062fa7d5.diff
LOG: [mlir][Linalg] Add support for min/max reduction vectorization in linalg.generic
This patch extends Linalg core vectorization with support for min/max reductions
in linalg.generic ops. It enables the reduction detection for min/max combiner ops.
It also renames MIN/MAX combining kinds to MINS/MAXS to make the sign explicit for
floating point and signed integer types. MINU/MAXU should be introduced in the future
for unsigned integer types.
Reviewed By: pifon2a, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D110854
Added:
Modified:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index f24da79b01673..ea86336fd787e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -38,20 +38,25 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
}
// The "kind" of combining function for contractions and reductions.
-def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
-def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
-def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">;
-def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">;
-def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">;
-def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x20, "or">;
-def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">;
+def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
+def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
+def COMBINING_KIND_MINUI : BitEnumAttrCase<"MINUI", 0x4, "minui">;
+def COMBINING_KIND_MINSI : BitEnumAttrCase<"MINSI", 0x8, "minsi">;
+def COMBINING_KIND_MINF : BitEnumAttrCase<"MINF", 0x10, "minf">;
+def COMBINING_KIND_MAXUI : BitEnumAttrCase<"MAXUI", 0x20, "maxui">;
+def COMBINING_KIND_MAXSI : BitEnumAttrCase<"MAXSI", 0x40, "maxsi">;
+def COMBINING_KIND_MAXF : BitEnumAttrCase<"MAXF", 0x80, "maxf">;
+def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x100, "and">;
+def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x200, "or">;
+def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x400, "xor">;
def CombiningKind : BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
- [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN,
- COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
- COMBINING_KIND_XOR]> {
+ [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
+ COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
+ COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
+ COMBINING_KIND_OR, COMBINING_KIND_XOR]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}
@@ -337,7 +342,7 @@ def Vector_MultiDimReductionOp :
static SmallVector<int64_t> inferDestShape(
ArrayRef<int64_t> shape, ArrayRef<bool> reducedDimsMask) {
- assert(shape.size() == reducedDimsMask.size() &&
+ assert(shape.size() == reducedDimsMask.size() &&
"shape and maks of different sizes");
SmallVector<int64_t> res;
for (auto it : llvm::zip(reducedDimsMask, shape))
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index ba942bbc1846d..f5ba71726af7f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -434,18 +434,16 @@ class VectorReductionOpConversion
else if (kind == "mul")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_mul>(reductionOp,
llvmType, operand);
- else if (kind == "min" &&
- (eltType.isIndex() || eltType.isUnsignedInteger()))
+ else if (kind == "minui")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umin>(
reductionOp, llvmType, operand);
- else if (kind == "min")
+ else if (kind == "minsi")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smin>(
reductionOp, llvmType, operand);
- else if (kind == "max" &&
- (eltType.isIndex() || eltType.isUnsignedInteger()))
+ else if (kind == "maxui")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_umax>(
reductionOp, llvmType, operand);
- else if (kind == "max")
+ else if (kind == "maxsi")
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_smax>(
reductionOp, llvmType, operand);
else if (kind == "and")
@@ -486,10 +484,14 @@ class VectorReductionOpConversion
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
reductionOp, llvmType, acc, operand,
rewriter.getBoolAttr(reassociateFPReductions));
- } else if (kind == "min")
+ } else if (kind == "minf")
+ // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
llvmType, operand);
- else if (kind == "max")
+ else if (kind == "maxf")
+ // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+ // NaNs/-0.0/+0.0 in the same way.
rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
llvmType, operand);
else
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 921f8efb55363..76279ef2d3ca8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -111,24 +111,40 @@ static VectorType extractVectorTypeFromShapedValue(Value v) {
return VectorType::get(st.getShape(), st.getElementType());
}
+static llvm::Optional<vector::CombiningKind>
+getKindForOp(Operation *reductionOp) {
+ if (!reductionOp)
+ return llvm::None;
+ return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
+ reductionOp)
+ .Case<AddIOp, AddFOp>([&](auto op) { return vector::CombiningKind::ADD; })
+ .Case<MaxSIOp>([&](auto op) { return vector::CombiningKind::MAXSI; })
+ .Case<MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
+ .Case<MinSIOp>([&](auto op) { return vector::CombiningKind::MINSI; })
+ .Case<MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+ .Default([&](auto op) { return llvm::None; });
+}
+
/// Check whether `outputOperand` is a reduction with a single combiner
-/// operation. Return the combiner operation of the reduction, which is assumed
-/// to be a binary operation. Multiple reduction operations would impose an
-/// ordering between reduction dimensions and is currently unsupported in
-/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
+/// operation. Return the combiner operation kind of the reduction, if
+/// supported. Return llvm::None, otherwise. Multiple reduction operations would
+/// impose an ordering between reduction dimensions and is currently unsupported
+/// in Linalg. This limitation is motivated by the fact that e.g. min(max(X)) !=
/// max(min(X))
// TODO: use in LinalgOp verification, there is a circular dependency atm.
-static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) {
+static llvm::Optional<vector::CombiningKind>
+matchLinalgReduction(OpOperand *outputOperand) {
auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
unsigned outputPos =
outputOperand->getOperandNumber() - linalgOp.getNumInputs();
+ // Only single combiner operations are supported for now.
SmallVector<Operation *, 4> combinerOps;
if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) ||
combinerOps.size() != 1)
- return nullptr;
+ return llvm::None;
- // TODO: also assert no other subsequent ops break the reduction.
- return combinerOps[0];
+ // Return the combiner operation kind, if supported.
+ return getKindForOp(combinerOps[0]);
}
/// If `value` of assumed VectorType has a shape different than `shape`, try to
@@ -151,19 +167,6 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
newVecType, value);
}
-static llvm::Optional<vector::CombiningKind>
-getKindForOp(Operation *reductionOp) {
- if (!reductionOp)
- return llvm::None;
- return llvm::TypeSwitch<Operation *, llvm::Optional<vector::CombiningKind>>(
- reductionOp)
- .Case<AddIOp, AddFOp>([&](auto op) {
- return llvm::Optional<vector::CombiningKind>{
- vector::CombiningKind::ADD};
- })
- .Default([&](auto op) { return llvm::None; });
-}
-
/// If value of assumed VectorType has a shape different than `shape`, build and
/// return a new vector.broadcast to `shape`.
/// Otherwise, just return value.
@@ -173,9 +176,7 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
auto vecType = value.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getShape() == targetVectorType.getShape())
return value;
- // At this point, we know we need to reduce. Detect the reduction operator.
- // TODO: Use the generic reduction detection util.
- Operation *reductionOp = getSingleBinaryOpAssumedReduction(outputOperand);
+
unsigned pos = 0;
MLIRContext *ctx = b.getContext();
SmallVector<AffineExpr> exprs;
@@ -183,8 +184,9 @@ static Value reduceIfNeeded(OpBuilder &b, VectorType targetVectorType,
if (isParallelIterator(s))
exprs.push_back(getAffineDimExpr(pos++, ctx));
auto loc = value.getLoc();
- // TODO: reuse common CombiningKing logic and support more than add.
- auto maybeKind = getKindForOp(reductionOp);
+
+ // At this point, we know we need to reduce. Detect the reduction operator.
+ auto maybeKind = matchLinalgReduction(outputOperand);
assert(maybeKind && "Failed precondition: could not get reduction kind");
unsigned idx = 0;
SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
@@ -597,8 +599,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
if (llvm::none_of(op.iterator_types(), isReductionIterator))
return failure();
for (OpOperand *opOperand : op.getOutputOperands()) {
- Operation *reductionOp = getSingleBinaryOpAssumedReduction(opOperand);
- if (!getKindForOp(reductionOp))
+ if (!matchLinalgReduction(opOperand))
return failure();
}
return success();
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 4857bebe0a65d..757e1a3362f0d 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -92,13 +92,18 @@ static bool isSupportedCombiningKind(CombiningKind combiningKind,
switch (combiningKind) {
case CombiningKind::ADD:
case CombiningKind::MUL:
- case CombiningKind::MIN:
- case CombiningKind::MAX:
return elementType.isIntOrIndexOrFloat();
+ case CombiningKind::MINUI:
+ case CombiningKind::MINSI:
+ case CombiningKind::MAXUI:
+ case CombiningKind::MAXSI:
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:
return elementType.isIntOrIndex();
+ case CombiningKind::MINF:
+ case CombiningKind::MAXF:
+ return elementType.isa<FloatType>();
}
return false;
}
@@ -151,8 +156,12 @@ static constexpr const CombiningKind combiningKindsList[] = {
// clang-format off
CombiningKind::ADD,
CombiningKind::MUL,
- CombiningKind::MIN,
- CombiningKind::MAX,
+ CombiningKind::MINUI,
+ CombiningKind::MINSI,
+ CombiningKind::MINF,
+ CombiningKind::MAXUI,
+ CombiningKind::MAXSI,
+ CombiningKind::MAXF,
CombiningKind::AND,
CombiningKind::OR,
CombiningKind::XOR,
@@ -291,22 +300,20 @@ static LogicalResult verify(ReductionOp op) {
return op.emitOpError("unsupported reduction rank: ") << rank;
// Verify supported reduction kind.
- auto kind = op.kind();
+ StringRef strKind = op.kind();
+ auto maybeKind = symbolizeCombiningKind(strKind);
+ if (!maybeKind)
+ return op.emitOpError("unknown reduction kind: ") << strKind;
+
Type eltType = op.dest().getType();
- if (kind == "add" || kind == "mul" || kind == "min" || kind == "max") {
- if (!eltType.isIntOrIndexOrFloat())
- return op.emitOpError("unsupported reduction type");
- } else if (kind == "and" || kind == "or" || kind == "xor") {
- if (!eltType.isIntOrIndex())
- return op.emitOpError("unsupported reduction type");
- } else {
- return op.emitOpError("unknown reduction kind: ") << kind;
- }
+ if (!isSupportedCombiningKind(*maybeKind, eltType))
+ return op.emitOpError("unsupported reduction type '")
+ << eltType << "' for kind '" << op.kind() << "'";
// Verify optional accumulator.
if (!op.acc().empty()) {
- if (kind != "add" && kind != "mul")
- return op.emitOpError("no accumulator for reduction kind: ") << kind;
+ if (strKind != "add" && strKind != "mul")
+ return op.emitOpError("no accumulator for reduction kind: ") << strKind;
if (!eltType.isa<FloatType>())
return op.emitOpError("no accumulator for type: ") << eltType;
}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index ba97f99972a06..8f1d9fc08593e 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -821,15 +821,17 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
case CombiningKind::MUL:
combinedResult = rewriter.create<MulIOp>(loc, mul, acc);
break;
- case CombiningKind::MIN:
- combinedResult = rewriter.create<SelectOp>(
- loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, mul, acc), mul,
- acc);
+ case CombiningKind::MINUI:
+ combinedResult = rewriter.create<MinUIOp>(loc, mul, acc);
break;
- case CombiningKind::MAX:
- combinedResult = rewriter.create<SelectOp>(
- loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, mul, acc), mul,
- acc);
+ case CombiningKind::MINSI:
+ combinedResult = rewriter.create<MinSIOp>(loc, mul, acc);
+ break;
+ case CombiningKind::MAXUI:
+ combinedResult = rewriter.create<MaxUIOp>(loc, mul, acc);
+ break;
+ case CombiningKind::MAXSI:
+ combinedResult = rewriter.create<MaxSIOp>(loc, mul, acc);
break;
case CombiningKind::AND:
combinedResult = rewriter.create<AndOp>(loc, mul, acc);
@@ -840,6 +842,9 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
case CombiningKind::XOR:
combinedResult = rewriter.create<XOrOp>(loc, mul, acc);
break;
+ case CombiningKind::MINF: // Only valid for floating point types.
+ case CombiningKind::MAXF: // Only valid for floating point types.
+ return Optional<Value>();
}
return Optional<Value>(combinedResult);
}
@@ -864,18 +869,18 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
case CombiningKind::MUL:
combinedResult = rewriter.create<MulFOp>(loc, mul, acc);
break;
- case CombiningKind::MIN:
- combinedResult = rewriter.create<SelectOp>(
- loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, mul, acc), mul,
- acc);
+ case CombiningKind::MINF:
+ combinedResult = rewriter.create<MinFOp>(loc, mul, acc);
break;
- case CombiningKind::MAX:
- combinedResult = rewriter.create<SelectOp>(
- loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, mul, acc), mul,
- acc);
+ case CombiningKind::MAXF:
+ combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
break;
case CombiningKind::ADD: // Already handled this special case above.
case CombiningKind::AND: // Only valid for integer types.
+ case CombiningKind::MINUI: // Only valid for integer types.
+ case CombiningKind::MINSI: // Only valid for integer types.
+ case CombiningKind::MAXUI: // Only valid for integer types.
+ case CombiningKind::MAXSI: // Only valid for integer types.
case CombiningKind::OR: // Only valid for integer types.
case CombiningKind::XOR: // Only valid for integer types.
return Optional<Value>();
@@ -3697,23 +3702,23 @@ struct UnrollOuterMultiReduction
else
result = rewriter.create<MulFOp>(loc, operand, result);
break;
- case vector::CombiningKind::MIN:
- if (elementType.isIntOrIndex())
- condition =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, result);
- else
- condition =
- rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, result);
- result = rewriter.create<SelectOp>(loc, condition, operand, result);
+ case vector::CombiningKind::MINUI:
+ result = rewriter.create<MinUIOp>(loc, operand, result);
break;
- case vector::CombiningKind::MAX:
- if (elementType.isIntOrIndex())
- condition =
- rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, operand, result);
- else
- condition =
- rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, operand, result);
- result = rewriter.create<SelectOp>(loc, condition, operand, result);
+ case vector::CombiningKind::MINSI:
+ result = rewriter.create<MinSIOp>(loc, operand, result);
+ break;
+ case vector::CombiningKind::MINF:
+ result = rewriter.create<MinFOp>(loc, operand, result);
+ break;
+ case vector::CombiningKind::MAXUI:
+ result = rewriter.create<MaxUIOp>(loc, operand, result);
+ break;
+ case vector::CombiningKind::MAXSI:
+ result = rewriter.create<MaxSIOp>(loc, operand, result);
+ break;
+ case vector::CombiningKind::MAXF:
+ result = rewriter.create<MaxFOp>(loc, operand, result);
break;
case vector::CombiningKind::AND:
result = rewriter.create<AndOp>(loc, operand, result);
@@ -3771,10 +3776,18 @@ struct TwoDimMultiReductionToReduction
return "add";
case vector::CombiningKind::MUL:
return "mul";
- case vector::CombiningKind::MIN:
- return "min";
- case vector::CombiningKind::MAX:
- return "max";
+ case vector::CombiningKind::MINUI:
+ return "minui";
+ case vector::CombiningKind::MINSI:
+ return "minsi";
+ case vector::CombiningKind::MINF:
+ return "minf";
+ case vector::CombiningKind::MAXUI:
+ return "maxui";
+ case vector::CombiningKind::MAXSI:
+ return "maxsi";
+ case vector::CombiningKind::MAXF:
+ return "maxf";
case vector::CombiningKind::AND:
return "and";
case vector::CombiningKind::OR:
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 823f3060caf4b..d56214fb2fdf7 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -806,3 +806,54 @@ func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: ten
} -> tensor<5x2xf32>
return %0 : tensor<5x2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @red_max_2d(
+func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+ // CHECK: linalg.init_tensor [4] : tensor<4xf32>
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+ // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+ // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
+ // CHECK: maxf {{.*}} : vector<4x4xf32>
+ // CHECK: vector.multi_reduction #vector.kind<maxf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+ %minf32 = constant -3.40282e+38 : f32
+ %init = linalg.init_tensor [4] : tensor<4xf32>
+ %fill = linalg.fill(%minf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+ ^bb0(%in0: f32, %out0: f32): // no predecessors
+ %max = maxf %in0, %out0 : f32
+ linalg.yield %max : f32
+ } -> tensor<4xf32>
+ return %red : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @red_min_2d(
+func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
+ // CHECK: linalg.init_tensor [4] : tensor<4xf32>
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+ // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
+ // CHECK: vector.transfer_read {{.*}} : tensor<4xf32>, vector<4x4xf32>
+ // CHECK: minf {{.*}} : vector<4x4xf32>
+ // CHECK: vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
+ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
+ %maxf32 = constant 3.40282e+38 : f32
+ %init = linalg.init_tensor [4] : tensor<4xf32>
+ %fill = linalg.fill(%maxf32, %init) : f32, tensor<4xf32> -> tensor<4xf32>
+ %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"]}
+ ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
+ ^bb0(%in0: f32, %out0: f32): // no predecessors
+ %min = minf %in0, %out0 : f32
+ linalg.yield %min : f32
+ } -> tensor<4xf32>
+ return %red : tensor<4xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 7fb4ecb5b0d3a..53c244716759c 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1019,7 +1019,7 @@ func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f3
func @reduce_unsupported_accumulator_kind(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
// expected-error @+1 {{'vector.reduction' op no accumulator for reduction kind: min}}
- %0 = vector.reduction "min", %arg0, %arg1 : vector<16xf32> into f32
+ %0 = vector.reduction "minf", %arg0, %arg1 : vector<16xf32> into f32
}
// -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 9ac671c632459..d5afa674274d8 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -243,13 +243,13 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
#contraction_to_scalar_max_trait = {
indexing_maps = #contraction_to_scalar_max_accesses,
iterator_types = ["reduction"],
- kind = #vector.kind<max>
+ kind = #vector.kind<maxf>
}
// CHECK-LABEL: @contraction_to_scalar_with_max
func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
%f0 = constant 0.0: f32
- // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<max>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
+ // CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<maxf>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
%0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0
: vector<10xf32>, vector<10xf32> into f32
// CHECK: return %[[X]] : f32
@@ -281,7 +281,7 @@ func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32
#contraction_trait2 = {
indexing_maps = #contraction_accesses1,
iterator_types = #iterator_types1,
- kind = #vector.kind<max>
+ kind = #vector.kind<maxf>
}
// CHECK-LABEL: @contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
@@ -309,7 +309,7 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
: vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
// Test contraction with "max" instead of "add".
- // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<max>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
+ // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<maxf>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
return
@@ -432,10 +432,10 @@ func @reduce_fp(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
vector.reduction "mul", %arg0 : vector<16xf32> into f32
// CHECK: vector.reduction "mul", %{{.*}}, %{{.*}} : vector<16xf32> into f32
vector.reduction "mul", %arg0, %arg1 : vector<16xf32> into f32
- // CHECK: vector.reduction "min", %{{.*}} : vector<16xf32> into f32
- vector.reduction "min", %arg0 : vector<16xf32> into f32
- // CHECK: %[[X:.*]] = vector.reduction "max", %{{.*}} : vector<16xf32> into f32
- %0 = vector.reduction "max", %arg0 : vector<16xf32> into f32
+ // CHECK: vector.reduction "minf", %{{.*}} : vector<16xf32> into f32
+ vector.reduction "minf", %arg0 : vector<16xf32> into f32
+ // CHECK: %[[X:.*]] = vector.reduction "maxf", %{{.*}} : vector<16xf32> into f32
+ %0 = vector.reduction "maxf", %arg0 : vector<16xf32> into f32
// CHECK: return %[[X]] : f32
return %0 : f32
}
@@ -446,10 +446,14 @@ func @reduce_int(%arg0: vector<16xi32>) -> i32 {
vector.reduction "add", %arg0 : vector<16xi32> into i32
// CHECK: vector.reduction "mul", %{{.*}} : vector<16xi32> into i32
vector.reduction "mul", %arg0 : vector<16xi32> into i32
- // CHECK: vector.reduction "min", %{{.*}} : vector<16xi32> into i32
- vector.reduction "min", %arg0 : vector<16xi32> into i32
- // CHECK: vector.reduction "max", %{{.*}} : vector<16xi32> into i32
- vector.reduction "max", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "minui", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "minui", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "minsi", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "minsi", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "maxui", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "maxui", %arg0 : vector<16xi32> into i32
+ // CHECK: vector.reduction "maxsi", %{{.*}} : vector<16xi32> into i32
+ vector.reduction "maxsi", %arg0 : vector<16xi32> into i32
// CHECK: vector.reduction "and", %{{.*}} : vector<16xi32> into i32
vector.reduction "and", %arg0 : vector<16xi32> into i32
// CHECK: vector.reduction "or", %{{.*}} : vector<16xi32> into i32
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
index 9c163453ee8b8..50e5ce8aaf9cf 100644
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -12,7 +12,7 @@
#matvecmax_trait = {
indexing_maps = #matvec_accesses,
iterator_types = ["parallel", "reduction"],
- kind = #vector.kind<max>
+ kind = #vector.kind<maxf>
}
#mattransvec_accesses = [
@@ -91,10 +91,10 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<max>} : vector<2xf32>, f32
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
// CHECK: memref.store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 91dcc2e0172f7..5f7d19fca61d4 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -18,7 +18,7 @@ func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<min>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction #vector.kind<minf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@@ -27,18 +27,15 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32>
-// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV01:.+]] = minf %[[V1]], %[[V0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV012:.+]] = minf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = minf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
- %0 = vector.multi_reduction #vector.kind<max>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+ %0 = vector.multi_reduction #vector.kind<maxf>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
return %0 : vector<2xf32>
}
@@ -47,14 +44,11 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32>
-// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV01:.+]] = maxf %[[V1]], %[[V0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32>
-// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RV012:.+]] = maxf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[C2:.+]] = cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = maxf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir
index ddd3131c4e86c..cd3d9708aa9ee 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32-reassoc.mlir
@@ -27,10 +27,10 @@ func @entry() {
%1 = vector.reduction "mul", %v2 : vector<64xf32> into f32
vector.print %1 : f32
// CHECK: 6
- %2 = vector.reduction "min", %v2 : vector<64xf32> into f32
+ %2 = vector.reduction "minf", %v2 : vector<64xf32> into f32
vector.print %2 : f32
// CHECK: 1
- %3 = vector.reduction "max", %v2 : vector<64xf32> into f32
+ %3 = vector.reduction "maxf", %v2 : vector<64xf32> into f32
vector.print %3 : f32
// CHECK: 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir
index 4fb0d6cd8d5b9..a5befd0f5a7c3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f32.mlir
@@ -39,10 +39,10 @@ func @entry() {
%1 = vector.reduction "mul", %v9 : vector<10xf32> into f32
vector.print %1 : f32
// CHECK: -5760
- %2 = vector.reduction "min", %v9 : vector<10xf32> into f32
+ %2 = vector.reduction "minf", %v9 : vector<10xf32> into f32
vector.print %2 : f32
// CHECK: -16
- %3 = vector.reduction "max", %v9 : vector<10xf32> into f32
+ %3 = vector.reduction "maxf", %v9 : vector<10xf32> into f32
vector.print %3 : f32
// CHECK: 5
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir
index 89f74ae875c60..144c8c7c8a3f7 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64-reassoc.mlir
@@ -27,10 +27,10 @@ func @entry() {
%1 = vector.reduction "mul", %v2 : vector<64xf64> into f64
vector.print %1 : f64
// CHECK: 6
- %2 = vector.reduction "min", %v2 : vector<64xf64> into f64
+ %2 = vector.reduction "minf", %v2 : vector<64xf64> into f64
vector.print %2 : f64
// CHECK: 1
- %3 = vector.reduction "max", %v2 : vector<64xf64> into f64
+ %3 = vector.reduction "maxf", %v2 : vector<64xf64> into f64
vector.print %3 : f64
// CHECK: 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir
index b83629f702a3f..2d79d4ab392bc 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-f64.mlir
@@ -39,10 +39,10 @@ func @entry() {
%1 = vector.reduction "mul", %v9 : vector<10xf64> into f64
vector.print %1 : f64
// CHECK: -5760
- %2 = vector.reduction "min", %v9 : vector<10xf64> into f64
+ %2 = vector.reduction "minf", %v9 : vector<10xf64> into f64
vector.print %2 : f64
// CHECK: -16
- %3 = vector.reduction "max", %v9 : vector<10xf64> into f64
+ %3 = vector.reduction "maxf", %v9 : vector<10xf64> into f64
vector.print %3 : f64
// CHECK: 5
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir
index e3e3567e4ed3b..d1432d7f689b1 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i32.mlir
@@ -39,10 +39,10 @@ func @entry() {
%1 = vector.reduction "mul", %v9 : vector<10xi32> into i32
vector.print %1 : i32
// CHECK: -1228800
- %2 = vector.reduction "min", %v9 : vector<10xi32> into i32
+ %2 = vector.reduction "minsi", %v9 : vector<10xi32> into i32
vector.print %2 : i32
// CHECK: -80
- %3 = vector.reduction "max", %v9 : vector<10xi32> into i32
+ %3 = vector.reduction "maxsi", %v9 : vector<10xi32> into i32
vector.print %3 : i32
// CHECK: 5
%4 = vector.reduction "and", %v9 : vector<10xi32> into i32
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir
index fafdb4fab7444..75e61e7a7e346 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i4.mlir
@@ -20,11 +20,11 @@ func @entry() {
vector.print %1 : i4
// CHECK: 0
- %2 = vector.reduction "min", %v : vector<24xi4> into i4
+ %2 = vector.reduction "minsi", %v : vector<24xi4> into i4
vector.print %2 : i4
// CHECK: -8
- %3 = vector.reduction "max", %v : vector<24xi4> into i4
+ %3 = vector.reduction "maxsi", %v : vector<24xi4> into i4
vector.print %3 : i4
// CHECK: 7
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir
index 4bb148f13a25a..162707d6cab6b 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-i64.mlir
@@ -39,10 +39,10 @@ func @entry() {
%1 = vector.reduction "mul", %v9 : vector<10xi64> into i64
vector.print %1 : i64
// CHECK: -1228800
- %2 = vector.reduction "min", %v9 : vector<10xi64> into i64
+ %2 = vector.reduction "minsi", %v9 : vector<10xi64> into i64
vector.print %2 : i64
// CHECK: -80
- %3 = vector.reduction "max", %v9 : vector<10xi64> into i64
+ %3 = vector.reduction "maxsi", %v9 : vector<10xi64> into i64
vector.print %3 : i64
// CHECK: 5
%4 = vector.reduction "and", %v9 : vector<10xi64> into i64
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir
index 48144ed54a174..e3d4c48a90c19 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-si4.mlir
@@ -19,11 +19,11 @@ func @entry() {
vector.print %1 : si4
// CHECK: 0
- %2 = vector.reduction "min", %v : vector<16xsi4> into si4
+ %2 = vector.reduction "minsi", %v : vector<16xsi4> into si4
vector.print %2 : si4
// CHECK: -8
- %3 = vector.reduction "max", %v : vector<16xsi4> into si4
+ %3 = vector.reduction "maxsi", %v : vector<16xsi4> into si4
vector.print %3 : si4
// CHECK: 7
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir
index a4d5fbe6621ee..0b048d01849c4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-reductions-ui4.mlir
@@ -19,11 +19,11 @@ func @entry() {
vector.print %1 : ui4
// CHECK: 0
- %2 = vector.reduction "min", %v : vector<16xui4> into ui4
+ %2 = vector.reduction "minui", %v : vector<16xui4> into ui4
vector.print %2 : ui4
// CHECK: 0
- %3 = vector.reduction "max", %v : vector<16xui4> into ui4
+ %3 = vector.reduction "maxui", %v : vector<16xui4> into ui4
vector.print %3 : ui4
// CHECK: 15
More information about the Mlir-commits
mailing list