[Mlir-commits] [mlir] c9b3680 - [mlir][spirv] Add a pass to unify aliased resource variables
Lei Zhang
llvmlistbot at llvm.org
Thu Feb 17 06:18:28 PST 2022
Author: Lei Zhang
Date: 2022-02-17T09:08:58-05:00
New Revision: c9b36807beaf120f4e06d9da3b7df7625e440825
URL: https://github.com/llvm/llvm-project/commit/c9b36807beaf120f4e06d9da3b7df7625e440825
DIFF: https://github.com/llvm/llvm-project/commit/c9b36807beaf120f4e06d9da3b7df7625e440825.diff
LOG: [mlir][spirv] Add a pass to unify aliased resource variables
In SPIR-V, resources are represented as global variables that
are bound to certain descriptors. SPIR-V requires those global
variables to be declared as aliased if multiple ones are bound
to the same slot. Such aliased decorations can cause issues
for transcompilers like SPIRV-Cross when converting to source
shading languages like MSL.
So this commit adds a pass that analyzes aliased resources
and, where possible, unifies them into one.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D119872
Added:
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
index ebefa3167a249..4201e0ee09333 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
@@ -385,7 +385,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
OptionalAttr<FlatSymbolRefAttr>:$initializer,
OptionalAttr<I32Attr>:$location,
OptionalAttr<I32Attr>:$binding,
- OptionalAttr<I32Attr>:$descriptorSet,
+ OptionalAttr<I32Attr>:$descriptor_set,
OptionalAttr<StrAttr>:$builtin
);
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
index 38548fee32682..116a37dc0b534 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.h
@@ -55,6 +55,11 @@ std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
/// spv.CompositeInsert into spv.CompositeConstruct.
std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
+/// Creates an operation pass that unifies access of multiple aliased resources
+/// into access of one single resource. For every (descriptor set, binding)
+/// slot whose aliased resources can be unified, uses of the non-canonical
+/// resources are rewritten to go through one canonical resource, and the
+/// non-canonical resource variables are removed.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+createUnifyAliasedResourcePass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
index 575bb0898faad..32abca53f8a59 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
@@ -28,6 +28,13 @@ def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
let constructor = "mlir::spirv::createRewriteInsertsPass()";
}
+def SPIRVUnifyAliasedResourcePass
+    : Pass<"spirv-unify-aliased-resource", "spirv::ModuleOp"> {
+  let summary = "Unify access of multiple aliased resources into access of one "
+                "single resource";
+  let description = [{
+    In SPIR-V, resources are global variables bound to a (descriptor set,
+    binding) slot; multiple variables bound to the same slot must be decorated
+    as aliased. Such decorations can cause issues for transcompilers like
+    SPIRV-Cross when converting to source shading languages like MSL.
+
+    This pass analyzes the aliased resources bound to each slot. When they can
+    be unified, it picks the resource whose runtime array element type is the
+    largest in bytes as canonical. It rewrites `spv.mlir.addressof`,
+    `spv.AccessChain`, `spv.Load` and `spv.Store` uses of the other resources
+    to go through the canonical one, adjusting indices and inserting bitcasts
+    as needed, and removes the other variables. It also drops the `aliased`
+    attribute from any slot that ends up with a single resource.
+
+    Currently only resources of the form
+    `!spv.ptr<!spv.struct<!spv.rtarray<T>>>` are handled. Each `T` must be a
+    scalar or an even-sized vector, and all of them must share the same
+    scalar bitwidth.
+  }];
+  let constructor = "mlir::spirv::createUnifyAliasedResourcePass()";
+}
+
def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
let summary = "Deduce and attach minimal (version, capabilities, extensions) "
"requirements to spv.module ops";
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
index db274088bdf22..affceebcfd3d4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
SPIRVConversion.cpp
+ UnifyAliasedResourcePass.cpp
UpdateVCEPass.cpp
)
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
+ UnifyAliasedResourcePass.cpp
UpdateVCEPass.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
new file mode 100644
index 0000000000000..fa0e551f5d53e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -0,0 +1,452 @@
+//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that unifies access of multiple aliased resources
+// into access of one single resource.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/AnalysisManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <algorithm>
+
+#define DEBUG_TYPE "spirv-unify-aliased-resource"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// A resource descriptor: (descriptor set number, binding number).
+using Descriptor = std::pair<uint32_t, uint32_t>;
+/// Groups the `aliased` resource variables by the descriptor they are bound to.
+using AliasedResourceMap =
+    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
+
+/// Walks `moduleOp` and buckets every global variable that carries the
+/// `aliased` unit attribute and has both a descriptor set and a binding by that
+/// (set, binding) pair. Variables missing either number are ignored.
+static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
+  AliasedResourceMap groups;
+  moduleOp->walk([&groups](spirv::GlobalVariableOp varOp) {
+    if (!varOp->getAttrOfType<UnitAttr>("aliased"))
+      return;
+    Optional<uint32_t> setNumber = varOp.descriptor_set();
+    Optional<uint32_t> bindingNumber = varOp.binding();
+    if (!setNumber || !bindingNumber)
+      return;
+    groups[Descriptor(*setNumber, *bindingNumber)].push_back(varOp);
+  });
+  return groups;
+}
+
+/// If `type` is a pointer to a single-member struct wrapping a runtime array,
+/// i.e. `!spv.ptr<!spv.struct<!spv.rtarray<T>>>`, returns `T`. Any other
+/// shape yields a null `Type`.
+static Type getRuntimeArrayElementType(Type type) {
+  auto ptrType = type.dyn_cast<spirv::PointerType>();
+  if (!ptrType)
+    return Type();
+
+  auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
+  bool isSingleMemberStruct = structType && structType.getNumElements() == 1;
+  if (!isSingleMemberStruct)
+    return Type();
+
+  if (auto arrayType =
+          structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>())
+    return arrayType.getElementType();
+  return Type();
+}
+
+/// Returns true if every entry of `types` (each must be a scalar or a vector)
+/// has a scalar type, or vector element type, of one common bitwidth.
+static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
+  SmallVector<int64_t> bitwidths;
+  bitwidths.reserve(types.size());
+  for (spirv::SPIRVType type : types) {
+    assert(type.isScalarOrVector());
+    auto vectorType = type.dyn_cast<VectorType>();
+    Type scalarType = vectorType ? vectorType.getElementType() : Type(type);
+    bitwidths.push_back(scalarType.getIntOrFloatBitWidth());
+  }
+  return llvm::is_splat(bitwidths);
+}
+
+//===----------------------------------------------------------------------===//
+// Analysis
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Analyzes the aliased resources of a SPIR-V module and decides, per
+/// descriptor, whether the aliases can be collapsed onto one canonical
+/// resource.
+///
+/// A resource is a spv.GlobalVariable carrying both a descriptor set and a
+/// binding number. Per Vulkan requirements its type is
+/// `!spv.ptr<!spv.struct<...>>`.
+///
+/// Only structs that wrap exactly one runtime array are handled for now.
+class ResourceAliasAnalysis {
+public:
+  explicit ResourceAliasAnalysis(Operation *root);
+
+  /// Returns whether `op` refers to a non-canonical resource and therefore has
+  /// to be rewritten onto the canonical one.
+  bool shouldUnify(Operation *op) const;
+
+  /// Returns each unifiable descriptor together with its aliased resources.
+  const AliasedResourceMap &getResourceMap() const { return resourceMap; }
+
+  /// Returns the canonical resource chosen for `descriptor`, or for the
+  /// descriptor that `varOp` is bound to. Returns null if none was recorded.
+  spirv::GlobalVariableOp
+  getCanonicalResource(const Descriptor &descriptor) const;
+  spirv::GlobalVariableOp
+  getCanonicalResource(spirv::GlobalVariableOp varOp) const;
+
+  /// Returns the runtime array element type recorded for `varOp`, or null.
+  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
+
+private:
+  /// Checks whether the `resources` aliased at `descriptor` can be unified
+  /// and, if so, fills in the maps below.
+  void recordIfUnifiable(const Descriptor &descriptor,
+                         ArrayRef<spirv::GlobalVariableOp> resources);
+
+  /// Descriptor -> every aliased resource bound to it.
+  AliasedResourceMap resourceMap;
+  /// Descriptor -> the resource selected as canonical.
+  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
+  /// Aliased resource -> the descriptor it is bound to.
+  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
+  /// Aliased resource -> its (scalar or vector) runtime array element type.
+  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
+};
+} // namespace
+
+ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
+  // Bucket every aliased resource by the descriptor it is bound to, then
+  // examine each bucket on its own. A unifiable bucket gets a canonical
+  // resource: the one whose element type is the largest in bytes.
+  auto moduleOp = cast<spirv::ModuleOp>(root);
+  for (const auto &entry : collectAliasedResources(moduleOp))
+    recordIfUnifiable(entry.first, entry.second);
+}
+
+/// Returns true if `op` is a non-canonical aliased resource variable, or an
+/// address-of / access-chain / load / store whose pointer (transitively)
+/// originates from such a variable.
+bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
+  // Pointers that are block arguments have no defining op and reach here as
+  // null, as does a symbol that cannot be resolved. `dyn_cast` must not be
+  // applied to a null pointer, and such values never name an aliased resource.
+  if (!op)
+    return false;
+
+  // A variable is unified away when a canonical resource other than itself
+  // exists for its descriptor.
+  if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+    auto canonicalOp = getCanonicalResource(varOp);
+    return canonicalOp && varOp != canonicalOp;
+  }
+  // Address-of ops are decided by the variable their symbol refers to.
+  if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
+    auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
+    auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
+    return shouldUnify(varOp);
+  }
+
+  // Access chains, loads and stores follow their pointer operand back to the
+  // op that produced it.
+  if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
+    return shouldUnify(acOp.base_ptr().getDefiningOp());
+  if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
+    return shouldUnify(loadOp.ptr().getDefiningOp());
+  if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
+    return shouldUnify(storeOp.ptr().getDefiningOp());
+
+  return false;
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    const Descriptor &descriptor) const {
+  // `lookup` yields a null op for descriptors that were never recorded.
+  return canonicalResourceMap.lookup(descriptor);
+}
+
+spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
+    spirv::GlobalVariableOp varOp) const {
+  // Variables the analysis never recorded have no canonical counterpart.
+  auto found = descriptorMap.find(varOp);
+  return found != descriptorMap.end() ? getCanonicalResource(found->second)
+                                      : spirv::GlobalVariableOp();
+}
+
+spirv::SPIRVType
+ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
+  // A null type comes back for variables the analysis did not record.
+  return elementTypeMap.lookup(varOp);
+}
+
+/// Decides whether the aliased `resources` bound to `descriptor` can be
+/// unified onto one canonical resource. If so, records the canonical resource
+/// and every resource's descriptor and element type. Otherwise the descriptor
+/// is left unrecorded, so its resources are not rewritten.
+void ResourceAliasAnalysis::recordIfUnifiable(
+    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
+  // Collect the element types and byte counts for all resources in the
+  // current set.
+  SmallVector<spirv::SPIRVType> elementTypes;
+  SmallVector<int64_t> numBytes;
+
+  for (spirv::GlobalVariableOp resource : resources) {
+    Type elementType = getRuntimeArrayElementType(resource.type());
+    if (!elementType)
+      return; // Unexpected resource variable type.
+
+    auto type = elementType.cast<spirv::SPIRVType>();
+    if (!type.isScalarOrVector())
+      return; // Unexpected resource element type.
+
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      if (vectorType.getNumElements() % 2 != 0)
+        return; // Odd-sized vector has special layout requirements.
+
+    Optional<int64_t> count = type.getSizeInBytes();
+    if (!count)
+      return;
+
+    elementTypes.push_back(type);
+    numBytes.push_back(*count);
+  }
+
+  // Make sure base scalar types have the same bitwidth, so that we don't need
+  // to handle extracting components for now.
+  if (!hasSameBitwidthScalarType(elementTypes))
+    return;
+
+  // Make sure the canonical resource's byte size is divisible by every other
+  // one; without this, we cannot properly adjust the index later.
+  auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
+  if (llvm::any_of(numBytes, [maxCount](int64_t count) {
+        return *maxCount % count != 0;
+      }))
+    return;
+
+  size_t canonicalIndex = std::distance(numBytes.begin(), maxCount);
+  spirv::GlobalVariableOp canonicalResource = resources[canonicalIndex];
+  spirv::SPIRVType canonicalElemType = elementTypes[canonicalIndex];
+
+  // The rewrite patterns can only redirect an access to the canonical resource
+  // when the aliased element type is identical to the canonical one or is a
+  // scalar. Reject anything else, e.g. vector<2xf32> aliasing vector<4xf32> or
+  // vector<4xi32> aliasing vector<4xf32>. Accepting them would leave illegal
+  // ops behind and make the whole conversion fail instead of skipping this
+  // descriptor.
+  if (llvm::any_of(elementTypes, [&](spirv::SPIRVType type) {
+        return type != canonicalElemType && !type.isIntOrFloat();
+      }))
+    return;
+
+  // Update internal data structures for later use.
+  resourceMap[descriptor].assign(resources.begin(), resources.end());
+  canonicalResourceMap[descriptor] = canonicalResource;
+  for (const auto &resource : llvm::enumerate(resources)) {
+    descriptorMap[resource.value()] = descriptor;
+    elementTypeMap[resource.value()] = elementTypes[resource.index()];
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+/// Shared base of the rewrite patterns below: an `OpConversionPattern` for
+/// `OpTy` that keeps a reference to the alias analysis guiding the rewrite.
+/// The analysis is held by reference and must outlive the pattern.
+template <typename OpTy>
+class ConvertAliasResoruce : public OpConversionPattern<OpTy> {
+public:
+  ConvertAliasResoruce(const ResourceAliasAnalysis &analysis,
+                       MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
+
+protected:
+  /// Alias analysis results consulted by the derived patterns.
+  const ResourceAliasAnalysis &analysis;
+};
+
+/// Erases a non-canonical aliased resource variable; its users are rewritten
+/// onto the canonical resource by the remaining patterns.
+struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> {
+  using ConvertAliasResoruce::ConvertAliasResoruce;
+
+  LogicalResult
+  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor /*adaptor*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(varOp);
+    return success();
+  }
+};
+
+/// Replaces the address of an aliased resource with the address of the
+/// canonical resource bound to the same descriptor.
+struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> {
+  using ConvertAliasResoruce::ConvertAliasResoruce;
+
+  LogicalResult
+  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *symbolOp = SymbolTable::lookupSymbolIn(
+        addressOp->getParentOfType<spirv::ModuleOp>(), addressOp.variable());
+    auto canonicalVarOp =
+        analysis.getCanonicalResource(cast<spirv::GlobalVariableOp>(symbolOp));
+    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, canonicalVarOp);
+    return success();
+  }
+};
+
+/// Rewrites a `spv.AccessChain` whose base is the address of an aliased
+/// resource so that it indexes into the canonical resource instead.
+///
+/// If the aliased and canonical element types are identical, or both are
+/// scalars (the analysis ensures a common bitwidth), the indices are kept.
+/// If the aliased element type is a scalar and the canonical one a vector,
+/// the last index is split into a vector index (SDiv) and a component index
+/// (SMod).
+struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> {
+  using ConvertAliasResoruce::ConvertAliasResoruce;
+
+  LogicalResult
+  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
+    if (!addressOp)
+      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
+
+    auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
+    auto srcVarOp = cast<spirv::GlobalVariableOp>(
+        SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
+    auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
+
+    spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
+    spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
+
+    if ((srcElemType == dstElemType) ||
+        (srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
+      // Source and destination element types have the same bitwidth, so the
+      // indices stay the same.
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), adaptor.indices());
+      return success();
+    }
+
+    if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
+      // The source indices are for a buffer with scalar element types. Rewrite
+      // them into a buffer with vector element types: scale the last index to
+      // address the vector as a whole, then add one level of index for the
+      // component inside the vector.
+      //
+      // Use the adaptor's (remapped) indices, like the branch above.
+      auto indices = llvm::to_vector<4>(adaptor.indices());
+      // The last index must select a runtime array element, i.e. there must be
+      // at least the struct member index plus the array index. Otherwise the
+      // rewritten chain would yield a pointer of a different type.
+      if (indices.size() < 2)
+        return rewriter.notifyMatchFailure(
+            acOp, "not indexing into the runtime array");
+
+      int64_t srcNumBytes = *srcElemType.getSizeInBytes();
+      int64_t dstNumBytes = *dstElemType.getSizeInBytes();
+      // The analysis only records descriptors whose canonical element size is
+      // a multiple of every aliased element size.
+      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
+      int ratio = dstNumBytes / srcNumBytes;
+
+      Location loc = acOp.getLoc();
+      auto i32Type = rewriter.getI32Type();
+      auto ratioValue = rewriter.create<spirv::ConstantOp>(
+          loc, i32Type, rewriter.getI32IntegerAttr(ratio));
+
+      Value oldIndex = indices.back();
+      indices.back() =
+          rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
+      indices.push_back(
+          rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
+
+      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+          acOp, adaptor.base_ptr(), indices);
+      return success();
+    }
+
+    return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
+  }
+};
+
+/// Rewrites a `spv.Load` through a pointer derived from an aliased resource to
+/// load through the corresponding pointer into the canonical resource.
+///
+/// Identical pointee types, including vectors, are loaded directly. Differing
+/// pointee types are only supported when both are scalars; the loaded value is
+/// then bitcast back to the original type.
+struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> {
+  using ConvertAliasResoruce::ConvertAliasResoruce;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+
+    // Only a type change requires scalars. Same-type loads (e.g. a vector
+    // loaded from an alias of the same vector type) must not be rejected,
+    // otherwise the op stays illegal and the conversion fails.
+    bool needsCast = srcElemType != dstElemType;
+    if (needsCast &&
+        (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat()))
+      return rewriter.notifyMatchFailure(loadOp, "not scalar type");
+
+    Location loc = loadOp.getLoc();
+    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
+    // Carry over the original load's attributes so they are not silently
+    // dropped by the rewrite.
+    newLoadOp->setAttrs(loadOp->getAttrDictionary());
+
+    if (!needsCast) {
+      rewriter.replaceOp(loadOp, newLoadOp->getResults());
+      return success();
+    }
+
+    auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
+                                                    newLoadOp.value());
+    rewriter.replaceOp(loadOp, castOp->getResults());
+    return success();
+  }
+};
+
+/// Rewrites a `spv.Store` through a pointer derived from an aliased resource to
+/// store through the corresponding pointer into the canonical resource.
+///
+/// Identical pointee types, including vectors, are stored directly. Differing
+/// pointee types are only supported when both are scalars; the stored value is
+/// then bitcast to the canonical type first.
+struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> {
+  using ConvertAliasResoruce::ConvertAliasResoruce;
+
+  LogicalResult
+  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcElemType =
+        storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+    auto dstElemType =
+        adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
+
+    Location loc = storeOp.getLoc();
+    Value value = adaptor.value();
+    if (srcElemType != dstElemType) {
+      // Only a type change requires scalars; same-type stores (e.g. vectors)
+      // go through unchanged instead of failing the conversion.
+      if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
+        return rewriter.notifyMatchFailure(storeOp, "not scalar type");
+      value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
+    }
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
+                                                storeOp->getAttrs());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Runs the resource alias analysis on a spv.module and applies the
+/// unification rewrite patterns.
+struct UnifyAliasedResourcePass final
+    : public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void UnifyAliasedResourcePass::runOnOperation() {
+  spirv::ModuleOp moduleOp = getOperation();
+  MLIRContext *context = &getContext();
+
+  // Decide which aliased resources can be collapsed before touching the IR.
+  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
+
+  // Resource-related ops are legal only once they no longer refer to a
+  // non-canonical resource; everything else in the SPIR-V dialect is legal.
+  auto isLegal = [&analysis](Operation *op) {
+    return !analysis.shouldUnify(op);
+  };
+  ConversionTarget target(*context);
+  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
+                               spirv::AccessChainOp, spirv::LoadOp,
+                               spirv::StoreOp>(isLegal);
+  target.addLegalDialect<spirv::SPIRVDialect>();
+
+  // Redirect every use of a non-canonical resource to the canonical one.
+  RewritePatternSet patterns(context);
+  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
+               ConvertLoad, ConvertStore>(analysis, context);
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
+    signalPassFailure();
+    return;
+  }
+
+  // The conversion is best effort and some descriptors may have been left
+  // alone, so recompute the aliased groups from the current IR. A group that
+  // is down to a single variable no longer needs the `aliased` attribute.
+  for (const auto &entry : collectAliasedResources(moduleOp)) {
+    if (entry.second.size() == 1)
+      entry.second.front()->removeAttr("aliased");
+  }
+}
+
+/// Factory for the pass declared in Passes.h.
+std::unique_ptr<OperationPass<spirv::ModuleOp>>
+mlir::spirv::createUnifyAliasedResourcePass() {
+  return std::make_unique<UnifyAliasedResourcePass>();
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
new file mode 100644
index 0000000000000..546fc1f93b097
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
@@ -0,0 +1,215 @@
+// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
+
+// A scalar f32 buffer aliased with a vector<4xf32> buffer at binding (0, 1):
+// the vector buffer is kept and each scalar index is split into SDiv/SMod by 4.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @load_store_scalar(%index: i32) -> f32 "None" {
+ %c0 = spv.Constant 0 : i32
+ %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %value = spv.Load "StorageBuffer" %ac : f32
+ spv.Store "StorageBuffer" %ac, %value : f32
+ spv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spv.func @load_store_scalar(%[[INDEX:.+]]: i32)
+// CHECK-DAG: %[[C0:.+]] = spv.Constant 0 : i32
+// CHECK-DAG: %[[C4:.+]] = spv.Constant 4 : i32
+// CHECK-DAG: %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK: %[[DIV:.+]] = spv.SDiv %[[INDEX]], %[[C4]] : i32
+// CHECK: %[[MOD:.+]] = spv.SMod %[[INDEX]], %[[C4]] : i32
+// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[C0]], %[[DIV]], %[[MOD]]]
+// CHECK: spv.Load "StorageBuffer" %[[AC]]
+// CHECK: spv.Store "StorageBuffer" %[[AC]]
+
+// -----
+
+// Two access chains share one address of the scalar alias; both are rewritten
+// against @var01v with an extra (component) index.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
+ %c0 = spv.Constant 0 : i32
+ %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val0 = spv.Load "StorageBuffer" %ac0 : f32
+ %ac1 = spv.AccessChain %addr[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val1 = spv.Load "StorageBuffer" %ac1 : f32
+ %value = spv.FAdd %val0, %val1 : f32
+ spv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spv.func @multiple_uses
+// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+// A vector<3xf32> element type has an odd component count, which the analysis
+// rejects, so both variables stay in place and keep their aliased attribute.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<3xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @vector3(%index: i32) -> f32 "None" {
+ %c0 = spv.Constant 0 : i32
+ %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %value = spv.Load "StorageBuffer" %ac : f32
+ spv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) {aliased}
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) {aliased}
+// CHECK: spv.func @vector3
+
+// -----
+
+// The variables are bound to different slots, (0, 1) and (1, 0), so each slot
+// holds a single resource: nothing is rewritten and the aliased attribute is
+// dropped from both.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(1, 0) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @not_aliased(%index: i32) -> f32 "None" {
+ %c0 = spv.Constant 0 : i32
+ %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %value = spv.Load "StorageBuffer" %ac : f32
+ spv.Store "StorageBuffer" %ac, %value : f32
+ spv.ReturnValue %value : f32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK: spv.GlobalVariable @var01s bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+// CHECK: spv.GlobalVariable @var01v bind(1, 0) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK: spv.func @not_aliased
+
+// -----
+
+// Four variables alias slot (0, 1). The first vector<4xf32> variable (@var01v)
+// becomes canonical, and every access, including the one through the
+// duplicate vector variable @var01v_1, is redirected to it.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01s_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @multiple_aliases(%index: i32) -> f32 "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr0 = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val0 = spv.Load "StorageBuffer" %ac0 : f32
+
+ %addr1 = spv.mlir.addressof @var01s_1 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val1 = spv.Load "StorageBuffer" %ac1 : f32
+
+ %addr2 = spv.mlir.addressof @var01v_1 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+ %ac2 = spv.AccessChain %addr2[%c0, %index, %c0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32, i32
+ %val2 = spv.Load "StorageBuffer" %ac2 : f32
+
+ %add0 = spv.FAdd %val0, %val1 : f32
+ %add1 = spv.FAdd %add0, %val2 : f32
+ spv.ReturnValue %add1 : f32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01v_1
+
+// CHECK: spv.func @multiple_aliases
+// CHECK: %[[ADDR0:.+]] = spv.mlir.addressof @var01v :
+// CHECK: spv.AccessChain %[[ADDR0]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK: %[[ADDR1:.+]] = spv.mlir.addressof @var01v :
+// CHECK: spv.AccessChain %[[ADDR1]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK: %[[ADDR2:.+]] = spv.mlir.addressof @var01v :
+// CHECK: spv.AccessChain %[[ADDR2]][%{{.+}}, %{{.+}}, %{{.+}}]
+
+// -----
+
+// Same-bitwidth scalars i32 and f32 alias slot (0, 1). The first one
+// (@var01s_i32) is canonical, indices are kept unchanged, and f32 values
+// stored through the f32 alias are bitcast to i32.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s_i32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+
+ spv.func @different_scalar_type(%index: i32, %val1: f32) -> i32 "None" {
+ %c0 = spv.Constant 0 : i32
+
+ %addr0 = spv.mlir.addressof @var01s_i32 : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+ %ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val0 = spv.Load "StorageBuffer" %ac0 : i32
+
+ %addr1 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
+ %ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
+ spv.Store "StorageBuffer" %ac1, %val1 : f32
+
+ spv.ReturnValue %val0 : i32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s_f32
+// CHECK: spv.GlobalVariable @var01s_i32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s_f32
+
+// CHECK: spv.func @different_scalar_type(%[[INDEX:.+]]: i32, %[[VAL1:.+]]: f32)
+
+// CHECK: %[[IADDR:.+]] = spv.mlir.addressof @var01s_i32
+// CHECK: %[[IAC:.+]] = spv.AccessChain %[[IADDR]][%{{.+}}, %[[INDEX]]]
+// CHECK: spv.Load "StorageBuffer" %[[IAC]] : i32
+
+// CHECK: %[[FADDR:.+]] = spv.mlir.addressof @var01s_i32
+// CHECK: %[[FAC:.+]] = spv.AccessChain %[[FADDR]][%{{.+}}, %[[INDEX]]]
+// CHECK: %[[CAST:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+// CHECK: spv.Store "StorageBuffer" %[[FAC]], %[[CAST]] : i32
+
+// -----
+
+// An i32 scalar buffer aliased with a vector<4xf32> buffer: the vector buffer
+// is canonical, the index gains a component level, and loads/stores bitcast
+// between i32 and f32.
+spv.module Logical GLSL450 {
+ spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+ spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+
+ spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
+ %c0 = spv.Constant 0 : i32
+ %addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
+ %ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
+ %val1 = spv.Load "StorageBuffer" %ac : i32
+ spv.Store "StorageBuffer" %ac, %val0 : i32
+ spv.ReturnValue %val1 : i32
+ }
+}
+
+// CHECK-LABEL: spv.module
+
+// CHECK-NOT: @var01s
+// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
+// CHECK-NOT: @var01s
+
+// CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
+// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
+// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
+// CHECK: %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
+// CHECK: %[[CAST1:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
+// CHECK: %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
+// CHECK: spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
+// CHECK: spv.ReturnValue %[[CAST1]] : i32
More information about the Mlir-commits
mailing list