[Mlir-commits] [mlir] 17fcf8a - [mlir][spirv][NFC] Clean up SPIR-V canonicalization
Jakub Kuderski
llvmlistbot at llvm.org
Fri May 26 16:55:10 PDT 2023
Author: Jakub Kuderski
Date: 2023-05-26T19:54:44-04:00
New Revision: 17fcf8a6bd472e116b39ed993cf73e5c7e28fba7
URL: https://github.com/llvm/llvm-project/commit/17fcf8a6bd472e116b39ed993cf73e5c7e28fba7
DIFF: https://github.com/llvm/llvm-project/commit/17fcf8a6bd472e116b39ed993cf73e5c7e28fba7.diff
LOG: [mlir][spirv][NFC] Clean up SPIR-V canonicalization
Follow best practices. Use llvm helper functions for readability.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D151600
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9219e31f1169..3ada160444dd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//
-#include <utility>
#include <optional>
+#include <utility>
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
@@ -20,6 +20,8 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
using namespace mlir;
@@ -82,14 +84,14 @@ namespace {
/// Combines chained `spirv::AccessChainOp` operations into one
/// `spirv::AccessChainOp` operation.
-struct CombineChainedAccessChain
- : public OpRewritePattern<spirv::AccessChainOp> {
- using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+struct CombineChainedAccessChain final
+ : OpRewritePattern<spirv::AccessChainOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
- auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
- accessChainOp.getBasePtr().getDefiningOp());
+ auto parentAccessChainOp =
+ accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
if (!parentAccessChainOp) {
return failure();
@@ -97,8 +99,7 @@ struct CombineChainedAccessChain
// Combine indices.
SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
- indices.append(accessChainOp.getIndices().begin(),
- accessChainOp.getIndices().end());
+ llvm::append_range(indices, accessChainOp.getIndices());
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
accessChainOp, parentAccessChainOp.getBasePtr(), indices);
@@ -155,17 +156,16 @@ OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
if (getIndices().size() == 1 &&
constructOp.getConstituents().size() == type.getNumElements()) {
- auto i = getIndices().begin()->cast<IntegerAttr>();
- if (static_cast<size_t>(i.getValue().getSExtValue()) <
- constructOp.getConstituents().size())
+ auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
+ if (i.getValue().getSExtValue() <
+ static_cast<int64_t>(constructOp.getConstituents().size()))
return constructOp.getConstituents()[i.getValue().getSExtValue()];
}
}
- auto indexVector =
- llvm::to_vector<8>(llvm::map_range(getIndices(), [](Attribute attr) {
- return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
- }));
+ auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
+ return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
+ });
return extractCompositeElement(adaptor.getComposite(), indexVector);
}
@@ -289,13 +289,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
- if (*rhs)
+ if (*rhs) {
// x || true = true
return adaptor.getOperand2();
+ }
- // x || false = x
- if (!*rhs)
+ if (!*rhs) {
+ // x || false = x
return getOperand1();
+ }
}
return Attribute();
@@ -331,14 +333,13 @@ namespace {
// | merge block |
// +-------------+
//
-struct ConvertSelectionOpToSelect
- : public OpRewritePattern<spirv::SelectionOp> {
- using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
+struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
+ using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
PatternRewriter &rewriter) const override {
- auto *op = selectionOp.getOperation();
- auto &body = op->getRegion(0);
+ Operation *op = selectionOp.getOperation();
+ Region &body = op->getRegion(0);
// Verifier allows an empty region for `spirv.mlir.selection`.
if (body.empty()) {
return failure();
@@ -346,11 +347,11 @@ struct ConvertSelectionOpToSelect
// Check that region consists of 4 blocks:
// header block, `true` block, `false` block and merge block.
- if (std::distance(body.begin(), body.end()) != 4) {
+ if (llvm::range_size(body) != 4) {
return failure();
}
- auto *headerBlock = selectionOp.getHeaderBlock();
+ Block *headerBlock = selectionOp.getHeaderBlock();
if (!onlyContainsBranchConditionalOp(headerBlock)) {
return failure();
}
@@ -358,16 +359,16 @@ struct ConvertSelectionOpToSelect
auto brConditionalOp =
cast<spirv::BranchConditionalOp>(headerBlock->front());
- auto *trueBlock = brConditionalOp.getSuccessor(0);
- auto *falseBlock = brConditionalOp.getSuccessor(1);
- auto *mergeBlock = selectionOp.getMergeBlock();
+ Block *trueBlock = brConditionalOp.getSuccessor(0);
+ Block *falseBlock = brConditionalOp.getSuccessor(1);
+ Block *mergeBlock = selectionOp.getMergeBlock();
if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
return failure();
- auto trueValue = getSrcValue(trueBlock);
- auto falseValue = getSrcValue(falseBlock);
- auto ptrValue = getDstPtr(trueBlock);
+ Value trueValue = getSrcValue(trueBlock);
+ Value falseValue = getSrcValue(falseBlock);
+ Value ptrValue = getDstPtr(trueBlock);
auto storeOpAttributes =
cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
@@ -393,7 +394,7 @@ struct ConvertSelectionOpToSelect
Block *mergeBlock) const;
bool onlyContainsBranchConditionalOp(Block *block) const {
- return std::next(block->begin()) == block->end() &&
+ return llvm::hasSingleElement(*block) &&
isa<spirv::BranchConditionalOp>(block->front());
}
@@ -419,8 +420,7 @@ struct ConvertSelectionOpToSelect
LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
// Each block must consists of 2 operations.
- if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
- (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
+ if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
return failure();
}
More information about the Mlir-commits mailing list