[Mlir-commits] [mlir] 2ba9720 - [mlir][spirv] Add folding for SPIR-V Shifting ops (#74192)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 6 15:00:48 PST 2023
Author: Finn Plummer
Date: 2023-12-06T18:00:44-05:00
New Revision: 2ba9720a76491709288467f3d51530cc6d031540
URL: https://github.com/llvm/llvm-project/commit/2ba9720a76491709288467f3d51530cc6d031540
DIFF: https://github.com/llvm/llvm-project/commit/2ba9720a76491709288467f3d51530cc6d031540.diff
LOG: [mlir][spirv] Add folding for SPIR-V Shifting ops (#74192)
Add missing constant propagation folder for ShiftLeftLogical,
ShiftRight[Logical|Arithmetic].
Implement additional folding when Shift value is 0.
This helps for readability of lowered code into SPIR-V.
Part of work for #70704
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 91b69576cd1d3..7423bb3c68d7c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -368,6 +368,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
%5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -399,6 +401,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
%5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -431,6 +435,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
%5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
```
}];
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 22cb9bf718e36..4643fc08c08d3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -709,6 +709,108 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+ spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+ // Identity: x << 0 -> x, i.e. a constant-zero Shift yields Base unchanged.
+ if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+ return getOperand1();
+ }
+
+ // Base == 0 cannot be folded to 0: the result may be undefined (see below).
+
+ // According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit...
+ //
+ // The result is undefined if Shift is greater than or equal to the bit width
+ // of the components of Base.
+ //
+ // So fold constants with APInt <<, but bail out if any Shift is out of range.
+ bool shiftToLarge = false; // Set once any Shift >= bit width of Base.
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (shiftToLarge || b.uge(a.getBitWidth())) {
+ shiftToLarge = true;
+ return a; // Placeholder only; `res` is discarded below in this case.
+ }
+ return a << b;
+ });
+ return shiftToLarge ? Attribute() : res; // Empty Attribute: do not fold.
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+ spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+ // Identity: x >> 0 -> x, i.e. a constant-zero Shift yields Base unchanged.
+ if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+ return getOperand1();
+ }
+
+ // Base of 0 or -1 cannot be folded either: the result may be undefined.
+
+ // According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit...
+ //
+ // The result is undefined if Shift is greater than or equal to the bit width
+ // of the components of Base.
+ //
+ // So fold constants with APInt ashr, but bail out if any Shift is too large.
+ bool shiftToLarge = false; // Set once any Shift >= bit width of Base.
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (shiftToLarge || b.uge(a.getBitWidth())) {
+ shiftToLarge = true;
+ return a; // Placeholder only; `res` is discarded below in this case.
+ }
+ return a.ashr(b);
+ });
+ return shiftToLarge ? Attribute() : res; // Empty Attribute: do not fold.
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+ spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+ // Identity: x >> 0 -> x, i.e. a constant-zero Shift yields Base unchanged.
+ if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+ return getOperand1();
+ }
+
+ // Base == 0 cannot be folded to 0: the result may be undefined (see below).
+
+ // According to the SPIR-V spec:
+ //
+ // Type is a scalar or vector of integer type.
+ // Results are computed per component, and within each component, per bit...
+ //
+ // The result is undefined if Shift is greater than or equal to the bit width
+ // of the components of Base.
+ //
+ // So fold constants with APInt lshr, but bail out if any Shift is too large.
+ bool shiftToLarge = false; // Set once any Shift >= bit width of Base.
+ auto res = constFoldBinaryOp<IntegerAttr>(
+ adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+ if (shiftToLarge || b.uge(a.getBitWidth())) {
+ shiftToLarge = true;
+ return a; // Placeholder only; `res` is discarded below in this case.
+ }
+ return a.lshr(b);
+ });
+ return shiftToLarge ? Attribute() : res; // Empty Attribute: do not fold.
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 867ddf3c80173..6be4b8d20ff36 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1139,6 +1139,184 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsl_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+ // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+ %c32 = spirv.Constant 32 : i32
+ %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+ // CHECK: %0 = spirv.ShiftLeftLogical %[[ARG0]], %[[C32]]
+ // CHECK: %1 = spirv.ShiftLeftLogical %[[ARG1]], %[[CV]]
+ %0 = spirv.ShiftLeftLogical %arg0, %c32 : i32, i32
+ %1 = spirv.ShiftLeftLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+ %c1 = spirv.Constant 65535 : i32 // 0x0000 ffff
+ %c2 = spirv.Constant 17 : i32
+
+ // CHECK: %[[RET:.*]] = spirv.Constant -131072
+ // 0x0000 ffff << 17 -> 0xfffe 0000
+ %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+
+ // CHECK: return %[[RET]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[-2147483648, -65536, 1040384]>
+ %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @asr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+ // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+ %c32 = spirv.Constant 32 : i32
+ %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+ // CHECK: %0 = spirv.ShiftRightArithmetic %[[ARG0]], %[[C32]]
+ // CHECK: %1 = spirv.ShiftRightArithmetic %[[ARG1]], %[[CV]]
+ %0 = spirv.ShiftRightArithmetic %arg0, %c32 : i32, i32
+ %1 = spirv.ShiftRightArithmetic %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+ %c1 = spirv.Constant -131072 : i32 // 0xfffe 0000
+ %c2 = spirv.Constant 17 : i32
+ // 0xfffe 0000 ashr 17 -> 0xffff ffff
+ // CHECK: %[[RET:.*]] = spirv.Constant -1
+ %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+
+ // CHECK: return %[[RET]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[-1, 3, 0]>
+ %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ %c0 = spirv.Constant 0 : i32
+ %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+ %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+ %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[ARG0]], %[[ARG1]]
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+ // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+ // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+ %c32 = spirv.Constant 32 : i32
+ %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+ // CHECK: %0 = spirv.ShiftRightLogical %[[ARG0]], %[[C32]]
+ // CHECK: %1 = spirv.ShiftRightLogical %[[ARG1]], %[[CV]]
+ %0 = spirv.ShiftRightLogical %arg0, %c32 : i32, i32
+ %1 = spirv.ShiftRightLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+ return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+ %c1 = spirv.Constant -131072 : i32 // 0xfffe 0000
+ %c2 = spirv.Constant 17 : i32
+
+ // 0xfffe 0000 lshr 17 -> 0x0000 7fff
+ // CHECK: %[[RET:.*]] = spirv.Constant 32767
+ %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+
+ // CHECK: return %[[RET]]
+ return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+ %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+ %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[1, 65535, 524287]>
+ %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list