[Mlir-commits] [mlir] [mlir][spirv] Add folding for SPIR-V Shifting ops (PR #74192)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 3 10:51:43 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Finn Plummer (inbelic)

<details>
<summary>Changes</summary>

Add the missing constant-propagation folders for ShiftLeftLogical and ShiftRight[Logical|Arithmetic].

Additionally, fold the identity case where the Shift operand is 0 (x << 0 -> x, x >> 0 -> x).

This improves the readability of code lowered to SPIR-V.

Part of work for #70704

---
Full diff: https://github.com/llvm/llvm-project/pull/74192.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td (+6) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (+102) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir (+178) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 286f4de6f90f6..e19bd640075c1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -457,6 +457,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
     %5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -499,6 +501,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
     %5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -542,6 +546,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
     %5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..528d6a5d483aa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -356,6 +356,108 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+  // x << 0 -> x for any Base: a zero shift is always in range.
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // Don't fold 0 << y -> 0: the result is undefined if y is out of range.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So fold per element with APInt <<; bail out if any shift is too large.
+  bool shiftToLarge = false; // set once any shift >= component bit width
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftToLarge || b.uge(a.getBitWidth())) {
+          shiftToLarge = true;
+          return a; // placeholder; the whole fold is discarded below
+        }
+        return a << b;
+      });
+  return shiftToLarge ? Attribute() : res; // null Attribute => not folded
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x for any Base: a zero shift is always in range.
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // Don't fold 0 >> y -> 0 or -1 >> y -> -1: undefined if y is out of range.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So fold per element with APInt::ashr; bail out if any shift is too large.
+  bool shiftToLarge = false; // set once any shift >= component bit width
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftToLarge || b.uge(a.getBitWidth())) {
+          shiftToLarge = true;
+          return a; // placeholder; the whole fold is discarded below
+        }
+        return a.ashr(b);
+      });
+  return shiftToLarge ? Attribute() : res; // null Attribute => not folded
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+  // x >> 0 -> x for any Base: a zero shift is always in range.
+  if (matchPattern(adaptor.getOperand2(), m_Zero())) {
+    return getOperand1();
+  }
+
+  // Don't fold 0 >> y -> 0: the result is undefined if y is out of range.
+
+  // According to the SPIR-V spec:
+  //
+  // Type is a scalar or vector of integer type.
+  // Results are computed per component, and within each component, per bit...
+  //
+  // The result is undefined if Shift is greater than or equal to the bit width
+  // of the components of Base.
+  //
+  // So fold per element with APInt::lshr; bail out if any shift is too large.
+  bool shiftToLarge = false; // set once any shift >= component bit width
+  auto res = constFoldBinaryOp<IntegerAttr>(
+      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
+        if (shiftToLarge || b.uge(a.getBitWidth())) {
+          shiftToLarge = true;
+          return a; // placeholder; the whole fold is discarded below
+        }
+        return a.lshr(b);
+      });
+  return shiftToLarge ? Attribute() : res; // null Attribute => not folded
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..3919a051fc875 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -660,6 +660,184 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsl_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32  // == i32 bit width: must not fold
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>  // 128 >= 32: no element folds
+
+  // CHECK: %0 = spirv.ShiftLeftLogical %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftLeftLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftLeftLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+  %c1 = spirv.Constant 65535 : i32  // 0x0000 ffff
+  %c2 = spirv.Constant 17 : i32
+
+  // CHECK: %[[RET:.*]] = spirv.Constant -131072
+  // 0x0000 ffff << 17 -> 0xfffe 0000
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-2147483648, -65536, 1040384]>
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @asr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32  // == i32 bit width: must not fold
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>  // 128 >= 32: no element folds
+
+  // CHECK: %0 = spirv.ShiftRightArithmetic %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftRightArithmetic %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightArithmetic %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+  // 0xfffe 0000 ashr 17 -> 0xffff ffff
+  // CHECK: %[[RET:.*]] = spirv.Constant -1
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-1, 3, 0]>
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32  // == i32 bit width: must not fold
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>  // 128 >= 32: no element folds
+
+  // CHECK: %0 = spirv.ShiftRightLogical %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftRightLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+
+  // 0xfffe 0000 lshr 17 -> 0x0000 7fff
+  // CHECK: %[[RET:.*]] = spirv.Constant 32767
+  %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[1, 65535, 524287]>
+  %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/74192


More information about the Mlir-commits mailing list