[Mlir-commits] [mlir] c9b3680 - [mlir][spirv] Add a pass to unify aliased resource variables

Lei Zhang llvmlistbot at llvm.org
Thu Feb 17 06:18:28 PST 2022


Author: Lei Zhang
Date: 2022-02-17T09:08:58-05:00
New Revision: c9b36807beaf120f4e06d9da3b7df7625e440825

URL: https://github.com/llvm/llvm-project/commit/c9b36807beaf120f4e06d9da3b7df7625e440825
DIFF: https://github.com/llvm/llvm-project/commit/c9b36807beaf120f4e06d9da3b7df7625e440825.diff

LOG: [mlir][spirv] Add a pass to unify aliased resource variables

In SPIR-V, resources are represented as global variables that
are bound to a descriptor (a descriptor set and binding number).
SPIR-V requires those global variables to be decorated as aliased
if multiple ones are bound to the same slot. Such aliased
decorations can cause issues for transcompilers like SPIRV-Cross
when converting to source shading languages like MSL.

So this commit adds a pass that analyzes aliased resources and,
where possible, unifies them into a single resource.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D119872

Added: 
    mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
    mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
    mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
    mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
    mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index ebefa3167a249..4201e0ee09333 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -385,7 +385,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
     OptionalAttr<FlatSymbolRefAttr>:$initializer,
     OptionalAttr<I32Attr>:$location,
     OptionalAttr<I32Attr>:$binding,
-    OptionalAttr<I32Attr>:$descriptorSet,
+    OptionalAttr<I32Attr>:$descriptor_set,
     OptionalAttr<StrAttr>:$builtin
   );
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
index 38548fee32682..116a37dc0b534 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
@@ -55,6 +55,11 @@ std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
 /// spv.CompositeInsert into spv.CompositeConstruct.
 std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
 
+/// Creates an operation pass that unifies access of multiple aliased resources
+/// into access of one single resource.
+///
+/// Candidates are spv.GlobalVariable ops marked `aliased` that share the same
+/// (descriptor set, binding). Descriptors whose resources the pass cannot
+/// unify are left unchanged.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+createUnifyAliasedResourcePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 575bb0898faad..32abca53f8a59 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -28,6 +28,13 @@ def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
   let constructor = "mlir::spirv::createRewriteInsertsPass()";
 }
 
+// Rewrites accesses to `aliased` spv.GlobalVariable resources that share a
+// (descriptor set, binding) so they all go through one canonical resource.
+def SPIRVUnifyAliasedResourcePass
+    : Pass<"spirv-unify-aliased-resource", "spirv::ModuleOp"> {
+  let summary = "Unify access of multiple aliased resources into access of one "
+                "single resource";
+  let constructor = "mlir::spirv::createUnifyAliasedResourcePass()";
+}
+
 def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
   let summary = "Deduce and attach minimal (version, capabilities, extensions) "
                 "requirements to spv.module ops";

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index db274088bdf22..affceebcfd3d4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
   SPIRVConversion.cpp
+  UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 )
 
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms
   DecorateCompositeTypeLayoutPass.cpp
   LowerABIAttributesPass.cpp
   RewriteInsertsPass.cpp
+  UnifyAliasedResourcePass.cpp
   UpdateVCEPass.cpp
 
   ADDITIONAL_HEADER_DIRS

diff  --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
new file mode 100644
index 0000000000000..fa0e551f5d53e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -0,0 +1,452 @@
+//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that unifies access of multiple aliased resources
+// into access of one single resource.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/AnalysisManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "spirv-unify-aliased-resource"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// A descriptor is identified by its (descriptor set #, binding #) pair.
+using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
+/// Groups the `aliased` resource variables by the descriptor they bind to.
+using AliasedResourceMap =
+    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
+
+/// Walks `moduleOp` and buckets every spv.GlobalVariable carrying the
+/// `aliased` unit attribute by its (set, binding) pair. Variables lacking
+/// either a descriptor set or a binding are ignored. Within a bucket,
+/// variables appear in walk order.
+static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
+  AliasedResourceMap resourcesByDescriptor;
+  moduleOp->walk([&resourcesByDescriptor](spirv::GlobalVariableOp varOp) {
+    if (!varOp->getAttrOfType<UnitAttr>("aliased"))
+      return;
+    Optional<uint32_t> set = varOp.descriptor_set();
+    Optional<uint32_t> binding = varOp.binding();
+    if (!set || !binding)
+      return;
+    resourcesByDescriptor[Descriptor(*set, *binding)].push_back(varOp);
+  });
+  return resourcesByDescriptor;
+}
+
+/// If `type` has the shape `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`, where the
+/// struct holds exactly one member, yields the runtime array's element type.
+/// Any other shape yields a null type.
+static Type getRuntimeArrayElementType(Type type) {
+  if (auto ptrType = type.dyn_cast<spirv::PointerType>()) {
+    auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+    if (structType && structType.getNumElements() == 1) {
+      if (auto rtArrayType = structType.getElementType(0)
+                                 .dyn_cast<spirv::RuntimeArrayType>())
+        return rtArrayType.getElementType();
+    }
+  }
+  return Type();
+}
+
+/// Checks that every entry of `types` (each a scalar or a vector) is built on
+/// a scalar type of one common bitwidth. An empty list yields false.
+static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
+  // Maps a scalar or vector type to the bitwidth of its scalar component.
+  auto getScalarBitwidth = [](spirv::SPIRVType type) -> int64_t {
+    assert(type.isScalarOrVector());
+    Type scalarType = type;
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      scalarType = vectorType.getElementType();
+    return scalarType.getIntOrFloatBitWidth();
+  };
+  auto bitwidths = llvm::to_vector<4>(llvm::map_range(types, getScalarBitwidth));
+  return llvm::is_splat(bitwidths);
+}
+
+//===----------------------------------------------------------------------===//
+// Analysis
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A class for analyzing aliased resources.
+///
+/// Resources are expected to be spv.GlobalVariable ops that carry a descriptor
+/// set and a binding number. Such resources are of the type
+/// `!spv.ptr<!spv.struct<...>>` per Vulkan requirements.
+///
+/// Right now, we only support the case that there is a single runtime array
+/// inside the struct.
+///
+/// All analysis work happens in the constructor; the query methods are const.
+class ResourceAliasAnalysis {
+public:
+  /// Analyzes `root`, which must be a spv.module (it is cast to
+  /// spirv::ModuleOp).
+  explicit ResourceAliasAnalysis(Operation *);
+
+  /// Returns true if the given `op` can be rewritten to use a canonical
+  /// resource. Understands spv.GlobalVariable, spv.mlir.addressof,
+  /// spv.AccessChain, spv.Load and spv.Store (following pointer operands back
+  /// to the resource variable); returns false for any other op.
+  bool shouldUnify(Operation *op) const;
+
+  /// Returns all descriptors and their corresponding aliased resources.
+  /// Only descriptors judged unifiable are recorded here.
+  const AliasedResourceMap &getResourceMap() const { return resourceMap; }
+
+  /// Returns the canonical resource for the given descriptor/variable, or a
+  /// null op if none was recorded.
+  spirv::GlobalVariableOp
+  getCanonicalResource(const Descriptor &descriptor) const;
+  spirv::GlobalVariableOp
+  getCanonicalResource(spirv::GlobalVariableOp varOp) const;
+
+  /// Returns the element type for the given variable, or a null type if the
+  /// variable was not recorded.
+  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
+
+private:
+  /// Given the descriptor and aliased resources bound to it, analyze whether we
+  /// can unify them and record if so. Nothing is recorded on rejection.
+  void recordIfUnifiable(const Descriptor &descriptor,
+                         ArrayRef<spirv::GlobalVariableOp> resources);
+
+  /// Mapping from a descriptor to all aliased resources bound to it.
+  AliasedResourceMap resourceMap;
+
+  /// Mapping from a descriptor to the chosen canonical resource.
+  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
+
+  /// Mapping from an aliased resource to its descriptor.
+  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
+
+  /// Mapping from an aliased resource to its element (scalar/vector) type.
+  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
+};
+} // namespace
+
+ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
+  // Collect all aliased resources first and put them into different sets
+  // according to the descriptor.
+  AliasedResourceMap aliasedResources =
+      collectAliasedResources(cast<spirv::ModuleOp>(root));
+
+  // For each resource set, analyze whether we can unify; if so, try to identify
+  // a canonical resource, whose element type has the largest bitwidth.
+  for (const auto &descriptorResources : aliasedResources)
+    recordIfUnifiable(descriptorResources.first, descriptorResources.second);
+}
+
+bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
+  // The recursive calls below chase pointer operands back to their defining
+  // op. A pointer that is a block argument has no defining op, and a symbol
+  // lookup may fail; in both cases `op` is null, which dyn_cast rejects, and
+  // there is nothing to unify.
+  if (!op)
+    return false;
+
+  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+    auto canonicalOp = getCanonicalResource(varOp);
+    return canonicalOp && varOp != canonicalOp;
+  }
+  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
+    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
+    auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
+    return shouldUnify(varOp);
+  }
+
+  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
+    return shouldUnify(acOp.base_ptr().getDefiningOp());
+  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
+    return shouldUnify(loadOp.ptr().getDefiningOp());
+  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
+    return shouldUnify(storeOp.ptr().getDefiningOp());
+
+  return false;
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    const Descriptor &descriptor) const {
+  // DenseMap::lookup yields a default-constructed (null) op when no canonical
+  // resource was recorded for this descriptor.
+  return canonicalResourceMap.lookup(descriptor);
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    spirv::GlobalVariableOp varOp) const {
+  // Only variables belonging to a unifiable descriptor have an entry.
+  auto it = descriptorMap.find(varOp);
+  return it == descriptorMap.end() ? spirv::GlobalVariableOp()
+                                   : getCanonicalResource(it->second);
+}
+
+spirv::SPIRVType
+ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
+  // Unrecorded variables map to a default-constructed (null) type.
+  return elementTypeMap.lookup(varOp);
+}
+
+void ResourceAliasAnalysis::recordIfUnifiable(
+    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
+  // Nothing to unify; also protects the max_element dereference below.
+  if (resources.empty())
+    return;
+
+  // Collect the element types and byte counts for all resources in the
+  // current set.
+  SmallVector<spirv::SPIRVType> elementTypes;
+  SmallVector<int64_t> numBytes;
+
+  for (spirv::GlobalVariableOp resource : resources) {
+    Type elementType = getRuntimeArrayElementType(resource.type());
+    if (!elementType)
+      return; // Unexpected resource variable type.
+
+    auto type = elementType.cast<spirv::SPIRVType>();
+    if (!type.isScalarOrVector())
+      return; // Unexpected resource element type.
+
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      if (vectorType.getNumElements() % 2 != 0)
+        return; // Odd-sized vector has special layout requirements.
+
+    Optional<int64_t> count = type.getSizeInBytes();
+    if (!count)
+      return;
+
+    elementTypes.push_back(type);
+    numBytes.push_back(*count);
+  }
+
+  // Make sure base scalar types have the same bitwidth, so that we don't need
+  // to handle extracting components for now.
+  if (!hasSameBitwidthScalarType(elementTypes))
+    return;
+
+  // Make sure that the canonical resource's byte count is divisible by all
+  // others. Without this, we cannot properly adjust the index later.
+  auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
+  if (llvm::any_of(numBytes, [maxCount](int64_t count) {
+        return *maxCount % count != 0;
+      }))
+    return;
+
+  size_t canonicalIndex = std::distance(numBytes.begin(), maxCount);
+  spirv::SPIRVType canonicalElemType = elementTypes[canonicalIndex];
+
+  // The rewrite patterns can only redirect an access whose element type is
+  // identical to the canonical one, or is a scalar (same-bitwidth scalar, or a
+  // single component of a canonical vector). Any other vector type (e.g.
+  // vector<2xf32> vs. vector<4xf32>, or vector<4xi32> vs. vector<4xf32>) would
+  // make the conversion fail and abort the whole pass, so skip such sets.
+  if (llvm::any_of(elementTypes, [&](spirv::SPIRVType type) {
+        return type != canonicalElemType && !type.isIntOrFloat();
+      }))
+    return;
+
+  spirv::GlobalVariableOp canonicalResource = resources[canonicalIndex];
+
+  // Update internal data structures for later use.
+  resourceMap[descriptor].assign(resources.begin(), resources.end());
+  canonicalResourceMap[descriptor] = canonicalResource;
+  for (const auto &resource : llvm::enumerate(resources)) {
+    descriptorMap[resource.value()] = descriptor;
+    elementTypeMap[resource.value()] = elementTypes[resource.index()];
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Common base for the patterns below: an OpConversionPattern holding a
+/// reference to the resource alias analysis that drives the rewrite. The
+/// analysis must outlive the pattern.
+template <typename OpTy>
+class ConvertAliasResource : public OpConversionPattern<OpTy> {
+public:
+  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
+                       MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
+
+protected:
+  const ResourceAliasAnalysis &analysis;
+};
+
+/// Erases a non-canonical aliased resource variable.
+struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Just remove the aliased resource. Users will be rewritten to use the
+    // canonical one.
+    rewriter.eraseOp(varOp);
+    return success();
+  }
+};
+
+/// Redirects spv.mlir.addressof from a non-canonical resource to the
+/// canonical one.
+struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Rewrite the AddressOf op to get the address of the canonical resource.
+    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
+    auto srcVarOp = cast<spirv::GlobalVariableOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
+    return success();
+  }
+};
+
+/// Rewrites an access chain rooted at a non-canonical resource into one rooted
+/// at the canonical resource. Indices are adjusted when a scalar buffer is
+/// re-expressed as a vector buffer.
+struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
+    if (!addressOp)
+      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
+
+    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
+    auto srcVarOp = cast<spirv::GlobalVariableOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
+
+    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
+    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
+
+    if ((srcElemType == dstElemType) ||
+        (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
+      // We have the same bitwidth for source and destination element types.
+      // The indices stay the same.
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), adaptor.indices());
+      return success();
+    }
+
+    Location loc = acOp.getLoc();
+    auto i32Type = rewriter.getI32Type();
+
+    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+      // The source indices are for a buffer with scalar element types. Rewrite
+      // them into a buffer with vector element types. We need to scale the last
+      // index for the vector as a whole, then add one level of index for inside
+      // the vector.
+      int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
+
+      auto indices = llvm::to_vector<4>(acOp.indices());
+      Value oldIndex = indices.back();
+      indices.back() =
+          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+      indices.push_back(
+          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), indices);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
+  }
+};
+
+/// Re-issues a load through the rewritten pointer, bitcasting the loaded
+/// value back to the original scalar type when the types differ.
+struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+
+    // Identical pointee types (e.g. a whole vector loaded from a resource
+    // whose element type equals the canonical one) need no cast. Otherwise we
+    // only know how to bitcast between same-bitwidth scalars.
+    if (srcElemType != dstElemType &&
+        (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()))
+      return rewriter.notifyMatchFailure(loadOp, "not scalar type");
+
+    Location loc = loadOp.getLoc();
+    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
+    if (srcElemType == dstElemType) {
+      rewriter.replaceOp(loadOp, newLoadOp->getResults());
+    } else {
+      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
+                                                      newLoadOp.value());
+      rewriter.replaceOp(loadOp, castOp->getResults());
+    }
+
+    return success();
+  }
+};
+
+/// Re-issues a store through the rewritten pointer, bitcasting the stored
+/// value to the canonical scalar type when the types differ.
+struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
+  using ConvertAliasResource::ConvertAliasResource;
+
+  LogicalResult
+  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+
+    // Same rule as for loads: identical types store directly, otherwise only
+    // scalar-to-scalar bitcasts are supported.
+    if (srcElemType != dstElemType &&
+        (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()))
+      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
+
+    Location loc = storeOp.getLoc();
+    Value value = adaptor.value();
+    if (srcElemType != dstElemType)
+      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
+                                                storeOp->getAttrs());
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pass driver. runOnOperation runs ResourceAliasAnalysis on the spv.module
+/// and rewrites uses of non-canonical aliased resources onto the canonical
+/// one. It then drops the `aliased` attribute from descriptors left with a
+/// single resource.
+class UnifyAliasedResourcePass final
+    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
+public:
+  void runOnOperation() override;
+};
+} // namespace
+
+void UnifyAliasedResourcePass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  spirv::ModuleOp moduleOp = getOperation();
+
+  // Decide up front which aliased resources fold into which canonical one.
+  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
+
+  // An op that still refers to a non-canonical resource is illegal and must be
+  // rewritten; every other SPIR-V op is legal as-is.
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
+                               spirv::AccessChainOp, spirv::LoadOp,
+                               spirv::StoreOp>(
+      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
+  target.addLegalDialect<spirv::SPIRVDialect>();
+
+  RewritePatternSet patterns(context);
+  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
+               ConvertLoad, ConvertStore>(analysis, context);
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+    signalPassFailure();
+    return;
+  }
+
+  // The conversion is best effort, so some descriptors may be untouched.
+  // Re-collect from the rewritten module: a descriptor that now has exactly
+  // one bound resource no longer needs the `aliased` attribute.
+  for (const auto &entry : collectAliasedResources(moduleOp)) {
+    ArrayRef<spirv::GlobalVariableOp> boundResources = entry.second;
+    if (boundResources.size() != 1)
+      continue;
+    boundResources.front()->removeAttr("aliased");
+  }
+}
+
+/// Factory named as the `constructor` of SPIRVUnifyAliasedResourcePass in
+/// Passes.td.
+std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
+spirv::createUnifyAliasedResourcePass() {
+  return std::make_unique<UnifyAliasedResourcePass>();
+}

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
new file mode 100644
index 0000000000000..546fc1f93b097
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -0,0 +1,215 @@
+// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
+
+// An f32 buffer aliased with a vector<4xf32> buffer on the same binding: the
+// f32 accesses are redirected to the vector buffer, splitting the index into
+// (index / 4, index % 4).
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @load_store_scalar(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.Store "StorageBuffer" %ac, %value : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+//     CHECK: spv.func @load_store_scalar(%[[INDEX:.+]]: i32)
+// CHECK-DAG:   %[[C0:.+]] = spv.Constant 0 : i32
+// CHECK-DAG:   %[[C4:.+]] = spv.Constant 4 : i32
+// CHECK-DAG:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+//     CHECK:   %[[DIV:.+]] = spv.SDiv %[[INDEX]], %[[C4]] : i32
+//     CHECK:   %[[MOD:.+]] = spv.SMod %[[INDEX]], %[[C4]] : i32
+//     CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[C0]], %[[DIV]], %[[MOD]]]
+//     CHECK:   spv.Load "StorageBuffer" %[[AC]]
+//     CHECK:   spv.Store "StorageBuffer" %[[AC]]
+
+// -----
+
+// Two access chains off one addressof of the f32 buffer: both are rewritten
+// onto the vector<4xf32> buffer.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : f32
+    %ac1 = spv.AccessChain %addr[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac1 : f32
+    %value = spv.FAdd %val0, %val1 : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+//     CHECK: spv.func @multiple_uses
+//     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+//     CHECK:   spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+//     CHECK:   spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+// vector<3xf32> is an odd-sized vector, which the analysis rejects, so both
+// resources keep their aliased decoration and nothing is rewritten.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<3xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @vector3(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) {aliased}
+// CHECK: spv.func @vector3
+
+// -----
+
+// The two resources use different bindings, so each is alone on its
+// descriptor: nothing is rewritten and only the aliased decoration is dropped.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(1, 0) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @not_aliased(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %value = spv.Load "StorageBuffer" %ac : f32
+    spv.Store "StorageBuffer" %ac, %value : f32
+    spv.ReturnValue %value : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+// CHECK: spv.GlobalVariable @var01v bind(1, 0) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK: spv.func @not_aliased
+
+// -----
+
+// Four resources share one binding. Every access is redirected to @var01v,
+// the first resource with the largest element type.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @multiple_aliases(%index: i32) -> f32 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : f32
+
+    %addr1 = spv.mlir.addressof @var01s_1 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac1 : f32
+
+    %addr2 = spv.mlir.addressof @var01v_1 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+    %ac2 = spv.AccessChain %addr2[%c0, %index, %c0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32, i32
+    %val2 = spv.Load "StorageBuffer" %ac2 : f32
+
+    %add0 = spv.FAdd %val0, %val1 : f32
+    %add1 = spv.FAdd %add0, %val2 : f32
+    spv.ReturnValue %add1 : f32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01v_1
+
+//     CHECK: spv.func @multiple_aliases
+//     CHECK:   %[[ADDR0:.+]] = spv.mlir.addressof @var01v :
+//     CHECK:   spv.AccessChain %[[ADDR0]][%{{.+}}, %{{.+}}, %{{.+}}]
+//     CHECK:   %[[ADDR1:.+]] = spv.mlir.addressof @var01v :
+//     CHECK:   spv.AccessChain %[[ADDR1]][%{{.+}}, %{{.+}}, %{{.+}}]
+//     CHECK:   %[[ADDR2:.+]] = spv.mlir.addressof @var01v :
+//     CHECK:   spv.AccessChain %[[ADDR2]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+// Two same-bitwidth scalar buffers (i32 and f32): the f32 store is redirected
+// to the i32 buffer through a bitcast of the stored value.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s_i32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+  spv.func @different_scalar_type(%index: i32, %val1: f32) -> i32 "None" {
+    %c0 = spv.Constant 0 : i32
+
+    %addr0 = spv.mlir.addressof @var01s_i32 : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+    %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val0 = spv.Load "StorageBuffer" %ac0 : i32
+
+    %addr1 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+    %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+    spv.Store "StorageBuffer" %ac1, %val1 : f32
+
+    spv.ReturnValue %val0 : i32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s_f32
+//     CHECK: spv.GlobalVariable @var01s_i32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s_f32
+
+//     CHECK: spv.func @different_scalar_type(%[[INDEX:.+]]: i32, %[[VAL1:.+]]: f32)
+
+//     CHECK:   %[[IADDR:.+]] = spv.mlir.addressof @var01s_i32
+//     CHECK:   %[[IAC:.+]] = spv.AccessChain %[[IADDR]][%{{.+}}, %[[INDEX]]]
+//     CHECK:   spv.Load "StorageBuffer" %[[IAC]] : i32
+
+//     CHECK:   %[[FADDR:.+]] = spv.mlir.addressof @var01s_i32
+//     CHECK:   %[[FAC:.+]] = spv.AccessChain %[[FADDR]][%{{.+}}, %[[INDEX]]]
+//     CHECK:   %[[CAST:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+//     CHECK:   spv.Store "StorageBuffer" %[[FAC]], %[[CAST]] : i32
+
+// -----
+
+// An i32 buffer aliased with a vector<4xf32> buffer: the i32 accesses are
+// redirected to one f32 component of the vector buffer, with bitcasts on the
+// loaded and stored values.
+spv.module Logical GLSL450 {
+  spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+  spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+  spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+    %c0 = spv.Constant 0 : i32
+    %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+    %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+    %val1 = spv.Load "StorageBuffer" %ac : i32
+    spv.Store "StorageBuffer" %ac, %val0 : i32
+    spv.ReturnValue %val1 : i32
+  }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+//     CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+//     CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+//     CHECK:   %[[ADDR:.+]] = spv.mlir.addressof @var01v
+//     CHECK:   %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+//     CHECK:   %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+//     CHECK:   %[[CAST1:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+//     CHECK:   %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
+//     CHECK:   spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
+//     CHECK:   spv.ReturnValue %[[CAST1]] : i32


        


More information about the Mlir-commits mailing list