[Mlir-commits] [mlir] b83d0d4 - [mlir][spirv] Make MemRef memory space mapping pass more flexible
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 9 11:21:56 PDT 2022
Author: Lei Zhang
Date: 2022-08-09T14:21:50-04:00
New Revision: b83d0d46c0b197c36c671493a3339d8d7a8bcd6a
URL: https://github.com/llvm/llvm-project/commit/b83d0d46c0b197c36c671493a3339d8d7a8bcd6a
DIFF: https://github.com/llvm/llvm-project/commit/b83d0d46c0b197c36c671493a3339d8d7a8bcd6a.diff
LOG: [mlir][spirv] Make MemRef memory space mapping pass more flexible
* Avoid restricting the pass to builtin module ops. The pass
should be able to run on any region ops.
* Avoid hardcoding func FuncOp when handling functions. Instead,
use the function op interface.
* Assigns the default mapping in the constructor. So for cases
where we are using the pass in a pipeline, we still have a
meaningful default.
Along the way, dropped unnecessary unrealized conversion casts and
switched to full conversion. The pass should be able to convert all sorts
of ops; there is really no need to have such bridges.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D131409
Added:
Modified:
mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
index 8b239699d8a5c..9f81d3376c5bd 100644
--- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h
@@ -21,7 +21,7 @@ class ModuleOp;
/// Creates a pass to map numeric MemRef memory spaces to symbolic SPIR-V
/// storage classes. The mapping is read from the command-line option.
-std::unique_ptr<OperationPass<ModuleOp>> createMapMemRefStorageClassPass();
+std::unique_ptr<OperationPass<>> createMapMemRefStorageClassPass();
/// Creates a pass to convert MemRef ops to SPIR-V ops.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMemRefToSPIRVPass();
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 88010d4d029ce..39ca0debc2ec3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -538,7 +538,7 @@ def ConvertMemRefToLLVM : Pass<"convert-memref-to-llvm", "ModuleOp"> {
// MemRefToSPIRV
//===----------------------------------------------------------------------===//
-def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class", "ModuleOp"> {
+def MapMemRefStorageClass : Pass<"map-memref-spirv-storage-class"> {
let summary = "Map numeric MemRef memory spaces to SPIR-V storage classes";
let constructor = "mlir::createMapMemRefStorageClassPass()";
let dependentDialects = ["spirv::SPIRVDialect"];
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 0ba6708e027b7..535613714d53d 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -14,10 +14,10 @@
#include "../PassDetail.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
@@ -86,15 +86,16 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
Attribute spaceAttr = memRefType.getMemorySpace();
if (spaceAttr && !spaceAttr.isa<IntegerAttr>()) {
LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
- << " due to non-IntegerAttr memory space");
+ << " due to non-IntegerAttr memory space\n");
return llvm::None;
}
unsigned space = memRefType.getMemorySpaceAsInt();
auto it = this->memorySpaceMap.find(space);
if (it == this->memorySpaceMap.end()) {
- LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType
- << " due to unable to find memory space in map");
+ LLVM_DEBUG(llvm::dbgs()
+ << "cannot convert " << memRefType
+ << " due to being unable to find memory space in map\n");
return llvm::None;
}
@@ -143,10 +144,9 @@ static bool isLegalAttr(Attribute attr) {
/// Returns true if the given `op` is considered as legal for SPIR-V conversion.
static bool isLegalOp(Operation *op) {
- if (auto funcOp = dyn_cast<func::FuncOp>(op)) {
- FunctionType funcType = funcOp.getFunctionType();
- return llvm::all_of(funcType.getInputs(), isLegalType) &&
- llvm::all_of(funcType.getResults(), isLegalType);
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
+ return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
+ llvm::all_of(funcOp.getResultTypes(), isLegalType);
}
auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
@@ -230,7 +230,18 @@ namespace {
class MapMemRefStorageClassPass final
: public MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
public:
- explicit MapMemRefStorageClassPass() = default;
+ explicit MapMemRefStorageClassPass() {
+ memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "memory space to storage class mapping:\n";
+ if (memorySpaceMap.empty())
+ llvm::dbgs() << " [empty]\n";
+ for (auto kv : memorySpaceMap)
+ llvm::dbgs() << " " << kv.first << " -> "
+ << spirv::stringifyStorageClass(kv.second) << "\n";
+ });
+ }
explicit MapMemRefStorageClassPass(
const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
: memorySpaceMap(memorySpaceMap) {}
@@ -251,46 +262,23 @@ LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
if (clientAPI != "vulkan")
return failure();
- memorySpaceMap = spirv::getDefaultVulkanStorageClassMap();
-
- LLVM_DEBUG({
- llvm::dbgs() << "memory space to storage class mapping:\n";
- if (memorySpaceMap.empty())
- llvm::dbgs() << " [empty]\n";
- for (auto kv : memorySpaceMap)
- llvm::dbgs() << " " << kv.first << " -> "
- << spirv::stringifyStorageClass(kv.second) << "\n";
- });
-
return success();
}
void MapMemRefStorageClassPass::runOnOperation() {
MLIRContext *context = &getContext();
- ModuleOp module = getOperation();
+ Operation *op = getOperation();
auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
-
spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
- // Use UnrealizedConversionCast as the bridge so that we don't need to pull in
- // patterns for other dialects.
- auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs,
- Location loc) {
- auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
- return Optional<Value>(cast.getResult(0));
- };
- converter.addSourceMaterialization(addUnrealizedCast);
- converter.addTargetMaterialization(addUnrealizedCast);
- target->addLegalOp<UnrealizedConversionCastOp>();
RewritePatternSet patterns(context);
spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
- if (failed(applyPartialConversion(module, *target, std::move(patterns))))
+ if (failed(applyFullConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createMapMemRefStorageClassPass() {
+std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
return std::make_unique<MapMemRefStorageClassPass>();
}
diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 1e3908618e986..fa0a1723d171d 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -41,7 +41,7 @@ func.func @type_attribute() {
// -----
-// VULKAN-LABEL: func @function_io
+// VULKAN-LABEL: func.func @function_io
func.func @function_io
// VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
(%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
@@ -52,7 +52,15 @@ func.func @function_io
// -----
-// VULKAN: func @region
+gpu.module @kernel {
+// VULKAN-LABEL: gpu.func @function_io
+// VULKAN-SAME: memref<8xi32, #spv.storage_class<StorageBuffer>>
+gpu.func @function_io(%arg0 : memref<8xi32>) kernel { gpu.return }
+}
+
+// -----
+
+// VULKAN-LABEL: func.func @region
func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
scf.if %cond {
// VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
More information about the Mlir-commits
mailing list