[Mlir-commits] [mlir] 9b1d90e - [mlir] Move min/max ops from Std to Arith.
Alexander Belyaev
llvmlistbot at llvm.org
Mon Nov 15 04:19:49 PST 2021
Author: Alexander Belyaev
Date: 2021-11-15T13:19:17+01:00
New Revision: 9b1d90e8ac9c95ba55a0b949118377f31e6703f8
URL: https://github.com/llvm/llvm-project/commit/9b1d90e8ac9c95ba55a0b949118377f31e6703f8
DIFF: https://github.com/llvm/llvm-project/commit/9b1d90e8ac9c95ba55a0b949118377f31e6703f8.diff
LOG: [mlir] Move min/max ops from Std to Arith.
Differential Revision: https://reviews.llvm.org/D113881
Added:
Modified:
mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
mlir/test/Dialect/Arithmetic/canonicalize.mlir
mlir/test/Dialect/Arithmetic/expand-ops.mlir
mlir/test/Dialect/Arithmetic/ops.mlir
mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Standard/canonicalize.mlir
mlir/test/Dialect/Standard/expand-ops.mlir
mlir/test/Dialect/Standard/ops.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
mlir/test/python/integration/dialects/linalg/opsrun.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 2e90455daaa1..5123215ba87a 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -620,6 +620,156 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf"> {
+ let summary = "floating-point maximum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.maxf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
+ If one of the arguments is NaN, then the result is also NaN.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point maximum.
+ %a = arith.maxf %b, %c : f64
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxSIOp : Arith_IntBinaryOp<"maxsi"> {
+ let summary = "signed integer maximum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.maxsi` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the larger of the two operands, comparing them as signed integers.
+
+ Example:
+
+ ```mlir
+ // Scalar signed integer maximum.
+ %a = arith.maxsi %b, %c : i64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MaxUIOp : Arith_IntBinaryOp<"maxui"> {
+ let summary = "unsigned integer maximum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.maxui` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the larger of the two operands, comparing them as unsigned integers.
+
+ Example:
+
+ ```mlir
+ // Scalar unsigned integer maximum.
+ %a = arith.maxui %b, %c : i64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinFOp : Arith_FloatBinaryOp<"minf"> {
+ let summary = "floating-point minimum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.minf` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
+ If one of the arguments is NaN, then the result is also NaN.
+
+ Example:
+
+ ```mlir
+ // Scalar floating-point minimum.
+ %a = arith.minf %b, %c : f64
+ ```
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinSIOp : Arith_IntBinaryOp<"minsi"> {
+ let summary = "signed integer minimum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.minsi` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the smaller of the two operands, comparing them as signed integers.
+
+ Example:
+
+ ```mlir
+ // Scalar signed integer minimum.
+ %a = arith.minsi %b, %c : i64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+def Arith_MinUIOp : Arith_IntBinaryOp<"minui"> {
+ let summary = "unsigned integer minimum operation";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `arith.minui` ssa-use `,` ssa-use `:` type
+ ```
+
+ Returns the smaller of the two operands, comparing them as unsigned integers.
+
+ Example:
+
+ ```mlir
+ // Scalar unsigned integer minimum.
+ %a = arith.minui %b, %c : i64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 985bd8207ffd..9f87333ca672 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -665,156 +665,6 @@ def ConstantOp : Std_Op<"constant",
let hasFolder = 1;
}
-//===----------------------------------------------------------------------===//
-// MaxFOp
-//===----------------------------------------------------------------------===//
-
-def MaxFOp : FloatBinaryOp<"maxf"> {
- let summary = "floating-point maximum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
- If one of the arguments is NaN, then the result is also NaN.
-
- Example:
-
- ```mlir
- // Scalar floating-point maximum.
- %a = maxf %b, %c : f64
- ```
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// MaxSIOp
-//===----------------------------------------------------------------------===//
-
-def MaxSIOp : IntBinaryOp<"maxsi"> {
- let summary = "signed integer maximum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the larger of %a and %b comparing the values as signed integers.
-
- Example:
-
- ```mlir
- // Scalar signed integer maximum.
- %a = maxsi %b, %c : i64
- ```
- }];
- let hasFolder = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// MaxUIOp
-//===----------------------------------------------------------------------===//
-
-def MaxUIOp : IntBinaryOp<"maxui"> {
- let summary = "unsigned integer maximum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the larger of %a and %b comparing the values as unsigned integers.
-
- Example:
-
- ```mlir
- // Scalar unsigned integer maximum.
- %a = maxui %b, %c : i64
- ```
- }];
- let hasFolder = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// MinFOp
-//===----------------------------------------------------------------------===//
-
-def MinFOp : FloatBinaryOp<"minf"> {
- let summary = "floating-point minimum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
- If one of the arguments is NaN, then the result is also NaN.
-
- Example:
-
- ```mlir
- // Scalar floating-point minimum.
- %a = minf %b, %c : f64
- ```
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// MinSIOp
-//===----------------------------------------------------------------------===//
-
-def MinSIOp : IntBinaryOp<"minsi"> {
- let summary = "signed integer minimum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the smaller of %a and %b comparing the values as signed integers.
-
- Example:
-
- ```mlir
- // Scalar signed integer minimum.
- %a = minsi %b, %c : i64
- ```
- }];
- let hasFolder = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// MinUIOp
-//===----------------------------------------------------------------------===//
-
-def MinUIOp : IntBinaryOp<"minui"> {
- let summary = "unsigned integer minimum operation";
- let description = [{
- Syntax:
-
- ```
- operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
- ```
-
- Returns the smaller of %a and %b comparing the values as unsigned integers.
-
- Example:
-
- ```mlir
- // Scalar unsigned integer minimum.
- %a = minui %b, %c : i64
- ```
- }];
- let hasFolder = 1;
-}
-
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index 252b2b5fe0f9..873d9b9aa3b4 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -58,12 +58,12 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
.Case([](arith::MulFOp) { return AtomicRMWKind::mulf; })
.Case([](arith::AddIOp) { return AtomicRMWKind::addi; })
.Case([](arith::MulIOp) { return AtomicRMWKind::muli; })
- .Case([](MinFOp) { return AtomicRMWKind::minf; })
- .Case([](MaxFOp) { return AtomicRMWKind::maxf; })
- .Case([](MinSIOp) { return AtomicRMWKind::mins; })
- .Case([](MaxSIOp) { return AtomicRMWKind::maxs; })
- .Case([](MinUIOp) { return AtomicRMWKind::minu; })
- .Case([](MaxUIOp) { return AtomicRMWKind::maxu; })
+ .Case([](arith::MinFOp) { return AtomicRMWKind::minf; })
+ .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; })
+ .Case([](arith::MinSIOp) { return AtomicRMWKind::mins; })
+ .Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; })
+ .Case([](arith::MinUIOp) { return AtomicRMWKind::minu; })
+ .Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; })
.Default([](Operation *) -> Optional<AtomicRMWKind> {
// TODO: AtomicRMW supports other kinds of reductions this is
// currently not detecting, add those when the need arises.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 87d57080ed80..fea7c7ca6169 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -230,12 +230,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
patterns.add<
// Unary and binary patterns
- spirv::UnaryAndBinaryOpPattern<MaxFOp, spirv::GLSLFMaxOp>,
- spirv::UnaryAndBinaryOpPattern<MaxSIOp, spirv::GLSLSMaxOp>,
- spirv::UnaryAndBinaryOpPattern<MaxUIOp, spirv::GLSLUMaxOp>,
- spirv::UnaryAndBinaryOpPattern<MinFOp, spirv::GLSLFMinOp>,
- spirv::UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
- spirv::UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
+ spirv::UnaryAndBinaryOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>,
ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern,
CondBranchOpPattern>(typeConverter, context);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 18f472634480..84e9cae77dd1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -126,9 +126,9 @@ convertElementwiseOpToMMA(Operation *op) {
return gpu::MMAElementwiseOp::ADDF;
if (isa<arith::MulFOp>(op))
return gpu::MMAElementwiseOp::MULF;
- if (isa<MaxFOp>(op))
+ if (isa<arith::MaxFOp>(op))
return gpu::MMAElementwiseOp::MAXF;
- if (isa<MinFOp>(op))
+ if (isa<arith::MinFOp>(op))
return gpu::MMAElementwiseOp::MINF;
if (isa<arith::DivFOp>(op))
return gpu::MMAElementwiseOp::DIVF;
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 19b5bd05ab08..29e938964363 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -561,6 +561,106 @@ OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
operands, [](APFloat a, APFloat b) { return a - b; });
}
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // maxsi(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ APInt intValue;
+ // maxsi(x,MAX_INT) -> MAX_INT
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
+ intValue.isMaxSignedValue())
+ return getRhs();
+
+ // maxsi(x, MIN_INT) -> x
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
+ intValue.isMinSignedValue())
+ return getLhs();
+
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // maxui(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ APInt intValue;
+ // maxui(x,MAX_INT) -> MAX_INT
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
+ return getRhs();
+
+ // maxui(x, MIN_INT) -> x
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
+ return getLhs();
+
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // minsi(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ APInt intValue;
+ // minsi(x,MIN_INT) -> MIN_INT
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
+ intValue.isMinSignedValue())
+ return getRhs();
+
+ // minsi(x, MAX_INT) -> x
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
+ intValue.isMaxSignedValue())
+ return getLhs();
+
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); });
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.size() == 2 && "binary operation takes two operands");
+
+ // minui(x,x) -> x
+ if (getLhs() == getRhs())
+ return getRhs();
+
+ APInt intValue;
+ // minui(x,MIN_INT) -> MIN_INT
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
+ return getRhs();
+
+ // minui(x, MAX_INT) -> x
+ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
+ return getLhs();
+
+ return constFoldBinaryOp<IntegerAttr>(
+ operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); });
+}
+
//===----------------------------------------------------------------------===//
// MulFOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
index 87e41bb1c2e2..56487cf7e16b 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp
@@ -8,6 +8,7 @@
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
@@ -147,6 +148,50 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
}
};
+template <typename OpTy, arith::CmpFPredicate pred>
+struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const final {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Location loc = op.getLoc();
+ Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
+ Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+
+ auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+ Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
+ lhs, rhs);
+
+ Value nan = rewriter.create<arith::ConstantFloatOp>(
+ loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
+ if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
+ nan = rewriter.create<SplatOp>(loc, vectorType, nan);
+
+ rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
+ return success();
+ }
+};
+
+template <typename OpTy, arith::CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+ LogicalResult matchAndRewrite(OpTy op,
+ PatternRewriter &rewriter) const final {
+ Value lhs = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Location loc = op.getLoc();
+ Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
+ rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
+ return success();
+ }
+};
+
struct ArithmeticExpandOpsPass
: public ArithmeticExpandOpsBase<ArithmeticExpandOpsPass> {
void runOnFunction() override {
@@ -156,9 +201,19 @@ struct ArithmeticExpandOpsPass
arith::populateArithmeticExpandOpsPatterns(patterns);
target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
- target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
- arith::FloorDivSIOp>();
-
+ // clang-format off
+ target.addIllegalOp<
+ arith::CeilDivSIOp,
+ arith::CeilDivUIOp,
+ arith::FloorDivSIOp,
+ arith::MaxFOp,
+ arith::MaxSIOp,
+ arith::MaxUIOp,
+ arith::MinFOp,
+ arith::MinSIOp,
+ arith::MinUIOp
+ >();
+ // clang-format on
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
@@ -169,9 +224,19 @@ struct ArithmeticExpandOpsPass
void mlir::arith::populateArithmeticExpandOpsPatterns(
RewritePatternSet &patterns) {
- patterns
- .add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>(
- patterns.getContext());
+ // clang-format off
+ patterns.add<
+ CeilDivSIOpConverter,
+ CeilDivUIOpConverter,
+ FloorDivSIOpConverter,
+ MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
+ MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
+ MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
+ MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
+ MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
+ MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>
+ >(patterns.getContext());
+ // clang-format on
}
std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 3d2f42d174fe..e792e4109aa0 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -283,36 +283,36 @@ class RegionBuilderHelper {
Value applyfn__max(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return builder.create<MaxSIOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MaxSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
Value applyfn__max_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MaxFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return builder.create<MaxUIOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MaxUIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
Value applyfn__min(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return builder.create<MinSIOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MinSIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
Value applyfn__min_unsigned(Value lhs, Value rhs) {
OpBuilder builder = getBuilder();
if (isFloatingPoint(lhs))
- return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MinFOp>(lhs.getLoc(), lhs, rhs);
if (isInteger(lhs))
- return builder.create<MinUIOp>(lhs.getLoc(), lhs, rhs);
+ return builder.create<arith::MinUIOp>(lhs.getLoc(), lhs, rhs);
llvm_unreachable("unsupported non numeric type");
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d8f1527a3306..0b856c2e5678 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -129,10 +129,12 @@ getKindForOp(Operation *reductionOp) {
.Case<arith::AddIOp, arith::AddFOp>(
[&](auto op) { return vector::CombiningKind::ADD; })
.Case<arith::AndIOp>([&](auto op) { return vector::CombiningKind::AND; })
- .Case<MaxSIOp>([&](auto op) { return vector::CombiningKind::MAXSI; })
- .Case<MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
- .Case<MinSIOp>([&](auto op) { return vector::CombiningKind::MINSI; })
- .Case<MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
+ .Case<arith::MaxSIOp>(
+ [&](auto op) { return vector::CombiningKind::MAXSI; })
+ .Case<arith::MaxFOp>([&](auto op) { return vector::CombiningKind::MAXF; })
+ .Case<arith::MinSIOp>(
+ [&](auto op) { return vector::CombiningKind::MINSI; })
+ .Case<arith::MinFOp>([&](auto op) { return vector::CombiningKind::MINF; })
.Case<arith::MulIOp, arith::MulFOp>(
[&](auto op) { return vector::CombiningKind::MUL; })
.Case<arith::OrIOp>([&](auto op) { return vector::CombiningKind::OR; })
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 6bc2d7fd436d..17a69b03671c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -251,17 +251,17 @@ Value mlir::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
case AtomicRMWKind::muli:
return builder.create<arith::MulIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxf:
- return builder.create<MaxFOp>(loc, lhs, rhs);
+ return builder.create<arith::MaxFOp>(loc, lhs, rhs);
case AtomicRMWKind::minf:
- return builder.create<MinFOp>(loc, lhs, rhs);
+ return builder.create<arith::MinFOp>(loc, lhs, rhs);
case AtomicRMWKind::maxs:
- return builder.create<MaxSIOp>(loc, lhs, rhs);
+ return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
case AtomicRMWKind::mins:
- return builder.create<MinSIOp>(loc, lhs, rhs);
+ return builder.create<arith::MinSIOp>(loc, lhs, rhs);
case AtomicRMWKind::maxu:
- return builder.create<MaxUIOp>(loc, lhs, rhs);
+ return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
case AtomicRMWKind::minu:
- return builder.create<MinUIOp>(loc, lhs, rhs);
+ return builder.create<arith::MinUIOp>(loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
@@ -921,106 +921,6 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
return value.isa<UnitAttr>();
}
-//===----------------------------------------------------------------------===//
-// MaxSIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // maxsi(x,x) -> x
- if (getLhs() == getRhs())
- return getRhs();
-
- APInt intValue;
- // maxsi(x,MAX_INT) -> MAX_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMaxSignedValue())
- return getRhs();
-
- // maxsi(x, MIN_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMinSignedValue())
- return getLhs();
-
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); });
-}
-
-//===----------------------------------------------------------------------===//
-// MaxUIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // maxui(x,x) -> x
- if (getLhs() == getRhs())
- return getRhs();
-
- APInt intValue;
- // maxui(x,MAX_INT) -> MAX_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
- return getRhs();
-
- // maxui(x, MIN_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
- return getLhs();
-
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); });
-}
-
-//===----------------------------------------------------------------------===//
-// MinSIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // minsi(x,x) -> x
- if (getLhs() == getRhs())
- return getRhs();
-
- APInt intValue;
- // minsi(x,MIN_INT) -> MIN_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMinSignedValue())
- return getRhs();
-
- // minsi(x, MAX_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMaxSignedValue())
- return getLhs();
-
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); });
-}
-
-//===----------------------------------------------------------------------===//
-// MinUIOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "binary operation takes two operands");
-
- // minui(x,x) -> x
- if (getLhs() == getRhs())
- return getRhs();
-
- APInt intValue;
- // minui(x,MIN_INT) -> MIN_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
- return getRhs();
-
- // minui(x, MAX_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
- return getLhs();
-
- return constFoldBinaryOp<IntegerAttr>(
- operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); });
-}
-
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 4955b83b80bb..63ac39fef50d 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -119,64 +119,16 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
}
};
-template <typename OpTy, arith::CmpFPredicate pred>
-struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy op,
- PatternRewriter &rewriter) const final {
- Value lhs = op.getLhs();
- Value rhs = op.getRhs();
-
- Location loc = op.getLoc();
- Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
- Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
-
- auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
- lhs, rhs);
-
- Value nan = rewriter.create<arith::ConstantFloatOp>(
- loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
- if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
- nan = rewriter.create<SplatOp>(loc, vectorType, nan);
-
- rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
- return success();
- }
-};
-
-template <typename OpTy, arith::CmpIPredicate pred>
-struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(OpTy op,
- PatternRewriter &rewriter) const final {
- Value lhs = op.getLhs();
- Value rhs = op.getRhs();
-
- Location loc = op.getLoc();
- Value cmp = rewriter.create<arith::CmpIOp>(loc, pred, lhs, rhs);
- rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
- return success();
- }
-};
-
struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
populateStdExpandOpsPatterns(patterns);
- arith::populateArithmeticExpandOpsPatterns(patterns);
-
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect>();
- target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
- arith::FloorDivSIOp>();
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
return op.getKind() != AtomicRMWKind::maxf &&
op.getKind() != AtomicRMWKind::minf;
@@ -184,16 +136,6 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
});
- // clang-format off
- target.addIllegalOp<
- MaxFOp,
- MaxSIOp,
- MaxUIOp,
- MinFOp,
- MinSIOp,
- MinUIOp
- >();
- // clang-format on
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
@@ -203,18 +145,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
} // namespace
void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
- // clang-format off
- patterns.add<
- AtomicRMWOpConverter,
- MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::OGT>,
- MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::OLT>,
- MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
- MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
- MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
- MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
- MemRefReshapeOpConverter
- >(patterns.getContext());
- // clang-format on
+ patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
+ patterns.getContext());
}
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
index 637c8729f06f..84102f0fe2a5 100644
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -255,22 +255,22 @@ struct TwoDimMultiReductionToElementWise
result = rewriter.create<arith::MulFOp>(loc, operand, result);
break;
case vector::CombiningKind::MINUI:
- result = rewriter.create<MinUIOp>(loc, operand, result);
+ result = rewriter.create<arith::MinUIOp>(loc, operand, result);
break;
case vector::CombiningKind::MINSI:
- result = rewriter.create<MinSIOp>(loc, operand, result);
+ result = rewriter.create<arith::MinSIOp>(loc, operand, result);
break;
case vector::CombiningKind::MINF:
- result = rewriter.create<MinFOp>(loc, operand, result);
+ result = rewriter.create<arith::MinFOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXUI:
- result = rewriter.create<MaxUIOp>(loc, operand, result);
+ result = rewriter.create<arith::MaxUIOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXSI:
- result = rewriter.create<MaxSIOp>(loc, operand, result);
+ result = rewriter.create<arith::MaxSIOp>(loc, operand, result);
break;
case vector::CombiningKind::MAXF:
- result = rewriter.create<MaxFOp>(loc, operand, result);
+ result = rewriter.create<arith::MaxFOp>(loc, operand, result);
break;
case vector::CombiningKind::AND:
result = rewriter.create<arith::AndIOp>(loc, operand, result);
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index df32b15a872a..3fb6d4c50e9b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -862,16 +862,16 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
combinedResult = rewriter.create<arith::MulIOp>(loc, mul, acc);
break;
case CombiningKind::MINUI:
- combinedResult = rewriter.create<MinUIOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MinUIOp>(loc, mul, acc);
break;
case CombiningKind::MINSI:
- combinedResult = rewriter.create<MinSIOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MinSIOp>(loc, mul, acc);
break;
case CombiningKind::MAXUI:
- combinedResult = rewriter.create<MaxUIOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MaxUIOp>(loc, mul, acc);
break;
case CombiningKind::MAXSI:
- combinedResult = rewriter.create<MaxSIOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MaxSIOp>(loc, mul, acc);
break;
case CombiningKind::AND:
combinedResult = rewriter.create<arith::AndIOp>(loc, mul, acc);
@@ -910,10 +910,10 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
combinedResult = rewriter.create<arith::MulFOp>(loc, mul, acc);
break;
case CombiningKind::MINF:
- combinedResult = rewriter.create<MinFOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MinFOp>(loc, mul, acc);
break;
case CombiningKind::MAXF:
- combinedResult = rewriter.create<MaxFOp>(loc, mul, acc);
+ combinedResult = rewriter.create<arith::MaxFOp>(loc, mul, acc);
break;
case CombiningKind::ADD: // Already handled this special case above.
case CombiningKind::AND: // Only valid for integer types.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 2ece9eb92bbe..933c26ad9a7b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -334,30 +334,30 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value:
def _eval_max(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MaxFOp(lhs, rhs).result
+ return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MaxSIOp(lhs, rhs).result
+ return arith.MaxSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'max' operand: {lhs}")
+ raise NotImplementedError(f"Unsupported 'max' operand: {lhs}")
def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MaxFOp(lhs, rhs).result
+ return arith.MaxFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MaxUIOp(lhs, rhs).result
+ return arith.MaxUIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
+ raise NotImplementedError(f"Unsupported 'max_unsigned' operand: {lhs}")
def _eval_min(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MinFOp(lhs, rhs).result
+ return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MinSIOp(lhs, rhs).result
+ return arith.MinSIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'min' operand: {lhs}")
+ raise NotImplementedError(f"Unsupported 'min' operand: {lhs}")
def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
if _is_floating_point_type(lhs.type):
- return std.MinFOp(lhs, rhs).result
+ return arith.MinFOp(lhs, rhs).result
if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
- return std.MinUIOp(lhs, rhs).result
+ return arith.MinUIOp(lhs, rhs).result
- raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
+ raise NotImplementedError(f"Unsupported 'min_unsigned' operand: {lhs}")
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index b8d9966c9a5b..db11f41c11e4 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -25,13 +25,13 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
// CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
%5 = arith.remui %lhs, %rhs: i32
// CHECK: spv.GLSL.SMax %{{.*}}, %{{.*}}: i32
- %6 = maxsi %lhs, %rhs : i32
+ %6 = arith.maxsi %lhs, %rhs : i32
// CHECK: spv.GLSL.UMax %{{.*}}, %{{.*}}: i32
- %7 = maxui %lhs, %rhs : i32
+ %7 = arith.maxui %lhs, %rhs : i32
// CHECK: spv.GLSL.SMin %{{.*}}, %{{.*}}: i32
- %8 = minsi %lhs, %rhs : i32
+ %8 = arith.minsi %lhs, %rhs : i32
// CHECK: spv.GLSL.UMin %{{.*}}, %{{.*}}: i32
- %9 = minui %lhs, %rhs : i32
+ %9 = arith.minui %lhs, %rhs : i32
return
}
@@ -76,9 +76,9 @@ func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
// CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
%4 = arith.remf %lhs, %rhs: f32
// CHECK: spv.GLSL.FMax %{{.*}}, %{{.*}}: f32
- %5 = maxf %lhs, %rhs: f32
+ %5 = arith.maxf %lhs, %rhs: f32
// CHECK: spv.GLSL.FMin %{{.*}}, %{{.*}}: f32
- %6 = minf %lhs, %rhs: f32
+ %6 = arith.minf %lhs, %rhs: f32
return
}
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index f82640e5daa3..024184408ae8 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -34,7 +34,7 @@ func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
- %min = minf %red_iter, %ld : f32
+ %min = arith.minf %red_iter, %ld : f32
affine.yield %min : f32
}
affine.store %final_red, %out[%i] : memref<256xf32>
@@ -47,7 +47,7 @@ func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
// CHECK: %[[vmax:.*]] = arith.constant dense<0x7F800000> : vector<128xf32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xf32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[min:.*]] = minf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK: %[[min:.*]] = arith.minf %[[red_iter]], %[[ld]] : vector<128xf32>
// CHECK: affine.yield %[[min]] : vector<128xf32>
// CHECK: }
// CHECK: %[[final_min:.*]] = vector.reduction "minf", %[[vred:.*]] : vector<128xf32> into f32
@@ -61,7 +61,7 @@ func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
- %max = maxf %red_iter, %ld : f32
+ %max = arith.maxf %red_iter, %ld : f32
affine.yield %max : f32
}
affine.store %final_red, %out[%i] : memref<256xf32>
@@ -74,7 +74,7 @@ func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
// CHECK: %[[vmin:.*]] = arith.constant dense<0xFF800000> : vector<128xf32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xf32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[max:.*]] = maxf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK: %[[max:.*]] = arith.maxf %[[red_iter]], %[[ld]] : vector<128xf32>
// CHECK: affine.yield %[[max]] : vector<128xf32>
// CHECK: }
// CHECK: %[[final_max:.*]] = vector.reduction "maxf", %[[vred:.*]] : vector<128xf32> into f32
@@ -88,7 +88,7 @@ func @vecdim_reduction_minsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
- %min = minsi %red_iter, %ld : i32
+ %min = arith.minsi %red_iter, %ld : i32
affine.yield %min : i32
}
affine.store %final_red, %out[%i] : memref<256xi32>
@@ -101,7 +101,7 @@ func @vecdim_reduction_minsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
// CHECK: %[[vmax:.*]] = arith.constant dense<2147483647> : vector<128xi32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xi32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32>
-// CHECK: %[[min:.*]] = minsi %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK: %[[min:.*]] = arith.minsi %[[red_iter]], %[[ld]] : vector<128xi32>
// CHECK: affine.yield %[[min]] : vector<128xi32>
// CHECK: }
// CHECK: %[[final_min:.*]] = vector.reduction "minsi", %[[vred:.*]] : vector<128xi32> into i32
@@ -115,7 +115,7 @@ func @vecdim_reduction_maxsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
- %max = maxsi %red_iter, %ld : i32
+ %max = arith.maxsi %red_iter, %ld : i32
affine.yield %max : i32
}
affine.store %final_red, %out[%i] : memref<256xi32>
@@ -128,7 +128,7 @@ func @vecdim_reduction_maxsi(%in: memref<256x512xi32>, %out: memref<256xi32>) {
// CHECK: %[[vmin:.*]] = arith.constant dense<-2147483648> : vector<128xi32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xi32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32>
-// CHECK: %[[max:.*]] = maxsi %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK: %[[max:.*]] = arith.maxsi %[[red_iter]], %[[ld]] : vector<128xi32>
// CHECK: affine.yield %[[max]] : vector<128xi32>
// CHECK: }
// CHECK: %[[final_max:.*]] = vector.reduction "maxsi", %[[vred:.*]] : vector<128xi32> into i32
@@ -142,7 +142,7 @@ func @vecdim_reduction_minui(%in: memref<256x512xi32>, %out: memref<256xi32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
- %min = minui %red_iter, %ld : i32
+ %min = arith.minui %red_iter, %ld : i32
affine.yield %min : i32
}
affine.store %final_red, %out[%i] : memref<256xi32>
@@ -155,7 +155,7 @@ func @vecdim_reduction_minui(%in: memref<256x512xi32>, %out: memref<256xi32>) {
// CHECK: %[[vmax:.*]] = arith.constant dense<-1> : vector<128xi32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xi32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32>
-// CHECK: %[[min:.*]] = minui %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK: %[[min:.*]] = arith.minui %[[red_iter]], %[[ld]] : vector<128xi32>
// CHECK: affine.yield %[[min]] : vector<128xi32>
// CHECK: }
// CHECK: %[[final_min:.*]] = vector.reduction "minui", %[[vred:.*]] : vector<128xi32> into i32
@@ -169,7 +169,7 @@ func @vecdim_reduction_maxui(%in: memref<256x512xi32>, %out: memref<256xi32>) {
affine.for %i = 0 to 256 {
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
- %max = maxui %red_iter, %ld : i32
+ %max = arith.maxui %red_iter, %ld : i32
affine.yield %max : i32
}
affine.store %final_red, %out[%i] : memref<256xi32>
@@ -182,7 +182,7 @@ func @vecdim_reduction_maxui(%in: memref<256x512xi32>, %out: memref<256xi32>) {
// CHECK: %[[vmin:.*]] = arith.constant dense<0> : vector<128xi32>
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xi32>) {
// CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32>
-// CHECK: %[[max:.*]] = maxui %[[red_iter]], %[[ld]] : vector<128xi32>
+// CHECK: %[[max:.*]] = arith.maxui %[[red_iter]], %[[ld]] : vector<128xi32>
// CHECK: affine.yield %[[max]] : vector<128xi32>
// CHECK: }
// CHECK: %[[final_max:.*]] = vector.reduction "maxui", %[[vred:.*]] : vector<128xi32> into i32
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 84b74df49c59..328ed1d028c7 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -386,3 +386,75 @@ func @bitcastOfBitcast(%arg : i16) -> i16 {
%res = arith.bitcast %bf : bf16 to i16
return %res : i16
}
+
+// -----
+
+// CHECK-LABEL: test_maxsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
+// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
+ %maxIntCst = arith.constant 127 : i8
+ %minIntCst = arith.constant -128 : i8
+ %c0 = arith.constant 42 : i8
+ %0 = arith.maxsi %arg0, %arg0 : i8
+ %1 = arith.maxsi %arg0, %maxIntCst : i8
+ %2 = arith.maxsi %arg0, %minIntCst : i8
+ %3 = arith.maxsi %arg0, %c0 : i8
+ return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+// -----
+
+// CHECK-LABEL: test_maxui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
+// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]]
+// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
+func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
+ %maxIntCst = arith.constant 255 : i8
+ %minIntCst = arith.constant 0 : i8
+ %c0 = arith.constant 42 : i8
+ %0 = arith.maxui %arg0, %arg0 : i8
+ %1 = arith.maxui %arg0, %maxIntCst : i8
+ %2 = arith.maxui %arg0, %minIntCst : i8
+ %3 = arith.maxui %arg0, %c0 : i8
+ return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+// -----
+
+// CHECK-LABEL: test_minsi
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
+// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
+ %maxIntCst = arith.constant 127 : i8
+ %minIntCst = arith.constant -128 : i8
+ %c0 = arith.constant 42 : i8
+ %0 = arith.minsi %arg0, %arg0 : i8
+ %1 = arith.minsi %arg0, %maxIntCst : i8
+ %2 = arith.minsi %arg0, %minIntCst : i8
+ %3 = arith.minsi %arg0, %c0 : i8
+ return %0, %1, %2, %3: i8, i8, i8, i8
+}
+
+// -----
+
+// CHECK-LABEL: test_minui
+// CHECK: %[[C0:.+]] = arith.constant 42
+// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
+// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]]
+// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
+func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
+ %maxIntCst = arith.constant 255 : i8
+ %minIntCst = arith.constant 0 : i8
+ %c0 = arith.constant 42 : i8
+ %0 = arith.minui %arg0, %arg0 : i8
+ %1 = arith.minui %arg0, %maxIntCst : i8
+ %2 = arith.minui %arg0, %minIntCst : i8
+ %3 = arith.minui %arg0, %c0 : i8
+ return %0, %1, %2, %3: i8, i8, i8, i8
+}
diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
index a1bd39208be3..2f14178e88f2 100644
--- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir
@@ -145,3 +145,92 @@ func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index
// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index
}
+
+// -----
+
+// CHECK-LABEL: func @maxf
+func @maxf(%a: f32, %b: f32) -> f32 {
+ %result = arith.maxf %a, %b : f32
+ return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @maxf_vector
+func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
+ %result = arith.maxf %a, %b : vector<4xf16>
+ return %result : vector<4xf16>
+}
+// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16
+// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
+// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
+
+// -----
+
+// CHECK-LABEL: func @minf
+func @minf(%a: f32, %b: f32) -> f32 {
+ %result = arith.minf %a, %b : f32
+ return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
+// CHECK-LABEL: func @maxsi
+func @maxsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// CHECK-LABEL: func @minsi
+func @minsi(%a: i32, %b: i32) -> i32 {
+ %result = arith.minsi %a, %b : i32
+ return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @maxui
+func @maxui(%a: i32, %b: i32) -> i32 {
+ %result = arith.maxui %a, %b : i32
+ return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @minui
+func @minui(%a: i32, %b: i32) -> i32 {
+ %result = arith.minui %a, %b : i32
+ return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
index 205f6ee9aedf..54a1014eb6e2 100644
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -704,3 +704,25 @@ func @test_constant() -> () {
return
}
+
+// CHECK-LABEL: func @maximum
+func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+ %f1: f32, %f2: f32,
+ %i1: i32, %i2: i32) {
+ %max_vector = arith.maxf %v1, %v2 : vector<4xf32>
+ %max_float = arith.maxf %f1, %f2 : f32
+ %max_signed = arith.maxsi %i1, %i2 : i32
+ %max_unsigned = arith.maxui %i1, %i2 : i32
+ return
+}
+
+// CHECK-LABEL: func @minimum
+func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
+ %f1: f32, %f2: f32,
+ %i1: i32, %i2: i32) {
+ %min_vector = arith.minf %v1, %v2 : vector<4xf32>
+ %min_float = arith.minf %f1, %f2 : f32
+ %min_signed = arith.minsi %i1, %i2 : i32
+ %min_unsigned = arith.minui %i1, %i2 : i32
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 652286f98184..fa9be9f4d89b 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -940,7 +940,7 @@ func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> ten
iterator_types = ["parallel", "reduction"]
} ins(%5 : tensor<?x?xf32>) outs(%7 : tensor<?xf32>) {
^bb0(%arg2: f32, %arg3: f32): // no predecessors
- %9 = maxf %arg2, %arg3 : f32
+ %9 = arith.maxf %arg2, %arg3 : f32
linalg.yield %9 : f32
} -> tensor<?xf32>
return %8 : tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
index 1a1fa0729592..cccb38150a98 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -111,7 +111,7 @@ func @generalize_pooling_nhwc_max_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_max_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MAX]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
@@ -125,7 +125,7 @@ func @generalize_pooling_nhwc_max_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_max_i32
// Verify signed integer maximum.
-// CHECK: = maxsi
+// CHECK: = arith.maxsi
// -----
@@ -137,7 +137,7 @@ func @generalize_pooling_nhwc_max_unsigned_i32(%input : tensor<1x4x16x1xi32>, %s
// CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32
-// Verify unsigned integer minimum.
+// Verify unsigned integer maximum.
-// CHECK: = maxui
+// CHECK: = arith.maxui
// -----
@@ -149,7 +149,7 @@ func @generalize_pooling_nhwc_min_f32(%input : tensor<1x4x16x1xf32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_min_f32
// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
-// CHECK-NEXT: %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: %[[MIN:.+]] = arith.minf %[[OUT_ARG]], %[[IN_ARG]] : f32
// CHECK-NEXT: linalg.yield %[[MIN]] : f32
// CHECK-NEXT: -> tensor<1x2x4x1xf32>
@@ -163,7 +163,7 @@ func @generalize_pooling_nhwc_min_i32(%input : tensor<1x4x16x1xi32>, %shape: ten
// CHECK-LABEL: @generalize_pooling_nhwc_min_i32
// Verify signed integer minimum.
-// CHECK: = minsi
+// CHECK: = arith.minsi
// -----
@@ -175,7 +175,7 @@ func @generalize_pooling_nhwc_min_unsigned_i32(%input : tensor<1x4x16x1xi32>, %s
// CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32
// Verify unsigned integer minimum.
-// CHECK: = minui
+// CHECK: = arith.minui
// -----
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 7e84ee92ff8b..c055ef47a36d 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -819,7 +819,7 @@ func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
^bb0(%in0: f32, %out0: f32): // no predecessors
- %max = maxf %in0, %out0 : f32
+ %max = arith.maxf %in0, %out0 : f32
linalg.yield %max : f32
} -> tensor<4xf32>
return %red : tensor<4xf32>
@@ -834,7 +834,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
// CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
// CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind<minf>, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
- // CHECK: minf %[[R]], %[[CMAXF]] : vector<4xf32>
+ // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
%maxf32 = arith.constant 3.40282e+38 : f32
%init = linalg.init_tensor [4] : tensor<4xf32>
@@ -844,7 +844,7 @@ func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
iterator_types = ["parallel", "reduction"]}
ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) {
^bb0(%in0: f32, %out0: f32): // no predecessors
- %min = minf %out0, %in0 : f32
+ %min = arith.minf %out0, %in0 : f32
linalg.yield %min : f32
} -> tensor<4xf32>
return %red : tensor<4xf32>
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 2ba3fe1fa600..875d9f7bc4fa 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -60,68 +60,3 @@ func @selToNot(%arg0: i1) -> i1 {
%res = select %arg0, %false, %true : i1
return %res : i1
}
-
-// CHECK-LABEL: test_maxsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127
-// CHECK: %[[X:.+]] = maxsi %arg0, %[[C0]]
-// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
-func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) {
- %maxIntCst = arith.constant 127 : i8
- %minIntCst = arith.constant -128 : i8
- %c0 = arith.constant 42 : i8
- %0 = maxsi %arg0, %arg0 : i8
- %1 = maxsi %arg0, %maxIntCst : i8
- %2 = maxsi %arg0, %minIntCst : i8
- %3 = maxsi %arg0, %c0 : i8
- return %0, %1, %2, %3: i8, i8, i8, i8
-}
-
-// CHECK-LABEL: test_maxui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1
-// CHECK: %[[X:.+]] = maxui %arg0, %[[C0]]
-// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]]
-func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) {
- %maxIntCst = arith.constant 255 : i8
- %minIntCst = arith.constant 0 : i8
- %c0 = arith.constant 42 : i8
- %0 = maxui %arg0, %arg0 : i8
- %1 = maxui %arg0, %maxIntCst : i8
- %2 = maxui %arg0, %minIntCst : i8
- %3 = maxui %arg0, %c0 : i8
- return %0, %1, %2, %3: i8, i8, i8, i8
-}
-
-
-// CHECK-LABEL: test_minsi
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128
-// CHECK: %[[X:.+]] = minsi %arg0, %[[C0]]
-// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
-func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) {
- %maxIntCst = arith.constant 127 : i8
- %minIntCst = arith.constant -128 : i8
- %c0 = arith.constant 42 : i8
- %0 = minsi %arg0, %arg0 : i8
- %1 = minsi %arg0, %maxIntCst : i8
- %2 = minsi %arg0, %minIntCst : i8
- %3 = minsi %arg0, %c0 : i8
- return %0, %1, %2, %3: i8, i8, i8, i8
-}
-
-// CHECK-LABEL: test_minui
-// CHECK: %[[C0:.+]] = arith.constant 42
-// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0
-// CHECK: %[[X:.+]] = minui %arg0, %[[C0]]
-// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]]
-func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) {
- %maxIntCst = arith.constant 255 : i8
- %minIntCst = arith.constant 0 : i8
- %c0 = arith.constant 42 : i8
- %0 = minui %arg0, %arg0 : i8
- %1 = minui %arg0, %maxIntCst : i8
- %2 = minui %arg0, %minIntCst : i8
- %3 = minui %arg0, %c0 : i8
- return %0, %1, %2, %3: i8, i8, i8, i8
-}
diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
index e46132964a01..45659aee0763 100644
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -52,92 +52,3 @@ func @memref_reshape(%input: memref<*xf32>,
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
-
-// -----
-
-// CHECK-LABEL: func @maxf
-func @maxf(%a: f32, %b: f32) -> f32 {
- %result = maxf(%a, %b): (f32, f32) -> f32
- return %result : f32
-}
-// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: func @maxf_vector
-func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
- %result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
- return %result : vector<4xf16>
-}
-// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
-// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16
-// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
-// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
-
-// -----
-
-// CHECK-LABEL: func @minf
-func @minf(%a: f32, %b: f32) -> f32 {
- %result = minf(%a, %b): (f32, f32) -> f32
- return %result : f32
-}
-// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32
-// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32
-// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-
-// -----
-
-// CHECK-LABEL: func @maxsi
-func @maxsi(%a: i32, %b: i32) -> i32 {
- %result = maxsi(%a, %b): (i32, i32) -> i32
- return %result : i32
-}
-// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
-
-// -----
-
-// CHECK-LABEL: func @minsi
-func @minsi(%a: i32, %b: i32) -> i32 {
- %result = minsi(%a, %b): (i32, i32) -> i32
- return %result : i32
-}
-// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
-
-
-// -----
-
-// CHECK-LABEL: func @maxui
-func @maxui(%a: i32, %b: i32) -> i32 {
- %result = maxui(%a, %b): (i32, i32) -> i32
- return %result : i32
-}
-// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
-
-
-// -----
-
-// CHECK-LABEL: func @minui
-func @minui(%a: i32, %b: i32) -> i32 {
- %result = minui(%a, %b): (i32, i32) -> i32
- return %result : i32
-}
-// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
-// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index a5f8a717067a..c3b40be816b5 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -62,27 +62,3 @@ func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}
-
-// CHECK-LABEL: func @maximum
-func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>,
- %f1: f32, %f2: f32,
- %i1: i32, %i2: i32) {
- %max_vector = maxf(%v1, %v2)
- : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
- %max_float = maxf(%f1, %f2) : (f32, f32) -> f32
- %max_signed = maxsi(%i1, %i2) : (i32, i32) -> i32
- %max_unsigned = maxui(%i1, %i2) : (i32, i32) -> i32
- return
-}
-
-// CHECK-LABEL: func @minimum
-func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>,
- %f1: f32, %f2: f32,
- %i1: i32, %i2: i32) {
- %min_vector = minf(%v1, %v2)
- : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
- %min_float = minf(%f1, %f2) : (f32, f32) -> f32
- %min_signed = minsi(%i1, %i2) : (i32, i32) -> i32
- %min_unsigned = minui(%i1, %i2) : (i32, i32) -> i32
- return
-}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 119594059632..3f0e184274dc 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -27,11 +27,11 @@ func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = minf %[[V1]], %[[V0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = minf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = minf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
@@ -44,11 +44,11 @@ func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> {
// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32>
// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32>
-// CHECK: %[[RV01:.+]] = maxf %[[V1]], %[[V0]] : vector<2xf32>
+// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32>
// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32>
-// CHECK: %[[RV012:.+]] = maxf %[[V2]], %[[RV01]] : vector<2xf32>
+// CHECK: %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32>
// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32>
-// CHECK: %[[RESULT_VEC:.+]] = maxf %[[V3]], %[[RV012]] : vector<2xf32>
+// CHECK: %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]] : vector<2xf32>
func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> {
diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
index 7528319c52ac..7c50fbddcd22 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir
@@ -6,6 +6,7 @@
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-memref-to-llvm \
diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
index 4ca1bc9029aa..3537616d839d 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir
@@ -6,6 +6,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-memref-to-llvm \
@@ -26,6 +27,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-linalg-to-loops \
// RUN: -convert-scf-to-std \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-memref-to-llvm \
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index e2fa535cbebd..4497fb0944e9 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -5,6 +5,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: -reconcile-unrealized-casts \
@@ -20,6 +21,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: -reconcile-unrealized-casts \
@@ -31,13 +33,14 @@
// RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \
// RUN: num-workers=20 \
-// RUN: min-task-size=1" \
+// RUN: min-task-size=1" \
// RUN: -async-to-async-runtime \
// RUN: -async-runtime-ref-counting \
// RUN: -async-runtime-ref-counting-opt \
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -arith-expand \
// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: -reconcile-unrealized-casts \
diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
index d55d39b93dbe..d9c36c807d03 100644
--- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -1,4 +1,7 @@
-// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std \
+// RUN: -std-expand -arith-expand -convert-vector-to-llvm \
+// RUN: -convert-memref-to-llvm -convert-std-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index 67ab176860ab..71dc8a5474aa 100644
--- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -295,7 +295,7 @@ def test_f32i32_conv(input, filter, init_result):
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32)
# CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32
- # CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
+ # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32
# CHECK-NEXT: linalg.yield %[[MAX]] : i32
# CHECK-NEXT: -> tensor<2x4xi32>
@builtin.FuncOp.from_py_func(
@@ -307,7 +307,7 @@ def test_f32i32_max_pooling(input, shape, init_result):
# CHECK-LABEL: @test_f32i32_max_unsigned_pooling
# CHECK: = arith.fptoui
- # CHECK: = maxui
+ # CHECK: = arith.maxui
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
@@ -320,7 +320,7 @@ def test_f32i32_max_unsigned_pooling(input, shape, init_result):
# CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]]
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"]
# CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32)
- # CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32
+ # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32
# CHECK-NEXT: linalg.yield %[[MAX]] : f32
# CHECK-NEXT: -> tensor<2x4xf32>
@builtin.FuncOp.from_py_func(
@@ -332,7 +332,7 @@ def test_f32f32_max_pooling(input, shape, init_result):
# CHECK-LABEL: @test_f32i32_min_pooling
# CHECK: = arith.fptosi
- # CHECK: = minsi
+ # CHECK: = arith.minsi
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
@@ -342,7 +342,7 @@ def test_f32i32_min_pooling(input, shape, init_result):
# CHECK-LABEL: @test_f32i32_min_unsigned_pooling
# CHECK: = arith.fptoui
- # CHECK: = minui
+ # CHECK: = arith.minui
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), i32))
@@ -351,7 +351,7 @@ def test_f32i32_min_unsigned_pooling(input, shape, init_result):
input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2])
# CHECK-LABEL: @test_f32f32_min_pooling
- # CHECK: = minf
+ # CHECK: = arith.minf
@builtin.FuncOp.from_py_func(
RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32),
RankedTensorType.get((2, 4), f32))
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
index c28e59bcdb54..d8b9d978c3ae 100644
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -128,7 +128,7 @@ def transform(module, boilerplate):
boilerplate)
pm = PassManager.parse(
"builtin.func(convert-linalg-to-loops, lower-affine, " +
- "convert-scf-to-std, std-expand), convert-vector-to-llvm," +
+ "convert-scf-to-std, arith-expand, std-expand), convert-vector-to-llvm," +
"convert-memref-to-llvm, convert-std-to-llvm," +
"reconcile-unrealized-casts")
pm.run(mod)
More information about the Mlir-commits
mailing list