[Mlir-commits] [mlir] [mlir][spirv] Add folding for [I|Logical][Not]Equal (PR #74194)

Jakub Kuderski llvmlistbot at llvm.org
Wed Dec 6 14:58:56 PST 2023


================
@@ -309,19 +309,62 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `spirv.LogicalEqual`.
+///
+/// - `x == x` folds to `true`: a `BoolAttr` when the result type is a scalar
+///   integer (i1), or a dense attribute filled with `true` when it is a vector.
+/// - Otherwise, attempts to constant-fold the comparison of the two operand
+///   attributes via `constFoldBinaryOp`.
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto trueAttr = BoolAttr::get(getContext(), true);
+    Type type = getType();
+    if (isa<IntegerType>(type))
+      return trueAttr;
+    if (auto vecType = dyn_cast<VectorType>(type))
+      return DenseElementsAttr::get(vecType, trueAttr);
+  }
+
+  // Operands are 1-bit values: spell out true/false explicitly rather than
+  // deriving "true" arithmetically from zero.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNotEqualOp
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
   if (std::optional<bool> rhs =
           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
-    // x && false = x
+    // x != false -> x
     if (!rhs.value())
       return getOperand1();
   }
 
-  return Attribute();
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? zero : (zero + 1);
----------------
kuhar wrote:

same here

https://github.com/llvm/llvm-project/pull/74194


More information about the Mlir-commits mailing list