[Mlir-commits] [mlir] [mlir][spirv] Add folding for SPIR-V Shifting ops (PR #74192)
Finn Plummer
llvmlistbot at llvm.org
Sat Dec 2 06:10:02 PST 2023
https://github.com/inbelic created https://github.com/llvm/llvm-project/pull/74192
Add the missing constant-propagation folders for ShiftLeftLogical and ShiftRight[Logical|Arithmetic].
Also fold these ops to their Base operand when the Shift value is 0.
This improves the readability of code lowered into SPIR-V.
Part of work for #70704
>From d270dd31f876dc64833d8d42cb271a437a63e15c Mon Sep 17 00:00:00 2001
From: inbelic <canadienfinn at gmail.com>
Date: Fri, 24 Nov 2023 10:53:36 +0100
Subject: [PATCH] [mlir][spirv] Add folding for SPIR-V Shifting ops
Add the missing constant-propagation folders for ShiftLeftLogical and
ShiftRight[Logical|Arithmetic].
Also fold these ops to their Base operand when the Shift value is 0.
This improves the readability of code lowered into SPIR-V.
Part of work for #70704
---
.../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td | 6 +
.../SPIRV/IR/SPIRVCanonicalization.cpp | 102 ++++++++++
.../SPIRV/Transforms/canonicalize.mlir | 176 ++++++++++++++++++
3 files changed, 284 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f6..e19bd640075c1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -457,6 +457,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
%5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -499,6 +501,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
%5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -542,6 +546,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
%5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..528d6a5d483aa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -356,6 +356,108 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x << 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // According to the SPIR-V spec, results are computed per component, and the
+  // result is undefined if Shift is greater than or equal to the bit width of
+  // the components of Base. For that reason `0 << x -> 0` is not folded for an
+  // unknown `x`. A fold of constant operands is abandoned as soon as any
+  // component has an out-of-range shift amount (compared as unsigned).
+  bool shiftTooLarge = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &base, const APInt &shift) {
+        shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+        // Once any component is out of range the whole fold is discarded, so
+        // the value produced for this component no longer matters.
+        return shiftTooLarge ? base : base << shift;
+      });
+  if (shiftTooLarge)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x >> 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // According to the SPIR-V spec, results are computed per component, and the
+  // result is undefined if Shift is greater than or equal to the bit width of
+  // the components of Base. For that reason `0 >> x -> 0` and `-1 >> x -> -1`
+  // are not folded for an unknown `x`. A fold of constant operands is abandoned
+  // as soon as any component has an out-of-range shift amount (compared as
+  // unsigned).
+  bool shiftTooLarge = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &base, const APInt &shift) {
+        shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+        // Once any component is out of range the whole fold is discarded, so
+        // the value produced for this component no longer matters.
+        return shiftTooLarge ? base : base.ashr(shift);
+      });
+  if (shiftTooLarge)
+    return Attribute();
+  return folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x >> 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // According to the SPIR-V spec, results are computed per component, and the
+  // result is undefined if Shift is greater than or equal to the bit width of
+  // the components of Base. For that reason `0 >> x -> 0` is not folded for an
+  // unknown `x`. A fold of constant operands is abandoned as soon as any
+  // component has an out-of-range shift amount (compared as unsigned).
+  bool shiftTooLarge = false;
+  Attribute folded = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &base, const APInt &shift) {
+        shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+        // Once any component is out of range the whole fold is discarded, so
+        // the value produced for this component no longer matters.
+        return shiftTooLarge ? base : base.lshr(shift);
+      });
+  if (shiftTooLarge)
+    return Attribute();
+  return folded;
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..96b707b4398d3 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -660,6 +660,182 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// Shifting a non-constant Base left by a constant zero (scalar or splat
+// vector) folds the op away and forwards Base unchanged.
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// Shift amounts >= the bit width of Base are undefined per the SPIR-V spec,
+// so the op must survive even when both operands are constants. Base must be
+// a constant here; otherwise nothing could fold and the guard would go
+// untested. In the vector case only the last lane (128) is out of range, which
+// is enough to block the fold of the whole vector.
+// CHECK-LABEL: @lsl_shift_overflow
+func.func @lsl_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32 : i32
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<1> : vector<3xi32>
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  %c1 = spirv.Constant 1 : i32
+  %c32 = spirv.Constant 32 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %[[R0:.*]] = spirv.ShiftLeftLogical %[[C1]], %[[C32]]
+  // CHECK: %[[R1:.*]] = spirv.ShiftLeftLogical %[[CV1]], %[[CV]]
+  %0 = spirv.ShiftLeftLogical %c1, %c32 : i32, i32
+  %1 = spirv.ShiftLeftLogical %cv1, %cv : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[R0]], %[[R1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// A constant Base with an in-range constant Shift folds to a single constant.
+// Bits shifted past bit 31 are discarded, so the i32 result becomes negative.
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+ %c1 = spirv.Constant 65535 : i32 // 0x0000 ffff
+ %c2 = spirv.Constant 17 : i32
+
+ // CHECK: %[[RET:.*]] = spirv.Constant -131072
+ // 0x0000 ffff << 17 -> 0xfffe 0000
+ %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+
+ // CHECK: return %[[RET]]
+ return %0 : i32
+}
+
+// Per-lane constant folding of a vector left shift, with all shifts below 32:
+//   1   << 31 -> 0x8000 0000 (-2147483648)
+//   -1  << 16 -> 0xffff 0000 (-65536)
+//   127 << 13 -> 1040384
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[-2147483648, -65536, 1040384]>
+ %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// Arithmetically shifting a non-constant Base right by a constant zero
+// (scalar or splat vector) folds the op away and forwards Base unchanged.
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// Shift amounts >= the bit width of Base are undefined per the SPIR-V spec,
+// so the op must survive even when both operands are constants. Base must be
+// a constant here; otherwise nothing could fold and the guard would go
+// untested. In the vector case only the last lane (128) is out of range, which
+// is enough to block the fold of the whole vector.
+// CHECK-LABEL: @asr_shift_overflow
+func.func @asr_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32 : i32
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<1> : vector<3xi32>
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  %c1 = spirv.Constant 1 : i32
+  %c32 = spirv.Constant 32 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %[[R0:.*]] = spirv.ShiftRightArithmetic %[[C1]], %[[C32]]
+  // CHECK: %[[R1:.*]] = spirv.ShiftRightArithmetic %[[CV1]], %[[CV]]
+  %0 = spirv.ShiftRightArithmetic %c1, %c32 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %cv1, %cv : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[R0]], %[[R1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// A constant Base with an in-range constant Shift folds to a single constant.
+// The arithmetic shift replicates the sign bit, so a negative Base stays
+// negative.
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32 // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+
+  // 0xfffe 0000 ashr 17 -> 0xffff ffff
+  // CHECK: %[[RET:.*]] = spirv.Constant -1 : i32
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// Per-lane constant folding of a vector arithmetic right shift, with all
+// shifts below 32:
+//   0x8000 0000 ashr 31 -> -1 (sign bit replicated into every bit)
+//   239847      ashr 16 -> 3
+//   127         ashr 13 -> 0
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[-1, 3, 0]>
+ %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// Logically shifting a non-constant Base right by a constant zero (scalar or
+// splat vector) folds the op away and forwards Base unchanged.
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// Shift amounts >= the bit width of Base are undefined per the SPIR-V spec,
+// so the op must survive even when both operands are constants. Base must be
+// a constant here; otherwise nothing could fold and the guard would go
+// untested. In the vector case only the last lane (128) is out of range, which
+// is enough to block the fold of the whole vector.
+// CHECK-LABEL: @lsr_shift_overflow
+func.func @lsr_shift_overflow() -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C1:.*]] = spirv.Constant 1 : i32
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32 : i32
+  // CHECK-DAG: %[[CV1:.*]] = spirv.Constant dense<1> : vector<3xi32>
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+  %c1 = spirv.Constant 1 : i32
+  %c32 = spirv.Constant 32 : i32
+  %cv1 = spirv.Constant dense<1> : vector<3xi32>
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %[[R0:.*]] = spirv.ShiftRightLogical %[[C1]], %[[C32]]
+  // CHECK: %[[R1:.*]] = spirv.ShiftRightLogical %[[CV1]], %[[CV]]
+  %0 = spirv.ShiftRightLogical %c1, %c32 : i32, i32
+  %1 = spirv.ShiftRightLogical %cv1, %cv : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[R0]], %[[R1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// A constant Base with an in-range constant Shift folds to a single constant.
+// The logical shift fills with zeros, so a negative Base becomes positive.
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32 // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+
+  // 0xfffe 0000 >> 17 -> 0x0000 7fff
+  // CHECK: %[[RET:.*]] = spirv.Constant 32767 : i32
+  %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// Per-lane constant folding of a vector logical right shift, with all shifts
+// below 32 (zeros are shifted in from the top):
+//   0x8000 0000 >> 31 -> 1
+//   0xffff ffff >> 16 -> 0x0000 ffff (65535)
+//   0xffff ff81 >> 13 -> 0x0007 ffff (524287)
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[1, 65535, 524287]>
+ %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list