[Mlir-commits] [mlir] [mlir][spirv] Add folding for SNegate, [Logical]Not (PR #74992)
Finn Plummer
llvmlistbot at llvm.org
Thu Dec 21 04:26:35 PST 2023
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/74992
>From e2e231d7e931c9234df0c56ab8d4dfea07abef09 Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 10:42:44 +0100
Subject: [PATCH 1/2] [mlir][spirv] Add folding for SNegate, [Logical]Not
Add the missing constant-propagation folders for SNegate, [Logical]Not.
Additionally fold a double application of the same op (e.g. !(!x) -> x) for all three ops.
This improves the readability of code lowered to SPIR-V.
Part of work for #70704
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 2 +
.../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td | 2 +
.../mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td | 1 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 55 +++++++++
.../SPIRV/Transforms/canonicalize.mlir | 116 ++++++++++++++++++
5 files changed, 176 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 51124e141c6d46..22d5afcd773817 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -582,6 +582,8 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
%3 = spirv.SNegate %2 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index b460c8e68aa0c6..38639a175ab4db 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -462,6 +462,8 @@ def SPIRV_NotOp : SPIRV_BitUnaryOp<"Not", [UsableInSpecConstantOp]> {
%3 = spirv.Not %1 : vector<4xi32>
```
}];
+
+ let hasFolder = 1;
}
#endif // MLIR_DIALECT_SPIRV_IR_BIT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index 47887ffb474f00..260d24b5502577 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -528,6 +528,7 @@ def SPIRV_LogicalNotOp : SPIRV_LogicalUnaryOp<"LogicalNot",
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9de1707dfca465..fe334d50b6faaa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -643,6 +643,45 @@ OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
return div0 ? Attribute() : res;
}
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+/// Folds spirv.SNegate: a directly nested spirv.SNegate cancels out, and a
+/// constant integer (scalar or vector) operand is negated at compile time.
+OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
+  // Negating twice is the identity in wrapping two's-complement arithmetic,
+  // even for the minimum signed value: -(-x) = 0 - (0 - x) = x (mod 2^N).
+  auto operand = getOperand();
+  if (auto innerNegate = operand.getDefiningOp<spirv::SNegateOp>())
+    return innerNegate->getOperand(0);
+
+  // SPIR-V spec: "Signed-integer subtract of Operand from zero." That
+  // wrapping subtraction is exactly APInt's unary negation.
+  return constFoldUnaryOp<IntegerAttr>(
+      adaptor.getOperands(), [](const APInt &value) { return -value; });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NotOp
+//===----------------------------------------------------------------------===//
+
+/// Folds spirv.Not: a directly nested spirv.Not cancels out, and a constant
+/// integer (scalar or vector) operand is replaced by its bitwise complement.
+OpFoldResult spirv::NotOp::fold(FoldAdaptor adaptor) {
+  // ~(~x) = x
+  auto op = getOperand();
+  if (auto notOp = op.getDefiningOp<spirv::NotOp>())
+    return notOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Complement the bits of Operand.
+  //
+  // The lambda captures nothing; it maps each element to its complement.
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
+                                       [](const APInt &a) { return ~a; });
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -681,6 +720,22 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
+/// Folds spirv.LogicalNot: a directly nested spirv.LogicalNot cancels out,
+/// and a constant boolean (scalar or vector) operand is negated.
+OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
+  // !(!x) = x
+  auto op = getOperand();
+  if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
+    return notOp->getOperand(0);
+
+  // According to the SPIR-V spec:
+  //
+  // Result is true if Operand is false. Result is false if Operand is true.
+  //
+  // The operand is a boolean (i1), so logical negation is the bitwise
+  // complement of its single bit. Using ~a keeps the folded value at the
+  // operand's own bit width instead of hard-coding a width of 1.
+  return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
+                                       [](const APInt &a) { return ~a; });
+}
+
void spirv::LogicalNotOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 29bea91ce461d9..7da2cf5be4e007 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1006,6 +1006,90 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.SNegate
+//===----------------------------------------------------------------------===//
+
+// Two nested spirv.SNegate ops cancel out, so the function returns %arg0 directly.
+// CHECK-LABEL: @snegate_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @snegate_twice(%arg0 : i32) -> i32 {
+ %0 = spirv.SNegate %arg0 : i32
+ %1 = spirv.SNegate %0 : i32
+
+ // CHECK: return %[[ARG]] : i32
+ return %1 : i32
+}
+
+// Scalar constant operands are folded: -(0) = 0, -(3) = -3, -(-3) = 3.
+// CHECK-LABEL: @const_fold_scalar_snegate
+func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[THREE:.*]] = spirv.Constant 3 : i32
+ // CHECK-DAG: %[[NTHREE:.*]] = spirv.Constant -3 : i32
+ // CHECK-DAG: %[[ZERO:.*]] = spirv.Constant 0 : i32
+ %0 = spirv.SNegate %c0 : i32
+ %1 = spirv.SNegate %c3 : i32
+ %2 = spirv.SNegate %cn3 : i32
+
+ // CHECK: return %[[ZERO]], %[[NTHREE]], %[[THREE]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// A vector constant is negated element-wise: [0, -3, 3] folds to [0, 3, -3].
+// CHECK-LABEL: @const_fold_vector_snegate
+func.func @const_fold_vector_snegate() -> vector<3xi32> {
+ // CHECK: spirv.Constant dense<[0, 3, -3]>
+ %cv = spirv.Constant dense<[0, -3, 3]> : vector<3xi32>
+ %0 = spirv.SNegate %cv : vector<3xi32>
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.Not
+//===----------------------------------------------------------------------===//
+
+// Two nested spirv.Not ops cancel out (~(~x) = x), so %arg0 is returned directly.
+// CHECK-LABEL: @not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @not_twice(%arg0 : i32) -> i32 {
+ %0 = spirv.Not %arg0 : i32
+ %1 = spirv.Not %0 : i32
+
+ // CHECK: return %[[ARG]] : i32
+ return %1 : i32
+}
+
+// Bitwise complement of scalar constants: ~0 = -1, ~3 = -4, ~(-3) = 2.
+// CHECK-LABEL: @const_fold_scalar_not
+func.func @const_fold_scalar_not() -> (i32, i32, i32) {
+ %c0 = spirv.Constant 0 : i32
+ %c3 = spirv.Constant 3 : i32
+ %cn3 = spirv.Constant -3 : i32
+
+ // CHECK-DAG: %[[TWO:.*]] = spirv.Constant 2 : i32
+ // CHECK-DAG: %[[NFOUR:.*]] = spirv.Constant -4 : i32
+ // CHECK-DAG: %[[NONE:.*]] = spirv.Constant -1 : i32
+ %0 = spirv.Not %c0 : i32
+ %1 = spirv.Not %c3 : i32
+ %2 = spirv.Not %cn3 : i32
+
+ // CHECK: return %[[NONE]], %[[NFOUR]], %[[TWO]]
+ return %0, %1, %2 : i32, i32, i32
+}
+
+// Element-wise complement of a vector constant: ~[-1, -4, 2] = [0, 3, -3].
+// CHECK-LABEL: @const_fold_vector_not
+func.func @const_fold_vector_not() -> vector<3xi32> {
+ %cv = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+
+ // CHECK: spirv.Constant dense<[0, 3, -3]>
+ %0 = spirv.Not %cv : vector<3xi32>
+
+ return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.LogicalAnd
//===----------------------------------------------------------------------===//
@@ -1040,6 +1124,38 @@ func.func @convert_logical_and_true_false_vector(%arg: vector<3xi1>) -> (vector<
// spirv.LogicalNot
//===----------------------------------------------------------------------===//
+// Two nested spirv.LogicalNot ops cancel out, so the function returns %arg0 directly.
+// CHECK-LABEL: @logical_not_twice
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+func.func @logical_not_twice(%arg0 : i1) -> i1 {
+ %0 = spirv.LogicalNot %arg0 : i1
+ %1 = spirv.LogicalNot %0 : i1
+
+ // CHECK: return %[[ARG]] : i1
+ return %1 : i1
+}
+
+// A constant true operand folds to a constant false result.
+// CHECK-LABEL: @const_fold_scalar_logical_not
+func.func @const_fold_scalar_logical_not() -> i1 {
+ %true = spirv.Constant true
+
+ // CHECK: spirv.Constant false
+ %0 = spirv.LogicalNot %true : i1
+
+ return %0 : i1
+}
+
+// A boolean vector constant is negated element-wise: [true, false] folds to [false, true].
+// CHECK-LABEL: @const_fold_vector_logical_not
+func.func @const_fold_vector_logical_not() -> vector<2xi1> {
+ %cv = spirv.Constant dense<[true, false]> : vector<2xi1>
+
+ // CHECK: spirv.Constant dense<[false, true]>
+ %0 = spirv.LogicalNot %cv : vector<2xi1>
+
+ return %0 : vector<2xi1>
+}
+
+// -----
+
func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
// CHECK: %[[RESULT:.*]] = spirv.INotEqual {{%.*}}, {{%.*}} : vector<3xi64>
// CHECK-NEXT: spirv.ReturnValue %[[RESULT]] : vector<3xi1>
>From 92f2ff7303df1a5a3680faa5b9c44d69f6e6f6ae Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Thu, 21 Dec 2023 13:25:26 +0100
Subject: [PATCH 2/2] review comments:
- add testcase to demonstrate SNegate behaviour for INT_MIN
---
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 7da2cf5be4e007..9440654d7a176c 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1020,6 +1020,18 @@ func.func @snegate_twice(%arg0 : i32) -> i32 {
return %1 : i32
}
+// Negating the minimum i8 value wraps: -(-128) = -128. Both the single and the
+// double negation must therefore fold to the original -128 constant.
+// CHECK-LABEL: @snegate_min
+func.func @snegate_min() -> (i8, i8) {
+ // CHECK: %[[MIN:.*]] = spirv.Constant -128 : i8
+ %cmin = spirv.Constant -128 : i8
+
+ %0 = spirv.SNegate %cmin : i8
+ %1 = spirv.SNegate %0 : i8
+
+ // CHECK: return %[[MIN]], %[[MIN]]
+ return %0, %1 : i8, i8
+}
+
// CHECK-LABEL: @const_fold_scalar_snegate
func.func @const_fold_scalar_snegate() -> (i32, i32, i32) {
%c0 = spirv.Constant 0 : i32
More information about the Mlir-commits
mailing list