[Mlir-commits] [mlir] c361435 - [mlir][StandardToSPIRV] Handle i1 case for lowering memref.load/store op

Hanhan Wang llvmlistbot at llvm.org
Thu Apr 8 12:15:51 PDT 2021


Author: Hanhan Wang
Date: 2021-04-08T12:15:25-07:00
New Revision: c3614358452e5050b5b191fd3df3fad8b2664221

URL: https://github.com/llvm/llvm-project/commit/c3614358452e5050b5b191fd3df3fad8b2664221
DIFF: https://github.com/llvm/llvm-project/commit/c3614358452e5050b5b191fd3df3fad8b2664221.diff

LOG: [mlir][StandardToSPIRV] Handle i1 case for lowering memref.load/store op

This patch unconditionally converts i1 types to i8 types on memrefs. If the
required extensions or capabilities are not available, the i8 types are further
converted to i32. Hence, the logic in IntLoadOpPattern and IntStoreOpPattern is
also updated.

Also added the implementation of SPIRVTypeConverter::getOptions().

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D99724

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
    mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 911a030d4a7e5..2186107851d91 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -49,9 +49,13 @@ class SPIRVTypeConverter : public TypeConverter {
     /// values will be packed into one 32-bit value to be memory efficient.
     bool emulateNon32BitScalarTypes;
 
+    /// The number of bits to store a boolean value. It is eight bits by
+    /// default.
+    unsigned boolNumBits;
+
     // Note: we need this instead of inline initializers becuase of
     // https://bugs.llvm.org/show_bug.cgi?id=36684
-    Options() : emulateNon32BitScalarTypes(true) {}
+    Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
   };
 
   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,

diff  --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index ed66252e20ae5..397d26b0499d8 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -991,6 +991,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                            loadOperands.indices(), loc, rewriter);
 
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  bool isBool = srcBits == 1;
+  if (isBool)
+    srcBits = typeConverter.getOptions().boolNumBits;
   auto dstType = typeConverter.convertType(memrefType)
                      .cast<spirv::PointerType>()
                      .getPointeeType()
@@ -1044,6 +1047,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp,
                                                       shiftValue);
   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                           shiftValue);
+
+  if (isBool) {
+    dstType = typeConverter.convertType(loadOp.getType());
+    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
+    Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    result = rewriter.create<spirv::SelectOp>(loc, dstType, isOne, one, zero);
+  } else if (result.getType().getIntOrFloatBitWidth() !=
+             static_cast<unsigned>(dstBits)) {
+    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
+  }
   rewriter.replaceOp(loadOp, result);
 
   assert(accessChainOp.use_empty());
@@ -1117,6 +1132,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                            storeOperands.indices(), loc, rewriter);
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+
+  bool isBool = srcBits == 1;
+  if (isBool)
+    srcBits = typeConverter.getOptions().boolNumBits;
   auto dstType = typeConverter.convertType(memrefType)
                      .cast<spirv::PointerType>()
                      .getPointeeType()
@@ -1156,8 +1175,14 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp,
       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
 
-  Value storeVal =
-      shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
+  Value storeVal = storeOperands.value();
+  if (isBool) {
+    Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
+    Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
+    storeVal =
+        rewriter.create<spirv::SelectOp>(loc, dstType, storeVal, one, zero);
+  }
+  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 4e2dc01108b2d..d26299fa82a9c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -153,6 +153,10 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
 #undef STORAGE_SPACE_MAP_FN
 }
 
+const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
+  return options;
+}
+
 #undef STORAGE_SPACE_MAP_LIST
 
 // TODO: This is a utility function that should probably be exposed by the
@@ -342,9 +346,66 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
 }
 
+static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
+                                  const SPIRVTypeConverter::Options &options,
+                                  MemRefType type) {
+  if (!type.hasStaticShape()) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " dynamic shape on i1 is not supported yet\n");
+    return nullptr;
+  }
+
+  Optional<spirv::StorageClass> storageClass =
+      SPIRVTypeConverter::getStorageClassForMemorySpace(
+          type.getMemorySpaceAsInt());
+  if (!storageClass) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot convert memory space\n");
+    return nullptr;
+  }
+
+  unsigned numBoolBits = options.boolNumBits;
+  if (numBoolBits != 8) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "using non-8-bit storage for bool types unimplemented");
+    return nullptr;
+  }
+  auto elementType = IntegerType::get(type.getContext(), numBoolBits)
+                         .dyn_cast<spirv::ScalarType>();
+  if (!elementType)
+    return nullptr;
+  Type arrayElemType =
+      convertScalarType(targetEnv, options, elementType, storageClass);
+  if (!arrayElemType)
+    return nullptr;
+  Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
+  if (!arrayElemSize) {
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: cannot deduce converted element size\n");
+    return nullptr;
+  }
+
+  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
+  auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
+  auto arrayType =
+      spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
+
+  // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
+  // workgroup storage class do not need the struct to be laid out explicitly.
+  auto structType = *storageClass == spirv::StorageClass::Workgroup
+                        ? spirv::StructType::get(arrayType)
+                        : spirv::StructType::get(arrayType, 0);
+  return spirv::PointerType::get(structType, *storageClass);
+}
+
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVTypeConverter::Options &options,
                               MemRefType type) {
+  if (type.getElementType().isa<IntegerType>() &&
+      type.getElementTypeBitWidth() == 1) {
+    return convertBoolMemrefType(targetEnv, options, type);
+  }
+
   Optional<spirv::StorageClass> storageClass =
       SPIRVTypeConverter::getStorageClassForMemorySpace(
           type.getMemorySpaceAsInt());

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
index d074969febff1..86d390a2ce703 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -911,12 +911,40 @@ func @load_store_zero_rank_int(%arg0: memref<i32>, %arg1: memref<i32>) {
 
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
-// TODO: Test i1 and i64 types.
+// TODO: Test i64 types.
 module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
 
+// CHECK-LABEL: @load_i1
+func @load_i1(%arg0: memref<i1>) -> i1 {
+  //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
+  //     CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
+  //     CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  //     CHECK: %[[LOAD:.+]] = spv.Load  "StorageBuffer" %[[PTR]]
+  //     CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Constant 255 : i32
+  //     CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  //     CHECK: %[[T2:.+]] = spv.Constant 24 : i32
+  //     CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  //     CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  // Convert to i1 type.
+  //     CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  //     CHECK: %[[ISONE:.+]]  = spv.IEqual %[[T4]], %[[ONE]] : i32
+  //     CHECK: %[[FALSE:.+]] = spv.Constant false
+  //     CHECK: %[[TRUE:.+]] = spv.Constant true
+  //     CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1
+  //     CHECK: spv.ReturnValue %[[RES]] : i1
+  %0 = memref.load %arg0[] : memref<i1>
+  return %0 : i1
+}
+
 // CHECK-LABEL: @load_i8
 func @load_i8(%arg0: memref<i8>) {
   //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -982,6 +1010,31 @@ func @load_f32(%arg0: memref<f32>) {
   return
 }
 
+// CHECK-LABEL: @store_i1
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1)
+func @store_i1(%arg0: memref<i1>, %value: i1) {
+  //     CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32
+  //     CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32
+  //     CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  memref.store %value, %arg0[] : memref<i1>
+  return
+}
+
 // CHECK-LABEL: @store_i8
 //       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
 func @store_i8(%arg0: memref<i8>, %value: i8) {

diff  --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
index cacc3e762c149..58513124907a5 100644
--- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir
@@ -286,20 +286,6 @@ func @memref_mem_space(
 
 // -----
 
-// Check that boolean memref is not supported at the moment.
-module attributes {
-  spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
-} {
-
-// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>)
-func @memref_type(%arg0: memref<3xi1>) {
-  return
-}
-
-} // end module
-
-// -----
-
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
 // satisfied.
@@ -307,6 +293,11 @@ module attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.0, [], []>, {}>
 } {
 
+// An i1 is stored in 8 bits, so 5xi1 has 40 bits, which are stored in 2xi32.
+// CHECK-LABEL: spv.func @memref_1bit_type
+// CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<2 x i32, stride=4> [0])>, StorageBuffer>
+func @memref_1bit_type(%arg0: memref<5xi1>) { return }
+
 // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spv.ptr<!spv.struct<(!spv.array<16 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer


        


More information about the Mlir-commits mailing list