[Mlir-commits] [mlir] [mlir][spirv] Add folding for [I|Logical][Not]Equal (PR #74194)
Finn Plummer
llvmlistbot at llvm.org
Wed Dec 20 00:14:23 PST 2023
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/74194
>From d7d302668a9238e01399e5db7e12428dea1ff9b3 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 10:18:08 +0100
Subject: [PATCH 1/3] [mlir][spirv] Add folding for [I|Logical][Not]Equal
Add the missing constant-propagation folders for [I|Logical][Not]Equal.
Implement additional folding when lhs == rhs for all ops.
Also fix test cases in logical-ops-to-llvm that failed due to the
newly introduced folding.
This improves the readability of code lowered into SPIR-V.
Part of work for #70704
---
.../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 9 +-
.../SPIRV/IR/SPIRVCanonicalization.cpp | 97 +++++++++-
.../SPIRVToLLVM/logical-ops-to-llvm.mlir | 16 +-
.../SPIRV/Transforms/canonicalize.mlir | 165 ++++++++++++++++++
4 files changed, 276 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 47887ffb474f00..2e26c44de281a0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -369,6 +369,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
%5 = spirv.IEqual %2, %3 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -395,6 +397,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -501,6 +505,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -557,7 +563,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
- let hasFolder = true;
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9de1707dfca465..421d2bc343358d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -662,6 +662,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//
@@ -669,12 +695,29 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
if (std::optional<bool> rhs =
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
- // x && false = x
+ // x != false -> x
if (!rhs.value())
return getOperand1();
}
- return Attribute();
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), false);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), false);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? zero : (zero + 1);
+ });
}
//===----------------------------------------------------------------------===//
@@ -709,6 +752,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.IEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), false);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), false);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? zero : (zero + 1);
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
index 6d93480d3ed142..aab2dce980ca7b 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
@@ -7,14 +7,14 @@
// CHECK-LABEL: @logical_equal_scalar
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+ %0 = spirv.LogicalEqual %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_equal_vector
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
// CHECK-LABEL: @logical_not_equal_scalar
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+ %0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_not_equal_vector
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
// CHECK-LABEL: @logical_and_scalar
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalAnd %arg0, %arg0 : i1
+ %0 = spirv.LogicalAnd %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_and_vector
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
// CHECK-LABEL: @logical_or_scalar
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalOr %arg0, %arg0 : i1
+ %0 = spirv.LogicalOr %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_or_vector
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 29bea91ce461d9..871ecd4f28b12e 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1048,6 +1048,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
spirv.ReturnValue %3 : vector<3xi1>
}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @logical_equal_same
+func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+
+ %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+ %1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_equal
+func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.LogicalEqual %true, %false : i1
+ %1 = spirv.LogicalEqual %false, %false : i1
+
+ // CHECK: return %[[CFALSE]], %[[CTRUE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_equal
+func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+ %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+ %0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
// -----
@@ -1064,6 +1106,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
spirv.ReturnValue %0 : vector<4xi1>
}
+// CHECK-LABEL: @logical_not_equal_same
+func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+ %1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not_equal
+func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.LogicalNotEqual %true, %false : i1
+ %1 = spirv.LogicalNotEqual %false, %false : i1
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not_equal
+func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+ %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+ %0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
// -----
func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -1139,6 +1218,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iequal_same
+func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+ %0 = spirv.IEqual %arg0, %arg0 : i32
+ %1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iequal
+func.func @const_fold_scalar_iequal() -> (i1, i1) {
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.IEqual %c5, %c6 : i32
+ %1 = spirv.IEqual %c5, %c5 : i32
+
+ // CHECK: return %[[CFALSE]], %[[CTRUE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_iequal
+func.func @const_fold_vector_iequal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+ %0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inotequal_same
+func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.INotEqual %arg0, %arg0 : i32
+ %1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_inotequal
+func.func @const_fold_scalar_inotequal() -> (i1, i1) {
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.INotEqual %c5, %c6 : i32
+ %1 = spirv.INotEqual %c5, %c5 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_inotequal
+func.func @const_fold_vector_inotequal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+ %0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.LeftShiftLogical
//===----------------------------------------------------------------------===//
>From 880f50cb5bafffa68d4e10abcfe84f933b492d44 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Sun, 10 Dec 2023 12:31:20 +0100
Subject: [PATCH 2/3] review comments:
- fix coding style
---
.../SPIRV/IR/SPIRVCanonicalization.cpp | 84 ++++++++-----------
1 file changed, 36 insertions(+), 48 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 421d2bc343358d..2690a3de518f66 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -670,22 +670,19 @@ OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
- auto type = getType();
- if (isa<IntegerType>(type)) {
- return BoolAttr::get(getContext(), true);
+ auto trueAttr = BoolAttr::get(getContext(), true);
+ if (isa<IntegerType>(getType())) {
+ return trueAttr;
}
- if (isa<VectorType>(type)) {
- auto vtType = cast<ShapedType>(type);
- auto element = BoolAttr::get(getContext(), true);
- return DenseElementsAttr::get(vtType, element);
+ if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ return SplatElementsAttr::get(vecTy, trueAttr);
}
}
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
- [](const APInt &a, const APInt &b) {
- APInt zero = APInt::getZero(1);
- return a == b ? (zero + 1) : zero;
- });
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+ return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
}
//===----------------------------------------------------------------------===//
@@ -702,22 +699,19 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
// x == x -> false
if (getOperand1() == getOperand2()) {
- auto type = getType();
- if (isa<IntegerType>(type)) {
- return BoolAttr::get(getContext(), false);
+ auto falseAttr = BoolAttr::get(getContext(), false);
+ if (isa<IntegerType>(getType())) {
+ return falseAttr;
}
- if (isa<VectorType>(type)) {
- auto vtType = cast<ShapedType>(type);
- auto element = BoolAttr::get(getContext(), false);
- return DenseElementsAttr::get(vtType, element);
+ if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ return SplatElementsAttr::get(vecTy, falseAttr);
}
}
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
- [](const APInt &a, const APInt &b) {
- APInt zero = APInt::getZero(1);
- return a == b ? zero : (zero + 1);
- });
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [](const APInt &a, const APInt &b) {
+ return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
}
//===----------------------------------------------------------------------===//
@@ -759,22 +753,19 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
- auto type = getType();
- if (isa<IntegerType>(type)) {
- return BoolAttr::get(getContext(), true);
+ auto trueAttr = BoolAttr::get(getContext(), true);
+ if (isa<IntegerType>(getType())) {
+ return trueAttr;
}
- if (isa<VectorType>(type)) {
- auto vtType = cast<ShapedType>(type);
- auto element = BoolAttr::get(getContext(), true);
- return DenseElementsAttr::get(vtType, element);
+ if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ return SplatElementsAttr::get(vecTy, trueAttr);
}
}
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
- [](const APInt &a, const APInt &b) {
- APInt zero = APInt::getZero(1);
- return a == b ? (zero + 1) : zero;
- });
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
+ });
}
//===----------------------------------------------------------------------===//
@@ -784,22 +775,19 @@ OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
// x == x -> false
if (getOperand1() == getOperand2()) {
- auto type = getType();
- if (isa<IntegerType>(type)) {
- return BoolAttr::get(getContext(), false);
+ auto falseAttr = BoolAttr::get(getContext(), false);
+ if (isa<IntegerType>(getType())) {
+ return falseAttr;
}
- if (isa<VectorType>(type)) {
- auto vtType = cast<ShapedType>(type);
- auto element = BoolAttr::get(getContext(), false);
- return DenseElementsAttr::get(vtType, element);
+ if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ return SplatElementsAttr::get(vecTy, falseAttr);
}
}
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
- [](const APInt &a, const APInt &b) {
- APInt zero = APInt::getZero(1);
- return a == b ? zero : (zero + 1);
- });
+ return constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
+ return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
+ });
}
//===----------------------------------------------------------------------===//
>From 250c6922f1800e18e5bab42efc85b7af4ad21dc9 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Wed, 20 Dec 2023 09:13:50 +0100
Subject: [PATCH 3/3] review comments: fix formatting and NotEqual fold comments
---
.../SPIRV/IR/SPIRVCanonicalization.cpp | 28 ++++++++--------------
1 file changed, 10 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 2690a3de518f66..08ddc7c25aa9e5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -671,12 +671,10 @@ spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
- if (isa<IntegerType>(getType())) {
+ if (isa<IntegerType>(getType()))
return trueAttr;
- }
- if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
- }
}
return constFoldBinaryOp<IntegerAttr>(
@@ -700,12 +698,10 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
- // x == x -> false
+ // x != x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
- if (isa<IntegerType>(getType())) {
+ if (isa<IntegerType>(getType()))
return falseAttr;
- }
- if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
- }
}
return constFoldBinaryOp<IntegerAttr>(
@@ -754,12 +750,10 @@ OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
- if (isa<IntegerType>(getType())) {
+ if (isa<IntegerType>(getType()))
return trueAttr;
- }
- if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
- }
}
return constFoldBinaryOp<IntegerAttr>(
@@ -776,12 +770,10 @@ OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
- // x == x -> false
+ // x != x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
- if (isa<IntegerType>(getType())) {
+ if (isa<IntegerType>(getType()))
return falseAttr;
- }
- if (auto vecTy = dyn_cast<VectorType>(getType())) {
+ if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
- }
}
return constFoldBinaryOp<IntegerAttr>(
More information about the Mlir-commits
mailing list