[Mlir-commits] [mlir] [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (PR #80055)

Jakub Kuderski llvmlistbot at llvm.org
Tue Jan 30 14:47:56 PST 2024


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/80055

>From 9e4c4edb361828c917da4c98b8854f261f9119b5 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 30 Jan 2024 15:22:08 -0500
Subject: [PATCH 1/2] [mlir][spirv] Use `AttrTypeReplacer` in
 map-memref-storage-class. NFC.

Keep the conversion target to allow for checking if the op is legal.
---
 .../Conversion/MemRefToSPIRV/MemRefToSPIRV.h  |  9 +-
 .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp  | 13 +--
 .../MapMemRefStorageClassPass.cpp             | 90 +++++--------------
 3 files changed, 33 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 9463ceb4363ef..54711c8ad727f 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -58,11 +58,10 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
 std::unique_ptr<ConversionTarget>
 getMemorySpaceToStorageClassTarget(MLIRContext &);
 
-/// Appends to a pattern list additional patterns for converting numeric MemRef
-/// memory spaces into SPIR-V symbolic ones.
-void populateMemorySpaceToStorageClassPatterns(
-    MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns);
+/// Converts all MemRef types and attributes in the op, as decided by the
+/// `typeConverter`.
+void convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter);
 
 } // namespace spirv
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0dd0e7e21b055..d3402fd766def 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTGPUTOSPIRV
@@ -90,19 +91,19 @@ void GPUToSPIRVPass::runOnOperation() {
 
     // Map MemRef memory space to SPIR-V storage class first if requested.
     if (mapMemorySpace) {
-      std::unique_ptr<ConversionTarget> target =
-          spirv::getMemorySpaceToStorageClassTarget(*context);
       spirv::MemorySpaceToStorageClassMap memorySpaceMap =
           targetEnvSupportsKernelCapability(
               dyn_cast<gpu::GPUModuleOp>(gpuModule))
               ? spirv::mapMemorySpaceToOpenCLStorageClass
               : spirv::mapMemorySpaceToVulkanStorageClass;
       spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+      spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
 
-      RewritePatternSet patterns(context);
-      spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+      // Run the conversion on an empty pattern set to diagnose any illegal ops
+      // left.
+      if (failed(applyFullConversion(
+              gpuModule, *spirv::getMemorySpaceToStorageClassTarget(*context),
+              RewritePatternSet{context})))
         return signalPassFailure();
     }
 
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index fef1055d7b3f2..7ce6a4035d5a8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include <optional>
 
 namespace mlir {
 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
@@ -243,66 +244,17 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
   return target;
 }
 
-//===----------------------------------------------------------------------===//
-// Conversion Pattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct MapMemRefStoragePattern final : ConversionPattern {
-  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
-      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    llvm::SmallVector<NamedAttribute> newAttrs;
-    newAttrs.reserve(op->getAttrs().size());
-    for (NamedAttribute attr : op->getAttrs()) {
-      if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-        Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-        if (!newAttr) {
-          return rewriter.notifyMatchFailure(
-              op, "type attribute conversion failed");
-        }
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-      } else {
-        newAttrs.push_back(attr);
-      }
-    }
-
-    llvm::SmallVector<Type, 4> newResults;
-    if (failed(
-            getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
-      return rewriter.notifyMatchFailure(op, "result type conversion failed");
-
-    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                         newResults, newAttrs, op->getSuccessors());
-
-    for (Region &region : op->getRegions()) {
-      Region *newRegion = state.addRegion();
-      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-      if (failed(getTypeConverter()->convertSignatureArgs(
-              newRegion->getArgumentTypes(), result))) {
-        return rewriter.notifyMatchFailure(
-            op, "signature argument type conversion failed");
-      }
-      rewriter.applySignatureConversion(newRegion, result);
-    }
-
-    Operation *newOp = rewriter.create(state);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-} // namespace
+void spirv::convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
+  AttrTypeReplacer replacer;
+  replacer.addReplacement([&typeConverter](BaseMemRefType origType)
+                              -> std::optional<BaseMemRefType> {
+    return typeConverter.convertType<BaseMemRefType>(origType);
+  });
 
-void spirv::populateMemorySpaceToStorageClassPatterns(
-    spirv::MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
+  replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
+                                        /*replaceLocs=*/false,
+                                        /*replaceTypes=*/true);
 }
 
 //===----------------------------------------------------------------------===//
@@ -335,23 +287,25 @@ class MapMemRefStorageClassPass final
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
     if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
       spirv::TargetEnv targetEnv(attr);
       if (targetEnv.allows(spirv::Capability::Kernel)) {
-        memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
       } else if (targetEnv.allows(spirv::Capability::Shader)) {
-        memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
       }
     }
 
-    std::unique_ptr<ConversionTarget> target =
-        spirv::getMemorySpaceToStorageClassTarget(*context);
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
-    RewritePatternSet patterns(context);
-    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+    spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
+    // Perform the replacement.
+    spirv::convertMemRefTypesAndAttrs(op, converter);
 
-    if (failed(applyFullConversion(op, *target, std::move(patterns))))
+    // Only perform the conversion to check that there are no illegal ops
+    // remaining. Do not attempt to convert anything.
+    if (failed(applyFullConversion(
+            op, *spirv::getMemorySpaceToStorageClassTarget(*context),
+            RewritePatternSet{context})))
       return signalPassFailure();
   }
 

>From 3daae166f5e14d6b2a6b54f572c7d9ff66279581 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Tue, 30 Jan 2024 17:47:46 -0500
Subject: [PATCH 2/2] Walk instead of doing full dialect conversion

---
 .../Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp  | 17 ++++++++++------
 .../MapMemRefStorageClassPass.cpp             | 20 ++++++++++++-------
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index d3402fd766def..1d1db913e3df2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -99,12 +99,17 @@ void GPUToSPIRVPass::runOnOperation() {
       spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
       spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
 
-      // Run the conversion on an empty pattern set to diagnose any illegal ops
-      // left.
-      if (failed(applyFullConversion(
-              gpuModule, *spirv::getMemorySpaceToStorageClassTarget(*context),
-              RewritePatternSet{context})))
-        return signalPassFailure();
+      // Check if there are any illegal ops remaining.
+      std::unique_ptr<ConversionTarget> target =
+          spirv::getMemorySpaceToStorageClassTarget(*context);
+      gpuModule->walk([&target, this](Operation *childOp) {
+        if (target->isIllegal(childOp)) {
+          childOp->emitOpError("failed to legalize memory space");
+          signalPassFailure();
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
     }
 
     std::unique_ptr<ConversionTarget> target =
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 7ce6a4035d5a8..76dab8ee4ac33 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -21,8 +21,9 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -301,12 +302,17 @@ class MapMemRefStorageClassPass final
     // Perform the replacement.
     spirv::convertMemRefTypesAndAttrs(op, converter);
 
-    // Only perform the conversion to check that there are no illegal ops
-    // remaining. Do not attempt to convert anything.
-    if (failed(applyFullConversion(
-            op, *spirv::getMemorySpaceToStorageClassTarget(*context),
-            RewritePatternSet{context})))
-      return signalPassFailure();
+    // Check if there are any illegal ops remaining.
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    op->walk([&target, this](Operation *childOp) {
+      if (target->isIllegal(childOp)) {
+        childOp->emitOpError("failed to legalize memory space");
+        signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
   }
 
 private:



More information about the Mlir-commits mailing list