[Mlir-commits] [mlir] 2ba9720 - [mlir][spirv] Add folding for SPIR-V Shifting ops (#74192)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 6 15:00:48 PST 2023


Author: Finn Plummer
Date: 2023-12-06T18:00:44-05:00
New Revision: 2ba9720a76491709288467f3d51530cc6d031540

URL: https://github.com/llvm/llvm-project/commit/2ba9720a76491709288467f3d51530cc6d031540
DIFF: https://github.com/llvm/llvm-project/commit/2ba9720a76491709288467f3d51530cc6d031540.diff

LOG: [mlir][spirv] Add folding for SPIR-V Shifting ops (#74192)

Add missing constant propagation folders for ShiftLeftLogical,
ShiftRight[Logical|Arithmetic].

Implement additional folding when Shift value is 0.

This helps for readability of lowered code into SPIR-V.

Part of work for #70704

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
    mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
index 91b69576cd1d3..7423bb3c68d7c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
@@ -368,6 +368,8 @@ def SPIRV_ShiftLeftLogicalOp : SPIRV_ShiftOp<"ShiftLeftLogical",
     %5 = spirv.ShiftLeftLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -399,6 +401,8 @@ def SPIRV_ShiftRightArithmeticOp : SPIRV_ShiftOp<"ShiftRightArithmetic",
     %5 = spirv.ShiftRightArithmetic %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----
@@ -431,6 +435,8 @@ def SPIRV_ShiftRightLogicalOp : SPIRV_ShiftOp<"ShiftRightLogical",
     %5 = spirv.ShiftRightLogical %3, %4 : vector<3xi32>, vector<3xi16>
     ```
   }];
+
+  let hasFolder = 1;
 }
 
 // -----

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 22cb9bf718e36..4643fc08c08d3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -709,6 +709,108 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   return Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftLeftLogicalOp::fold(
+    spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x << 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // A zero Base cannot be folded on its own: Shift may still be out of range,
+  // which makes the result undefined. Quoting the SPIR-V spec:
+  //
+  //   Type is a scalar or vector of integer type.
+  //   Results are computed per component, and within each component, per
+  //   bit...
+  //
+  //   The result is undefined if Shift is greater than or equal to the bit
+  //   width of the components of Base.
+  //
+  // So APInt's operator<< computes each component, and the fold is abandoned
+  // as soon as any component's Shift is out of range.
+  bool shiftTooLarge = false;
+  auto shiftLeft = [&shiftTooLarge](const APInt &base, const APInt &shift) {
+    // Values returned after the flag is set are discarded by the caller.
+    shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+    return shiftTooLarge ? base : base << shift;
+  };
+
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), shiftLeft);
+  return shiftTooLarge ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightArithmeticOp::fold(
+    spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x >> 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // Base values of 0 or -1 cannot be folded on their own: Shift may still be
+  // out of range, which makes the result undefined. Quoting the SPIR-V spec:
+  //
+  //   Type is a scalar or vector of integer type.
+  //   Results are computed per component, and within each component, per
+  //   bit...
+  //
+  //   The result is undefined if Shift is greater than or equal to the bit
+  //   width of the components of Base.
+  //
+  // So APInt::ashr computes each component, and the fold is abandoned as soon
+  // as any component's Shift is out of range.
+  bool shiftTooLarge = false;
+  auto shiftRight = [&shiftTooLarge](const APInt &base, const APInt &shift) {
+    // Values returned after the flag is set are discarded by the caller.
+    shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+    return shiftTooLarge ? base : base.ashr(shift);
+  };
+
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), shiftRight);
+  return shiftTooLarge ? Attribute() : folded;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ShiftRightLogicalOp::fold(
+    spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
+  // Shifting by zero is the identity: x >> 0 -> x.
+  if (matchPattern(adaptor.getOperand2(), m_Zero()))
+    return getOperand1();
+
+  // A zero Base cannot be folded on its own: Shift may still be out of range,
+  // which makes the result undefined. Quoting the SPIR-V spec:
+  //
+  //   Type is a scalar or vector of integer type.
+  //   Results are computed per component, and within each component, per
+  //   bit...
+  //
+  //   The result is undefined if Shift is greater than or equal to the bit
+  //   width of the components of Base.
+  //
+  // So APInt::lshr computes each component, and the fold is abandoned as soon
+  // as any component's Shift is out of range.
+  bool shiftTooLarge = false;
+  auto shiftRight = [&shiftTooLarge](const APInt &base, const APInt &shift) {
+    // Values returned after the flag is set are discarded by the caller.
+    shiftTooLarge = shiftTooLarge || shift.uge(base.getBitWidth());
+    return shiftTooLarge ? base : base.lshr(shift);
+  };
+
+  Attribute folded =
+      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), shiftRight);
+  return shiftTooLarge ? Attribute() : folded;
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 867ddf3c80173..6be4b8d20ff36 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -1139,6 +1139,184 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.ShiftLeftLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsl_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftLeftLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsl_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsl_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %0 = spirv.ShiftLeftLogical %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftLeftLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftLeftLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftLeftLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsl
+func.func @const_fold_scalar_lsl() -> i32 {
+  %c1 = spirv.Constant 65535 : i32  // 0x0000 ffff
+  %c2 = spirv.Constant 17 : i32
+
+  // CHECK: %[[RET:.*]] = spirv.Constant -131072
+  // 0x0000 ffff << 17 -> 0xfffe 0000
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsl
+func.func @const_fold_vector_lsl() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[1, -1, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-2147483648, -65536, 1040384]>
+  %0 = spirv.ShiftLeftLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightArithmetic
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @asr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightArithmetic %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @asr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @asr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %0 = spirv.ShiftRightArithmetic %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftRightArithmetic %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightArithmetic %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightArithmetic %arg1, %cv : vector<3xi32>, vector<3xi32>
+
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_asr
+func.func @const_fold_scalar_asr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+  // 0xfffe 0000 ashr 17 -> 0xffff ffff
+  // CHECK: %[[RET:.*]] = spirv.Constant -1
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_asr
+func.func @const_fold_vector_asr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, 239847, 127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[-1, 3, 0]>
+  %0 = spirv.ShiftRightArithmetic %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ShiftRightLogical
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @lsr_x_0
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_x_0(%arg0 : i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  %c0 = spirv.Constant 0 : i32
+  %cv0 = spirv.Constant dense<0> : vector<3xi32>
+
+  %0 = spirv.ShiftRightLogical %arg0, %c0 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv0 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[ARG0]], %[[ARG1]]
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @lsr_shift_overflow
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: vector<3xi32>)
+func.func @lsr_shift_overflow(%arg0: i32, %arg1: vector<3xi32>) -> (i32, vector<3xi32>) {
+  // CHECK-DAG: %[[C32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CV:.*]] = spirv.Constant dense<[6, 18, 128]>
+  %c32 = spirv.Constant 32 : i32
+  %cv = spirv.Constant dense<[6, 18, 128]> : vector<3xi32>
+
+  // CHECK: %0 = spirv.ShiftRightLogical %[[ARG0]], %[[C32]]
+  // CHECK: %1 = spirv.ShiftRightLogical %[[ARG1]], %[[CV]]
+  %0 = spirv.ShiftRightLogical %arg0, %c32 : i32, i32
+  %1 = spirv.ShiftRightLogical %arg1, %cv : vector<3xi32>, vector<3xi32>
+  return %0, %1 : i32, vector<3xi32>
+}
+
+// CHECK-LABEL: @const_fold_scalar_lsr
+func.func @const_fold_scalar_lsr() -> i32 {
+  %c1 = spirv.Constant -131072 : i32  // 0xfffe 0000
+  %c2 = spirv.Constant 17 : i32
+
+  // 0xfffe 0000 lshr 17 -> 0x0000 7fff
+  // CHECK: %[[RET:.*]] = spirv.Constant 32767
+  %0 = spirv.ShiftRightLogical %c1, %c2 : i32, i32
+
+  // CHECK: return %[[RET]]
+  return %0 : i32
+}
+
+// CHECK-LABEL: @const_fold_vector_lsr
+func.func @const_fold_vector_lsr() -> vector<3xi32> {
+  %c1 = spirv.Constant dense<[-2147483648, -1, -127]> : vector<3xi32>
+  %c2 = spirv.Constant dense<[31, 16, 13]> : vector<3xi32>
+
+  // CHECK: %[[RET:.*]] = spirv.Constant dense<[1, 65535, 524287]>
+  %0 = spirv.ShiftRightLogical %c1, %c2 : vector<3xi32>, vector<3xi32>
+
+  // CHECK: return %[[RET]]
+  return %0 : vector<3xi32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.selection
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list