[Mlir-commits] [mlir] b83d0d4 - [mlir][spirv] Make MemRef memory space mapping pass more flexible

Lei Zhang llvmlistbot at llvm.org
Tue Aug 9 11:21:56 PDT 2022


Author: Lei Zhang
Date: 2022-08-09T14:21:50-04:00
New Revision: b83d0d46c0b197c36c671493a3339d8d7a8bcd6a

URL: https://github.com/llvm/llvm-project/commit/b83d0d46c0b197c36c671493a3339d8d7a8bcd6a
DIFF: https://github.com/llvm/llvm-project/commit/b83d0d46c0b197c36c671493a3339d8d7a8bcd6a.diff

LOG: [mlir][spirv] Make MemRef memory space mapping pass more flexible

* Avoid restricting the pass to builtin module ops. The pass
  should be able to run on any op that has regions.
* Avoid hardcoding `func::FuncOp` when handling functions. Instead,
  use the function op interface.
* Assign the default mapping in the constructor, so that when the
  pass is used in a pipeline (without explicit options) it still has
  a meaningful default.

Along the way, dropped unnecessary unrealized conversion casts and
switched to full conversion. The pass should be able to convert all sorts
of ops; there is really no need for such bridges.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131409

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
    mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
index 8b239699d8a5c..9f81d3376c5bd 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
@@ -21,7 +21,7 @@ class ModuleOp;
 
 /// Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V
 /// storage classes. The mapping is read from the command-line option.
-std::unique_ptr<OperationPass<ModuleOp>> createMapMemRefStorageClassPass();
+std::unique_ptr<OperationPass<>> createMapMemRefStorageClassPass();
 
 /// Creates a pass to convert MemRef ops to SPIR-V ops.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMemRefToSPIRVPass();

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 88010d4d029ce..39ca0debc2ec3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -538,7 +538,7 @@ def ConvertMemRefToLLVM : Pass<"convert-memref-to-llvm", "ModuleOp"> {
 // MemRefToSPIRV
 //===----------------------------------------------------------------------===//
 
-def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class", "ModuleOp"> {
+def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class"> {
   let summary = "Map numeric MemRef memory spaces to SPIR-V storage classes";
   let constructor = "mlir::createMapMemRefStorageClassPass()";
   let dependentDialects = ["spirv::SPIRVDialect"];

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 0ba6708e027b7..535613714d53d 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -14,10 +14,10 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -86,15 +86,16 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
     Attribute spaceAttr = memRefType.getMemorySpace();
     if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
       LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
-                              << " due to non-IntegerAttr memory space");
+                              << " due to non-IntegerAttr memory space\n");
       return llvm::None;
     }
 
     unsigned space = memRefType.getMemorySpaceAsInt();
     auto it = this->memorySpaceMap.find(space);
     if (it == this->memorySpaceMap.end()) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
-                              << " due to unable to find memory space in map");
+      LLVM_DEBUG(llvm::dbgs()
+                 << "cannot convert " << memRefType
+                 << " due to being unable to find memory space in map\n");
       return llvm::None;
     }
 
@@ -143,10 +144,9 @@ static bool isLegalAttr(Attribute attr) {
 
 /// Returns true if the given `op` is considered as legal for SPIR-V conversion.
 static bool isLegalOp(Operation *op) {
-  if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
-    FunctionType funcType = funcOp.getFunctionType();
-    return llvm::all_of(funcType.getInputs(), isLegalType) &&
-           llvm::all_of(funcType.getResults(), isLegalType);
+  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+    return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
+           llvm::all_of(funcOp.getResultTypes(), isLegalType);
   }
 
   auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
@@ -230,7 +230,18 @@ namespace {
 class MapMemRefStorageClassPass final
     : public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
 public:
-  explicit MapMemRefStorageClassPass() = default;
+  explicit MapMemRefStorageClassPass() {
+    memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
+
+    LLVM_DEBUG({
+      llvm::dbgs() << "memory space to storage class mapping:\n";
+      if (memorySpaceMap.empty())
+        llvm::dbgs() << "  [empty]\n";
+      for (auto kv : memorySpaceMap)
+        llvm::dbgs() << "  " << kv.first << " -> "
+                     << spirv::stringifyStorageClass(kv.second) << "\n";
+    });
+  }
   explicit MapMemRefStorageClassPass(
       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
       : memorySpaceMap(memorySpaceMap) {}
@@ -251,46 +262,23 @@ LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
   if (clientAPI != "vulkan")
     return failure();
 
-  memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "memory space to storage class mapping:\n";
-    if (memorySpaceMap.empty())
-      llvm::dbgs() << "  [empty]\n";
-    for (auto kv : memorySpaceMap)
-      llvm::dbgs() << "  " << kv.first << " -> "
-                   << spirv::stringifyStorageClass(kv.second) << "\n";
-  });
-
   return success();
 }
 
 void MapMemRefStorageClassPass::runOnOperation() {
   MLIRContext *context = &getContext();
-  ModuleOp module = getOperation();
+  Operation *op = getOperation();
 
   auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
-
   spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-  // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
-  // patterns for other dialects.
-  auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
-                              Location loc) {
-    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
-    return Optional<Value>(cast.getResult(0));
-  };
-  converter.addSourceMaterialization(addUnrealizedCast);
-  converter.addTargetMaterialization(addUnrealizedCast);
-  target->addLegalOp<UnrealizedConversionCastOp>();
 
   RewritePatternSet patterns(context);
   spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
 
-  if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+  if (failed(applyFullConversion(op, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createMapMemRefStorageClassPass() {
+std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
   return std::make_unique<MapMemRefStorageClassPass>();
 }

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 1e3908618e986..fa0a1723d171d 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -41,7 +41,7 @@ func.func @type_attribute() {
 
 // -----
 
-// VULKAN-LABEL: func @function_io
+// VULKAN-LABEL: func.func @function_io
 func.func @function_io
   // VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
   (%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
@@ -52,7 +52,15 @@ func.func @function_io
 
 // -----
 
-// VULKAN: func @region
+gpu.module @kernel {
+// VULKAN-LABEL: gpu.func @function_io
+// VULKAN-SAME: memref<8xi32, #spv.storage_class<StorageBuffer>>
+gpu.func @function_io(%arg0 : memref<8xi32>) kernel { gpu.return }
+}
+
+// -----
+
+// VULKAN-LABEL: func.func @region
 func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
   scf.if %cond {
     //      VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}


        


More information about the Mlir-commits mailing list