[Mlir-commits] [mlir] f772dcb - [mlir][spirv] Support sub-byte integer types in type conversion

Lei Zhang llvmlistbot at llvm.org
Thu May 11 15:18:04 PDT 2023


Author: Lei Zhang
Date: 2023-05-11T22:17:56Z
New Revision: f772dcbb5104bc83548e2454909f0a870dfadde5

URL: https://github.com/llvm/llvm-project/commit/f772dcbb5104bc83548e2454909f0a870dfadde5
DIFF: https://github.com/llvm/llvm-project/commit/f772dcbb5104bc83548e2454909f0a870dfadde5.diff

LOG: [mlir][spirv] Support sub-byte integer types in type conversion

GPUs typically cannot access memory at sub-byte granularity, so
sub-byte integer values must either be expanded to full bytes or
tightly packed. This commit adds support for tightly packed
power-of-two sub-byte types.

Sub-byte types aren't allowed by the SPIR-V spec, so unlike other
supported integer types there are no compute/storage capabilities
for them. Therefore we don't recognize sub-byte types in
`spirv::ScalarType`; we just special-case them in the type converter
and always convert them to i32 under the hood.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D150395

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
    mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
    mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
    mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 7d362526cc22f..e3b5e24ae5681 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -26,10 +26,19 @@ namespace mlir {
 // Type Converter
 //===----------------------------------------------------------------------===//
 
+/// How sub-byte values are stored in memory.
+enum class SPIRVSubByteTypeStorage {
+  /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
+  Packed,
+};
+
 struct SPIRVConversionOptions {
   /// The number of bits to store a boolean value.
   unsigned boolNumBits{8};
 
+  /// How sub-byte values are stored in memory.
+  SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
+
   /// Whether to emulate narrower scalar types with 32-bit scalar types if not
   /// supported by the target.
   ///

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index fe860b7fb7404..5a5cdfe341942 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 
 #include <functional>
 #include <optional>
@@ -256,6 +257,31 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
                           intType.getSignedness());
 }
 
+/// Converts a sub-byte integer `type` to i32 regardless of target environment.
+/// Returns a null type if the storage kind is not `Packed` or the bitwidth is
+/// not a power of two. Sub-byte types are deliberately not recognized as
+/// `spirv::ScalarType`: SPIR-V does not support them at all, so unlike other
+/// supported integer types there are no compute/storage capabilities to check
+/// for them.
+static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
+                                      IntegerType type) {
+  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
+    LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
+    return nullptr;
+  }
+
+  if (!llvm::isPowerOf2_32(type.getWidth())) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "unsupported non-power-of-two bitwidth in sub-byte " << type
+               << "\n");
+    return nullptr;
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+  return IntegerType::get(type.getContext(), /*width=*/32,
+                          type.getSignedness());
+}
+
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
@@ -417,8 +443,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
     return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
-  int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
-  auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
+  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
+  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
   if (targetEnv.allows(spirv::Capability::Kernel))
@@ -426,6 +452,38 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
   return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
+/// Converts a memref of sub-byte integers into a SPIR-V pointer to tightly
+/// packed 32-bit words. Returns a null type if the element type is unsupported.
+static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
+                                     const SPIRVConversionOptions &options,
+                                     MemRefType type,
+                                     spirv::StorageClass storageClass) {
+  auto subByteType = cast<IntegerType>(type.getElementType());
+  Type wordType = convertSubByteIntegerType(options, subByteType);
+  if (!wordType)
+    return nullptr;
+  int64_t wordBytes = *getTypeNumBytes(options, wordType);
+  int64_t layoutStride = needsExplicitLayout(storageClass) ? wordBytes : 0;
+  bool isKernel = targetEnv.allows(spirv::Capability::Kernel);
+
+  // Dynamic: kernels use an element pointer, Vulkan a wrapped runtime array.
+  if (!type.hasStaticShape()) {
+    if (isKernel)
+      return spirv::PointerType::get(wordType, storageClass);
+    return wrapInStructAndGetPointer(
+        spirv::RuntimeArrayType::get(wordType, layoutStride), storageClass);
+  }
+
+  // Static: round the packed bit count up to bytes, then up to whole words.
+  int64_t packedBytes =
+      llvm::divideCeil(type.getNumElements() * subByteType.getWidth(), 8);
+  auto wordArray = spirv::ArrayType::get(
+      wordType, llvm::divideCeil(packedBytes, wordBytes), layoutStride);
+  if (isKernel)
+    return spirv::PointerType::get(wordArray, storageClass);
+  return wrapInStructAndGetPointer(wordArray, storageClass);
+}
+
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVConversionOptions &options,
                               MemRefType type) {
@@ -441,9 +499,11 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   }
   spirv::StorageClass storageClass = attr.getValue();
 
-  if (type.getElementType().isa<IntegerType>() &&
-      type.getElementTypeBitWidth() == 1) {
-    return convertBoolMemrefType(targetEnv, options, type, storageClass);
+  if (type.getElementType().isa<IntegerType>()) {
+    if (type.getElementTypeBitWidth() == 1)
+      return convertBoolMemrefType(targetEnv, options, type, storageClass);
+    if (type.getElementTypeBitWidth() < 8)
+      return convertSubByteMemrefType(targetEnv, options, type, storageClass);
   }
 
   Type arrayElemType;
@@ -497,7 +557,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
+  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
   int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
   if (targetEnv.allows(spirv::Capability::Kernel))
@@ -514,10 +574,10 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
   // were tried before.
   //
-  // TODO: this assumes that the SPIR-V types are valid to use in
-  // the given target environment, which should be the case if the whole
-  // pipeline is driven by the same target environment. Still, we probably still
-  // want to validate and convert to be safe.
+  // TODO: This assumes that the SPIR-V types are valid to use in the given
+  // target environment, which should be the case if the whole pipeline is
+  // driven by the same target environment. Still, we probably want to
+  // validate and convert to be safe.
   addConversion([](spirv::SPIRVType type) { return type; });
 
   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
@@ -525,6 +585,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](IntegerType intType) -> std::optional<Type> {
     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
       return convertScalarType(this->targetEnv, this->options, scalarType);
+    if (intType.getWidth() < 8)
+      return convertSubByteIntegerType(this->options, intType);
     return Type();
   });
 

diff  --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index fe6edc132e0e3..ef1ee00b709fd 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -94,14 +94,32 @@ func.func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }
 
 // -----
 
-// Check that weird bitwidths are not supported.
+// Check that power-of-two sub-byte bitwidths are converted to i32.
 module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-NOT: spirv.func @integer4
+// CHECK: spirv.func @integer2(%{{.+}}: i32)
+func.func @integer2(%arg0: i2) { return }
+
+// CHECK: spirv.func @integer4(%{{.+}}: i32)
 func.func @integer4(%arg0: i4) { return }
 
+} // end module
+
+// -----
+
+// Check that other bitwidths are not supported.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-NOT: spirv.func @integer3
+func.func @integer3(%arg0: i3) { return }
+
+// CHECK-NOT: spirv.func @integer13
+func.func @integer13(%arg0: i13) { return }
+
 // CHECK-NOT: spirv.func @integer128
 func.func @integer128(%arg0: i128) { return }
 
@@ -109,6 +127,7 @@ func.func @integer128(%arg0: i128) { return }
 func.func @integer42(%arg0: i42) { return }
 
 } // end module
+
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -421,6 +440,16 @@ module attributes {
 // NOEMU-SAME: memref<5xi1, #spirv.storage_class<StorageBuffer>>
 func.func @memref_1bit_type(%arg0: memref<5xi1, #spirv.storage_class<StorageBuffer>>) { return }
 
+// 16 i2 values are tightly packed into one i32 value, so 33 i2 values take 3 i32 values.
+// CHECK-LABEL: spirv.func @memref_2bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<3 x i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_2bit_type(%arg0: memref<33xi2, #spirv.storage_class<StorageBuffer>>) { return }
+
+// 8 i4 values are tightly packed into one i32 value, so 16 i4 values take 2 i32 values.
+// CHECK-LABEL: spirv.func @memref_4bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_4bit_type(%arg0: memref<16xi4, #spirv.storage_class<StorageBuffer>>) { return }
+
 // CHECK-LABEL: spirv.func @memref_8bit_StorageBuffer
 // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
 // NOEMU-LABEL: func @memref_8bit_StorageBuffer
@@ -725,6 +754,14 @@ func.func @unranked_memref(%arg0: memref<*xi32>) { return }
 // NOEMU-SAME: memref<?xi1, #spirv.storage_class<StorageBuffer>>
 func.func @memref_1bit_type(%arg0: memref<?xi1, #spirv.storage_class<StorageBuffer>>) { return }
 
+// CHECK-LABEL: spirv.func @memref_2bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_2bit_type(%buffer: memref<?xi2, #spirv.storage_class<StorageBuffer>>) { return }
+
+// CHECK-LABEL: spirv.func @memref_4bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_4bit_type(%buffer: memref<?xi4, #spirv.storage_class<StorageBuffer>>) { return }
+
 // CHECK-LABEL: func @dynamic_dim_memref
 // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
 // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2552503fc3043..97d1a3add9c18 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -419,3 +419,63 @@ func.func @store_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %v
 }
 
 } // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if sub-byte types are
+// emulated via 32-bit types.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i4
+func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
+  // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
+  // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
+  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
+  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
+  // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
+  // CHECK: spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+  %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
+  return %0 : i4
+}
+
+// CHECK-LABEL: @store_i4
+func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %value: i4, %i: index) {
+  // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
+  // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
+  // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
+  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
+  // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
+  // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
+  // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
+  // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
+  // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK2]]
+  // CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  memref.store %value, %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+} // end module


        


More information about the Mlir-commits mailing list