[Mlir-commits] [mlir] [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (PR #80055)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Jan 30 12:24:24 PST 2024
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/80055
Keep the conversion target so that the pass can still check that no illegal ops remain after the replacement.
>From 9e4c4edb361828c917da4c98b8854f261f9119b5 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 30 Jan 2024 15:22:08 -0500
Subject: [PATCH] [mlir][spirv] Use `AttrTypeReplacer` in
map-memref-storage-class. NFC.
Keep the conversion target so that the pass can still check that no illegal ops remain after the replacement.
---
.../Conversion/MemRefToSPIRV/MemRefToSPIRV.h | 9 +-
.../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 13 +--
.../MapMemRefStorageClassPass.cpp | 90 +++++--------------
3 files changed, 33 insertions(+), 79 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 9463ceb4363ef..54711c8ad727f 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -58,11 +58,10 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
std::unique_ptr<ConversionTarget>
getMemorySpaceToStorageClassTarget(MLIRContext &);
-/// Appends to a pattern list additional patterns for converting numeric MemRef
-/// memory spaces into SPIR-V symbolic ones.
-void populateMemorySpaceToStorageClassPatterns(
- MemorySpaceToStorageClassConverter &typeConverter,
- RewritePatternSet &patterns);
+/// Converts all MemRef types and attributes in the op, as decided by the
+/// `typeConverter`.
+void convertMemRefTypesAndAttrs(
+ Operation *op, MemorySpaceToStorageClassConverter &typeConverter);
} // namespace spirv
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0dd0e7e21b055..d3402fd766def 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUTOSPIRV
@@ -90,19 +91,19 @@ void GPUToSPIRVPass::runOnOperation() {
// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
- std::unique_ptr<ConversionTarget> target =
- spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
targetEnvSupportsKernelCapability(
dyn_cast<gpu::GPUModuleOp>(gpuModule))
? spirv::mapMemorySpaceToOpenCLStorageClass
: spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+ spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
- if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+ // Run the conversion on an empty pattern set to diagnose any illegal ops
+ // left.
+ if (failed(applyFullConversion(
+ gpuModule, *spirv::getMemorySpaceToStorageClassTarget(*context),
+ RewritePatternSet{context})))
return signalPassFailure();
}
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index fef1055d7b3f2..7ce6a4035d5a8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -26,6 +26,7 @@
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
+#include <optional>
namespace mlir {
#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
@@ -243,66 +244,17 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
return target;
}
-//===----------------------------------------------------------------------===//
-// Conversion Pattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct MapMemRefStoragePattern final : ConversionPattern {
- MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
- : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- llvm::SmallVector<NamedAttribute> newAttrs;
- newAttrs.reserve(op->getAttrs().size());
- for (NamedAttribute attr : op->getAttrs()) {
- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
- if (!newAttr) {
- return rewriter.notifyMatchFailure(
- op, "type attribute conversion failed");
- }
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
- } else {
- newAttrs.push_back(attr);
- }
- }
-
- llvm::SmallVector<Type, 4> newResults;
- if (failed(
- getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
- return rewriter.notifyMatchFailure(op, "result type conversion failed");
-
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, newAttrs, op->getSuccessors());
-
- for (Region &region : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- TypeConverter::SignatureConversion result(newRegion->getNumArguments());
- if (failed(getTypeConverter()->convertSignatureArgs(
- newRegion->getArgumentTypes(), result))) {
- return rewriter.notifyMatchFailure(
- op, "signature argument type conversion failed");
- }
- rewriter.applySignatureConversion(newRegion, result);
- }
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
- }
-};
-} // namespace
+void spirv::convertMemRefTypesAndAttrs(
+ Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
+ AttrTypeReplacer replacer;
+ replacer.addReplacement([&typeConverter](BaseMemRefType origType)
+ -> std::optional<BaseMemRefType> {
+ return typeConverter.convertType<BaseMemRefType>(origType);
+ });
-void spirv::populateMemorySpaceToStorageClassPatterns(
- spirv::MemorySpaceToStorageClassConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
+ replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
+ /*replaceLocs=*/false,
+ /*replaceTypes=*/true);
}
//===----------------------------------------------------------------------===//
@@ -335,23 +287,25 @@ class MapMemRefStorageClassPass final
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
spirv::TargetEnv targetEnv(attr);
if (targetEnv.allows(spirv::Capability::Kernel)) {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
} else if (targetEnv.allows(spirv::Capability::Shader)) {
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+ spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
}
}
- std::unique_ptr<ConversionTarget> target =
- spirv::getMemorySpaceToStorageClassTarget(*context);
- spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+ spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
+ // Perform the replacement.
+ spirv::convertMemRefTypesAndAttrs(op, converter);
- if (failed(applyFullConversion(op, *target, std::move(patterns))))
+ // Only perform the conversion to check that there are no illegal ops
+ // remaining. Do not attempt to convert anything.
+ if (failed(applyFullConversion(
+ op, *spirv::getMemorySpaceToStorageClassTarget(*context),
+ RewritePatternSet{context})))
return signalPassFailure();
}
More information about the Mlir-commits
mailing list