[Mlir-commits] [mlir] 11cf2d5 - [mlir][spirv] Unify aliases of different bitwidth scalar types

Lei Zhang llvmlistbot at llvm.org
Fri Jun 10 15:09:34 PDT 2022


Author: Lei Zhang
Date: 2022-06-10T18:01:31-04:00
New Revision: 11cf2d5f62f94b14644fab3478f17af9cb015706

URL: https://github.com/llvm/llvm-project/commit/11cf2d5f62f94b14644fab3478f17af9cb015706
DIFF: https://github.com/llvm/llvm-project/commit/11cf2d5f62f94b14644fab3478f17af9cb015706.diff

LOG: [mlir][spirv] Unify aliases of different bitwidth scalar types

This commit extends the UnifyAliasedResourcePass to handle scalar
types of different bitwidths. It requires choosing the resource with
the smaller bitwidth as the canonical resource, so that we can avoid
subcomponent load/store. Instead, we load/store multiple values of the
smaller bitwidth.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D127266

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index e2b83a8ba22fc..2dc4d737a71cd 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include <algorithm>
+#include <iterator>
 
 #define DEBUG_TYPE "spirv-unify-aliased-resource"
 
@@ -72,20 +73,65 @@ static Type getRuntimeArrayElementType(Type type) {
   return rtArrayType.getElementType();
 }
 
-/// Returns true if all `types`, which can either be scalar or vector types,
-/// have the same bitwidth base scalar type.
-static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
-  SmallVector<int64_t> scalarTypes;
-  scalarTypes.reserve(types.size());
+/// Given a list of resource element `types`, returns the index of the canonical
+/// resource that all resources should be unified into. Returns llvm::None if
+/// unable to unify.
+static Optional<int> deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
+  SmallVector<int> scalarNumBits, totalNumBits;
+  scalarNumBits.reserve(types.size());
+  totalNumBits.reserve(types.size());
+  bool hasVector = false;
+
   for (spirv::SPIRVType type : types) {
     assert(type.isScalarOrVector());
-    if (auto vectorType = type.dyn_cast<VectorType>())
-      scalarTypes.push_back(
+    if (auto vectorType = type.dyn_cast<VectorType>()) {
+      if (vectorType.getNumElements() % 2 != 0)
+        return llvm::None; // Odd-sized vector has special layout requirements.
+
+      Optional<int64_t> numBytes = type.getSizeInBytes();
+      if (!numBytes)
+        return llvm::None;
+
+      scalarNumBits.push_back(
           vectorType.getElementType().getIntOrFloatBitWidth());
-    else
-      scalarTypes.push_back(type.getIntOrFloatBitWidth());
+      totalNumBits.push_back(*numBytes * 8);
+      hasVector = true;
+    } else {
+      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
+      totalNumBits.push_back(scalarNumBits.back());
+    }
   }
-  return llvm::is_splat(scalarTypes);
+
+  if (hasVector) {
+    // If there are vector types, require all element types to be the same for
+    // now to simplify the transformation.
+    if (!llvm::is_splat(scalarNumBits))
+      return llvm::None;
+
+    // Choose the one with the largest bitwidth as the canonical resource, so
+    // that we can still keep vectorized load/store.
+    auto *maxVal = std::max_element(totalNumBits.begin(), totalNumBits.end());
+    // Make sure that the canonical resource's bitwidth is divisible by others.
+    // With out this, we cannot properly adjust the index later.
+    if (llvm::any_of(totalNumBits,
+                     [maxVal](int64_t bits) { return *maxVal % bits != 0; }))
+      return llvm::None;
+
+    return std::distance(totalNumBits.begin(), maxVal);
+  }
+
+  // All element types are scalars. Then choose the smallest bitwidth as the
+  // cannonical resource to avoid subcomponent load/store.
+  auto *minVal = std::min_element(scalarNumBits.begin(), scalarNumBits.end());
+  if (llvm::any_of(scalarNumBits,
+                   [minVal](int64_t bit) { return bit % *minVal != 0; }))
+    return llvm::None;
+  return std::distance(scalarNumBits.begin(), minVal);
+}
+
+static bool areSameBitwidthScalarType(Type a, Type b) {
+  return a.isIntOrFloat() && b.isIntOrFloat() &&
+         a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
 }
 
 //===----------------------------------------------------------------------===//
@@ -203,11 +249,8 @@ ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
 
 void ResourceAliasAnalysis::recordIfUnifiable(
     const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
-  // Collect the element types and byte counts for all resources in the
-  // current set.
+  // Collect the element types for all resources in the current set.
   SmallVector<spirv::SPIRVType> elementTypes;
-  SmallVector<int64_t> numBytes;
-
   for (spirv::GlobalVariableOp resource : resources) {
     Type elementType = getRuntimeArrayElementType(resource.type());
     if (!elementType)
@@ -217,37 +260,16 @@ void ResourceAliasAnalysis::recordIfUnifiable(
     if (!type.isScalarOrVector())
       return; // Unexpected resource element type.
 
-    if (auto vectorType = type.dyn_cast<VectorType>())
-      if (vectorType.getNumElements() % 2 != 0)
-        return; // Odd-sized vector has special layout requirements.
-
-    Optional<int64_t> count = type.getSizeInBytes();
-    if (!count)
-      return;
-
     elementTypes.push_back(type);
-    numBytes.push_back(*count);
   }
 
-  // Make sure base scalar types have the same bitwdith, so that we don't need
-  // to handle extracting components for now.
-  if (!hasSameBitwidthScalarType(elementTypes))
-    return;
-
-  // Make sure that the canonical resource's bitwidth is divisible by others.
-  // With out this, we cannot properly adjust the index later.
-  auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
-  if (llvm::any_of(numBytes, [maxCount](int64_t count) {
-        return *maxCount % count != 0;
-      }))
+  Optional<int> index = deduceCanonicalResource(elementTypes);
+  if (!index)
     return;
 
-  spirv::GlobalVariableOp canonicalResource =
-      resources[std::distance(numBytes.begin(), maxCount)];
-
   // Update internal data structures for later use.
   resourceMap[descriptor].assign(resources.begin(), resources.end());
-  canonicalResourceMap[descriptor] = canonicalResource;
+  canonicalResourceMap[descriptor] = resources[*index];
   for (const auto &resource : llvm::enumerate(resources)) {
     descriptorMap[resource.value()] = descriptor;
     elementTypeMap[resource.value()] = elementTypes[resource.index()];
@@ -316,8 +338,8 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
     spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
     spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
 
-    if ((srcElemType == dstElemType) ||
-        (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
+    if (srcElemType == dstElemType ||
+        areSameBitwidthScalarType(srcElemType, dstElemType)) {
       // We have the same bitwidth for source and destination element types.
       // Thie indices keep the same.
       rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
@@ -333,7 +355,10 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       // them into a buffer with vector element types. We need to scale the last
       // index for the vector as a whole, then add one level of index for inside
       // the vector.
-      int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
+      int srcNumBits = *srcElemType.getSizeInBytes();
+      int dstNumBits = *dstElemType.getSizeInBytes();
+      assert(dstNumBits > srcNumBits && dstNumBits % srcNumBits == 0);
+      int ratio = dstNumBits / srcNumBits;
       auto ratioValue = rewriter.create<spirv::ConstantOp>(
           loc, i32Type, rewriter.getI32IntegerAttr(ratio));
 
@@ -349,6 +374,27 @@ struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
       return success();
     }
 
+    if (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) {
+      // The source indices are for a buffer with larger bitwidth scalar element
+      // types. Rewrite them into a buffer with smaller bitwidth element types.
+      // We only need to scale the last index.
+      int srcNumBits = *srcElemType.getSizeInBytes();
+      int dstNumBits = *dstElemType.getSizeInBytes();
+      assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
+      int ratio = srcNumBits / dstNumBits;
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
+
+      auto indices = llvm::to_vector<4>(acOp.indices());
+      Value oldIndex = indices.back();
+      indices.back() =
+          rewriter.create<spirv::IMulOp>(loc, i32Type, oldIndex, ratioValue);
+
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), indices);
+      return success();
+    }
+
     return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
   }
 };
@@ -370,12 +416,56 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
     auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
     if (srcElemType == dstElemType) {
       rewriter.replaceOp(loadOp, newLoadOp->getResults());
-    } else {
+      return success();
+    }
+
+    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
       auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                       newLoadOp.value());
       rewriter.replaceOp(loadOp, castOp->getResults());
+
+      return success();
     }
 
+    // The source and destination have scalar types of different bitwidths.
+    // For such cases, we need to load multiple smaller bitwidth values and
+    // construct a larger bitwidth one.
+
+    int srcNumBits = srcElemType.getIntOrFloatBitWidth();
+    int dstNumBits = dstElemType.getIntOrFloatBitWidth();
+    assert(srcNumBits > dstNumBits && srcNumBits % dstNumBits == 0);
+    int ratio = srcNumBits / dstNumBits;
+    if (ratio > 4)
+      return rewriter.notifyMatchFailure(loadOp, "more than 4 components");
+
+    SmallVector<Value> components;
+    components.reserve(ratio);
+    components.push_back(newLoadOp);
+
+    auto acOp = adaptor.ptr().getDefiningOp<spirv::AccessChainOp>();
+    if (!acOp)
+      return rewriter.notifyMatchFailure(loadOp, "ptr not spv.AccessChain");
+
+    auto i32Type = rewriter.getI32Type();
+    Value oneValue = spirv::ConstantOp::getOne(i32Type, loc, rewriter);
+    auto indices = llvm::to_vector<4>(acOp.indices());
+    for (int i = 1; i < ratio; ++i) {
+      // Load all subsequent components belonging to this element.
+      indices.back() = rewriter.create<spirv::IAddOp>(loc, i32Type,
+                                                      indices.back(), oneValue);
+      auto componentAcOp =
+          rewriter.create<spirv::AccessChainOp>(loc, acOp.base_ptr(), indices);
+      components.push_back(rewriter.create<spirv::LoadOp>(loc, componentAcOp));
+    }
+    std::reverse(components.begin(), components.end()); // For little endian..
+
+    // Create a vector of the components and then cast back to the larger
+    // bitwidth element type.
+    auto vectorType = VectorType::get({ratio}, dstElemType);
+    Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
+        loc, vectorType, components);
+    rewriter.replaceOpWithNewOp<spirv::BitcastOp>(loadOp, srcElemType,
+                                                  vectorValue);
     return success();
   }
 };
@@ -392,6 +482,8 @@ struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
         adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
     if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
       return rewriter.notifyMatchFailure(storeOp, "not scalar type");
+    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
+      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");
 
     Location loc = storeOp.getLoc();
     Value value = adaptor.value();

diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
index 546fc1f93b097..0b36178783f07 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
+// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource -verify-diagnostics %s | FileCheck %s
 
 spv.module Logical GLSL450 {
   spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
@@ -213,3 +213,68 @@ spv.module Logical GLSL450 {
 //     CHECK:   %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
 //     CHECK:   spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
 //     CHECK:   spv.ReturnValue %[[CAST1]] : i32
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+  spv.func @load_different_scalar_bitwidth(%index: i32) -> i64 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : i64
+
+    spv.ReturnValue %val0 : i64
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s_i64
+//     CHECK: spv.GlobalVariable @var01s_f32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s_i64
+
+//     CHECK: spv.func @load_different_scalar_bitwidth(%[[INDEX:.+]]: i32)
+//     CHECK:   %[[ZERO:.+]] = spv.Constant 0 : i32
+//     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01s_f32
+
+//     CHECK:   %[[TWO:.+]] = spv.Constant 2 : i32
+//     CHECK:   %[[BASE:.+]] = spv.IMul %[[INDEX]], %[[TWO]] : i32
+//     CHECK:   %[[AC0:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[BASE]]]
+//     CHECK:   %[[LOAD0:.+]] = spv.Load "StorageBuffer" %[[AC0]] : f32
+
+//     CHECK:   %[[ONE:.+]] = spv.Constant 1 : i32
+//     CHECK:   %[[ADD:.+]] = spv.IAdd %[[BASE]], %[[ONE]] : i32
+//     CHECK:   %[[AC1:.+]] = spv.AccessChain %[[ADDR]][%[[ZERO]], %[[ADD]]]
+//     CHECK:   %[[LOAD1:.+]] = spv.Load "StorageBuffer" %[[AC1]] : f32
+
+//     CHECK:   %[[CC:.+]] = spv.CompositeConstruct %[[LOAD1]], %[[LOAD0]]
+//     CHECK:   %[[CAST:.+]] = spv.Bitcast %[[CC]] : vector<2xf32> to i64
+//     CHECK:   spv.ReturnValue %[[CAST]]
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s_i64 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+  spv.func @store_different_scalar_bitwidth(%i0: i32, %i1: i32) "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %f32val = spv.Load "StorageBuffer" %ac0 : f32
+    %f64val = spv.FConvert %f32val : f32 to f64
+    %i64val = spv.Bitcast %f64val : f64 to i64
+
+    %addr1 = spv.mlir.addressof @var01s_i64 : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<i64, stride=4> [0])>, StorageBuffer>, i32, i32
+    // expected-error @+1 {{failed to legalize operation 'spv.Store'}}
+    spv.Store "StorageBuffer" %ac1, %i64val : i64
+
+    spv.Return
+  }
+}


        


More information about the Mlir-commits mailing list