[Mlir-commits] [mlir] b91bba8 - [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (#80055)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 17:32:30 PST 2024
Author: Jakub Kuderski
Date: 2024-01-30T20:32:26-05:00
New Revision: b91bba89edfb25d011e1f2366cda5dec605c87f6
URL: https://github.com/llvm/llvm-project/commit/b91bba89edfb25d011e1f2366cda5dec605c87f6
DIFF: https://github.com/llvm/llvm-project/commit/b91bba89edfb25d011e1f2366cda5dec605c87f6.diff
LOG: [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (#80055)
Keep the conversion target to allow for checking if the op is legal.
Added:
Modified:
mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 9463ceb4363ef..54711c8ad727f 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -58,11 +58,10 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
std::unique_ptr<ConversionTarget>
getMemorySpaceToStorageClassTarget(MLIRContext &);
-/// Appends to a pattern list additional patterns for converting numeric MemRef
-/// memory spaces into SPIR-V symbolic ones.
-void populateMemorySpaceToStorageClassPatterns(
- MemorySpaceToStorageClassConverter &typeConverter,
- RewritePatternSet &patterns);
+/// Converts all MemRef types and attributes in the op, as decided by the
+/// `typeConverter`.
+void convertMemRefTypesAndAttrs(
+ Operation *op, MemorySpaceToStorageClassConverter &typeConverter);
} // namespace spirv
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0dd0e7e21b055..1d1db913e3df2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUTOSPIRV
@@ -90,20 +91,25 @@ void GPUToSPIRVPass::runOnOperation() {
// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
- std::unique_ptr<ConversionTarget> target =
- spirv::getMemorySpaceToStorageClassTarget(*context);
spirv::MemorySpaceToStorageClassMap memorySpaceMap =
targetEnvSupportsKernelCapability(
dyn_cast<gpu::GPUModuleOp>(gpuModule))
? spirv::mapMemorySpaceToOpenCLStorageClass
: spirv::mapMemorySpaceToVulkanStorageClass;
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+ spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
- if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
- return signalPassFailure();
+ // Check if there are any illegal ops remaining.
+ std::unique_ptr<ConversionTarget> target =
+ spirv::getMemorySpaceToStorageClassTarget(*context);
+ gpuModule->walk([&target, this](Operation *childOp) {
+ if (target->isIllegal(childOp)) {
+ childOp->emitOpError("failed to legalize memory space");
+ signalPassFailure();
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
}
std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index fef1055d7b3f2..76dab8ee4ac33 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -21,11 +21,13 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
+#include <optional>
namespace mlir {
#define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
@@ -243,66 +245,17 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
return target;
}
-//===----------------------------------------------------------------------===//
-// Conversion Pattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct MapMemRefStoragePattern final : ConversionPattern {
- MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
- : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
- LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const override {
- llvm::SmallVector<NamedAttribute> newAttrs;
- newAttrs.reserve(op->getAttrs().size());
- for (NamedAttribute attr : op->getAttrs()) {
- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
- Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
- if (!newAttr) {
- return rewriter.notifyMatchFailure(
- op, "type attribute conversion failed");
- }
- newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
- } else {
- newAttrs.push_back(attr);
- }
- }
-
- llvm::SmallVector<Type, 4> newResults;
- if (failed(
- getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
- return rewriter.notifyMatchFailure(op, "result type conversion failed");
-
- OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
- newResults, newAttrs, op->getSuccessors());
-
- for (Region &region : op->getRegions()) {
- Region *newRegion = state.addRegion();
- rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
- TypeConverter::SignatureConversion result(newRegion->getNumArguments());
- if (failed(getTypeConverter()->convertSignatureArgs(
- newRegion->getArgumentTypes(), result))) {
- return rewriter.notifyMatchFailure(
- op, "signature argument type conversion failed");
- }
- rewriter.applySignatureConversion(newRegion, result);
- }
-
- Operation *newOp = rewriter.create(state);
- rewriter.replaceOp(op, newOp->getResults());
- return success();
- }
-};
-} // namespace
+void spirv::convertMemRefTypesAndAttrs(
+ Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
+ AttrTypeReplacer replacer;
+ replacer.addReplacement([&typeConverter](BaseMemRefType origType)
+ -> std::optional<BaseMemRefType> {
+ return typeConverter.convertType<BaseMemRefType>(origType);
+ });
-void spirv::populateMemorySpaceToStorageClassPatterns(
- spirv::MemorySpaceToStorageClassConverter &typeConverter,
- RewritePatternSet &patterns) {
- patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
+ replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
+ /*replaceLocs=*/false,
+ /*replaceTypes=*/true);
}
//===----------------------------------------------------------------------===//
@@ -335,24 +288,31 @@ class MapMemRefStorageClassPass final
MLIRContext *context = &getContext();
Operation *op = getOperation();
+ spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
spirv::TargetEnv targetEnv(attr);
if (targetEnv.allows(spirv::Capability::Kernel)) {
- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+ spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
} else if (targetEnv.allows(spirv::Capability::Shader)) {
- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+ spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
}
}
+ spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
+ // Perform the replacement.
+ spirv::convertMemRefTypesAndAttrs(op, converter);
+
+ // Check if there are any illegal ops remaining.
std::unique_ptr<ConversionTarget> target =
spirv::getMemorySpaceToStorageClassTarget(*context);
- spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
- RewritePatternSet patterns(context);
- spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
- if (failed(applyFullConversion(op, *target, std::move(patterns))))
- return signalPassFailure();
+ op->walk([&target, this](Operation *childOp) {
+ if (target->isIllegal(childOp)) {
+ childOp->emitOpError("failed to legalize memory space");
+ signalPassFailure();
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
}
private:
More information about the Mlir-commits
mailing list