[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 21 10:26:36 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

Basic constant folds are missing for SPIR-V arithmetic operations, which negatively impacts the readability of lowered or otherwise generated code. This commit implements them to address that problem.

Resolves #<!-- -->70704

---

Patch is 69.34 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71414.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+17) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td (+10) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td (+8-1) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+673-3) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (-49) 
- (modified) mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir (+8-8) 
- (modified) mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (+4-2) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+946-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef83..a73989c41c04cfb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -534,6 +536,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +577,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -607,6 +613,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -634,6 +642,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
     %3 = spirv.SNegate %2 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +683,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +719,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -742,6 +756,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -811,6 +827,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f621..dbba4f7ec6cff76 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -412,6 +412,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
     %2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -457,6 +459,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
     %5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -499,6 +503,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
     %5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -542,6 +548,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
     %5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +581,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
     %3 = spirv.Not %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc326..0053cd5fc9448b5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
     %2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
-  let hasFolder = true;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6d..ba2281d30bdb589 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  bool div0 = b.isZero();
+
+  bool overflow = a.isMinSignedValue() && b.isAllOnes();
+
+  return div0 || overflow;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -115,6 +123,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <0, x>
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[1]);
+      constituents.push_back(operands[0]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+
+    auto carrys = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carrys)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carrysVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
+    constituents.push_back(carrysVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// We are required to use CompositeConstructOp to create a constant struct as
+// they are not yet implemented as constant, hence we can not do so in a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // [su]mulextended (x, 0) = <0, 0>
+    if (matchPattern(operands[1], m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(zero);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    // Result Type must be from OpTypeStruct.  The struct must have two
+    // members...
+    //
+    // Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    // Member 1 of the result gets the high-order bits of the multiplication.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+
+    if (!lowBits)
+      return failure();
+
+    auto highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          APInt c;
+          if (IsSigned) {
+            c = a.sext(bitWidth * 2) * b.sext(bitWidth * 2);
+          } else {
+            c = a.zext(bitWidth * 2) * b.zext(bitWidth * 2);
+          }
+          return c.extractBits(bitWidth, bitWidth); // Extract high result
+        });
+
+    if (!highBits)
+      return failure();
+
+    Value lowBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    constituents.push_back(lowBitsVal);
+
+    Value highBitsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    constituents.push_back(highBitsVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // umulextended (x, 1) = <x, 0>
+    if (matchPattern(operands[1], m_One())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      constituents.push_back(operands[0]);
+      constituents.push_back(zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
@@ -278,7 +476,7 @@ OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
   // x - x = 0
   if (getOperand1() == getOperand2())
-    return Builder(getContext()).getIntegerAttr(getType(), 0);
+    return Builder(getContext()).getZeroAttr(getType());
 
   // According to the SPIR-V spec:
   //
@@ -290,6 +488,178 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // sdiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer division of Operand 1 divided by Operand 2.
+  // Results are computed per component. Behavior is undefined if Operand 2 is
+  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+  // representable value for the operands' type, causing signed overflow.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.sdiv(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  //
+  // So don't fold during undefined behaviour
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) {
+          APInt zero = APInt::getZero(c.getBitWidth());
+          return a.isNegative() ? (zero - c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c;
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // -(-x) = 0 - (0 - x) = x
+  auto op = getOperand();
+  if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
+    return negateOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer subtract of Operand from zero.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        APInt zero = APInt::getZero(a.getBitWidth());
+        return zero - a;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // x % 1 = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+
+  // Don't fold if it would do undefined behaviour.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](APInt a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a;
+        }
+        return a.srem(b);
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // udiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.udiv(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  //
+  // So don't fold during undefined behaviour.
+  bool div0 = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0 || b.isZero()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
+  return div0 ? Attribute() : res;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
@@ -309,6 +679,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), true);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), true);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? (zero + 1) : zero;
+                                        });
+}
+
 //===-----------------------------------...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71414


More information about the Mlir-commits mailing list