[Mlir-commits] [mlir] [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (PR #80055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 12:24:51 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

<details>
<summary>Changes</summary>

Keep the conversion target so that the passes can still verify that no illegal ops remain after the in-place type and attribute replacement.

---
Full diff: https://github.com/llvm/llvm-project/pull/80055.diff


3 Files Affected:

- (modified) mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h (+4-5) 
- (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+7-6) 
- (modified) mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp (+22-68) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 9463ceb4363ef..54711c8ad727f 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -58,11 +58,10 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
 std::unique_ptr<ConversionTarget>
 getMemorySpaceToStorageClassTarget(MLIRContext &);
 
-/// Appends to a pattern list additional patterns for converting numeric MemRef
-/// memory spaces into SPIR-V symbolic ones.
-void populateMemorySpaceToStorageClassPatterns(
-    MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns);
+/// Converts all MemRef types and attributes in the op, as decided by the
+/// `typeConverter`.
+void convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter);
 
 } // namespace spirv
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0dd0e7e21b055..d3402fd766def 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTGPUTOSPIRV
@@ -90,19 +91,19 @@ void GPUToSPIRVPass::runOnOperation() {
 
     // Map MemRef memory space to SPIR-V storage class first if requested.
     if (mapMemorySpace) {
-      std::unique_ptr<ConversionTarget> target =
-          spirv::getMemorySpaceToStorageClassTarget(*context);
       spirv::MemorySpaceToStorageClassMap memorySpaceMap =
           targetEnvSupportsKernelCapability(
               dyn_cast<gpu::GPUModuleOp>(gpuModule))
               ? spirv::mapMemorySpaceToOpenCLStorageClass
               : spirv::mapMemorySpaceToVulkanStorageClass;
       spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+      spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
 
-      RewritePatternSet patterns(context);
-      spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+      // Run the conversion on an empty pattern set to diagnose any illegal ops
+      // left.
+      if (failed(applyFullConversion(
+              gpuModule, *spirv::getMemorySpaceToStorageClassTarget(*context),
+              RewritePatternSet{context})))
         return signalPassFailure();
     }
 
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index fef1055d7b3f2..7ce6a4035d5a8 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include <optional>
 
 namespace mlir {
 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
@@ -243,66 +244,17 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
   return target;
 }
 
-//===----------------------------------------------------------------------===//
-// Conversion Pattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct MapMemRefStoragePattern final : ConversionPattern {
-  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
-      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    llvm::SmallVector<NamedAttribute> newAttrs;
-    newAttrs.reserve(op->getAttrs().size());
-    for (NamedAttribute attr : op->getAttrs()) {
-      if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-        Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-        if (!newAttr) {
-          return rewriter.notifyMatchFailure(
-              op, "type attribute conversion failed");
-        }
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-      } else {
-        newAttrs.push_back(attr);
-      }
-    }
-
-    llvm::SmallVector<Type, 4> newResults;
-    if (failed(
-            getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
-      return rewriter.notifyMatchFailure(op, "result type conversion failed");
-
-    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                         newResults, newAttrs, op->getSuccessors());
-
-    for (Region &region : op->getRegions()) {
-      Region *newRegion = state.addRegion();
-      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-      if (failed(getTypeConverter()->convertSignatureArgs(
-              newRegion->getArgumentTypes(), result))) {
-        return rewriter.notifyMatchFailure(
-            op, "signature argument type conversion failed");
-      }
-      rewriter.applySignatureConversion(newRegion, result);
-    }
-
-    Operation *newOp = rewriter.create(state);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-} // namespace
+void spirv::convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
+  AttrTypeReplacer replacer;
+  replacer.addReplacement([&typeConverter](BaseMemRefType origType)
+                              -> std::optional<BaseMemRefType> {
+    return typeConverter.convertType<BaseMemRefType>(origType);
+  });
 
-void spirv::populateMemorySpaceToStorageClassPatterns(
-    spirv::MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
+  replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
+                                        /*replaceLocs=*/false,
+                                        /*replaceTypes=*/true);
 }
 
 //===----------------------------------------------------------------------===//
@@ -335,23 +287,25 @@ class MapMemRefStorageClassPass final
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
     if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
       spirv::TargetEnv targetEnv(attr);
       if (targetEnv.allows(spirv::Capability::Kernel)) {
-        memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
       } else if (targetEnv.allows(spirv::Capability::Shader)) {
-        memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
       }
     }
 
-    std::unique_ptr<ConversionTarget> target =
-        spirv::getMemorySpaceToStorageClassTarget(*context);
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
-    RewritePatternSet patterns(context);
-    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+    spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
+    // Perform the replacement.
+    spirv::convertMemRefTypesAndAttrs(op, converter);
 
-    if (failed(applyFullConversion(op, *target, std::move(patterns))))
+    // Only perform the conversion to check that there are no illegal ops
+    // remaining. Do not attempt to convert anything.
+    if (failed(applyFullConversion(
+            op, *spirv::getMemorySpaceToStorageClassTarget(*context),
+            RewritePatternSet{context})))
       return signalPassFailure();
   }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/80055


More information about the Mlir-commits mailing list