[Mlir-commits] [mlir] c8544d2 - [mlir] Support attributes in `matchPattern`
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 5 11:55:20 PDT 2023
Author: Jakub Kuderski
Date: 2023-09-05T14:52:27-04:00
New Revision: c8544d280a8ef1e41efc5c036e00dc671edbb93d
URL: https://github.com/llvm/llvm-project/commit/c8544d280a8ef1e41efc5c036e00dc671edbb93d
DIFF: https://github.com/llvm/llvm-project/commit/c8544d280a8ef1e41efc5c036e00dc671edbb93d.diff
LOG: [mlir] Support attributes in `matchPattern`
The primary motivation is to have a simple mechanism to extract
values from attributes in folders and canon patterns without having to
re-fold constants or write nested conditions over attribute types.
Matching over attributes composes especially well with fold adaptors.
Update folders in the Arith and SPIR-V dialects to match over attributes
where applicable.
Reviewed By: mehdi_amini, zero9178
Differential Revision: https://reviews.llvm.org/D159437
Added:
Modified:
mlir/include/mlir/IR/Matchers.h
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 2361a541efc222b..f6417f62d09e8c0 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -15,6 +15,7 @@
#ifndef MLIR_IR_MATCHERS_H
#define MLIR_IR_MATCHERS_H
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
@@ -38,7 +39,7 @@ struct attr_value_binder {
/// Creates a matcher instance that binds the value to bv if match succeeds.
attr_value_binder(ValueType *bv) : bind_value(bv) {}
- bool match(const Attribute &attr) {
+ bool match(Attribute attr) {
if (auto intAttr = llvm::dyn_cast<AttrClass>(attr)) {
*bind_value = intAttr.getValue();
return true;
@@ -123,27 +124,33 @@ struct AttrOpBinder {
};
/// The matcher that matches a constant scalar / vector splat / tensor splat
-/// float operation and binds the constant float value.
-struct constant_float_op_binder {
+/// float Attribute or Operation and binds the constant float value.
+struct constant_float_value_binder {
FloatAttr::ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
- constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
+ constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {}
+
+ bool match(Attribute attr) {
+ attr_value_binder<FloatAttr> matcher(bind_value);
+ if (matcher.match(attr))
+ return true;
+
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
+ return matcher.match(splatAttr.getSplatValue<Attribute>());
+
+ return false;
+ }
bool match(Operation *op) {
Attribute attr;
if (!constant_op_binder<Attribute>(&attr).match(op))
return false;
- auto type = op->getResult(0).getType();
-
- if (llvm::isa<FloatType>(type))
- return attr_value_binder<FloatAttr>(bind_value).match(attr);
- if (llvm::isa<VectorType, RankedTensorType>(type)) {
- if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
- return attr_value_binder<FloatAttr>(bind_value)
- .match(splatAttr.getSplatValue<Attribute>());
- }
- }
+
+ Type type = op->getResult(0).getType();
+ if (isa<FloatType, VectorType, RankedTensorType>(type))
+ return match(attr);
+
return false;
}
};
@@ -153,34 +160,45 @@ struct constant_float_op_binder {
struct constant_float_predicate_matcher {
bool (*predicate)(const APFloat &);
+ bool match(Attribute attr) {
+ APFloat value(APFloat::Bogus());
+ return constant_float_value_binder(&value).match(attr) && predicate(value);
+ }
+
bool match(Operation *op) {
APFloat value(APFloat::Bogus());
- return constant_float_op_binder(&value).match(op) && predicate(value);
+ return constant_float_value_binder(&value).match(op) && predicate(value);
}
};
/// The matcher that matches a constant scalar / vector splat / tensor splat
-/// integer operation and binds the constant integer value.
-struct constant_int_op_binder {
+/// integer Attribute or Operation and binds the constant integer value.
+struct constant_int_value_binder {
IntegerAttr::ValueType *bind_value;
/// Creates a matcher instance that binds the value to bv if match succeeds.
- constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
+ constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {}
+
+ bool match(Attribute attr) {
+ attr_value_binder<IntegerAttr> matcher(bind_value);
+ if (matcher.match(attr))
+ return true;
+
+ if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr))
+ return matcher.match(splatAttr.getSplatValue<Attribute>());
+
+ return false;
+ }
bool match(Operation *op) {
Attribute attr;
if (!constant_op_binder<Attribute>(&attr).match(op))
return false;
- auto type = op->getResult(0).getType();
-
- if (llvm::isa<IntegerType, IndexType>(type))
- return attr_value_binder<IntegerAttr>(bind_value).match(attr);
- if (llvm::isa<VectorType, RankedTensorType>(type)) {
- if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr)) {
- return attr_value_binder<IntegerAttr>(bind_value)
- .match(splatAttr.getSplatValue<Attribute>());
- }
- }
+
+ Type type = op->getResult(0).getType();
+ if (isa<IntegerType, IndexType, VectorType, RankedTensorType>(type))
+ return match(attr);
+
return false;
}
};
@@ -190,9 +208,14 @@ struct constant_int_op_binder {
struct constant_int_predicate_matcher {
bool (*predicate)(const APInt &);
+ bool match(Attribute attr) {
+ APInt value;
+ return constant_int_value_binder(&value).match(attr) && predicate(value);
+ }
+
bool match(Operation *op) {
APInt value;
- return constant_int_op_binder(&value).match(op) && predicate(value);
+ return constant_int_value_binder(&value).match(op) && predicate(value);
}
};
@@ -203,14 +226,14 @@ struct op_matcher {
};
/// Trait to check whether T provides a 'match' method with type
-/// `OperationOrValue`.
-template <typename T, typename OperationOrValue>
-using has_operation_or_value_matcher_t =
- decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
+/// `MatchTarget` (Value, Operation, or Attribute).
+template <typename T, typename MatchTarget>
+using has_compatible_matcher_t =
+ decltype(std::declval<T>().match(std::declval<MatchTarget>()));
/// Statically switch to a Value matcher.
template <typename MatcherClass>
-std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
+std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
MatcherClass, Value>::value,
bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
@@ -219,7 +242,7 @@ matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
/// Statically switch to an Operation matcher.
template <typename MatcherClass>
-std::enable_if_t<llvm::is_detected<detail::has_operation_or_value_matcher_t,
+std::enable_if_t<llvm::is_detected<detail::has_compatible_matcher_t,
MatcherClass, Operation *>::value,
bool>
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
@@ -376,6 +399,7 @@ inline detail::op_matcher<OpClass> m_Op() {
/// Entry point for matching a pattern over a Value.
template <typename Pattern>
inline bool matchPattern(Value value, const Pattern &pattern) {
+ assert(value);
// TODO: handle other cases
if (auto *op = value.getDefiningOp())
return const_cast<Pattern &>(pattern).match(op);
@@ -385,21 +409,34 @@ inline bool matchPattern(Value value, const Pattern &pattern) {
/// Entry point for matching a pattern over an Operation.
template <typename Pattern>
inline bool matchPattern(Operation *op, const Pattern &pattern) {
+ assert(op);
return const_cast<Pattern &>(pattern).match(op);
}
+/// Entry point for matching a pattern over an Attribute. Returns `false`
+/// when `attr` is null.
+template <typename Pattern>
+inline bool matchPattern(Attribute attr, const Pattern &pattern) {
+ static_assert(llvm::is_detected<detail::has_compatible_matcher_t, Pattern,
+ Attribute>::value,
+ "Pattern does not support matching Attributes");
+ if (!attr)
+ return false;
+ return const_cast<Pattern &>(pattern).match(attr);
+}
+
/// Matches a constant holding a scalar/vector/tensor float (splat) and
/// writes the float value to bind_value.
-inline detail::constant_float_op_binder
+inline detail::constant_float_value_binder
m_ConstantFloat(FloatAttr::ValueType *bind_value) {
- return detail::constant_float_op_binder(bind_value);
+ return detail::constant_float_value_binder(bind_value);
}
/// Matches a constant holding a scalar/vector/tensor integer (splat) and
/// writes the integer value to bind_value.
-inline detail::constant_int_op_binder
+inline detail::constant_int_value_binder
m_ConstantInt(IntegerAttr::ValueType *bind_value) {
- return detail::constant_int_op_binder(bind_value);
+ return detail::constant_int_value_binder(bind_value);
}
template <typename OpType, typename... Matchers>
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 007b105d2328c79..c87b4185722fb01 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -106,12 +106,9 @@ static int64_t getScalarOrElementWidth(Value value) {
}
static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
- if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
- return intAttr.getValue();
-
- if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
- if (llvm::isa<IntegerType>(splatAttr.getElementType()))
- return splatAttr.getSplatValue<APInt>();
+ APInt value;
+ if (matchPattern(attr, m_ConstantInt(&value)))
+ return value;
return failure();
}
@@ -258,7 +255,7 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
// addi(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// addi(subi(a, b), b) -> a
@@ -349,7 +346,7 @@ OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
if (getOperand(0) == getOperand(1))
return Builder(getContext()).getZeroAttr(getType());
// subi(x,0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
if (auto add = getLhs().getDefiningOp<AddIOp>()) {
@@ -379,11 +376,11 @@ void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
// muli(x, 0) -> 0
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getRhs();
// muli(x, 1) -> x
- if (matchPattern(getRhs(), m_One()))
- return getOperand(0);
+ if (matchPattern(adaptor.getRhs(), m_One()))
+ return getLhs();
// TODO: Handle the overflow case.
// default folder
@@ -412,7 +409,7 @@ LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// mulsi_extended(x, 0) -> 0, 0
- if (matchPattern(getRhs(), m_Zero())) {
+ if (matchPattern(adaptor.getRhs(), m_Zero())) {
Attribute zero = adaptor.getRhs();
results.push_back(zero);
results.push_back(zero);
@@ -460,7 +457,7 @@ LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
// mului_extended(x, 0) -> 0, 0
- if (matchPattern(getRhs(), m_Zero())) {
+ if (matchPattern(adaptor.getRhs(), m_Zero())) {
Attribute zero = adaptor.getRhs();
results.push_back(zero);
results.push_back(zero);
@@ -468,7 +465,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
}
// mului_extended(x, 1) -> x, 0
- if (matchPattern(getRhs(), m_One())) {
+ if (matchPattern(adaptor.getRhs(), m_One())) {
Builder builder(getContext());
Attribute zero = builder.getZeroAttr(getLhs().getType());
results.push_back(getLhs());
@@ -508,7 +505,7 @@ void arith::MulUIExtendedOp::getCanonicalizationPatterns(
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
// divui (x, 1) -> x.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// Don't fold if it would require a division by zero.
@@ -537,7 +534,7 @@ Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
// divsi (x, 1) -> x.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
@@ -584,7 +581,7 @@ static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
// ceildivui (x, 1) -> x.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
bool overflowOrDiv0 = false;
@@ -616,7 +613,7 @@ Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
// ceildivsi (x, 1) -> x.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
@@ -677,7 +674,7 @@ Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
// floordivsi (x, 1) -> x.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
@@ -726,7 +723,7 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
// remui (x, 1) -> 0.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return Builder(getContext()).getZeroAttr(getType());
// Don't fold if it would require a division by zero.
@@ -749,7 +746,7 @@ OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
// remsi (x, 1) -> 0.
- if (matchPattern(getRhs(), m_One()))
+ if (matchPattern(adaptor.getRhs(), m_One()))
return Builder(getContext()).getZeroAttr(getType());
// Don't fold if it would require a division by zero.
@@ -789,11 +786,12 @@ static Value foldAndIofAndI(arith::AndIOp op) {
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
/// and(x, 0) -> 0
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getRhs();
/// and(x, allOnes) -> x
APInt intValue;
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
+ if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) &&
+ intValue.isAllOnes())
return getLhs();
/// and(x, not(x)) -> 0
if (matchPattern(getRhs(), m_Op<XOrIOp>(matchers::m_Val(getLhs()),
@@ -820,13 +818,14 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
- /// or(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
- return getLhs();
- /// or(x, <all ones>) -> <all ones>
- if (auto rhsAttr = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs()))
- if (rhsAttr.getValue().isAllOnes())
- return rhsAttr;
+ if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) {
+ /// or(x, 0) -> x
+ if (rhsVal.isZero())
+ return getLhs();
+ /// or(x, <all ones>) -> <all ones>
+ if (rhsVal.isAllOnes())
+ return adaptor.getRhs();
+ }
APInt intValue;
/// or(x, xor(x, 1)) -> 1
@@ -851,7 +850,7 @@ OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
/// xor(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
/// xor(x, x) -> 0
if (getLhs() == getRhs())
@@ -901,7 +900,7 @@ OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
// addf(x, -0) -> x
- if (matchPattern(getRhs(), m_NegZeroFloat()))
+ if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -915,7 +914,7 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
// subf(x, +0) -> x
- if (matchPattern(getRhs(), m_PosZeroFloat()))
+ if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -933,7 +932,7 @@ OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) {
return getRhs();
// maxf(x, -inf) -> x
- if (matchPattern(getRhs(), m_NegInfFloat()))
+ if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -950,16 +949,15 @@ OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- APInt intValue;
- // maxsi(x,MAX_INT) -> MAX_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMaxSignedValue())
- return getRhs();
-
- // maxsi(x, MIN_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMinSignedValue())
- return getLhs();
+ if (APInt intValue;
+ matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
+ // maxsi(x,MAX_INT) -> MAX_INT
+ if (intValue.isMaxSignedValue())
+ return getRhs();
+ // maxsi(x, MIN_INT) -> x
+ if (intValue.isMinSignedValue())
+ return getLhs();
+ }
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
@@ -976,14 +974,15 @@ OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- APInt intValue;
- // maxui(x,MAX_INT) -> MAX_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
- return getRhs();
-
- // maxui(x, MIN_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
- return getLhs();
+ if (APInt intValue;
+ matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
+ // maxui(x,MAX_INT) -> MAX_INT
+ if (intValue.isMaxValue())
+ return getRhs();
+ // maxui(x, MIN_INT) -> x
+ if (intValue.isMinValue())
+ return getLhs();
+ }
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
@@ -1001,7 +1000,7 @@ OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) {
return getRhs();
// minf(x, +inf) -> x
- if (matchPattern(getRhs(), m_PosInfFloat()))
+ if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -1018,16 +1017,15 @@ OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- APInt intValue;
- // minsi(x,MIN_INT) -> MIN_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMinSignedValue())
- return getRhs();
-
- // minsi(x, MAX_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
- intValue.isMaxSignedValue())
- return getLhs();
+ if (APInt intValue;
+ matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
+ // minsi(x,MIN_INT) -> MIN_INT
+ if (intValue.isMinSignedValue())
+ return getRhs();
+ // minsi(x, MAX_INT) -> x
+ if (intValue.isMaxSignedValue())
+ return getLhs();
+ }
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
@@ -1044,14 +1042,15 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
if (getLhs() == getRhs())
return getRhs();
- APInt intValue;
- // minui(x,MIN_INT) -> MIN_INT
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
- return getRhs();
-
- // minui(x, MAX_INT) -> x
- if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
- return getLhs();
+ if (APInt intValue;
+ matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
+ // minui(x,MIN_INT) -> MIN_INT
+ if (intValue.isMinValue())
+ return getRhs();
+ // minui(x, MAX_INT) -> x
+ if (intValue.isMaxValue())
+ return getLhs();
+ }
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
[](const APInt &a, const APInt &b) {
@@ -1065,7 +1064,7 @@ OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
// mulf(x, 1) -> x
- if (matchPattern(getRhs(), m_OneFloat()))
+ if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -1084,7 +1083,7 @@ void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
// divf(x, 1) -> x
- if (matchPattern(getRhs(), m_OneFloat()))
+ if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
return constFoldBinaryOp<FloatAttr>(
@@ -1685,7 +1684,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
return getBoolAttribute(getType(), getContext(), val);
}
- if (matchPattern(getRhs(), m_Zero())) {
+ if (matchPattern(adaptor.getRhs(), m_Zero())) {
if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
// extsi(%x : i1 -> iN) != 0 -> %x
std::optional<int64_t> integerWidth =
@@ -2188,11 +2187,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
Value condition = getCondition();
// select true, %0, %1 => %0
- if (matchPattern(condition, m_One()))
+ if (matchPattern(adaptor.getCondition(), m_One()))
return trueVal;
// select false, %0, %1 => %1
- if (matchPattern(condition, m_Zero()))
+ if (matchPattern(adaptor.getCondition(), m_Zero()))
return falseVal;
// If either operand is fully poisoned, return the other.
@@ -2203,8 +2202,8 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
return trueVal;
// select %x, true, false => %x
- if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
- matchPattern(getFalseValue(), m_Zero()))
+ if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) &&
+ matchPattern(adaptor.getFalseValue(), m_Zero()))
return condition;
if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
@@ -2313,7 +2312,7 @@ LogicalResult arith::SelectOp::verify() {
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
// shli(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
@@ -2331,7 +2330,7 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
// shrui(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
@@ -2349,7 +2348,7 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
// shrsi(x, 0) -> x
- if (matchPattern(getRhs(), m_Zero()))
+ if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
// Don't fold if shifting more than the bit width.
bool bounded = false;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index f7ab3c0702a98b0..6ebd8515caf037d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
@@ -1966,27 +1967,12 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
// spirv.BtiwiseAndOp
//===----------------------------------------------------------------------===//
-static std::optional<APInt> extractIntConstant(Attribute attr) {
- IntegerAttr intAttr;
- if (auto splat = dyn_cast_if_present<SplatElementsAttr>(attr))
- intAttr = dyn_cast<IntegerAttr>(splat.getSplatValue<Attribute>());
- else
- intAttr = dyn_cast_if_present<IntegerAttr>(attr);
-
- if (!intAttr)
- return std::nullopt;
-
- return intAttr.getValue();
-}
-
OpFoldResult
spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
- std::optional<APInt> rhsVal = extractIntConstant(adaptor.getOperand2());
- if (!rhsVal)
+ APInt rhsMask;
+ if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
return {};
- APInt rhsMask = *rhsVal;
-
// x & 0 -> 0
if (rhsMask.isZero())
return getOperand2();
@@ -2011,12 +1997,10 @@ spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
- std::optional<APInt> rhsVal = extractIntConstant(adaptor.getOperand2());
- if (!rhsVal)
+ APInt rhsMask;
+ if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
return {};
- APInt rhsMask = *rhsVal;
-
// x | 0 -> x
if (rhsMask.isZero())
return getOperand1();
More information about the Mlir-commits
mailing list