[Mlir-commits] [mlir] 17fcf8a - [mlir][spirv][NFC] Clean up SPIR-V canonicalization

Jakub Kuderski llvmlistbot at llvm.org
Fri May 26 16:55:10 PDT 2023


Author: Jakub Kuderski
Date: 2023-05-26T19:54:44-04:00
New Revision: 17fcf8a6bd472e116b39ed993cf73e5c7e28fba7

URL: https://github.com/llvm/llvm-project/commit/17fcf8a6bd472e116b39ed993cf73e5c7e28fba7
DIFF: https://github.com/llvm/llvm-project/commit/17fcf8a6bd472e116b39ed993cf73e5c7e28fba7.diff

LOG: [mlir][spirv][NFC] Clean up SPIR-V canonicalization

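Follow LLVM coding best practices (e.g., mark patterns `final`, spell out non-obvious types instead of `auto`) and use LLVM ADT helper functions (`llvm::append_range`, `llvm::map_to_vector`, `llvm::range_size`, `llvm::hasSingleElement`) for readability.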

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D151600

Added:
    (none)

Modified: 
    mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9219e31f1169..3ada160444dd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -10,8 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 
@@ -20,6 +20,8 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 
 using namespace mlir;
 
@@ -82,14 +84,14 @@ namespace {
 
 /// Combines chained `spirv::AccessChainOp` operations into one
 /// `spirv::AccessChainOp` operation.
-struct CombineChainedAccessChain
-    : public OpRewritePattern<spirv::AccessChainOp> {
-  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+struct CombineChainedAccessChain final
+    : OpRewritePattern<spirv::AccessChainOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
                                 PatternRewriter &rewriter) const override {
-    auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
-        accessChainOp.getBasePtr().getDefiningOp());
+    auto parentAccessChainOp =
+        accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
 
     if (!parentAccessChainOp) {
       return failure();
@@ -97,8 +99,7 @@ struct CombineChainedAccessChain
 
     // Combine indices.
     SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
-    indices.append(accessChainOp.getIndices().begin(),
-                   accessChainOp.getIndices().end());
+    llvm::append_range(indices, accessChainOp.getIndices());
 
     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
         accessChainOp, parentAccessChainOp.getBasePtr(), indices);
@@ -155,17 +156,16 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
     auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
     if (getIndices().size() == 1 &&
         constructOp.getConstituents().size() == type.getNumElements()) {
-      auto i = getIndices().begin()->cast<IntegerAttr>();
-      if (static_cast<size_t>(i.getValue().getSExtValue()) <
-          constructOp.getConstituents().size())
+      auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
+      if (i.getValue().getSExtValue() <
+          static_cast<int64_t>(constructOp.getConstituents().size()))
         return constructOp.getConstituents()[i.getValue().getSExtValue()];
     }
   }
 
-  auto indexVector =
-      llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
-        return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
-      }));
+  auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
+    return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
+  });
   return extractCompositeElement(adaptor.getComposite(), indexVector);
 }
 
@@ -289,13 +289,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
 
 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
   if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
-    if (*rhs)
+    if (*rhs) {
       // x || true = true
       return adaptor.getOperand2();
+    }
 
-    // x || false = x
-    if (!*rhs)
+    if (!*rhs) {
+      // x || false = x
       return getOperand1();
+    }
   }
 
   return Attribute();
@@ -331,14 +333,13 @@ namespace {
 //                       | merge block |
 //                       +-------------+
 //
-struct ConvertSelectionOpToSelect
-    : public OpRewritePattern<spirv::SelectionOp> {
-  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
+struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
                                 PatternRewriter &rewriter) const override {
-    auto *op = selectionOp.getOperation();
-    auto &body = op->getRegion(0);
+    Operation *op = selectionOp.getOperation();
+    Region &body = op->getRegion(0);
     // Verifier allows an empty region for `spirv.mlir.selection`.
     if (body.empty()) {
       return failure();
@@ -346,11 +347,11 @@ struct ConvertSelectionOpToSelect
 
     // Check that region consists of 4 blocks:
     // header block, `true` block, `false` block and merge block.
-    if (std::distance(body.begin(), body.end()) != 4) {
+    if (llvm::range_size(body) != 4) {
       return failure();
     }
 
-    auto *headerBlock = selectionOp.getHeaderBlock();
+    Block *headerBlock = selectionOp.getHeaderBlock();
     if (!onlyContainsBranchConditionalOp(headerBlock)) {
       return failure();
     }
@@ -358,16 +359,16 @@ struct ConvertSelectionOpToSelect
     auto brConditionalOp =
         cast<spirv::BranchConditionalOp>(headerBlock->front());
 
-    auto *trueBlock = brConditionalOp.getSuccessor(0);
-    auto *falseBlock = brConditionalOp.getSuccessor(1);
-    auto *mergeBlock = selectionOp.getMergeBlock();
+    Block *trueBlock = brConditionalOp.getSuccessor(0);
+    Block *falseBlock = brConditionalOp.getSuccessor(1);
+    Block *mergeBlock = selectionOp.getMergeBlock();
 
     if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
       return failure();
 
-    auto trueValue = getSrcValue(trueBlock);
-    auto falseValue = getSrcValue(falseBlock);
-    auto ptrValue = getDstPtr(trueBlock);
+    Value trueValue = getSrcValue(trueBlock);
+    Value falseValue = getSrcValue(falseBlock);
+    Value ptrValue = getDstPtr(trueBlock);
     auto storeOpAttributes =
         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
 
@@ -393,7 +394,7 @@ struct ConvertSelectionOpToSelect
                                          Block *mergeBlock) const;
 
   bool onlyContainsBranchConditionalOp(Block *block) const {
-    return std::next(block->begin()) == block->end() &&
+    return llvm::hasSingleElement(*block) &&
            isa<spirv::BranchConditionalOp>(block->front());
   }
 
@@ -419,8 +420,7 @@ struct ConvertSelectionOpToSelect
 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
   // Each block must consists of 2 operations.
-  if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
-      (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
+  if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
     return failure();
   }
 


        


More information about the Mlir-commits mailing list