[Mlir-commits] [mlir] f772dcb - [mlir][spirv] Support sub-byte integer types in type conversion
Lei Zhang
llvmlistbot at llvm.org
Thu May 11 15:18:04 PDT 2023
Author: Lei Zhang
Date: 2023-05-11T22:17:56Z
New Revision: f772dcbb5104bc83548e2454909f0a870dfadde5
URL: https://github.com/llvm/llvm-project/commit/f772dcbb5104bc83548e2454909f0a870dfadde5
DIFF: https://github.com/llvm/llvm-project/commit/f772dcbb5104bc83548e2454909f0a870dfadde5.diff
LOG: [mlir][spirv] Support sub-byte integer types in type conversion
Typically GPUs cannot access memory in sub-byte manner. So for
sub-byte integer type values, we need to either expand them to
full bytes or tightly pack them. This commit adds support for
tightly packed power-of-two sub-byte types.
Sub-byte types aren't allowed in the SPIR-V spec, so there are no
compute/storage capabilities for them like other supported integer
types. So we don't recognize sub-byte types in `spirv::ScalarType`.
We just special case them in type converter and always convert
to use i32 under the hood.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D150395
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 7d362526cc22f..e3b5e24ae5681 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -26,10 +26,19 @@ namespace mlir {
// Type Converter
//===----------------------------------------------------------------------===//
+/// How sub-byte values are stored in memory.
+enum class SPIRVSubByteTypeStorage {
+ /// Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
+ Packed,
+};
+
struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};
+ /// How sub-byte values are stored in memory.
+ SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
+
/// Whether to emulate narrower scalar types with 32-bit scalar types if not
/// supported by the target.
///
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index fe860b7fb7404..5a5cdfe341942 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -20,6 +20,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include <functional>
#include <optional>
@@ -256,6 +257,31 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
intType.getSignedness());
}
+/// Converts a sub-byte integer `type` to i32 regardless of target environment.
+///
+/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
+/// the above given that these sub-byte types are not supported at all in
+/// SPIR-V; there are no compute/storage capabilities for them like other
+/// supported integer types.
+static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
+ IntegerType type) {
+ if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
+ LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
+ return nullptr;
+ }
+
+ if (!llvm::isPowerOf2_32(type.getWidth())) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "unsupported non-power-of-two bitwidth in sub-byte: " << type
+ << "\n");
+ return nullptr;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
+ return IntegerType::get(type.getContext(), /*width=*/32,
+ type.getSignedness());
+}
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -417,8 +443,8 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
- int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
- auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
+ int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
+ int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
if (targetEnv.allows(spirv::Capability::Kernel))
@@ -426,6 +452,38 @@ static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
return wrapInStructAndGetPointer(arrayType, storageClass);
}
+static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
+ const SPIRVConversionOptions &options,
+ MemRefType type,
+ spirv::StorageClass storageClass) {
+ IntegerType elementType = cast<IntegerType>(type.getElementType());
+ Type arrayElemType = convertSubByteIntegerType(options, elementType);
+ if (!arrayElemType)
+ return nullptr;
+ int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
+
+ if (!type.hasStaticShape()) {
+ // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
+ // to the element.
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayElemType, storageClass);
+ int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
+ auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
+ // For Vulkan we need extra wrapping struct and array to satisfy interface
+ // needs.
+ return wrapInStructAndGetPointer(arrayType, storageClass);
+ }
+
+ int64_t memrefSize =
+ llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
+ int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
+ int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
+ auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
+ if (targetEnv.allows(spirv::Capability::Kernel))
+ return spirv::PointerType::get(arrayType, storageClass);
+ return wrapInStructAndGetPointer(arrayType, storageClass);
+}
+
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options,
MemRefType type) {
@@ -441,9 +499,11 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
}
spirv::StorageClass storageClass = attr.getValue();
- if (type.getElementType().isa<IntegerType>() &&
- type.getElementTypeBitWidth() == 1) {
- return convertBoolMemrefType(targetEnv, options, type, storageClass);
+ if (type.getElementType().isa<IntegerType>()) {
+ if (type.getElementTypeBitWidth() == 1)
+ return convertBoolMemrefType(targetEnv, options, type, storageClass);
+ if (type.getElementTypeBitWidth() < 8)
+ return convertSubByteMemrefType(targetEnv, options, type, storageClass);
}
Type arrayElemType;
@@ -497,7 +557,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
+ int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
if (targetEnv.allows(spirv::Capability::Kernel))
@@ -514,10 +574,10 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
// were tried before.
//
- // TODO: this assumes that the SPIR-V types are valid to use in
- // the given target environment, which should be the case if the whole
- // pipeline is driven by the same target environment. Still, we probably still
- // want to validate and convert to be safe.
+ // TODO: This assumes that the SPIR-V types are valid to use in the given
+ // target environment, which should be the case if the whole pipeline is
+ // driven by the same target environment. Still, we probably want to
+ // validate and convert to be safe.
addConversion([](spirv::SPIRVType type) { return type; });
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
@@ -525,6 +585,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IntegerType intType) -> std::optional<Type> {
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (intType.getWidth() < 8)
+ return convertSubByteIntegerType(this->options, intType);
return Type();
});
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index fe6edc132e0e3..ef1ee00b709fd 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -94,14 +94,32 @@ func.func @integer64(%arg0: i64, %arg1: si64, %arg2: ui64) { return }
// -----
-// Check that weird bitwidths are not supported.
+// Check that power-of-two sub-byte bitwidths are converted to i32.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {
-// CHECK-NOT: spirv.func @integer4
+// CHECK: spirv.func @integer2(%{{.+}}: i32)
+func.func @integer2(%arg0: i2) { return }
+
+// CHECK: spirv.func @integer4(%{{.+}}: i32)
func.func @integer4(%arg0: i4) { return }
+} // end module
+
+// -----
+
+// Check that other bitwidths are not supported.
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-NOT: spirv.func @integer3
+func.func @integer3(%arg0: i3) { return }
+
+// CHECK-NOT: spirv.func @integer13
+func.func @integer13(%arg0: i13) { return }
+
// CHECK-NOT: spirv.func @integer128
func.func @integer128(%arg0: i128) { return }
@@ -109,6 +127,7 @@ func.func @integer128(%arg0: i128) { return }
func.func @integer42(%arg0: i42) { return }
} // end module
+
// -----
//===----------------------------------------------------------------------===//
@@ -421,6 +440,16 @@ module attributes {
// NOEMU-SAME: memref<5xi1, #spirv.storage_class<StorageBuffer>>
func.func @memref_1bit_type(%arg0: memref<5xi1, #spirv.storage_class<StorageBuffer>>) { return }
+// 16 i2 values are tightly packed into one i32 value, so 33 i2 values take 3 i32 values.
+// CHECK-LABEL: spirv.func @memref_2bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<3 x i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_2bit_type(%arg0: memref<33xi2, #spirv.storage_class<StorageBuffer>>) { return }
+
+// 8 i4 values are tightly packed into one i32 value, so 16 i4 values take 2 i32 values.
+// CHECK-LABEL: spirv.func @memref_4bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_4bit_type(%arg0: memref<16xi4, #spirv.storage_class<StorageBuffer>>) { return }
+
// CHECK-LABEL: spirv.func @memref_8bit_StorageBuffer
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>, StorageBuffer>
// NOEMU-LABEL: func @memref_8bit_StorageBuffer
@@ -725,6 +754,14 @@ func.func @unranked_memref(%arg0: memref<*xi32>) { return }
// NOEMU-SAME: memref<?xi1, #spirv.storage_class<StorageBuffer>>
func.func @memref_1bit_type(%arg0: memref<?xi1, #spirv.storage_class<StorageBuffer>>) { return }
+// CHECK-LABEL: spirv.func @memref_2bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_2bit_type(%arg0: memref<?xi2, #spirv.storage_class<StorageBuffer>>) { return }
+
+// CHECK-LABEL: spirv.func @memref_4bit_type
+// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+func.func @memref_4bit_type(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>) { return }
+
// CHECK-LABEL: func @dynamic_dim_memref
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<i32, stride=4> [0])>, StorageBuffer>
// CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32, stride=4> [0])>, StorageBuffer>
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 2552503fc3043..97d1a3add9c18 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -419,3 +419,63 @@ func.func @store_f32(%arg0: memref<f32, #spirv.storage_class<StorageBuffer>>, %v
}
} // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if sub-byte types are
+// emulated via 32-bit types.
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i4
+func.func @load_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %i: index) -> i4 {
+ // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+ // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
+ // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
+ // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+ // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+ // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
+ // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
+ // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+ // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+ // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+ // CHECK: %[[C28:.+]] = spirv.Constant 28 : i32
+ // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[AND]], %[[C28]] : i32, i32
+ // CHECK: spirv.ShiftRightArithmetic %[[SL]], %[[C28]] : i32, i32
+ %0 = memref.load %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
+ return %0 : i4
+}
+
+// CHECK-LABEL: @store_i4
+func.func @store_i4(%arg0: memref<?xi4, #spirv.storage_class<StorageBuffer>>, %value: i4, %i: index) {
+ // CHECK: %[[VAL:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
+ // CHECK: %[[INDEX:.+]] = builtin.unrealized_conversion_cast %{{.+}} : index to i32
+ // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+ // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
+ // CHECK: %[[MUL:.+]] = spirv.IMul %[[ONE]], %[[INDEX]] : i32
+ // CHECK: %[[OFFSET:.+]] = spirv.IAdd %[[ZERO]], %[[MUL]] : i32
+ // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+ // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+ // CHECK: %[[IDX:.+]] = spirv.UMod %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[FOUR]] : i32
+ // CHECK: %[[MASK1:.+]] = spirv.Constant 15 : i32
+ // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[BITS]] : i32, i32
+ // CHECK: %[[MASK2:.+]] = spirv.Not %[[SL]] : i32
+ // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[VAL]], %[[MASK1]] : i32
+ // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[BITS]] : i32, i32
+ // CHECK: %[[ACCESS_INDEX:.+]] = spirv.SDiv %[[OFFSET]], %[[EIGHT]] : i32
+ // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[ACCESS_INDEX]]]
+ // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK2]]
+ // CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+ memref.store %value, %arg0[%i] : memref<?xi4, #spirv.storage_class<StorageBuffer>>
+ return
+}
+
+} // end module
More information about the Mlir-commits
mailing list