[Mlir-commits] [mlir] d9728a9 - [mlir][spirv] Unify mixed scalar/vector primitive type resources
Lei Zhang
llvmlistbot at llvm.org
Mon Aug 8 11:30:30 PDT 2022
Author: Lei Zhang
Date: 2022-08-08T14:30:14-04:00
New Revision: d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62
URL: https://github.com/llvm/llvm-project/commit/d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62
DIFF: https://github.com/llvm/llvm-project/commit/d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62.diff
LOG: [mlir][spirv] Unify mixed scalar/vector primitive type resources
This further relaxes the requirement to allow aliased resources
to have different primitive types, where some are scalars while
others are vectors.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D131207
Added:
Modified:
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index ab99934434910..8299fd05cda96 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -78,6 +78,7 @@ static Type getRuntimeArrayElementType(Type type) {
static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
// scalarNumBits: contains all resources' scalar types' bit counts.
// vectorNumBits: only contains resources whose element types are vectors.
+ // vectorIndices: each vector's original index in `types`.
SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
scalarNumBits.reserve(types.size());
vectorNumBits.reserve(types.size());
@@ -104,11 +105,6 @@ static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
}
if (!vectorNumBits.empty()) {
- // If there are vector types, require all element types to be the same for
- // now to simplify the transformation.
- if (!llvm::is_splat(scalarNumBits))
- return llvm::None;
-
// Choose the *vector* with the smallest bitwidth as the canonical resource,
// so that we can still keep vectorized load/store and avoid partial updates
// to large vectors.
@@ -116,10 +112,18 @@ static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
// Make sure that the canonical resource's bitwidth is divisible by others.
// Without this, we cannot properly adjust the index later.
if (llvm::any_of(vectorNumBits,
- [minVal](int64_t bits) { return bits % *minVal != 0; }))
+ [&](int bits) { return bits % *minVal != 0; }))
+ return llvm::None;
+
+ // Require all scalar type bit counts to be a multiple of the chosen
+ // vector's primitive type to avoid reading/writing subcomponents.
+ int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+ int baseNumBits = scalarNumBits[index];
+ if (llvm::any_of(scalarNumBits,
+ [&](int bits) { return bits % baseNumBits != 0; }))
return llvm::None;
- return vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+ return index;
}
// All element types are scalars. Then choose the smallest bitwidth as the
@@ -357,10 +361,10 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
// them into a buffer with vector element types. We need to scale the last
// index for the vector as a whole, then add one level of index for inside
// the vector.
- int srcNumBits = *srcElemType.getSizeInBytes();
- int dstNumBits = *dstElemType.getSizeInBytes();
- assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
- int ratio = dstNumBits / srcNumBits;
+ int srcNumBytes = *srcElemType.getSizeInBytes();
+ int dstNumBytes = *dstElemType.getSizeInBytes();
+ assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
+ int ratio = dstNumBytes / srcNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
@@ -381,10 +385,10 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
// The source indices are for a buffer with larger bitwidth scalar/vector
// element types. Rewrite them into a buffer with smaller bitwidth element
// types. We only need to scale the last index.
- int srcNumBits = *srcElemType.getSizeInBytes();
- int dstNumBits = *dstElemType.getSizeInBytes();
- assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
- int ratio = srcNumBits / dstNumBits;
+ int srcNumBytes = *srcElemType.getSizeInBytes();
+ int dstNumBytes = *dstElemType.getSizeInBytes();
+ assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
+ int ratio = srcNumBytes / dstNumBytes;
auto ratioValue = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
@@ -435,10 +439,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// vector types of different component counts. For such cases, we load
// multiple smaller bitwidth values and construct a larger bitwidth one.
- int srcNumBits = *srcElemType.getSizeInBytes() * 8;
- int dstNumBits = *dstElemType.getSizeInBytes() * 8;
- assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
- int ratio = srcNumBits / dstNumBits;
+ int srcNumBytes = *srcElemType.getSizeInBytes();
+ int dstNumBytes = *dstElemType.getSizeInBytes();
+ assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
+ int ratio = srcNumBytes / dstNumBytes;
if (ratio > 4)
return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index d363666451f7f..c17331a161061 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -189,7 +189,7 @@ spv.module Logical GLSL450 {
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
-  spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+  spv.func @different_primitive_type(%index: i32, %val0: i32) -> i32 "None" {
%c0 = spv.Constant 0 : i32
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
%ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
@@ -205,7 +205,7 @@ spv.module Logical GLSL450 {
// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
// CHECK-NOT: @var01s
-// CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+// CHECK: spv.func @different_primitive_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
// CHECK: %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
@@ -329,3 +329,122 @@ spv.module Logical GLSL450 {
// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[DIV]], %[[MOD]]]
// CHECK: %[[LD:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+
+// -----
+
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01_v4f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+ spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> vector<4xf32> "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr0 = spv.mlir.addressof @var01_v4f32 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+ %vec4val = spv.Load "StorageBuffer" %ac0 : vector<4xf32>
+
+ %addr1 = spv.mlir.addressof @var01_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac1 = spv.AccessChain %addr1[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %f32val = spv.Load "StorageBuffer" %ac1 : f32
+
+ %addr2 = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+ %ac2 = spv.AccessChain %addr2[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+ %i64val = spv.Load "StorageBuffer" %ac2 : i64
+ %i32val = spv.SConvert %i64val : i64 to i32
+ %castval = spv.Bitcast %i32val : i32 to f32
+
+ %val1 = spv.CompositeInsert %f32val, %vec4val[0 : i32] : f32 into vector<4xf32>
+ %val2 = spv.CompositeInsert %castval, %val1[1 : i32] : f32 into vector<4xf32>
+ spv.ReturnValue %val2 : vector<4xf32>
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+// CHECK: spv.GlobalVariable @var01_v4f32 bind(0, 1) : !spv.ptr<{{.+}}>
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+
+// CHECK: spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
+// CHECK: %[[ADDR0:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[AC0:.+]] = spv.AccessChain %[[ADDR0]][%[[ZERO]], %[[IDX]]]
+// CHECK: spv.Load "StorageBuffer" %[[AC0]] : vector<4xf32>
+
+// CHECK: %[[ADDR1:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32
+// CHECK: %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[FOUR]] : i32
+// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[FOUR]] : i32
+// CHECK: %[[AC1:.+]] = spv.AccessChain %[[ADDR1]][%[[ZERO]], %[[DIV]], %[[MOD]]]
+// CHECK: spv.Load "StorageBuffer" %[[AC1]] : f32
+
+// CHECK: %[[ADDR2:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK: %[[TWO:.+]] = spv.Constant 2 : i32
+// CHECK: %[[DIV0:.+]] = spv.SDiv %[[IDX]], %[[TWO]] : i32
+// CHECK: %[[MOD0:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
+// CHECK: %[[AC2:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD0]]]
+// CHECK: %[[LD0:.+]] = spv.Load "StorageBuffer" %[[AC2]] : f32
+
+// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK: %[[MOD1:.+]] = spv.IAdd %[[MOD0]], %[[ONE]]
+// CHECK: %[[AC3:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD1]]]
+// CHECK: %[[LD1:.+]] = spv.Load "StorageBuffer" %[[AC3]] : f32
+// CHECK: %[[CC:.+]] = spv.CompositeConstruct %[[LD0]], %[[LD1]]
+// CHECK: %[[BC:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64
+
+// -----
+
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+ spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> i64 "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+ %val = spv.Load "StorageBuffer" %ac : i64
+
+ spv.ReturnValue %val : i64
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01_v2f2
+// CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK: %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[ONE]] : i32
+// CHECK: %[[MOD:.+]] = spv.SMod %[[IDX]], %[[ONE]] : i32
+// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %[[DIV]], %[[MOD]]]
+// CHECK: spv.Load
+// CHECK: spv.Load
+
+// -----
+
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01_i16 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+
+ spv.func @scalar_type_bitwidth_smaller_than_vector(%i0: i32) -> i16 "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr = spv.mlir.addressof @var01_i16 : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>, i32, i32
+ %val = spv.Load "StorageBuffer" %ac : i16
+
+ spv.ReturnValue %val : i16
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01_i16 bind(0, 1) {aliased}
+
+// CHECK: spv.func @scalar_type_bitwidth_smaller_than_vector
More information about the Mlir-commits
mailing list