[Mlir-commits] [mlir] 520a570 - [mlir][StandardToSPIRV] Fix signedness issue in bitwidth emulation.

Hanhan Wang llvmlistbot at llvm.org
Tue May 19 11:00:21 PDT 2020


Author: Hanhan Wang
Date: 2020-05-19T11:00:01-07:00
New Revision: 520a5702680ea0b5059193a0d4ad52c217da7325

URL: https://github.com/llvm/llvm-project/commit/520a5702680ea0b5059193a0d4ad52c217da7325
DIFF: https://github.com/llvm/llvm-project/commit/520a5702680ea0b5059193a0d4ad52c217da7325.diff

LOG: [mlir][StandardToSPIRV] Fix signedness issue in bitwidth emulation.

Summary:
Previously, after applying the mask, a negative number would convert to a
positive number because the sign flag was forgotten. This patch adds two more
shift operations to do the sign extension. This assumes that we're using two's
complement.

This patch applies sign extension unconditionally when loading an unsupported integer width, and it relies on the pattern to do the casting because the signedness semantics are carried by the operator itself.

Differential Revision: https://reviews.llvm.org/D79753

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index fbe02560008a..560bc4acf436 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -147,14 +147,44 @@ static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
 }
 
 /// Returns the shifted `targetBits`-bit value with the given offset.
-Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                 int targetBits, OpBuilder &builder) {
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
   Type targetType = builder.getIntegerType(targetBits);
   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
                                                    offset);
 }
 
+/// Returns true if the operator is operating on unsigned integers.
+/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
+/// to the ops themselves.
+template <typename SPIRVOp>
+bool isUnsignedOp() {
+  return false;
+}
+
+#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
+  template <>                                                                  \
+  bool isUnsignedOp<SPIRVOp>() {                                               \
+    return true;                                                               \
+  }
+
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
+CHECK_UNSIGNED_OP(spirv::UConvertOp);
+CHECK_UNSIGNED_OP(spirv::UDivOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
+CHECK_UNSIGNED_OP(spirv::ULessThanOp);
+CHECK_UNSIGNED_OP(spirv::UModOp);
+
+#undef CHECK_UNSIGNED_OP
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -178,6 +208,10 @@ class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
+    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+      return operation.emitError(
+          "bitwidth emulation is not implemented yet on unsigned op");
+    }
     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands,
                                                   ArrayRef<NamedAttribute>());
     return success();
@@ -581,6 +615,11 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
+    if (isUnsignedOp<spirvOp>() &&                                             \
+        operandType != this->typeConverter.convertType(operandType)) {         \
+      return cmpIOp.emitError(                                                 \
+          "bitwidth emulation is not implemented yet on unsigned op");         \
+    }                                                                          \
     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                          cmpIOpOperands.lhs(),                 \
                                          cmpIOpOperands.rhs());                \
@@ -661,6 +700,18 @@ IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
   Value mask = rewriter.create<spirv::ConstantOp>(
       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+
+  // Apply sign extension on the loaded value unconditionally. The signedness
+  // semantics are carried by the operator itself; we rely on other patterns to
+  // handle the casting.
+  IntegerAttr shiftValueAttr =
+      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
+  Value shiftValue =
+      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
+  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
+                                                      shiftValue);
+  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
+                                                          shiftValue);
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index 16633664c4ea..bf54dbaadb18 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-std-to-spirv -verify-diagnostics %s -o - | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // std arithmetic ops
@@ -128,14 +128,12 @@ module attributes {
      max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
 } {
 
-// CHECK-LABEL: @int_vector234
-func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<3xi16>, %arg2: vector<4xi64>) {
+// CHECK-LABEL: @int_vector23
+func @int_vector23(%arg0: vector<2xi8>, %arg1: vector<3xi16>) {
   // CHECK: spv.SDiv %{{.*}}, %{{.*}}: vector<2xi32>
   %0 = divi_signed %arg0, %arg0: vector<2xi8>
   // CHECK: spv.SRem %{{.*}}, %{{.*}}: vector<3xi32>
   %1 = remi_signed %arg1, %arg1: vector<3xi16>
-  // CHECK: spv.UDiv %{{.*}}, %{{.*}}: vector<4xi32>
-  %2 = divi_unsigned %arg2, %arg2: vector<4xi64>
   return
 }
 
@@ -152,6 +150,27 @@ func @float_scalar(%arg0: f16, %arg1: f64) {
 
 // -----
 
+// Check that types are converted to 32-bit when no special capabilities are
+// available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LEBEL: @int_vector4_invalid
+func @int_vector4_invalid(%arg0: vector<4xi64>) {
+  // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}}
+  // expected-error @+1 {{op requires the same type for all operands and results}}
+  %0 = divi_unsigned %arg0, %arg0: vector<4xi64>
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // std bit ops
 //===----------------------------------------------------------------------===//
@@ -717,7 +736,10 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }
@@ -738,7 +760,10 @@ func @load_i16(%arg0: memref<10xi16>, %index : index) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 16 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[%index] : memref<10xi16>
   return
 }
@@ -852,7 +877,10 @@ func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
   //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
   //     CHECK: %[[MASK:.+]] = spv.constant 255 : i32
-  //     CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
   %0 = load %arg0[] : memref<i8>
   return
 }


        


More information about the Mlir-commits mailing list