[Mlir-commits] [mlir] b91bba8 - [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (#80055)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 17:32:30 PST 2024


Author: Jakub Kuderski
Date: 2024-01-30T20:32:26-05:00
New Revision: b91bba89edfb25d011e1f2366cda5dec605c87f6

URL: https://github.com/llvm/llvm-project/commit/b91bba89edfb25d011e1f2366cda5dec605c87f6
DIFF: https://github.com/llvm/llvm-project/commit/b91bba89edfb25d011e1f2366cda5dec605c87f6.diff

LOG: [mlir][spirv] Use `AttrTypeReplacer` in map-memref-storage-class. NFC. (#80055)

Keep the conversion target so that the pass can still check whether each op is legal after the replacement.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
    mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
index 9463ceb4363ef..54711c8ad727f 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h
@@ -58,11 +58,10 @@ class MemorySpaceToStorageClassConverter : public TypeConverter {
 std::unique_ptr<ConversionTarget>
 getMemorySpaceToStorageClassTarget(MLIRContext &);
 
-/// Appends to a pattern list additional patterns for converting numeric MemRef
-/// memory spaces into SPIR-V symbolic ones.
-void populateMemorySpaceToStorageClassPatterns(
-    MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns);
+/// Converts all MemRef types and attributes in `op` and its nested ops, as
+/// decided by the `typeConverter`.
+void convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter);
 
 } // namespace spirv
 

diff  --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 0dd0e7e21b055..1d1db913e3df2 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTGPUTOSPIRV
@@ -90,20 +91,25 @@ void GPUToSPIRVPass::runOnOperation() {
 
     // Map MemRef memory space to SPIR-V storage class first if requested.
     if (mapMemorySpace) {
-      std::unique_ptr<ConversionTarget> target =
-          spirv::getMemorySpaceToStorageClassTarget(*context);
       spirv::MemorySpaceToStorageClassMap memorySpaceMap =
           targetEnvSupportsKernelCapability(
               dyn_cast<gpu::GPUModuleOp>(gpuModule))
               ? spirv::mapMemorySpaceToOpenCLStorageClass
               : spirv::mapMemorySpaceToVulkanStorageClass;
       spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+      spirv::convertMemRefTypesAndAttrs(gpuModule, converter);
 
-      RewritePatternSet patterns(context);
-      spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
-        return signalPassFailure();
+      // Check if there are any illegal ops remaining.
+      std::unique_ptr<ConversionTarget> target =
+          spirv::getMemorySpaceToStorageClassTarget(*context);
+      gpuModule->walk([&target, this](Operation *childOp) {
+        if (target->isIllegal(childOp)) {
+          childOp->emitOpError("failed to legalize memory space");
+          signalPassFailure();
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
     }
 
     std::unique_ptr<ConversionTarget> target =

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index fef1055d7b3f2..76dab8ee4ac33 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -21,11 +21,13 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
+#include <optional>
 
 namespace mlir {
 #define GEN_PASS_DEF_MAPMEMREFSTORAGECLASS
@@ -243,66 +245,17 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
   return target;
 }
 
-//===----------------------------------------------------------------------===//
-// Conversion Pattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Converts any op that has operands/results/attributes with numeric MemRef
-/// memory spaces.
-struct MapMemRefStoragePattern final : ConversionPattern {
-  MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
-      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    llvm::SmallVector<NamedAttribute> newAttrs;
-    newAttrs.reserve(op->getAttrs().size());
-    for (NamedAttribute attr : op->getAttrs()) {
-      if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-        Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-        if (!newAttr) {
-          return rewriter.notifyMatchFailure(
-              op, "type attribute conversion failed");
-        }
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-      } else {
-        newAttrs.push_back(attr);
-      }
-    }
-
-    llvm::SmallVector<Type, 4> newResults;
-    if (failed(
-            getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
-      return rewriter.notifyMatchFailure(op, "result type conversion failed");
-
-    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                         newResults, newAttrs, op->getSuccessors());
-
-    for (Region &region : op->getRegions()) {
-      Region *newRegion = state.addRegion();
-      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-      if (failed(getTypeConverter()->convertSignatureArgs(
-              newRegion->getArgumentTypes(), result))) {
-        return rewriter.notifyMatchFailure(
-            op, "signature argument type conversion failed");
-      }
-      rewriter.applySignatureConversion(newRegion, result);
-    }
-
-    Operation *newOp = rewriter.create(state);
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-} // namespace
+void spirv::convertMemRefTypesAndAttrs(
+    Operation *op, MemorySpaceToStorageClassConverter &typeConverter) {
+  AttrTypeReplacer replacer;
+  replacer.addReplacement([&typeConverter](BaseMemRefType origType)
+                              -> std::optional<BaseMemRefType> {
+    return typeConverter.convertType<BaseMemRefType>(origType);
+  });
 
-void spirv::populateMemorySpaceToStorageClassPatterns(
-    spirv::MemorySpaceToStorageClassConverter &typeConverter,
-    RewritePatternSet &patterns) {
-  patterns.add<MapMemRefStoragePattern>(patterns.getContext(), typeConverter);
+  replacer.recursivelyReplaceElementsIn(op, /*replaceAttrs=*/true,
+                                        /*replaceLocs=*/false,
+                                        /*replaceTypes=*/true);
 }
 
 //===----------------------------------------------------------------------===//
@@ -335,24 +288,31 @@ class MapMemRefStorageClassPass final
     MLIRContext *context = &getContext();
     Operation *op = getOperation();
 
+    spirv::MemorySpaceToStorageClassMap spaceToStorage = memorySpaceMap;
     if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
       spirv::TargetEnv targetEnv(attr);
       if (targetEnv.allows(spirv::Capability::Kernel)) {
-        memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToOpenCLStorageClass;
       } else if (targetEnv.allows(spirv::Capability::Shader)) {
-        memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+        spaceToStorage = spirv::mapMemorySpaceToVulkanStorageClass;
       }
     }
 
+    spirv::MemorySpaceToStorageClassConverter converter(spaceToStorage);
+    // Perform the replacement.
+    spirv::convertMemRefTypesAndAttrs(op, converter);
+
+    // Check if there are any illegal ops remaining.
     std::unique_ptr<ConversionTarget> target =
         spirv::getMemorySpaceToStorageClassTarget(*context);
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
-    RewritePatternSet patterns(context);
-    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-    if (failed(applyFullConversion(op, *target, std::move(patterns))))
-      return signalPassFailure();
+    op->walk([&target, this](Operation *childOp) {
+      if (target->isIllegal(childOp)) {
+        childOp->emitOpError("failed to legalize memory space");
+        signalPassFailure();
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
   }
 
 private:


        


More information about the Mlir-commits mailing list