[Mlir-commits] [mlir] [mlir][spirv] Clean up map memref-storage-class pass (PR #79937)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jan 29 19:11:07 PST 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/79937
Clean up the code before making more substantial changes. NFC modulo extra error checking and physical storage buffer storage class handling.
* Add switch case for physical storage buffer
* Handle type conversion failures
* Inline methods to reduce scrolling
* Other minor cleanups
>From b8cd6fd1f79ddc0bdd8109a24e07db88f84e41f6 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Mon, 29 Jan 2024 22:06:09 -0500
Subject: [PATCH] [mlir][spirv] Clean up map memref-storage-class pass
Clean up the code before making more substantial changes. NFC modulo extra
error checking and physical storage buffer storage class handling.
* Add switch case for physical storage buffer
* Handle type conversion failures
* Inline methods to reduce scrolling
* Clean up code
---
.../MapMemRefStorageClassPass.cpp | 156 +++++++++---------
1 file changed, 78 insertions(+), 78 deletions(-)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index c6ef5be2494ad..cb969e0b5d7f3 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -18,10 +18,12 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -54,7 +56,8 @@ using namespace mlir;
MAP_FN(spirv::StorageClass::PushConstant, 7) \
MAP_FN(spirv::StorageClass::UniformConstant, 8) \
MAP_FN(spirv::StorageClass::Input, 9) \
- MAP_FN(spirv::StorageClass::Output, 10)
+ MAP_FN(spirv::StorageClass::Output, 10) \
+ MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
std::optional<spirv::StorageClass>
spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
});
addConversion([this](FunctionType type) {
- SmallVector<Type> inputs, results;
- inputs.reserve(type.getNumInputs());
- results.reserve(type.getNumResults());
- for (Type input : type.getInputs())
- inputs.push_back(convertType(input));
- for (Type result : type.getResults())
- results.push_back(convertType(result));
+ auto inputs = llvm::to_vector(llvm::map_range(
+ type.getInputs(), [this](Type ty) { return convertType(ty); }));
+ auto results = llvm::to_vector(llvm::map_range(
+ type.getResults(), [this](Type ty) { return convertType(ty); }));
return FunctionType::get(type.getContext(), inputs, results);
});
}
@@ -250,49 +250,69 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
namespace {
/// Converts any op that has operands/results/attributes with numeric MemRef
/// memory spaces.
-struct MapMemRefStoragePattern final : public ConversionPattern {
+struct MapMemRefStoragePattern final : ConversionPattern {
MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
: ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
+ /// Recreates `op` with converted result types, type attributes, and region
+ /// argument types. Returns failure without modifying the IR if any of these
+ /// conversions fails.
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult MapMemRefStoragePattern::matchAndRewrite(
- Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- llvm::SmallVector<NamedAttribute, 4> newAttrs;
- newAttrs.reserve(op->getAttrs().size());
- for (auto attr : op->getAttrs()) {
- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
- } else {
- newAttrs.push_back(attr);
+ ConversionPatternRewriter &rewriter) const override {
+ llvm::SmallVector<NamedAttribute> newAttrs;
+ newAttrs.reserve(op->getAttrs().size());
+ for (NamedAttribute attr : op->getAttrs()) {
+ if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
+ Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
+ if (!newAttr) {
+ return rewriter.notifyMatchFailure(
+ op, "type attribute conversion failed");
+ }
+ newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+ } else {
+ newAttrs.push_back(attr);
+ }
}
- }
- llvm::SmallVector<Type, 4> newResults;
- (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, newAttrs, op->getSuccessors());
+ llvm::SmallVector<Type, 4> newResults;
+ if (failed(
+ getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
+ return rewriter.notifyMatchFailure(op, "result type conversion failed");
+
+ // Verify that every region signature converts *before* mutating the IR.
+ // A conversion pattern must not return failure after changing the IR:
+ // `inlineRegionBefore` below moves blocks into regions owned by `state`,
+ // which would be destroyed on an early return while the rewriter still
+ // tracks the moved blocks for rollback.
+ for (Region &region : op->getRegions()) {
+ TypeConverter::SignatureConversion probe(region.getNumArguments());
+ if (failed(getTypeConverter()->convertSignatureArgs(
+ region.getArgumentTypes(), probe))) {
+ return rewriter.notifyMatchFailure(
+ op, "signature argument type conversion failed");
+ }
+ }
+
+ OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+ newResults, newAttrs, op->getSuccessors());
+
+ for (Region &region : op->getRegions()) {
+ Region *newRegion = state.addRegion();
+ rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+ TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+ // Cannot fail: the same argument types were converted successfully above.
+ (void)getTypeConverter()->convertSignatureArgs(
+ newRegion->getArgumentTypes(), result);
+ rewriter.applySignatureConversion(newRegion, result);
+ }
- for (Region &region : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- TypeConverter::SignatureConversion result(newRegion->getNumArguments());
- (void)getTypeConverter()->convertSignatureArgs(
- newRegion->getArgumentTypes(), result);
- rewriter.applySignatureConversion(newRegion, result);
+ Operation *newOp = rewriter.create(state);
+ rewriter.replaceOp(op, newOp->getResults());
+ return success();
}
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
-}
+};
+} // namespace
void spirv::populateMemorySpaceToStorageClassPatterns(
spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -315,51 +320,46 @@ class MapMemRefStorageClassPass final
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
: memorySpaceMap(memorySpaceMap) {}
- LogicalResult initializeOptions(StringRef options) override;
-
- void runOnOperation() override;
-
-private:
- spirv::MemorySpaceToStorageClassMap memorySpaceMap;
-};
-} // namespace
+ LogicalResult initializeOptions(StringRef options) override {
+ if (failed(Pass::initializeOptions(options)))
+ return failure();
-LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
- if (failed(Pass::initializeOptions(options)))
- return failure();
+ if (clientAPI == "opencl")
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ else if (clientAPI != "vulkan")
+ return failure();
- if (clientAPI == "opencl") {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ return success();
}
- if (clientAPI != "vulkan" && clientAPI != "opencl")
- return failure();
+ void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ Operation *op = getOperation();
+
+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+ spirv::TargetEnv targetEnv(attr);
+ if (targetEnv.allows(spirv::Capability::Kernel)) {
+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ } else if (targetEnv.allows(spirv::Capability::Shader)) {
+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+ }
+ }
- return success();
-}
+ std::unique_ptr<ConversionTarget> target =
+ spirv::getMemorySpaceToStorageClassTarget(*context);
+ spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-void MapMemRefStorageClassPass::runOnOperation() {
- MLIRContext *context = &getContext();
- Operation *op = getOperation();
+ RewritePatternSet patterns(context);
+ spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
- if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
- spirv::TargetEnv targetEnv(attr);
- if (targetEnv.allows(spirv::Capability::Kernel)) {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
- } else if (targetEnv.allows(spirv::Capability::Shader)) {
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
- }
+ if (failed(applyFullConversion(op, *target, std::move(patterns))))
+ return signalPassFailure();
}
- auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
- spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
- if (failed(applyFullConversion(op, *target, std::move(patterns))))
- return signalPassFailure();
-}
+private:
+ spirv::MemorySpaceToStorageClassMap memorySpaceMap;
+};
+} // namespace
std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
return std::make_unique<MapMemRefStorageClassPass>();
More information about the Mlir-commits
mailing list