[Mlir-commits] [mlir] 067daa5 - [mlir][spirv] Unify resources of different vector sizes

Lei Zhang llvmlistbot at llvm.org
Wed Jul 27 16:28:40 PDT 2022


Author: Lei Zhang
Date: 2022-07-27T19:22:50-04:00
New Revision: 067daa56a9040f5e785aff4ebd299a6f57eb68c6

URL: https://github.com/llvm/llvm-project/commit/067daa56a9040f5e785aff4ebd299a6f57eb68c6
DIFF: https://github.com/llvm/llvm-project/commit/067daa56a9040f5e785aff4ebd299a6f57eb68c6.diff

LOG: [mlir][spirv] Unify resources of different vector sizes

This commit extends UnifyAliasedResourcePass to handle the case
where aliased resources have different vector sizes. (It still
requires all scalar types to be of the same bitwidth.) This is
effectively reusing the code for handling different-bitwidth
scalar types.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D130671

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index e95a4cc395cee..ab99934434910 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -20,7 +20,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/SymbolTable.h"
-#include "mlir/Pass/AnalysisManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
@@ -77,12 +76,15 @@ static Type getRuntimeArrayElementType(Type type) {
 /// resource that all resources should be unified into. Returns llvm::None if
 /// unable to unify.
 static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
-  SmallVector<int> scalarNumBits, totalNumBits;
+  // scalarNumBits: contains all resources' scalar types' bit counts.
+  // vectorNumBits: only contains resources whose element types are vectors.
+  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
   scalarNumBits.reserve(types.size());
-  totalNumBits.reserve(types.size());
-  bool hasVector = false;
+  vectorNumBits.reserve(types.size());
+  vectorIndices.reserve(types.size());
 
-  for (spirv::SPIRVType type : types) {
+  for (const auto &indexedTypes : llvm::enumerate(types)) {
+    spirv::SPIRVType type = indexedTypes.value();
     assert(type.isScalarOrVector());
     if (auto vectorType = type.dyn_cast<VectorType>()) {
       if (vectorType.getNumElements() % 2 != 0)
@@ -94,30 +96,30 @@ static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
 
       scalarNumBits.push_back(
           vectorType.getElementType().getIntOrFloatBitWidth());
-      totalNumBits.push_back(*numBytes * 8);
-      hasVector = true;
+      vectorNumBits.push_back(*numBytes * 8);
+      vectorIndices.push_back(indexedTypes.index());
     } else {
       scalarNumBits.push_back(type.getIntOrFloatBitWidth());
-      totalNumBits.push_back(scalarNumBits.back());
     }
   }
 
-  if (hasVector) {
+  if (!vectorNumBits.empty()) {
     // If there are vector types, require all element types to be the same for
     // now to simplify the transformation.
     if (!llvm::is_splat(scalarNumBits))
       return llvm::None;
 
-    // Choose the one with the largest bitwidth as the canonical resource, so
-    // that we can still keep vectorized load/store.
-    auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
+    // Choose the *vector* with the smallest bitwidth as the canonical resource,
+    // so that we can still keep vectorized load/store and avoid partial updates
+    // to large vectors.
+    auto *minVal = std::min_element(vectorNumBits.begin(), vectorNumBits.end());
     // Make sure that the canonical resource's bitwidth is divisible by others.
     // With out this, we cannot properly adjust the index later.
-    if (llvm::any_of(totalNumBits,
-                     [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
+    if (llvm::any_of(vectorNumBits,
+                     [minVal](int64_t bits) { return bits % *minVal != 0; }))
       return llvm::None;
 
-    return std::distance(totalNumBits.begin(), maxVal);
+    return vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
   }
 
   // All element types are scalars. Then choose the smallest bitwidth as the
@@ -374,10 +376,11 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       return success();
     }
 
-    if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
-      // The source indices are for a buffer with larger bitwidth scalar element
-      // types. Rewrite them into a buffer with smaller bitwidth element types.
-      // We only need to scale the last index.
+    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
+        (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+      // The source indices are for a buffer with larger bitwidth scalar/vector
+      // element types. Rewrite them into a buffer with smaller bitwidth element
+      // types. We only need to scale the last index.
       int srcNumBits = *srcElemType.getSizeInBytes();
       int dstNumBits = *dstElemType.getSizeInBytes();
       assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
@@ -395,7 +398,8 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       return success();
     }
 
-    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
+    return rewriter.notifyMatchFailure(
+        acOp, "unsupported src/dst types for spv.AccessChain");
   }
 };
 
@@ -405,12 +409,10 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
   LogicalResult
   matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto srcElemType =
-        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
-    auto dstElemType =
-        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
-    if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
-      return rewriter.notifyMatchFailure(loadOp, "not scalar type");
+    auto srcPtrType = loadOp.ptr().getType().cast<spirv::PointerType>();
+    auto srcElemType = srcPtrType.getPointeeType().cast<spirv::SPIRVType>();
+    auto dstPtrType = adaptor.ptr().getType().cast<spirv::PointerType>();
+    auto dstElemType = dstPtrType.getPointeeType().cast<spirv::SPIRVType>();
 
     Location loc = loadOp.getLoc();
     auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
@@ -427,48 +429,60 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
       return success();
     }
 
-    // The source and destination have scalar types of different bitwidths.
-    // For such cases, we need to load multiple smaller bitwidth values and
-    // construct a larger bitwidth one.
+    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
+        (srcElemType.isa<VectorType>() && dstElemType.isa<VectorType>())) {
+      // The source and destination have scalar types of different bitwidths, or
+      // vector types of different component counts. For such cases, we load
+      // multiple smaller bitwidth values and construct a larger bitwidth one.
 
-    int srcNumBits = srcElemType.getIntOrFloatBitWidth();
-    int dstNumBits = dstElemType.getIntOrFloatBitWidth();
-    assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
-    int ratio = srcNumBits / dstNumBits;
-    if (ratio > 4)
-      return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
+      int srcNumBits = *srcElemType.getSizeInBytes() * 8;
+      int dstNumBits = *dstElemType.getSizeInBytes() * 8;
+      assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
+      int ratio = srcNumBits / dstNumBits;
+      if (ratio > 4)
+        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
 
-    SmallVector<Value> components;
-    components.reserve(ratio);
-    components.push_back(newLoadOp);
+      SmallVector<Value> components;
+      components.reserve(ratio);
+      components.push_back(newLoadOp);
 
-    auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
-    if (!acOp)
-      return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
+      auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
+      if (!acOp)
+        return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
 
-    auto i32Type = rewriter.getI32Type();
-    Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
-    auto indices = llvm::to_vector<4>(acOp.indices());
-    for (int i = 1; i < ratio; ++i) {
-      // Load all subsequent components belonging to this element.
-      indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
-                                                      indices.back(), oneValue);
-      auto componentAcOp =
-          rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
-      // Assuming little endian, this reads lower-ordered bits of the number to
-      // lower-numbered components of the vector.
-      components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+      auto i32Type = rewriter.getI32Type();
+      Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
+      auto indices = llvm::to_vector<4>(acOp.indices());
+      for (int i = 1; i < ratio; ++i) {
+        // Load all subsequent components belonging to this element.
+        indices.back() = rewriter.create<spirv::IAddOp>(
+            loc, i32Type, indices.back(), oneValue);
+        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
+            loc, acOp.base_ptr(), indices);
+        // Assuming little endian, this reads lower-ordered bits of the number
+        // to lower-numbered components of the vector.
+        components.push_back(
+            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+      }
+
+      // Create a vector of the components and then cast back to the larger
+      // bitwidth element type. For spv.bitcast, the lower-numbered components
+      // of the vector map to lower-ordered bits of the larger bitwidth element
+      // type.
+      Type vectorType = srcElemType;
+      if (!srcElemType.isa<VectorType>())
+        vectorType = VectorType::get({ratio}, dstElemType);
+      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
+          loc, vectorType, components);
+      if (!srcElemType.isa<VectorType>())
+        vectorValue =
+            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
+      rewriter.replaceOp(loadOp, vectorValue);
+      return success();
     }
 
-    // Create a vector of the components and then cast back to the larger
-    // bitwidth element type. For spv.bitcast, the lower-numbered components of
-    // the vector map to lower-ordered bits of the larger bitwidth element type.
-    auto vectorType = VectorType::get({ratio}, dstElemType);
-    Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
-        loc, vectorType, components);
-    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
-                                                  vectorValue);
-    return success();
+    return rewriter.notifyMatchFailure(
+        loadOp, "unsupported src/dst types for spv.Load");
   }
 };
 

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 1d0c81d723666..d363666451f7f 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -278,3 +278,54 @@ spv.module Logical GLSL450 {
     spv.Return
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01_scalar bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_vec2 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<2xf32>, stride=8> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01_vec4 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @load_
diff erent_vector_sizes(%i0: i32) -> vector<4xf32> "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01_vec4 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32
+    %vec4val = spv.Load "StorageBuffer" %ac0 : vector<4xf32>
+
+    %addr1 = spv.mlir.addressof @var01_scalar : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %scalarval = spv.Load "StorageBuffer" %ac1 : f32
+
+    %val = spv.CompositeInsert %scalarval, %vec4val[0 : i32] : f32 into vector<4xf32>
+    spv.ReturnValue %val : vector<4xf32>
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01_scalar
+// CHECK-NOT: @var01_vec4
+//     CHECK: spv.GlobalVariable @var01_vec2 bind(0, 1) : !spv.ptr<{{.+}}>
+// CHECK-NOT: @var01_scalar
+// CHECK-NOT: @var01_vec4
+
+//     CHECK: spv.func @load_
diff erent_vector_sizes(%[[IDX:.+]]: i32)
+//     CHECK:   %[[ZERO:.+]] = spv.Constant 0 : i32
+//     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01_vec2
+//     CHECK:   %[[TWO:.+]] = spv.Constant 2 : i32
+//     CHECK:   %[[IDX0:.+]] = spv.IMul %[[IDX]], %[[TWO]] : i32
+//     CHECK:   %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[IDX0]]]
+//     CHECK:   %[[LD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : vector<2xf32>
+//     CHECK:   %[[ONE:.+]] = spv.Constant 1 : i32
+//     CHECK:   %[[IDX1:.+]] = spv.IAdd %0, %[[ONE]] : i32
+//     CHECK:   %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[IDX1]]]
+//     CHECK:   %[[LD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : vector<2xf32>
+//     CHECK:   spv.CompositeConstruct %[[LD0]], %[[LD1]] : (vector<2xf32>, vector<2xf32>) -> vector<4xf32>
+
+//     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01_vec2
+//     CHECK:   %[[TWO:.+]] = spv.Constant 2 : i32
+//     CHECK:   %[[DIV:.+]] = spv.SDiv %[[IDX]], %[[TWO]] : i32
+//     CHECK:   %[[MOD:.+]] = spv.SMod %[[IDX]], %[[TWO]] : i32
+//     CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[DIV]], %[[MOD]]]
+//     CHECK:   %[[LD:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32


        


More information about the Mlir-commits mailing list