[Mlir-commits] [mlir] d9728a9 - [mlir][spirv] Unify mixed scalar/vector primitive type resources

Lei Zhang llvmlistbot at llvm.org
Mon Aug 8 11:30:30 PDT 2022


Author: Lei Zhang
Date: 2022-08-08T14:30:14-04:00
New Revision: d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62

URL: https://github.com/llvm/llvm-project/commit/d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62
DIFF: https://github.com/llvm/llvm-project/commit/d9728a9baa49e6ddfa6c91af9fd7bf6a760f9a62.diff

LOG: [mlir][spirv] Unify mixed scalar/vector primitive type resources

This further relaxes the requirement to allow aliased resources
to have different primitive types, where some are scalars while
others are vectors.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D131207
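
For illustration, here is a minimal sketch of the newly supported case (the
names @var_f32 and @var_v4f32 are made up for this note and do not come from
the patch): the pass can now unify aliased bindings that mix scalar and vector
primitive types, provided every scalar bit count is a multiple of the chosen
vector's primitive type bit count.

    spv.module Logical GLSL450 {
      // Both bindings alias descriptor (set = 0, binding = 1). One holds scalar
      // f32 elements, the other vector<4xf32> elements. The vector resource is
      // picked as the canonical one and scalar accesses are rewritten onto it.
      spv.GlobalVariable @var_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
      spv.GlobalVariable @var_v4f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
    }

The added tests below exercise this: f32 (32 bits) and i64 (64 bits) are both
multiples of vector<4xf32>'s 32-bit primitive type and get unified, while an
i16 resource (16 bits) aliased with vector<2xf32> is left untouched.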

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index ab99934434910..8299fd05cda96 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -78,6 +78,7 @@ static Type getRuntimeArrayElementType(Type type) {
 static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
   // scalarNumBits: contains all resources' scalar types' bit counts.
   // vectorNumBits: only contains resources whose element types are vectors.
+  // vectorIndices: each vector's original index in `types`.
   SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
   scalarNumBits.reserve(types.size());
   vectorNumBits.reserve(types.size());
@@ -104,11 +105,6 @@ static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
   }
 
   if (!vectorNumBits.empty()) {
-    // If there are vector types, require all element types to be the same for
-    // now to simplify the transformation.
-    if (!llvm::is_splat(scalarNumBits))
-      return llvm::None;
-
     // Choose the *vector* with the smallest bitwidth as the canonical resource,
     // so that we can still keep vectorized load/store and avoid partial updates
     // to large vectors.
@@ -116,10 +112,18 @@ static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
     // Make sure that the canonical resource's bitwidth is divisible by others.
     // With out this, we cannot properly adjust the index later.
     if (llvm::any_of(vectorNumBits,
-                     [minVal](int64_t bits) { return bits % *minVal != 0; }))
+                     [&](int bits) { return bits % *minVal != 0; }))
+      return llvm::None;
+
+    // Require all scalar type bit counts to be a multiple of the chosen
+    // vector's primitive type to avoid reading/writing subcomponents.
+    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+    int baseNumBits = scalarNumBits[index];
+    if (llvm::any_of(scalarNumBits,
+                     [&](int bits) { return bits % baseNumBits != 0; }))
       return llvm::None;
 
-    return vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
+    return index;
   }
 
   // All element types are scalars. Then choose the smallest bitwidth as the
@@ -357,10 +361,10 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       // them into a buffer with vector element types. We need to scale the last
       // index for the vector as a whole, then add one level of index for inside
       // the vector.
-      int srcNumBits = *srcElemType.getSizeInBytes();
-      int dstNumBits = *dstElemType.getSizeInBytes();
-      assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
-      int ratio = dstNumBits / srcNumBits;
+      int srcNumBytes = *srcElemType.getSizeInBytes();
+      int dstNumBytes = *dstElemType.getSizeInBytes();
+      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
+      int ratio = dstNumBytes / srcNumBytes;
       auto ratioValue = rewriter.create<spirv::ConstantOp>(
           loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
@@ -381,10 +385,10 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       // The source indices are for a buffer with larger bitwidth scalar/vector
       // element types. Rewrite them into a buffer with smaller bitwidth element
       // types. We only need to scale the last index.
-      int srcNumBits = *srcElemType.getSizeInBytes();
-      int dstNumBits = *dstElemType.getSizeInBytes();
-      assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
-      int ratio = srcNumBits / dstNumBits;
+      int srcNumBytes = *srcElemType.getSizeInBytes();
+      int dstNumBytes = *dstElemType.getSizeInBytes();
+      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
+      int ratio = srcNumBytes / dstNumBytes;
       auto ratioValue = rewriter.create<spirv::ConstantOp>(
           loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
@@ -435,10 +439,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
      // vector types of different component counts. For such cases, we load
       // multiple smaller bitwidth values and construct a larger bitwidth one.
 
-      int srcNumBits = *srcElemType.getSizeInBytes() * 8;
-      int dstNumBits = *dstElemType.getSizeInBytes() * 8;
-      assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
-      int ratio = srcNumBits / dstNumBits;
+      int srcNumBytes = *srcElemType.getSizeInBytes();
+      int dstNumBytes = *dstElemType.getSizeInBytes();
+      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
+      int ratio = srcNumBytes / dstNumBytes;
       if (ratio > 4)
         return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
 

diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index d363666451f7f..c17331a161061 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -189,7 +189,7 @@ spv.module Logical GLSL450 {
   spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
   spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
 
-  spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+  spv.func @different_primitive_type(%index: i32, %val0: i32) -> i32 "None" {
     %c0 = spv.Constant 0 : i32
     %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
     %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
@@ -205,7 +205,7 @@ spv.module Logical GLSL450 {
 //     CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
 // CHECK-NOT: @var01s
 
-//     CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+//     CHECK: spv.func @different_primitive_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
 //     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
 //     CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
 //     CHECK:   %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
@@ -329,3 +329,122 @@ spv.module Logical GLSL450 {
 //     CHECK:   %[[MOD:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
 //     CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[DIV]], %[[MOD]]]
 //     CHECK:   %[[LD:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v4f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+  spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> vector<4xf32> "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01_v4f32 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+    %vec4val = spv.Load "StorageBuffer" %ac0 : vector<4xf32>
+
+    %addr1 = spv.mlir.addressof @var01_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %f32val = spv.Load "StorageBuffer" %ac1 : f32
+
+    %addr2 = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+    %ac2 = spv.AccessChain %addr2[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+    %i64val = spv.Load "StorageBuffer" %ac2 : i64
+    %i32val = spv.SConvert %i64val : i64 to i32
+    %castval = spv.Bitcast %i32val : i32 to f32
+
+    %val1 = spv.CompositeInsert %f32val, %vec4val[0 : i32] : f32 into vector<4xf32>
+    %val2 = spv.CompositeInsert %castval, %val1[1 : i32] : f32 into vector<4xf32>
+    spv.ReturnValue %val2 : vector<4xf32>
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+//     CHECK: spv.GlobalVariable @var01_v4f32 bind(0, 1) : !spv.ptr<{{.+}}>
+// CHECK-NOT: @var01_f32
+// CHECK-NOT: @var01_i64
+
+// CHECK:  spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK:    %[[ZERO:.+]] = spv.Constant 0 : i32
+// CHECK:    %[[ADDR0:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK:    %[[AC0:.+]] = spv.AccessChain %[[ADDR0]][%[[ZERO]], %[[IDX]]]
+// CHECK:    spv.Load "StorageBuffer" %[[AC0]] : vector<4xf32>
+
+// CHECK:    %[[ADDR1:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK:    %[[FOUR:.+]] = spv.Constant 4 : i32
+// CHECK:    %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[FOUR]] : i32
+// CHECK:    %[[MOD:.+]] = spv.SMod %[[IDX]], %[[FOUR]] : i32
+// CHECK:    %[[AC1:.+]] = spv.AccessChain %[[ADDR1]][%[[ZERO]], %[[DIV]], %[[MOD]]]
+// CHECK:    spv.Load "StorageBuffer" %[[AC1]] : f32
+
+// CHECK:    %[[ADDR2:.+]] = spv.mlir.addressof @var01_v4f32
+// CHECK:    %[[TWO:.+]] = spv.Constant 2 : i32
+// CHECK:    %[[DIV0:.+]] = spv.SDiv %[[IDX]], %[[TWO]] : i32
+// CHECK:    %[[MOD0:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
+// CHECK:    %[[AC2:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD0]]]
+// CHECK:    %[[LD0:.+]] = spv.Load "StorageBuffer" %[[AC2]] : f32
+
+// CHECK:    %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK:    %[[MOD1:.+]] = spv.IAdd %[[MOD0]], %[[ONE]]
+// CHECK:    %[[AC3:.+]] = spv.AccessChain %[[ADDR2]][%[[ZERO]], %[[DIV0]], %[[MOD1]]]
+// CHECK:    %[[LD1:.+]] = spv.Load "StorageBuffer" %[[AC3]] : f32
+// CHECK:    %[[CC:.+]] = spv.CompositeConstruct %[[LD0]], %[[LD1]]
+// CHECK:    %[[BC:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+
+  spv.func @load_mixed_scalar_vector_primitive_types(%i0: i32) -> i64 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr = spv.mlir.addressof @var01_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=8> [0])>, StorageBuffer>, i32, i32
+    %val = spv.Load "StorageBuffer" %ac : i64
+
+    spv.ReturnValue %val : i64
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK:  spv.func @load_mixed_scalar_vector_primitive_types(%[[IDX:.+]]: i32)
+
+// CHECK:    %[[ADDR:.+]] = spv.mlir.addressof @var01_v2f2
+// CHECK:    %[[ONE:.+]] = spv.Constant 1 : i32
+// CHECK:    %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[ONE]] : i32
+// CHECK:    %[[MOD:.+]] = spv.SMod %[[IDX]], %[[ONE]] : i32
+// CHECK:    spv.AccessChain %[[ADDR]][%{{.+}}, %[[DIV]], %[[MOD]]]
+// CHECK:    spv.Load
+// CHECK:    spv.Load
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_i16 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+
+  spv.func @scalar_type_bitwidth_smaller_than_vector(%i0: i32) -> i16 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr = spv.mlir.addressof @var01_i16 : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<i16, stride=2> [0])>, StorageBuffer>, i32, i32
+    %val = spv.Load "StorageBuffer" %ac : i16
+
+    spv.ReturnValue %val : i16
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01_v2f2 bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01_i16 bind(0, 1) {aliased}
+
+// CHECK: spv.func @scalar_type_bitwidth_smaller_than_vector
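
As a rough, hand-written sketch of the last-index scaling done by the
ConvertAccessChain hunks above (the SSA names and the surrounding values %c0
and %i are illustrative only, not pass output): when an f32 element (4 bytes)
is accessed through a canonical vector<4xf32> resource (16 bytes), the ratio
is 16 / 4 = 4, so the last index %i is split into %i / 4 (which vector) and
%i % 4 (which component inside it).

    // Before unification: index straight into the scalar f32 buffer.
    %ac_s = spv.AccessChain %addr_f32[%c0, %i] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
    // After unification: scale the last index for the vector as a whole, then
    // add one more index level for the component inside the vector.
    %four = spv.Constant 4 : i32
    %div = spv.SDiv %i, %four : i32
    %mod = spv.SMod %i, %four : i32
    %ac_v = spv.AccessChain %addr_v4f32[%c0, %div, %mod] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32, i32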


        

