[Mlir-commits] [mlir] [mlir][spirv] Add folding for [I|Logical][Not]Equal (PR #74194)

Finn Plummer llvmlistbot at llvm.org
Sun Dec 10 03:33:31 PST 2023


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/74194

>From 73ff53cc8501a8fae214b6306fe6a84344c19fc2 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 10:18:08 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add folding for [I|Logical][Not]Equal

Add the missing constant-propagation folders for [I|Logical][Not]Equal.

Also fold the case where lhs and rhs are the same SSA value, for all four ops.

Additionally, fix test cases in logical-ops-to-llvm that failed due to the
newly introduced folding.

This improves the readability of code lowered to SPIR-V.

Part of work for #70704
---
 .../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td  |   9 +-
 .../SPIRV/IR/SPIRVCanonicalization.cpp        |  97 +++++++++-
 .../SPIRVToLLVM/logical-ops-to-llvm.mlir      |  16 +-
 .../SPIRV/Transforms/canonicalize.mlir        | 165 ++++++++++++++++++
 4 files changed, 276 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc32..0053cd5fc9448b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
 
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
     %2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
     %2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
     ```
   }];
-  let hasFolder = true;
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af6..16efe8797f4a32 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -309,6 +309,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), true);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), true);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? (zero + 1) : zero;
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalNotEqualOp
 //===----------------------------------------------------------------------===//
@@ -316,12 +342,29 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
   if (std::optional<bool> rhs =
           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
-    // x && false = x
+    // x != false -> x
     if (!rhs.value())
       return getOperand1();
   }
 
-  return Attribute();
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? zero : (zero + 1);
+                                        });
 }
 
 //===----------------------------------------------------------------------===//
@@ -356,6 +399,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
+  // x == x -> true
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), true);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), true);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? (zero + 1) : zero;
+                                        });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
+  // x == x -> false
+  if (getOperand1() == getOperand2()) {
+    auto type = getType();
+    if (isa<IntegerType>(type)) {
+      return BoolAttr::get(getContext(), false);
+    }
+    if (isa<VectorType>(type)) {
+      auto vtType = cast<ShapedType>(type);
+      auto element = BoolAttr::get(getContext(), false);
+      return DenseElementsAttr::get(vtType, element);
+    }
+  }
+
+  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+                                        [](const APInt &a, const APInt &b) {
+                                          APInt zero = APInt::getZero(1);
+                                          return a == b ? zero : (zero + 1);
+                                        });
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
index 6d93480d3ed142..aab2dce980ca7b 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
@@ -7,14 +7,14 @@
 // CHECK-LABEL: @logical_equal_scalar
 spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalEqual %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_equal_vector
 spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
 // CHECK-LABEL: @logical_not_equal_scalar
 spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_not_equal_vector
 spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
 // CHECK-LABEL: @logical_and_scalar
 spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalAnd %arg0, %arg0 : i1
+  %0 = spirv.LogicalAnd %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_and_vector
 spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
 
@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
 // CHECK-LABEL: @logical_or_scalar
 spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : i1
-  %0 = spirv.LogicalOr %arg0, %arg0 : i1
+  %0 = spirv.LogicalOr %arg0, %arg1 : i1
   spirv.Return
 }
 
 // CHECK-LABEL: @logical_or_vector
 spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
   // CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
-  %0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
+  %0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a444397..7a8e262db266a4 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -569,6 +569,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
   spirv.ReturnValue %3 : vector<3xi1>
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @logical_equal_same
+func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+
+  %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+  %1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_equal
+func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.LogicalEqual %true, %false : i1
+  %1 = spirv.LogicalEqual %false, %false : i1
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_equal
+func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+  %0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
 
 // -----
 
@@ -585,6 +627,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
   spirv.ReturnValue %0 : vector<4xi1>
 }
 
+// CHECK-LABEL: @logical_not_equal_same
+func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+  %1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not_equal
+func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
+  %true = spirv.Constant true
+  %false = spirv.Constant false
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.LogicalNotEqual %true, %false : i1
+  %1 = spirv.LogicalNotEqual %false, %false : i1
+
+  // CHECK: return %[[CTRUE]], %[[CFALSE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not_equal
+func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+  %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+  %0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
 // -----
 
 func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -660,6 +739,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iequal_same
+func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+  %0 = spirv.IEqual %arg0, %arg0 : i32
+  %1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iequal
+func.func @const_fold_scalar_iequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.IEqual %c5, %c6 : i32
+  %1 = spirv.IEqual %c5, %c5 : i32
+
+  // CHECK: return %[[CFALSE]], %[[CTRUE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_iequal
+func.func @const_fold_vector_iequal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+  %0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inotequal_same
+func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+  %0 = spirv.INotEqual %arg0, %arg0 : i32
+  %1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>
+
+  // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+  return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_inotequal
+func.func @const_fold_scalar_inotequal() -> (i1, i1) {
+  %c5 = spirv.Constant 5 : i32
+  %c6 = spirv.Constant 6 : i32
+
+  // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+  // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+  %0 = spirv.INotEqual %c5, %c6 : i32
+  %1 = spirv.INotEqual %c5, %c5 : i32
+
+  // CHECK: return %[[CTRUE]], %[[CFALSE]]
+  return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_inotequal
+func.func @const_fold_vector_inotequal() -> vector<3xi1> {
+  %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+  %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+  %0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

>From cb66d1ef5758ccf6cf47f07d1da85f6ae378dc93 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sun, 10 Dec 2023 12:31:20 +0100
Subject: [PATCH 2/2] review comments:

- fix coding style
- drop braces around single-statement if bodies
- fix the "x != x -> false" comments in the NotEqual folders
---
 .../SPIRV/IR/SPIRVCanonicalization.cpp        | 96 ++++++++-----------
 1 file changed, 38 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 16efe8797f4a32..ff96cb1715867c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -317,22 +317,17 @@ OpFoldResult
 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
   // x == x -> true
   if (getOperand1() == getOperand2()) {
-    auto type = getType();
-    if (isa<IntegerType>(type)) {
-      return BoolAttr::get(getContext(), true);
-    }
-    if (isa<VectorType>(type)) {
-      auto vtType = cast<ShapedType>(type);
-      auto element = BoolAttr::get(getContext(), true);
-      return DenseElementsAttr::get(vtType, element);
-    }
+    auto trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(getType()))
+      return trueAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, trueAttr);
   }
 
-  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
-                                        [](const APInt &a, const APInt &b) {
-                                          APInt zero = APInt::getZero(1);
-                                          return a == b ? (zero + 1) : zero;
-                                        });
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,22 +344,17 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
 
-  // x == x -> false
+  // x != x -> false
   if (getOperand1() == getOperand2()) {
-    auto type = getType();
-    if (isa<IntegerType>(type)) {
-      return BoolAttr::get(getContext(), false);
-    }
-    if (isa<VectorType>(type)) {
-      auto vtType = cast<ShapedType>(type);
-      auto element = BoolAttr::get(getContext(), false);
-      return DenseElementsAttr::get(vtType, element);
-    }
+    auto falseAttr = BoolAttr::get(getContext(), false);
+    if (isa<IntegerType>(getType()))
+      return falseAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, falseAttr);
   }
 
-  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
-                                        [](const APInt &a, const APInt &b) {
-                                          APInt zero = APInt::getZero(1);
-                                          return a == b ? zero : (zero + 1);
-                                        });
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -406,22 +396,17 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
   // x == x -> true
   if (getOperand1() == getOperand2()) {
-    auto type = getType();
-    if (isa<IntegerType>(type)) {
-      return BoolAttr::get(getContext(), true);
-    }
-    if (isa<VectorType>(type)) {
-      auto vtType = cast<ShapedType>(type);
-      auto element = BoolAttr::get(getContext(), true);
-      return DenseElementsAttr::get(vtType, element);
-    }
+    auto trueAttr = BoolAttr::get(getContext(), true);
+    if (isa<IntegerType>(getType()))
+      return trueAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, trueAttr);
   }
 
-  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
-                                        [](const APInt &a, const APInt &b) {
-                                          APInt zero = APInt::getZero(1);
-                                          return a == b ? (zero + 1) : zero;
-                                        });
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -431,22 +416,17 @@ OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
-  // x == x -> false
+  // x != x -> false
   if (getOperand1() == getOperand2()) {
-    auto type = getType();
-    if (isa<IntegerType>(type)) {
-      return BoolAttr::get(getContext(), false);
-    }
-    if (isa<VectorType>(type)) {
-      auto vtType = cast<ShapedType>(type);
-      auto element = BoolAttr::get(getContext(), false);
-      return DenseElementsAttr::get(vtType, element);
-    }
+    auto falseAttr = BoolAttr::get(getContext(), false);
+    if (isa<IntegerType>(getType()))
+      return falseAttr;
+    if (auto vecTy = dyn_cast<VectorType>(getType()))
+      return SplatElementsAttr::get(vecTy, falseAttr);
   }
 
-  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
-                                        [](const APInt &a, const APInt &b) {
-                                          APInt zero = APInt::getZero(1);
-                                          return a == b ? zero : (zero + 1);
-                                        });
+  return constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+        return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
+      });
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list