[Mlir-commits] [mlir] [mlir][spirv] Add folding for [I|Logical][Not]Equal (PR #74194)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Dec 6 14:58:55 PST 2023
================
@@ -309,19 +309,62 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
if (std::optional<bool> rhs =
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
- // x && false = x
+ // x != false -> x
if (!rhs.value())
return getOperand1();
}
- return Attribute();
+ // x != x -> false
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), false);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), false);
+ return DenseElementsAttr::get(vtType, element);
+ }
----------------
kuhar wrote:
same here
https://github.com/llvm/llvm-project/pull/74194
More information about the Mlir-commits
mailing list