[Mlir-commits] [mlir] 7f6d445 - [mlir][spirv] Clean up map-memref-storage-class pass (#79937)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 29 19:58:31 PST 2024


Author: Jakub Kuderski
Date: 2024-01-29T22:58:27-05:00
New Revision: 7f6d4455231d458a6925ca5c957bb2814c196ea2

URL: https://github.com/llvm/llvm-project/commit/7f6d4455231d458a6925ca5c957bb2814c196ea2
DIFF: https://github.com/llvm/llvm-project/commit/7f6d4455231d458a6925ca5c957bb2814c196ea2.diff

LOG: [mlir][spirv] Clean up map-memref-storage-class pass (#79937)

Clean up the code before making more substantial changes. NFC modulo
extra error checking and physical storage buffer storage class handling.

* Add switch case for physical storage buffer
* Handle type conversion failures
* Inline methods to reduce scrolling
* Other minor cleanups

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index c6ef5be2494ad..fef1055d7b3f2 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -18,10 +18,12 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
 
@@ -54,7 +56,8 @@ using namespace mlir;
   MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
   MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
   MAP_FN(spirv::StorageClass::Input, 9)                                        \
-  MAP_FN(spirv::StorageClass::Output, 10)
+  MAP_FN(spirv::StorageClass::Output, 10)                                      \
+  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11)
 
 std::optional<spirv::StorageClass>
 spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
   });
 
   addConversion([this](FunctionType type) {
-    SmallVector<Type> inputs, results;
-    inputs.reserve(type.getNumInputs());
-    results.reserve(type.getNumResults());
-    for (Type input : type.getInputs())
-      inputs.push_back(convertType(input));
-    for (Type result : type.getResults())
-      results.push_back(convertType(result));
+    auto inputs = llvm::map_to_vector(
+        type.getInputs(), [this](Type ty) { return convertType(ty); });
+    auto results = llvm::map_to_vector(
+        type.getResults(), [this](Type ty) { return convertType(ty); });
     return FunctionType::get(type.getContext(), inputs, results);
   });
 }
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
 namespace {
 /// Converts any op that has operands/results/attributes with numeric MemRef
 /// memory spaces.
-struct MapMemRefStoragePattern final : public ConversionPattern {
+struct MapMemRefStoragePattern final : ConversionPattern {
   MapMemRefStoragePattern(MLIRContext *context, TypeConverter &converter)
       : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override;
-};
-} // namespace
-
-LogicalResult MapMemRefStoragePattern::matchAndRewrite(
-    Operation *op, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) const {
-  llvm::SmallVector<NamedAttribute, 4> newAttrs;
-  newAttrs.reserve(op->getAttrs().size());
-  for (auto attr : op->getAttrs()) {
-    if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-      auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
-      newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
-    } else {
-      newAttrs.push_back(attr);
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<NamedAttribute> newAttrs;
+    newAttrs.reserve(op->getAttrs().size());
+    for (NamedAttribute attr : op->getAttrs()) {
+      if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
+        Type newAttr = getTypeConverter()->convertType(typeAttr.getValue());
+        if (!newAttr) {
+          return rewriter.notifyMatchFailure(
+              op, "type attribute conversion failed");
+        }
+        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+      } else {
+        newAttrs.push_back(attr);
+      }
     }
-  }
 
-  llvm::SmallVector<Type, 4> newResults;
-  (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);
-
-  OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
-                       newResults, newAttrs, op->getSuccessors());
+    llvm::SmallVector<Type, 4> newResults;
+    if (failed(
+            getTypeConverter()->convertTypes(op->getResultTypes(), newResults)))
+      return rewriter.notifyMatchFailure(op, "result type conversion failed");
+
+    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
+                         newResults, newAttrs, op->getSuccessors());
+
+    for (Region &region : op->getRegions()) {
+      Region *newRegion = state.addRegion();
+      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
+      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
+      if (failed(getTypeConverter()->convertSignatureArgs(
+              newRegion->getArgumentTypes(), result))) {
+        return rewriter.notifyMatchFailure(
+            op, "signature argument type conversion failed");
+      }
+      rewriter.applySignatureConversion(newRegion, result);
+    }
 
-  for (Region &region : op->getRegions()) {
-    Region *newRegion = state.addRegion();
-    rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
-    TypeConverter::SignatureConversion result(newRegion->getNumArguments());
-    (void)getTypeConverter()->convertSignatureArgs(
-        newRegion->getArgumentTypes(), result);
-    rewriter.applySignatureConversion(newRegion, result);
+    Operation *newOp = rewriter.create(state);
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
   }
-
-  Operation *newOp = rewriter.create(state);
-  rewriter.replaceOp(op, newOp->getResults());
-  return success();
-}
+};
+} // namespace
 
 void spirv::populateMemorySpaceToStorageClassPatterns(
     spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -308,58 +313,53 @@ namespace {
 class MapMemRefStorageClassPass final
     : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
 public:
-  explicit MapMemRefStorageClassPass() {
-    memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
-  }
+  MapMemRefStorageClassPass() = default;
+
   explicit MapMemRefStorageClassPass(
       const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
       : memorySpaceMap(memorySpaceMap) {}
 
-  LogicalResult initializeOptions(StringRef options) override;
-
-  void runOnOperation() override;
-
-private:
-  spirv::MemorySpaceToStorageClassMap memorySpaceMap;
-};
-} // namespace
+  LogicalResult initializeOptions(StringRef options) override {
+    if (failed(Pass::initializeOptions(options)))
+      return failure();
 
-LogicalResult MapMemRefStorageClassPass::initializeOptions(StringRef options) {
-  if (failed(Pass::initializeOptions(options)))
-    return failure();
+    if (clientAPI == "opencl")
+      memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+    else if (clientAPI != "vulkan")
+      return failure();
 
-  if (clientAPI == "opencl") {
-    memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+    return success();
   }
 
-  if (clientAPI != "vulkan" && clientAPI != "opencl")
-    return failure();
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    Operation *op = getOperation();
+
+    if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
+      spirv::TargetEnv targetEnv(attr);
+      if (targetEnv.allows(spirv::Capability::Kernel)) {
+        memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
+      } else if (targetEnv.allows(spirv::Capability::Shader)) {
+        memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
+      }
+    }
 
-  return success();
-}
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
 
-void MapMemRefStorageClassPass::runOnOperation() {
-  MLIRContext *context = &getContext();
-  Operation *op = getOperation();
+    RewritePatternSet patterns(context);
+    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
 
-  if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) {
-    spirv::TargetEnv targetEnv(attr);
-    if (targetEnv.allows(spirv::Capability::Kernel)) {
-      memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
-    } else if (targetEnv.allows(spirv::Capability::Shader)) {
-      memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
-    }
+    if (failed(applyFullConversion(op, *target, std::move(patterns))))
+      return signalPassFailure();
   }
 
-  auto target = spirv::getMemorySpaceToStorageClassTarget(*context);
-  spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
-
-  RewritePatternSet patterns(context);
-  spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
-
-  if (failed(applyFullConversion(op, *target, std::move(patterns))))
-    return signalPassFailure();
-}
+private:
+  spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+      spirv::mapMemorySpaceToVulkanStorageClass;
+};
+} // namespace
 
 std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass() {
   return std::make_unique<MapMemRefStorageClassPass>();


        


More information about the Mlir-commits mailing list