[Mlir-commits] [mlir] [mlir][spirv] Add basic arithmetic folds (PR #71414)

Finn Plummer llvmlistbot at llvm.org
Tue Nov 21 10:16:41 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/71414

>From 5c42f595be08f4f30ba485180e0ccd3c241e3546 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Tue, 21 Nov 2023 18:51:36 +0100
Subject: [PATCH 1/3] [mlir][spirv] Add missed arith op folds

We are missing basic constant folds for SPIR-V arithmetic operations, which
negatively impacts the readability of lowered or otherwise generated code. This
commit implements them to remove that hindrance. It also corrects some folds
that were found to be incorrect during testing.

Resolves #70704
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    |  17 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 372 +++++++++++++-
 .../SPIRV/Transforms/canonicalize.mlir        | 459 +++++++++++++++++-
 3 files changed, 844 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index c4d1e01f9feef831..a73989c41c04cfb5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -379,6 +379,8 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
     %2 = spirv.IAddCarry %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -534,6 +536,8 @@ def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +577,8 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -607,6 +613,8 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
     %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -634,6 +642,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
     %3 = spirv.SNegate %2 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -673,6 +683,8 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -707,6 +719,8 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -742,6 +756,8 @@ def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
     %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 // -----
@@ -811,6 +827,7 @@ def SPIRV_UModOp : SPIRV_ArithmeticBinaryOp<"UMod",
     ```
   }];
 
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6dc..fdfdc08a89abd617 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -69,6 +69,14 @@ static Attribute extractCompositeElement(Attribute composite,
   return {};
 }
 
+// Returns true if the signed division a / b is undefined: b is zero, or a is
+// the minimum signed value and b is -1 (signed overflow).
+static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
+  if (b.isZero())
+    return true;
+  return a.isMinSignedValue() && b.isAllOnes();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
@@ -115,6 +123,196 @@ void spirv::AccessChainOp::getCanonicalizationPatterns(
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// Constant structs cannot yet be materialized with spirv.Constant, so the
+// result is built with spirv.CompositeConstruct in a pattern, not a fold.
+struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto operands = op.getOperands();
+
+    SmallVector<Value> constituents;
+    Type constituentType = operands[0].getType();
+
+    // iaddcarry (x, 0) = <x, 0>: sum (member 0) is x, carry (member 1) is 0.
+    if (matchPattern(operands[1], m_Zero())) {
+      constituents.push_back(operands[0]);
+      constituents.push_back(operands[1]);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               constituents);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //  Result Type must be from OpTypeStruct.  The struct must have two
+    //  members...
+    //
+    //  Member 0 of the result gets the low-order bits (full component width) of
+    //  the addition.
+    //
+    //  Member 1 of the result gets the high-order (carry) bit of the result of
+    //  the addition. That is, it gets the value 1 if the addition overflowed
+    //  the component width, and 0 otherwise.
+    Attribute lhs;
+    Attribute rhs;
+    if (!matchPattern(operands[0], m_Constant(&lhs)) ||
+        !matchPattern(operands[1], m_Constant(&rhs)))
+      return failure();
+
+    auto adds = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a + b; });
+    if (!adds)
+      return failure();
+    // The unsigned addition carried iff the wrapped sum is less than lhs.
+    auto carries = constFoldBinaryOp<IntegerAttr>(
+        ArrayRef{adds, lhs}, [](const APInt &a, const APInt &b) {
+          APInt zero = APInt::getZero(a.getBitWidth());
+          return a.ult(b) ? (zero + 1) : zero;
+        });
+
+    if (!carries)
+      return failure();
+
+    Value addsVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
+    constituents.push_back(addsVal);
+
+    Value carriesVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, carries);
+    constituents.push_back(carriesVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+void spirv::IAddCarryOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<IAddCarryFold>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.[S|U]MulExtended
+//===----------------------------------------------------------------------===//
+
+// Constant structs cannot yet be materialized with spirv.Constant, so the
+// result is built with spirv.CompositeConstruct in a pattern, not a fold.
+template <typename MulOp, bool IsSigned>
+struct MulExtendedFold final : OpRewritePattern<MulOp> {
+  using OpRewritePattern<MulOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MulOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value operand1 = op.getOperands()[0];
+    Value operand2 = op.getOperands()[1];
+    Type constituentType = operand1.getType();
+    SmallVector<Value> parts;
+
+    // [su]mulextended (x, 0) = <0, 0>: both halves of the product are zero.
+    if (matchPattern(operand2, m_Zero())) {
+      Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
+      parts.assign(2, zero);
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                               parts);
+      return success();
+    }
+
+    // According to the SPIR-V spec:
+    //
+    //   Result Type must be from OpTypeStruct. The struct must have two
+    //   members...
+    //
+    //   Member 0 of the result gets the low-order bits of the multiplication.
+    //
+    //   Member 1 of the result gets the high-order bits of the multiplication.
+    //
+    // Both halves can only be computed when both operands are constants.
+    Attribute lhs;
+    Attribute rhs;
+    bool bothConstant = matchPattern(operand1, m_Constant(&lhs)) &&
+                        matchPattern(operand2, m_Constant(&rhs));
+    if (!bothConstant)
+      return failure();
+
+    // Low half: the wrapped product, identical for signed and unsigned.
+    Attribute lowBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) { return a * b; });
+    if (!lowBits)
+      return failure();
+
+    // High half: extend both operands to twice the width (sign- or zero-
+    // extension per IsSigned), multiply exactly, keep the upper bitWidth bits.
+    Attribute highBits = constFoldBinaryOp<IntegerAttr>(
+        {lhs, rhs}, [](const APInt &a, const APInt &b) {
+          unsigned bitWidth = a.getBitWidth();
+          unsigned wideWidth = bitWidth * 2;
+          APInt product = IsSigned ? a.sext(wideWidth) * b.sext(wideWidth)
+                                   : a.zext(wideWidth) * b.zext(wideWidth);
+          return product.extractBits(bitWidth, bitWidth);
+        });
+    if (!highBits)
+      return failure();
+
+    Value lowVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
+    parts.push_back(lowVal);
+
+    Value highVal =
+        rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
+    parts.push_back(highVal);
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             parts);
+    return success();
+  }
+};
+
+using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
+void spirv::SMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<SMulExtendedOpFold>(context);
+}
+
+// umulextended (x, 1) = <x, 0>: an unsigned product with 1 is just x, so the
+// high-order half is zero. This does not hold for SMulExtended, whose high
+// half replicates the sign bit of x.
+struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
+                                PatternRewriter &rewriter) const override {
+    auto operands = op.getOperands();
+    if (!matchPattern(operands[1], m_One()))
+      return failure();
+
+    Value x = operands[0];
+    Location loc = op.getLoc();
+    Value zero = spirv::ConstantOp::getZero(x.getType(), loc, rewriter);
+    SmallVector<Value> constituents;
+    constituents.push_back(x);
+    constituents.push_back(zero);
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
+                                                             constituents);
+    return success();
+  }
+};
+
+using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
+void spirv::UMulExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  // The generic [su]mulextended folds plus the unsigned-only x * 1 rewrite.
+  patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
@@ -278,7 +476,7 @@ OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
   // x - x = 0
   if (getOperand1() == getOperand2())
-    return Builder(getContext()).getIntegerAttr(getType(), 0);
+    return Builder(getContext()).getZeroAttr(getType());
 
   // According to the SPIR-V spec:
   //
@@ -290,6 +488,178 @@ OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
       [](APInt a, const APInt &b) { return std::move(a) - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
+  // sdiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Signed-integer division of Operand 1 divided by Operand 2.
+  // Results are computed per component. Behavior is undefined if Operand 2 is
+  // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
+  // representable value for the operands' type, causing signed overflow.
+  //
+  // If any component would hit undefined behaviour, the whole fold is
+  // abandoned rather than producing an arbitrary value.
+  bool sawUB = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &num, const APInt &den) {
+        sawUB = sawUB || isDivZeroOrOverflow(num, den);
+        if (sawUB)
+          return num; // Placeholder; the result is discarded below.
+        return num.sdiv(den);
+      });
+  return sawUB ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
+  // smod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
+  // If any component would hit undefined behaviour, the fold is abandoned.
+  bool div0OrOverflow = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
+          div0OrOverflow = true;
+          return a; // Placeholder; the result is discarded.
+        }
+        // |a| urem |b|; abs(INT_MIN) wraps, but urem reads it as 2^(w-1).
+        APInt c = a.abs().urem(b.abs());
+        if (c.isZero())
+          return c;
+        if (b.isNegative()) { // b < 0, so a nonzero result must be negative.
+          APInt zero = APInt::getZero(c.getBitWidth());
+          return a.isNegative() ? (zero - c) : (b + c);
+        }
+        return a.isNegative() ? (b - c) : c; // b > 0: result is positive.
+      });
+  return div0OrOverflow ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // -(-x) = 0 - (0 - x) = x, which holds even with two's-complement wrap.
+  auto inner = getOperand().getDefiningOp<spirv::SNegateOp>();
+  if (inner)
+    return inner->getOperand(0);
+
+  // Per the SPIR-V spec SNegate subtracts Operand from zero, so a constant
+  // operand folds to 0 - a (wrapping in two's complement).
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a) {
+        APInt result = APInt::getZero(a.getBitWidth());
+        result -= a;
+        return result;
+      });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
+  // srem (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to SPIR-V spec:
+  //
+  // Signed remainder operation for the remainder whose sign matches the sign
+  // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
+  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
+  // value for the operands' type, causing signed overflow. Otherwise, the
+  // result is the remainder r of Operand 1 divided by Operand 2 where if
+  // r ≠ 0, the sign of r is the same as the sign of Operand 1.
+  //
+  // APInt::srem gives r the sign of Operand 1, as required. If any component
+  // would hit undefined behaviour, the whole fold is abandoned.
+  bool sawUB = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &num, const APInt &den) {
+        sawUB = sawUB || isDivZeroOrOverflow(num, den);
+        if (sawUB)
+          return num; // Placeholder; the result is discarded below.
+        return num.srem(den);
+      });
+  return sawUB ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
+  // udiv (x, 1) = x
+  if (matchPattern(getOperand2(), m_One()))
+    return getOperand1();
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  // A zero divisor in any component therefore abandons the whole fold.
+  bool divByZero = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &num, const APInt &den) {
+        divByZero = divByZero || den.isZero();
+        if (divByZero)
+          return num; // Placeholder; the result is discarded below.
+        return num.udiv(den);
+      });
+  if (divByZero)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
+  // umod (x, 1) = 0
+  if (matchPattern(getOperand2(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
+  // undefined if Operand 2 is 0.
+  // A zero divisor in any component therefore abandons the whole fold.
+  bool divByZero = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &num, const APInt &den) {
+        divByZero = divByZero || den.isZero();
+        if (divByZero)
+          return num; // Placeholder; the result is discarded below.
+        return num.urem(den);
+      });
+  if (divByZero)
+    return Attribute();
+  return folded;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397a4..2db4d94453a55592 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -336,6 +336,52 @@ func.func @iadd_poison(%arg0: i32) -> i32 {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IAddCarry
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iaddcarry_x_0
+func.func @iaddcarry_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+  // NOTE(review): only the op kind is checked, not the <x, 0> operand order.
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iaddcarry
+func.func @const_fold_scalar_iaddcarry() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+  // 5 + -8 = -3 (no unsigned carry); -5 + -8 = -13 (carry out).
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant -3
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 1
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.IAddCarry %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_iaddcarry
+func.func @const_fold_vector_iaddcarry() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[5, -3, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[-8, -8, 1]> : vector<3xi32>
+  // Per lane: -3 has no carry; -11 and 0 both carry out of 32 bits.
+  // CHECK-DAG: spirv.Constant dense<[0, 1, 1]>
+  // CHECK-DAG: spirv.Constant dense<[-3, -11, 0]>
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.IAddCarry %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IMul
 //===----------------------------------------------------------------------===//
@@ -400,15 +446,119 @@ func.func @const_fold_vector_imul() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smulextended_x_0
+func.func @smulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+  // x * 0 = <0, 0>; only the replacement op kind is checked here.
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_smulextended
+func.func @const_fold_scalar_smulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+  // Signed: 5 * -8 = <-40, -1> (low, high); -5 * -8 = <40, 0>.
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant -1
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.SMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_smulextended
+func.func @const_fold_vector_smulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+  // Lane 0: (2^31 - 1) * 5 = 2 * 2^32 + (2^31 - 5), i.e. <2147483643, 2>.
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, 0, -1]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.SMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umulextended_x_0
+func.func @umulextended_x_0(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c0 = spirv.Constant 0 : i32
+  // x * 0 = <0, 0>; only the replacement op kind is checked here.
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c0 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @umulextended_x_1
+func.func @umulextended_x_1(%arg0 : i32) -> !spirv.struct<(i32, i32)> {
+  %c1 = spirv.Constant 1 : i32
+  // x * 1 = <x, 0> (unsigned only); only the replacement op kind is checked.
+  // CHECK: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %arg0, %c1 : !spirv.struct<(i32, i32)>
+  return %0 : !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_scalar_umulextended
+func.func @const_fold_scalar_umulextended() -> (!spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>) {
+  %c5 = spirv.Constant 5 : i32
+  %cn5 = spirv.Constant -5 : i32
+  %cn8 = spirv.Constant -8 : i32
+  // Unsigned: 5 * -8 = <-40, 4> (low, high); -5 * -8 = <40, -13>.
+  // CHECK-DAG: spirv.Constant 40
+  // CHECK-DAG: spirv.Constant -13
+  // CHECK-DAG: spirv.CompositeConstruct
+  // CHECK-DAG: spirv.Constant -40
+  // CHECK-DAG: spirv.Constant 4
+  // CHECK-DAG: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %c5, %cn8 : !spirv.struct<(i32, i32)>
+  %1 = spirv.UMulExtended %cn5, %cn8 : !spirv.struct<(i32, i32)>
+
+  return %0, %1 : !spirv.struct<(i32, i32)>, !spirv.struct<(i32, i32)>
+}
+
+// CHECK-LABEL: @const_fold_vector_umulextended
+func.func @const_fold_vector_umulextended() -> !spirv.struct<(vector<3xi32>, vector<3xi32>)> {
+  %v0 = spirv.Constant dense<[2147483647, -5, -1]> : vector<3xi32>
+  %v1 = spirv.Constant dense<[5, -8, 1]> : vector<3xi32>
+  // Lane 1: (2^32 - 5) * (2^32 - 8) = (2^32 - 13) * 2^32 + 40, i.e. <40, -13>.
+  // CHECK: spirv.Constant dense<[2147483643, 40, -1]>
+  // CHECK-NEXT: spirv.Constant dense<[2, -13, 0]>
+  // CHECK-NEXT: spirv.CompositeConstruct
+  %0 = spirv.UMulExtended %v0, %v1 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+  return %0 : !spirv.struct<(vector<3xi32>, vector<3xi32>)>
+
+}
+
+// -----
+
+
 //===----------------------------------------------------------------------===//
 // spirv.ISub
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @isub_x_x
-func.func @isub_x_x(%arg0: i32) -> i32 {
-  // CHECK: spirv.Constant 0
+func.func @isub_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
   %0 = spirv.ISub %arg0, %arg0: i32
-  return %0: i32
+  %1 = spirv.ISub %arg1, %arg1: vector<3xi32> // vector x - x: splat zero
+  return %0, %1 : i32, vector<3xi32>
 }
 
 // CHECK-LABEL: @const_fold_scalar_isub_normal
@@ -462,10 +612,313 @@ func.func @const_fold_vector_isub() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.SDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @sdiv_x_1
+func.func @sdiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.SDiv %arg0, %c1: i32
+  return %2 : i32
+}
+
+// CHECK-LABEL: @sdiv_div_0_or_overflow
+func.func @sdiv_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SDiv
+  // CHECK: spirv.SDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+  // -1 / 0 divides by zero and INT_MIN / -1 overflows: neither may fold.
+  %0 = spirv.SDiv %cn1, %c0 : i32
+  %1 = spirv.SDiv %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_sdiv
+func.func @const_fold_scalar_sdiv() -> (i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+  // Rounds toward zero: 56/7 = 8, 56/-8 = -7, -8/3 = -2, 56/-3 = -18.
+  // CHECK-DAG: spirv.Constant -18
+  // CHECK-DAG: spirv.Constant -2
+  // CHECK-DAG: spirv.Constant -7
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.SDiv %c56, %c7 : i32
+  %1 = spirv.SDiv %c56, %cn8 : i32
+  %2 = spirv.SDiv %cn8, %c3 : i32
+  %3 = spirv.SDiv %c56, %cn3 : i32
+  return %0, %1, %2, %3: i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_sdiv
+func.func @const_fold_vector_sdiv() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[0, -1, -3]>
+  // 42/76 = 0, 24/-24 = -1, -16/5 = -3 (rounding toward zero).
+  %cv_num = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_denom = spirv.Constant dense<[76, -24, 5]> : vector<3xi32>
+  %0 = spirv.SDiv %cv_num, %cv_denom : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @smod_x_1
+func.func @smod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SMod %arg0, %c1: i32
+  %1 = spirv.SMod %arg1, %cv1: vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @smod_div_0_or_overflow
+func.func @smod_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SMod
+  // CHECK: spirv.SMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+  // -1 mod 0 divides by zero and INT_MIN mod -1 overflows: neither folds.
+  %0 = spirv.SMod %cn1, %c0 : i32
+  %1 = spirv.SMod %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_smod
+func.func @const_fold_scalar_smod() -> (i32, i32, i32, i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %cn56 = spirv.Constant -56 : i32
+  %c59 = spirv.Constant 59 : i32
+  %cn59 = spirv.Constant -59 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+  // A nonzero result takes the sign of the divisor (Operand 2).
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[FIFTYTHREE:.*]] = spirv.Constant 53 : i32
+  // CHECK-DAG: %[[NFIFTYTHREE:.*]] = spirv.Constant -53 : i32
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  %0 = spirv.SMod %c56, %c7 : i32
+  %1 = spirv.SMod %c56, %cn8 : i32
+  %2 = spirv.SMod %c56, %c3 : i32
+  %3 = spirv.SMod %cn3, %c56 : i32
+  %4 = spirv.SMod %cn3, %cn56 : i32
+  %5 = spirv.SMod %c59, %c56 : i32
+  %6 = spirv.SMod %c59, %cn56 : i32
+  %7 = spirv.SMod %cn59, %cn56 : i32
+  // e.g. -3 mod 56 = 53, 59 mod -56 = -53, -59 mod -56 = -3.
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[FIFTYTHREE]], %[[NTHREE]], %[[THREE]], %[[NFIFTYTHREE]], %[[NTHREE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7 : i32, i32, i32, i32, i32, i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_smod
+func.func @const_fold_vector_smod() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[42, -4, 4]>
+  // 42 mod 76 = 42, 24 mod -7 = -4, -16 mod 5 = 4.
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.SMod %cv, %cv_mod : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @snegate_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+  %0 = spirv.SNegate %arg0 : i32
+  %1 = spirv.SNegate %0 : i32
+  // -(-x) folds back to x, so both negations disappear.
+  // CHECK: return %[[ARG]] : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+  // Each constant folds to 0 - c: 0, -3 and 3 respectively.
+  // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SNegate %c0 : i32
+  %1 = spirv.SNegate %c3 : i32
+  %2 = spirv.SNegate %cn3 : i32
+
+  // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_snegate
+func.func @const_fold_vector_snegate() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[0, 3, -3]>
+  %cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32>
+  %0 = spirv.SNegate %cv : vector<3xi32>
+  return %0  : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SRem
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @srem_x_1
+func.func @srem_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.SRem %arg0, %c1: i32
+  %1 = spirv.SRem %arg1, %cv1: vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @srem_div_0_or_overflow
+func.func @srem_div_0_or_overflow() -> (i32, i32) {
+  // CHECK: spirv.SRem
+  // CHECK: spirv.SRem
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %min_i32 = spirv.Constant -2147483648 : i32
+  // -1 rem 0 divides by zero and INT_MIN rem -1 overflows: neither folds.
+  %0 = spirv.SRem %cn1, %c0 : i32
+  %1 = spirv.SRem %min_i32, %cn1 : i32
+  return %0, %1 : i32, i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_srem
+func.func @const_fold_scalar_srem() -> (i32, i32, i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+  // A nonzero result takes the sign of the dividend (Operand 1).
+  // CHECK-DAG: %[[ONE:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+  %0 = spirv.SRem %c56, %c7 : i32
+  %1 = spirv.SRem %c56, %cn8 : i32
+  %2 = spirv.SRem %c56, %c3 : i32
+  %3 = spirv.SRem %cn3, %c56 : i32
+  %4 = spirv.SRem %c7, %cn3 : i32
+  // CHECK: return %[[ZERO]], %[[ZERO]], %[[TWO]], %[[NTHREE]], %[[ONE]]
+  return %0, %1, %2, %3, %4 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDiv
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @udiv_x_1
+func.func @udiv_x_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: return %arg0 : i32
+  %c1 = spirv.Constant 1  : i32
+  %2 = spirv.UDiv %arg0, %c1: i32  // x / 1 -> x
+  return %2 : i32
+}
+
+// CHECK-LABEL: @udiv_div_0
+func.func @udiv_div_0() -> i32 {
+  // CHECK: spirv.UDiv
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UDiv %cn1, %c0 : i32  // division by zero: left unfolded
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_udiv
+func.func @const_fold_scalar_udiv() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32  // 0xffff fff8 = 4294967288 as unsigned
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 1431655762
+  // CHECK-DAG: spirv.Constant 8
+  %0 = spirv.UDiv %c56, %c7 : i32   // 56 / 7 = 8
+  %1 = spirv.UDiv %cn8, %c3 : i32   // 4294967288 / 3 = 1431655762
+  %2 = spirv.UDiv %c56, %cn8 : i32  // 56 / 4294967288 = 0
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.UMod
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: @umod_x_1
+func.func @umod_x_1(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %c1 = spirv.Constant 1 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %0 = spirv.UMod %arg0, %c1: i32  // x umod 1 -> 0
+  %1 = spirv.UMod %arg1, %cv1: vector<3xi32>  // lane-wise x umod 1 -> 0
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @umod_div_0
+func.func @umod_div_0() -> i32 {
+  // CHECK: spirv.UMod
+  %c0 = spirv.Constant 0 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.UMod %cn1, %c0 : i32  // modulo by zero: left unfolded
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_umod
+func.func @const_fold_scalar_umod() -> (i32, i32, i32) {
+  %c56 = spirv.Constant 56 : i32
+  %c7 = spirv.Constant 7 : i32
+  %cn8 = spirv.Constant -8 : i32  // 0xffff fff8 = 4294967288 as unsigned
+  %c3 = spirv.Constant 3 : i32
+
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant 2
+  // CHECK-DAG: spirv.Constant 56
+  %0 = spirv.UMod %c56, %c7 : i32   // 56 mod 7 = 0
+  %1 = spirv.UMod %cn8, %c3 : i32   // 4294967288 mod 3 = 2
+  %2 = spirv.UMod %c56, %cn8 : i32  // 56 mod 4294967288 = 56
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_umod
+func.func @const_fold_vector_umod() -> vector<3xi32> {
+  // CHECK: spirv.Constant dense<[42, 24, 0]>
+  // Lanes are read as unsigned: -16 is 4294967280 and -7 is 4294967289.
+  %cv = spirv.Constant dense<[42, 24, -16]> : vector<3xi32>
+  %cv_mod = spirv.Constant dense<[76, -7, 5]> : vector<3xi32>
+  %0 = spirv.UMod %cv, %cv_mod : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
 // CHECK-LABEL: @umod_fold
 // CHECK-SAME: (%[[ARG:.*]]: i32)
 func.func @umod_fold(%arg0: i32) -> (i32, i32) {

>From d90ff6ddafbe4d1324b92982730c0f497d6560bb Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Tue, 21 Nov 2023 19:04:27 +0100
Subject: [PATCH 2/3] [mlir][spirv] Add missed bit op folds

We have missing basic constant folds for SPIR-V bit operations which
negatively impacts readability of lowered or otherwise generated code. This
commit implements them to address those hindrances.

Resolves #70704
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td      |  10 +
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 207 +++++++++++
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  49 ---
 mlir/test/Dialect/SPIRV/IR/bit-ops.mlir       |   6 +-
 .../SPIRV/Transforms/canonicalize.mlir        | 337 ++++++++++++++++++
 5 files changed, 558 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f6217..dbba4f7ec6cff764 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -412,6 +412,8 @@ def SPIRV_BitwiseXorOp : SPIRV_BitBinaryOp<"BitwiseXor",
     %2 = spirv.BitwiseXor %0, %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -457,6 +459,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
     %5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -499,6 +503,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
     %5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -542,6 +548,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
     %5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -573,6 +581,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
     %3 = spirv.Not %1 : vector<4xi32>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index fdfdc08a89abd617..8584e63eb52a8bcb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -726,6 +726,213 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAndOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
+  APInt mask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&mask))) {
+    // x & 0 -> 0: reuse the zero constant operand as the result.
+    if (mask.isZero())
+      return getOperand2();
+
+    // x & <all ones> -> x
+    if (mask.isAllOnes())
+      return getOperand1();
+
+    // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x.
+    // When widening, UConvert zero-extends, so bits above N are already clear.
+    auto uconvert = getOperand1().getDefiningOp<spirv::UConvertOp>();
+    if (uconvert) {
+      Type srcElemType = getElementTypeOrSelf(uconvert.getOperand());
+      unsigned srcBits = srcElemType.getIntOrFloatBitWidth();
+      if (mask.zextOrTrunc(srcBits).isAllOnes())
+        return getOperand1();
+    }
+  }
+
+  // Per the SPIR-V spec, Type is a scalar or vector of integer type and
+  // results are computed per component and, within each component, per bit,
+  // so an element-wise APInt AND folds constant scalars and vectors alike.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOrOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
+  // Identities that only need a constant right-hand side operand.
+  APInt mask;
+  if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&mask))) {
+    // x | 0 -> x
+    if (mask.isZero())
+      return getOperand1();
+
+    // x | <all ones> -> <all ones>
+    // (reuse the all-ones constant operand as the result).
+    if (mask.isAllOnes())
+      return getOperand2();
+  }
+
+  // Otherwise try to fold two constant operands. Per the SPIR-V spec, Type is
+  // a scalar or vector of integer type and results are computed per component
+  // and, within each component, per bit, so an element-wise APInt OR is exact.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXorOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
+  // x ^ 0 -> x (m_Zero matches a scalar zero or a zero splat).
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // x ^ x -> 0, even when x itself is not a constant; getZeroAttr builds a
+  // zero of the result type (a zero splat for vectors).
+  if (getOperand1() == getOperand2()) {
+    Builder builder(getContext());
+    return builder.getZeroAttr(getType());
+  }
+
+  // Otherwise try to fold two constant operands. Per the SPIR-V spec, Type is
+  // a scalar or vector of integer type and results are computed per component
+  // and, within each component, per bit, so an element-wise APInt XOR is exact.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+  // x << 0 -> x (m_Zero matches a scalar zero or a zero splat).
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // 0 << y is not folded to 0: the result is undefined when y >= bit width.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So use APInt <<, but abandon the whole fold if any shift is too large.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a; // Placeholder; the folded result is discarded below.
+        }
+        return a << b;
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x (m_Zero matches a scalar zero or a zero splat).
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // 0 >> y and -1 >> y are not folded: results are undefined if y >= width.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So use APInt ashr, but abandon the whole fold if any shift is too large.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a; // Placeholder; the folded result is discarded below.
+        }
+        return a.ashr(b);
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x (m_Zero matches a scalar zero or a zero splat).
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // 0 >> y is not folded to 0: the result is undefined when y >= bit width.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So use APInt lshr, but abandon the whole fold if any shift is too large.
+  bool shiftTooLarge = false;
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftTooLarge || b.uge(a.getBitWidth())) {
+          shiftTooLarge = true;
+          return a; // Placeholder; the folded result is discarded below.
+        }
+        return a.lshr(b);
+      });
+  return shiftTooLarge ? Attribute() : res;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NotOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
+  // ~(~x) -> x: bitwise complement undoes itself.
+  Value operand = getOperand();
+  if (auto innerNot = operand.getDefiningOp<spirv::NotOp>())
+    return innerNot->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Complement the bits of Operand.
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [](APInt a) {
+    a.flipAllBits();
+    return a;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3906bf74ea72235b..2a1d083308282a8f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1968,55 +1968,6 @@ LogicalResult spirv::ShiftRightLogicalOp::verify() {
   return verifyShiftOp(*this);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseAndOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult
-spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x & 0 -> 0
-  if (rhsMask.isZero())
-    return getOperand2();
-
-  // x & <all ones> -> x
-  if (rhsMask.isAllOnes())
-    return getOperand1();
-
-  // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
-  if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
-    int valueBits =
-        getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
-    if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
-      return getOperand1();
-  }
-
-  return {};
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BtiwiseOrOp
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
-  APInt rhsMask;
-  if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)))
-    return {};
-
-  // x | 0 -> x
-  if (rhsMask.isZero())
-    return getOperand1();
-
-  // x | <all ones> -> <all ones>
-  if (rhsMask.isAllOnes())
-    return getOperand2();
-
-  return {};
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.ImageQuerySize
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
index 82a2316f6c784fbe..f3f0ebf60f468e6c 100644
--- a/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
@@ -149,14 +149,16 @@ func.func @bitwise_or_float(%arg0: f16, %arg1: f16) -> f16 {
 //===----------------------------------------------------------------------===//
 
 func.func @bitwise_xor_scalar(%arg: i32) -> i32 {
+  %c1 = spirv.Constant 1 : i32 // distinct operands so the x ^ x -> 0 folder cannot apply
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : i32
+  %0 = spirv.BitwiseXor %c1, %arg : i32
   return %0 : i32
 }
 
 func.func @bitwise_xor_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  %c1 = spirv.Constant dense<1> : vector<4xi32> // distinct operands so the x ^ x -> 0 folder cannot apply
   // CHECK: spirv.BitwiseXor
-  %0 = spirv.BitwiseXor %arg, %arg : vector<4xi32>
+  %0 = spirv.BitwiseXor %c1, %arg : vector<4xi32>
   return %0 : vector<4xi32>
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 2db4d94453a55592..e4c85a542462052f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1113,6 +1113,343 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseAnd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_and_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_0(%arg0 : i32) -> i32 {
+  // CHECK: %[[C0:.*]] = spirv.Constant 0 : i32
+  %c0 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseAnd %arg0, %c0 : i32
+
+  // CHECK: return %[[C0]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_and_x_n1
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_and_x_n1(%arg0 : i32) -> i32 {
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseAnd %arg0, %cn1 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_band
+func.func @const_fold_scalar_band() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff & 0x1000 7000 = 0x0000 0000 = 0
+  // CHECK: spirv.Constant 0
+  %0 = spirv.BitwiseAnd %c1, %c2 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_band
+func.func @const_fold_vector_band() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[40, -63, 28]>
+  %0 = spirv.BitwiseAnd %c1, %c2 : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseOr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_or_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_or_x_0(%arg0 : i32) -> i32 {
+  %c0 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseOr %arg0, %c0 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_or_x_n1
+func.func @bitwise_or_x_n1(%arg0 : i32) -> i32 {
+  // CHECK: %[[CN1:.*]] = spirv.Constant -1 : i32
+  %cn1 = spirv.Constant -1 : i32
+  %0 = spirv.BitwiseOr %arg0, %cn1 : i32
+
+  // CHECK: return %[[CN1]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_bor
+func.func @const_fold_scalar_bor() -> i32 {
+  %c1 = spirv.Constant -268464129 : i32   // 0xefff 8fff
+  %c2 = spirv.Constant 268464128: i32     // 0x1000 7000
+
+  // 0xefff 8fff | 0x1000 7000 = 0xffff ffff = -1
+  // CHECK: spirv.Constant -1
+  %0 = spirv.BitwiseOr %c1, %c2 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bor
+func.func @const_fold_vector_bor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[-1, -7, 127]>
+  %0 = spirv.BitwiseOr %c1, %c2 : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.BitwiseXor
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @bitwise_xor_x_0
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @bitwise_xor_x_0(%arg0 : i32) -> i32 {
+  %c0 = spirv.Constant 0 : i32
+  %0 = spirv.BitwiseXor %arg0, %c0 : i32
+
+  // CHECK: return %[[ARG]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @bitwise_xor_x_x
+func.func @bitwise_xor_x_x(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: spirv.Constant 0
+  // CHECK-DAG: spirv.Constant dense<0>
+  %0 = spirv.BitwiseXor %arg0, %arg0 : i32
+  %1 = spirv.BitwiseXor %arg1, %arg1 : vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_bxor
+func.func @const_fold_scalar_bxor() -> i32 {
+  %c1 = spirv.Constant 4294967295 : i32  // 2^32 - 1: 0xffff ffff
+  %c2 = spirv.Constant -2147483648 : i32 // -2^31   : 0x8000 0000
+
+  // 0xffff ffff ^ 0x8000 0000 = 0x7fff ffff
+  // CHECK: spirv.Constant 2147483647
+  %0 = spirv.BitwiseXor %c1, %c2 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_bxor
+func.func @const_fold_vector_bxor() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[42, -55, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[-3, -15, 28]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[-41, 56, 99]>
+  %0 = spirv.BitwiseXor %c1, %c2 : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsl_shift_overflow
+func.func @lsl_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK: spirv.ShiftLeftLogical
+  // CHECK: spirv.ShiftLeftLogical
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  // Constant operands, but each op has a lane with Shift >= 32 (undefined).
+  %0 = spirv.ShiftLeftLogical %c32, %c32 : i32, i32
+  %1 = spirv.ShiftLeftLogical %cv, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+  %c1 = spirv.Constant 65535 : i32  // 0x0000 ffff
+  %c2 = spirv.Constant 17 : i32
+  // 0x0000 ffff << 17 -> 0xfffe 0000
+  // CHECK: spirv.Constant -131072
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[-2147483648, -65536, 1040384]>
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @asr_shift_overflow
+func.func @asr_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK: spirv.ShiftRightArithmetic
+  // CHECK: spirv.ShiftRightArithmetic
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  // Constant operands, but each op has a lane with Shift >= 32 (undefined).
+  %0 = spirv.ShiftRightArithmetic %c32, %c32 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %cv, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+  // 0xfffe 0000 ashr 17 -> 0xffff ffff
+  // CHECK: spirv.Constant -1
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[-1, 3, 0]>
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsr_shift_overflow
+func.func @lsr_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK: spirv.ShiftRightLogical
+  // CHECK: spirv.ShiftRightLogical
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  // Constant operands, but each op has a lane with Shift >= 32 (undefined).
+  %0 = spirv.ShiftRightLogical %c32, %c32 : i32, i32
+  %1 = spirv.ShiftRightLogical %cv, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+  // 0xfffe 0000 >> 17 -> 0x0000 7fff
+  // CHECK: spirv.Constant 32767
+  %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: spirv.Constant dense<[1, 65535, 524287]>
+  %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Not
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @not_twice(%arg0 : i32) -> i32 {
+  %0 = spirv.Not %arg0 : i32
+  %1 = spirv.Not %0 : i32
+  // ~(~x) folds back to x.
+  // CHECK: return %[[ARG]] : i32
+  return %1 : i32
+}
+
+// CHECK-LABEL: @const_fold_scalar_not
+func.func @const_fold_scalar_not() -> (i32, i32, i32) {
+  %c0 = spirv.Constant 0 : i32
+  %c3 = spirv.Constant 3 : i32
+  %cn3 = spirv.Constant -3 : i32
+  // In two's complement, ~x == -x - 1.
+  // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+  // CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32
+  // CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32
+  %0 = spirv.Not %c0 : i32   // -1
+  %1 = spirv.Not %c3 : i32   // -4
+  %2 = spirv.Not %cn3 : i32  // 2
+
+  // CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]]
+  return %0, %1, %2  : i32, i32, i32
+}
+
+// CHECK-LABEL: @const_fold_vector_not
+func.func @const_fold_vector_not() -> vector<3xi32> {
+  %cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  // Lane-wise complement: ~(-1) = 0, ~(-4) = 3, ~2 = -3.
+  // CHECK: spirv.Constant dense<[0, 3, -3]>
+  %0 = spirv.Not %cv : vector<3xi32>
+
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

>From f252ccf5269dad395ffd7a5fea98aed35db0203d Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Tue, 21 Nov 2023 19:07:32 +0100
Subject: [PATCH 3/3] [mlir][spirv] Add missed logical op folds

We have missing basic constant folds for SPIR-V logical operations which
negatively impacts readability of lowered or otherwise generated code. This
commit implements them to address those hindrances.
Corrects some testcases in logical-ops-to-llvm as required.

Resolves #70704
---
 .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td  |   9 +-
 .../SPIRV/IR/SPIRVCanonicalization.cpp        |  97 ++++++++++-
 .../SPIRVToLLVM/logical-ops-to-llvm.mlir      |  16 +-
 .../SPIRV/Transforms/canonicalize.mlir        | 153 ++++++++++++++++++
 4 files changed, 264 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc3267..0053cd5fc9448b54 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
     %2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
-  let hasFolder = true;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 8584e63eb52a8bcb..ba2281d30bdb5893 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -679,6 +679,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true. The result is an i1 scalar or a vector of i1, so return
+  // either a true BoolAttr or an all-true splat of the vector type.
+  if (getOperand1() == getOperand2()) {
+    Type resultType = getType();
+    BoolAttr trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(resultType))
+      return trueAttr;
+    if (isa<VectorType>(resultType))
+      return DenseElementsAttr::get(cast<ShapedType>(resultType), trueAttr);
+  }
+
+  // Otherwise try to fold two constant operands element-wise; each result
+  // element is an i1 that is 1 exactly when the operand elements match.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
+        uint64_t bit = lhs == rhs ? 1 : 0;
+        return APInt(/*numBits=*/1, bit);
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNotEqualOp
 //===----------------------------------------------------------------------===//
@@ -686,12 +712,29 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
   if (std::optional<bool> rhs =
           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
-    // x && false = x
+    // x != false -> x
     if (!rhs.value())
       return getOperand1();
   }
 
-  return Attribute();
+  // x != x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  // Otherwise fold two constant operands: each i1 element is 1 iff they differ.
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+        return APInt(/*numBits=*/1, a != b ? 1 : 0);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -933,6 +976,56 @@ OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  // Self-comparison is always true. Emit a plain bool for a scalar i1 result
+  // or a splat for a vector-of-i1 result.
+  if (getOperand1() == getOperand2()) {
+    Type resultType = getType();
+    BoolAttr trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(resultType))
+      return trueAttr;
+    if (isa<VectorType>(resultType))
+      return DenseElementsAttr::get(cast<ShapedType>(resultType), trueAttr);
+  }
+
+  // Fold constant operands element-wise. Operands may be wider than i1, so
+  // the i1-based result type is passed explicitly.
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          return APInt(1, a == b);
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
+  // x != x -> false
+  // A value never differs from itself. The result type is either a scalar i1
+  // or a vector of i1, so build the matching constant attribute.
+  if (getOperand1() == getOperand2()) {
+    Type type = getType();
+    BoolAttr falseAttr = BoolAttr::get(getContext(), false);
+    if (isa<IntegerType>(type))
+      return falseAttr;
+    if (isa<VectorType>(type))
+      return DenseElementsAttr::get(cast<ShapedType>(type), falseAttr);
+  }
+
+  // Fold constant operands element-wise (i1 result set iff they differ);
+  // operands may be wider than i1, so the result type is passed explicitly.
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          return APInt(1, a != b);
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
index 6d93480d3ed142e5..aab2dce980ca7bf6 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
@@ -7,14 +7,14 @@
 // CHECK-LABEL: @logical_equal_scalar
 spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalEqual %arg0, %arg1 : i1 // distinct operands: LogicalEqual(x, x) now folds to a constant
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_equal_vector
 spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1> // distinct operands: LogicalEqual(x, x) now folds to a constant
   spirv.Return
 }
 
@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
 // CHECK-LABEL: @logical_not_equal_scalar
 spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : i1 // distinct operands: LogicalNotEqual(x, x) now folds to a constant
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_not_equal_vector
 spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1> // distinct operands: LogicalNotEqual(x, x) now folds to a constant
   spirv.Return
 }
 
@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
 // CHECK-LABEL: @logical_and_scalar
 spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalAnd %arg0, %arg0 : i1
+  %0 = spirv.LogicalAnd %arg0, %arg1 : i1 // NOTE(review): distinct operands, presumably so no same-operand fold pre-empts the lowering
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_and_vector
 spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1> // NOTE(review): distinct operands, presumably so no same-operand fold pre-empts the lowering
   spirv.Return
 }
 
@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
 // CHECK-LABEL: @logical_or_scalar
 spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalOr %arg0, %arg0 : i1
+  %0 = spirv.LogicalOr %arg0, %arg1 : i1 // NOTE(review): distinct operands, presumably so no same-operand fold pre-empts the lowering
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_or_vector
 spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1> // NOTE(review): distinct operands, presumably so no same-operand fold pre-empts the lowering
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index e4c85a542462052f..13370fd693f7d559 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1022,6 +1022,45 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
   spirv.ReturnValue %3 : vector<3xi1>
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lequal_same
+func.func @lequal_same(%scalar : i1, %vec : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // Comparing a value with itself folds to an all-true constant.
+  // CHECK-DAG: spirv.Constant true
+  // CHECK-DAG: spirv.Constant dense<true>
+  %scalar_eq = spirv.LogicalEqual %scalar, %scalar : i1
+  %vec_eq = spirv.LogicalEqual %vec, %vec : vector<3xi1>
+  return %scalar_eq, %vec_eq : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lequal
+func.func @const_fold_scalar_lequal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+  // Pin the result order too, so an inverted comparison fold cannot pass.
+  // CHECK-DAG: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK-DAG: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  %0 = spirv.LogicalEqual %true, %false : i1
+  %1 = spirv.LogicalEqual %false, %false : i1
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_lequal
+func.func @const_fold_vector_lequal() -> vector<3xi1> {
+  %lhs = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %rhs = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // Element-wise: [t == t, f == f, t == f].
+  // CHECK: spirv.Constant dense<[true, true, false]>
+  %eq = spirv.LogicalEqual %lhs, %rhs : vector<3xi1>
+  return %eq : vector<3xi1>
+}
 
 // -----
 
@@ -1038,6 +1077,40 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
   spirv.ReturnValue %0 : vector<4xi1>
 }
 
+// CHECK-LABEL: @lnotequal_same
+func.func @lnotequal_same(%scalar : i1, %vec : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // A value never differs from itself, so both results fold to false.
+  // CHECK-DAG: spirv.Constant false
+  // CHECK-DAG: spirv.Constant dense<false>
+  %scalar_ne = spirv.LogicalNotEqual %scalar, %scalar : i1
+  %vec_ne = spirv.LogicalNotEqual %vec, %vec : vector<3xi1>
+  return %scalar_ne, %vec_ne : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lnotequal
+func.func @const_fold_scalar_lnotequal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+  // Pin the result order too, so an inverted comparison fold cannot pass.
+  // CHECK-DAG: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK-DAG: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  %0 = spirv.LogicalNotEqual %true, %false : i1
+  %1 = spirv.LogicalNotEqual %false, %false : i1
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_lnotequal
+func.func @const_fold_vector_lnotequal() -> vector<3xi1> {
+  %lhs = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %rhs = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // Element-wise: [t != t, f != f, t != f].
+  // CHECK: spirv.Constant dense<[false, false, true]>
+  %ne = spirv.LogicalNotEqual %lhs, %rhs : vector<3xi1>
+  return %ne : vector<3xi1>
+}
+
 // -----
 
 func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -1450,6 +1523,86 @@ func.func @const_fold_vector_not() -> vector<3xi32> {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iequal_same
+func.func @iequal_same(%scalar : i32, %vec : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // Comparing a value with itself folds to an all-true constant.
+  // CHECK-DAG: spirv.Constant true
+  // CHECK-DAG: spirv.Constant dense<true>
+  %scalar_eq = spirv.IEqual %scalar, %scalar : i32
+  %vec_eq = spirv.IEqual %vec, %vec : vector<3xi32>
+  return %scalar_eq, %vec_eq : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iequal
+func.func @const_fold_scalar_iequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+  // Pin the result order too, so an inverted comparison fold cannot pass.
+  // CHECK-DAG: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK-DAG: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: return %[[FALSE]], %[[TRUE]]
+  %0 = spirv.IEqual %c5, %c6 : i32
+  %1 = spirv.IEqual %c5, %c5 : i32
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_iequal
+func.func @const_fold_vector_iequal() -> vector<3xi1> {
+  %lhs = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %rhs = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // Element-wise: [-1 == -1, -4 == -3, 2 == 2].
+  // CHECK: spirv.Constant dense<[true, false, true]>
+  %eq = spirv.IEqual %lhs, %rhs : vector<3xi32>
+  return %eq : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inotequal_same
+func.func @inotequal_same(%scalar : i32, %vec : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // A value never differs from itself, so both results fold to false.
+  // CHECK-DAG: spirv.Constant false
+  // CHECK-DAG: spirv.Constant dense<false>
+  %scalar_ne = spirv.INotEqual %scalar, %scalar : i32
+  %vec_ne = spirv.INotEqual %vec, %vec : vector<3xi32>
+  return %scalar_ne, %vec_ne : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_inotequal
+func.func @const_fold_scalar_inotequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+  // Pin the result order too, so an inverted comparison fold cannot pass.
+  // CHECK-DAG: %[[TRUE:.+]] = spirv.Constant true
+  // CHECK-DAG: %[[FALSE:.+]] = spirv.Constant false
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  %0 = spirv.INotEqual %c5, %c6 : i32
+  %1 = spirv.INotEqual %c5, %c5 : i32
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_inotequal
+func.func @const_fold_vector_inotequal() -> vector<3xi1> {
+  %lhs = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %rhs = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // Element-wise: [-1 != -1, -4 != -3, 2 != 2].
+  // CHECK: spirv.Constant dense<[false, true, false]>
+  %ne = spirv.INotEqual %lhs, %rhs : vector<3xi32>
+  return %ne : vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list