[Mlir-commits] [mlir] 1e9321e - [mlir][spirv] NFC: move folders and canonicalizers in a separate file
Lei Zhang
llvmlistbot at llvm.org
Wed Feb 26 09:41:47 PST 2020
Author: Lei Zhang
Date: 2020-02-26T12:41:14-05:00
New Revision: 1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af
URL: https://github.com/llvm/llvm-project/commit/1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af
DIFF: https://github.com/llvm/llvm-project/commit/1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af.diff
LOG: [mlir][spirv] NFC: move folders and canonicalizers into a separate file
This gives us better file organization and faster compilation time
by avoiding having a gigantic SPIRVOps.cpp file.
Added:
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
Modified:
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index d0ff25ef68f0..85bb7390b716 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_llvm_library(MLIRSPIRV
LayoutUtils.cpp
+ SPIRVCanonicalization.cpp
SPIRVDialect.cpp
SPIRVOps.cpp
SPIRVLowering.cpp
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
new file mode 100644
index 000000000000..32090f3d1ec0
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -0,0 +1,367 @@
+//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the folders and canonicalization patterns for SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/Functional.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+// Extracts an element from the given `composite` by following the given
+// `indices`. Returns a null Attribute if an error happens, i.e. when
+// `composite` is null (not a constant), is neither an ElementsAttr nor an
+// ArrayAttr while indices remain, or an index is out of range.
+//
+// An ElementsAttr (vector constant) consumes exactly one index; an ArrayAttr
+// consumes one index per nesting level and recurses into the selected element.
+static Attribute extractCompositeElement(Attribute composite,
+                                         ArrayRef<unsigned> indices) {
+  // Check that given composite is a constant.
+  if (!composite)
+    return {};
+  // Return composite itself if we reach the end of the index chain.
+  if (indices.empty())
+    return composite;
+
+  if (auto vector = composite.dyn_cast<ElementsAttr>()) {
+    assert(indices.size() == 1 && "must have exactly one index for a vector");
+    // Honor the "null on error" contract instead of hitting
+    // ElementsAttr::getValue's internal index assertion.
+    if (vector.getType().getRank() != 1 ||
+        indices[0] >= vector.getType().getDimSize(0))
+      return {};
+    return vector.getValue({indices[0]});
+  }
+
+  if (auto array = composite.dyn_cast<ArrayAttr>()) {
+    // `indices` is known to be non-empty here (checked above).
+    ArrayRef<Attribute> elements = array.getValue();
+    if (indices[0] >= elements.size())
+      return {};
+    return extractCompositeElement(elements[indices[0]], indices.drop_front());
+  }
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'erated canonicalizers
+//===----------------------------------------------------------------------===//
+
+// TableGen-generated rewrite patterns, kept file-local via the anonymous
+// namespace. NOTE(review): ConvertChainedBitcast and the ConvertLogicalNotOf*
+// patterns registered below are not defined elsewhere in this file, so they
+// presumably come from this include — confirm against the .td source.
+namespace {
+#include "SPIRVCanonicalization.inc"
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// spv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Rewrites an access chain whose base pointer is produced by another
+/// `spirv::AccessChainOp` into a single access chain rooted at the inner base
+/// pointer. The combined index list is the inner op's indices followed by the
+/// outer op's indices.
+struct CombineChainedAccessChain
+    : public OpRewritePattern<spirv::AccessChainOp> {
+  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
+                                     PatternRewriter &rewriter) const override {
+    // Only applies when the base pointer comes from another access chain
+    // (block arguments have no defining op, hence dyn_cast_or_null).
+    Operation *baseDefiningOp = accessChainOp.base_ptr().getDefiningOp();
+    auto innerChain = dyn_cast_or_null<spirv::AccessChainOp>(baseDefiningOp);
+    if (!innerChain)
+      return matchFailure();
+
+    // Inner indices come first, then this op's own indices.
+    auto innerIndices = innerChain.indices();
+    auto outerIndices = accessChainOp.indices();
+    SmallVector<Value, 4> mergedIndices(innerIndices.begin(),
+                                        innerIndices.end());
+    mergedIndices.append(outerIndices.begin(), outerIndices.end());
+
+    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+        accessChainOp, innerChain.base_ptr(), mergedIndices);
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+// Registers the single spv.AccessChain canonicalization: collapsing chained
+// access chains.
+void spirv::AccessChainOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CombineChainedAccessChain>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitcastOp
+//===----------------------------------------------------------------------===//
+
+// Registers the canonicalization patterns for spv.Bitcast.
+// NOTE(review): ConvertChainedBitcast is not defined in this file; it
+// presumably comes from the TableGen-generated "SPIRVCanonicalization.inc".
+void spirv::BitcastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ConvertChainedBitcast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtractOp
+//===----------------------------------------------------------------------===//
+
+// Folds spv.CompositeExtract of a constant composite into the selected
+// constant element. Yields a null result (no fold) when the composite operand
+// is not a constant or the extraction fails.
+OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
+
+  // Turn the integer index attributes into plain unsigned indices.
+  SmallVector<unsigned, 8> indexVector;
+  for (Attribute indexAttr : indices()) {
+    auto intAttr = indexAttr.cast<IntegerAttr>();
+    indexVector.push_back(static_cast<unsigned>(intAttr.getInt()));
+  }
+
+  Attribute composite = operands.front();
+  return extractCompositeElement(composite, indexVector);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+// Folds spv.constant to its own `value` attribute. The op has no operands, so
+// `operands` is always empty.
+OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
+ assert(operands.empty() && "spv.constant has no operands");
+ return value();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+// Folds spv.IAdd: `x + 0` becomes `x`; otherwise the op is handed to
+// constFoldBinaryOp for constant folding of the two operands.
+OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.IAdd expects two operands");
+
+  // Identity element on the right-hand side.
+  bool rhsIsZero = matchPattern(operand2(), m_Zero());
+  if (rhsIsZero)
+    return operand1();
+
+  // APInt addition wraps modulo 2^N, which matches the SPIR-V spec:
+  //
+  //   The resulting value will equal the low-order N bits of the correct
+  //   result R, where N is the component width and R is computed with enough
+  //   precision to avoid overflow and underflow.
+  auto wrappingAdd = [](APInt a, APInt b) { return a + b; };
+  return constFoldBinaryOp<IntegerAttr>(operands, wrappingAdd);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+// Folds spv.IMul: `x * 0` becomes the zero operand, `x * 1` becomes `x`;
+// otherwise the op is handed to constFoldBinaryOp for constant folding.
+OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.IMul expects two operands");
+
+  Value multiplier = operand2();
+  // Absorbing element: the constant-zero right operand is itself the result.
+  if (matchPattern(multiplier, m_Zero()))
+    return multiplier;
+  // Identity element.
+  if (matchPattern(multiplier, m_One()))
+    return operand1();
+
+  // APInt multiplication wraps modulo 2^N, which matches the SPIR-V spec:
+  //
+  //   The resulting value will equal the low-order N bits of the correct
+  //   result R, where N is the component width and R is computed with enough
+  //   precision to avoid overflow and underflow.
+  auto wrappingMul = [](APInt a, APInt b) { return a * b; };
+  return constFoldBinaryOp<IntegerAttr>(operands, wrappingMul);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+// Folds spv.ISub: `x - x` becomes a zero constant of the result type;
+// otherwise the op is handed to constFoldBinaryOp for constant folding.
+OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.ISub expects two operands");
+  // x - x = 0
+  //
+  // getZeroAttr builds a zero of the exact result type, including a splat for
+  // vector results. getIntegerAttr(getType(), 0) would build a scalar
+  // IntegerAttr, which is only valid when the result is a scalar integer.
+  if (operand1() == operand2())
+    return Builder(getContext()).getZeroAttr(getType());
+
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+// Registers the spv.LogicalNot canonicalizations. Judging by their names,
+// these rewrite a negated equality/inequality comparison (integer or logical)
+// into the opposite comparison. NOTE(review): the patterns are not defined in
+// this file — presumably TableGen-generated; confirm against the .td source.
+void spirv::LogicalNotOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertLogicalNotOfIEqual>(context);
+  results.insert<ConvertLogicalNotOfINotEqual>(context);
+  results.insert<ConvertLogicalNotOfLogicalEqual>(context);
+  results.insert<ConvertLogicalNotOfLogicalNotEqual>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.selection
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Blocks from the given `spv.selection` operation must satisfy the following
+// layout:
+//
+//       +-----------------------------------------------+
+//       | header block                                  |
+//       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
+//       +-----------------------------------------------+
+//                       true /   \ false
+//                             ...
+//
+//
+//   +------------------------+    +------------------------+
+//   | case #0                |    | case #1                |
+//   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
+//   | spv.Branch ^merge      |    | spv.Branch ^merge      |
+//   +------------------------+    +------------------------+
+//
+//
+//                             ...
+//                            \   /
+//                              v
+//                       +-------------+
+//                       | merge block |
+//                       +-------------+
+//
+// Such a selection is replaced by
+//
+//   %sel = spv.Select %cond, %value0, %value1
+//   spv.Store %ptr %sel   // with the attributes of case #0's store
+//
+// and the `spv.selection` op (including its region) is erased.
+struct ConvertSelectionOpToSelect
+    : public OpRewritePattern<spirv::SelectionOp> {
+  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
+                                     PatternRewriter &rewriter) const override {
+    auto *op = selectionOp.getOperation();
+    auto &body = op->getRegion(0);
+    // Verifier allows an empty region for `spv.selection`.
+    if (body.empty())
+      return matchFailure();
+
+    // Check that region consists of 4 blocks:
+    // header block, `true` block, `false` block and merge block.
+    if (std::distance(body.begin(), body.end()) != 4)
+      return matchFailure();
+
+    auto *headerBlock = selectionOp.getHeaderBlock();
+    // Block arguments would be destroyed together with the region, so any use
+    // moved out of it (condition, stored values, pointer) would dangle.
+    if (headerBlock->getNumArguments() != 0 ||
+        !onlyContainsBranchConditionalOp(headerBlock))
+      return matchFailure();
+
+    auto brConditionalOp =
+        cast<spirv::BranchConditionalOp>(headerBlock->front());
+
+    auto *trueBlock = brConditionalOp.getSuccessor(0);
+    auto *falseBlock = brConditionalOp.getSuccessor(1);
+    auto *mergeBlock = selectionOp.getMergeBlock();
+
+    if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock))
+      return matchFailure();
+
+    auto trueValue = getSrcValue(trueBlock);
+    auto falseValue = getSrcValue(falseBlock);
+    auto ptrValue = getDstPtr(trueBlock);
+    auto storeOpAttributes =
+        cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
+
+    auto selectOp = rewriter.create<spirv::SelectOp>(
+        selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
+        trueValue, falseValue);
+    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
+                                    selectOp.getResult(), storeOpAttributes);
+
+    // `spv.selection` is not needed anymore.
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+private:
+  // Returns true if the given blocks follow these rules:
+  // 1. Each conditional block consists of two operations, the first operation
+  //    is a `spv.Store` and the last operation is a `spv.Branch`.
+  // 2. Neither conditional block has block arguments, so the stored values and
+  //    pointer stay valid after the selection region is erased.
+  // 3. Each `spv.Store` uses the same pointer and the same memory attributes,
+  //    and stores a value of a type accepted by isValidType.
+  // 4. A control flow goes into the given merge block from the given
+  //    conditional blocks.
+  bool canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
+                                Block *mergeBlock) const;
+
+  // Returns true if `block` holds exactly one op, a spv.BranchConditional.
+  bool onlyContainsBranchConditionalOp(Block *block) const {
+    return std::next(block->begin()) == block->end() &&
+           isa<spirv::BranchConditionalOp>(block->front());
+  }
+
+  // Returns true if both stores carry the same attribute dictionary.
+  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
+    return lhs.getOperation()->getAttrList().getDictionary() ==
+           rhs.getOperation()->getAttrList().getDictionary();
+  }
+
+  // Checks that given type is valid for `spv.SelectOp`.
+  // According to SPIR-V spec:
+  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
+  // Starting with version 1.4, Result Type can additionally be a composite type
+  // other than a vector."
+  // Only scalar and vector types are accepted here; pointers and other
+  // composites are conservatively rejected.
+  bool isValidType(Type type) const {
+    return spirv::SPIRVDialect::isValidScalarType(type) ||
+           type.isa<VectorType>();
+  }
+
+  // Returns a source value for the given block.
+  Value getSrcValue(Block *block) const {
+    auto storeOp = cast<spirv::StoreOp>(block->front());
+    return storeOp.value();
+  }
+
+  // Returns a destination value for the given block.
+  Value getDstPtr(Block *block) const {
+    auto storeOp = cast<spirv::StoreOp>(block->front());
+    return storeOp.ptr();
+  }
+};
+
+bool ConvertSelectionOpToSelect::canCanonicalizeSelection(
+    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
+  // Each block must consist of 2 operations.
+  if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
+      (std::distance(falseBlock->begin(), falseBlock->end()) != 2))
+    return false;
+
+  // The stored values/pointer are used outside the region after the rewrite;
+  // block arguments of these blocks would be erased with the region.
+  if (trueBlock->getNumArguments() != 0 || falseBlock->getNumArguments() != 0)
+    return false;
+
+  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
+  auto trueBrBranchOp =
+      dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
+  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
+  auto falseBrBranchOp =
+      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
+
+  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
+      !falseBrBranchOp)
+    return false;
+
+  // Check that each `spv.Store` uses the same pointer, memory access
+  // attributes and a valid type of the value.
+  if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
+      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
+      !isValidType(trueBrStoreOp.value().getType()))
+    return false;
+
+  return trueBrBranchOp.getOperation()->getSuccessor(0) == mergeBlock &&
+         falseBrBranchOp.getOperation()->getSuccessor(0) == mergeBlock;
+}
+} // end anonymous namespace
+
+void spirv::SelectionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertSelectionOpToSelect>(context);
+}
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 01197498a704..1dc4dd9aee0a 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -13,17 +13,13 @@
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Analysis/CallInterfaces.h"
-#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/FunctionImplementation.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/Functional.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/bit.h"
@@ -360,31 +356,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
-// Extracts an element from the given `composite` by following the given
-// `indices`. Returns a null Attribute if error happens.
-static Attribute extractCompositeElement(Attribute composite,
- ArrayRef<unsigned> indices) {
- // Check that given composite is a constant.
- if (!composite)
- return {};
- // Return composite itself if we reach the end of the index chain.
- if (indices.empty())
- return composite;
-
- if (auto vector = composite.dyn_cast<ElementsAttr>()) {
- assert(indices.size() == 1 && "must have exactly one index for a vector");
- return vector.getValue({indices[0]});
- }
-
- if (auto array = composite.dyn_cast<ArrayAttr>()) {
- assert(!indices.empty() && "must have at least one index for an array");
- return extractCompositeElement(array.getValue()[indices[0]],
- indices.drop_front());
- }
-
- return {};
-}
-
// Get bit width of types.
static unsigned getBitWidth(Type type) {
if (type.isa<spirv::PointerType>()) {
@@ -477,14 +448,6 @@ static inline bool isMergeBlock(Block &block) {
isa<spirv::MergeOp>(block.front());
}
-//===----------------------------------------------------------------------===//
-// TableGen'erated canonicalizers
-//===----------------------------------------------------------------------===//
-
-namespace {
-#include "SPIRVCanonicalization.inc"
-}
-
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
@@ -848,41 +811,6 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
return success();
}
-namespace {
-
-/// Combines chained `spirv::AccessChainOp` operations into one
-/// `spirv::AccessChainOp` operation.
-struct CombineChainedAccessChain
- : public OpRewritePattern<spirv::AccessChainOp> {
- using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
- PatternRewriter &rewriter) const override {
- auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
- accessChainOp.base_ptr().getDefiningOp());
-
- if (!parentAccessChainOp) {
- return matchFailure();
- }
-
- // Combine indices.
- SmallVector<Value, 4> indices(parentAccessChainOp.indices());
- indices.append(accessChainOp.indices().begin(),
- accessChainOp.indices().end());
-
- rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
- accessChainOp, parentAccessChainOp.base_ptr(), indices);
-
- return matchSuccess();
- }
-};
-} // end anonymous namespace
-
-void spirv::AccessChainOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<CombineChainedAccessChain>(context);
-}
-
//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//
@@ -1013,11 +941,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
return success();
}
-void spirv::BitcastOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertChainedBitcast>(context);
-}
-
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
@@ -1230,16 +1153,6 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
return success();
}
-OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
- auto indexVector = functional::map(
- [](Attribute attr) {
- return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
- },
- indices());
- return extractCompositeElement(operands[0], indexVector);
-}
-
//===----------------------------------------------------------------------===//
// spv.CompositeInsert
//===----------------------------------------------------------------------===//
@@ -1390,11 +1303,6 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
return success();
}
-OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.empty() && "spv.constant has no operands");
- return value();
-}
-
bool spirv::ConstantOp::isBuildableWith(Type type) {
// Must be valid SPIR-V type first.
if (!SPIRVDialect::isValidType(type))
@@ -1890,65 +1798,6 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
return success();
}
-//===----------------------------------------------------------------------===//
-// spv.IAdd
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spv.IAdd expects two operands");
- // x + 0 = x
- if (matchPattern(operand2(), m_Zero()))
- return operand1();
-
- // According to the SPIR-V spec:
- //
- // The resulting value will equal the low-order N bits of the correct result
- // R, where N is the component width and R is computed with enough precision
- // to avoid overflow and underflow.
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a + b; });
-}
-
-//===----------------------------------------------------------------------===//
-// spv.IMul
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
- assert(operands.size() == 2 && "spv.IMul expects two operands");
- // x * 0 == 0
- if (matchPattern(operand2(), m_Zero()))
- return operand2();
- // x * 1 = x
- if (matchPattern(operand2(), m_One()))
- return operand1();
-
- // According to the SPIR-V spec:
- //
- // The resulting value will equal the low-order N bits of the correct result
- // R, where N is the component width and R is computed with enough precision
- // to avoid overflow and underflow.
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a * b; });
-}
-
-//===----------------------------------------------------------------------===//
-// spv.ISub
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
- // x - x = 0
- if (operand1() == operand2())
- return Builder(getContext()).getIntegerAttr(getType(), 0);
-
- // According to the SPIR-V spec:
- //
- // The resulting value will equal the low-order N bits of the correct result
- // R, where N is the component width and R is computed with enough precision
- // to avoid overflow and underflow.
- return constFoldBinaryOp<IntegerAttr>(operands,
- [](APInt a, APInt b) { return a - b; });
-}
-
//===----------------------------------------------------------------------===//
// spv.LoadOp
//===----------------------------------------------------------------------===//
@@ -2008,17 +1857,6 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
return verifyMemoryAccessAttribute(loadOp);
}
-//===----------------------------------------------------------------------===//
-// spv.LogicalNot
-//===----------------------------------------------------------------------===//
-
-void spirv::LogicalNotOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
- ConvertLogicalNotOfLogicalEqual,
- ConvertLogicalNotOfLogicalNotEqual>(context);
-}
-
//===----------------------------------------------------------------------===//
// spv.loop
//===----------------------------------------------------------------------===//
@@ -2547,170 +2385,6 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen(
return selectionOp;
}
-namespace {
-// Blocks from the given `spv.selection` operation must satisfy the following
-// layout:
-//
-// +-----------------------------------------------+
-// | header block |
-// | spv.BranchConditionalOp %cond, ^case0, ^case1 |
-// +-----------------------------------------------+
-// / \
-// ...
-//
-//
-// +------------------------+ +------------------------+
-// | case #0 | | case #1 |
-// | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
-// | spv.Branch ^merge | | spv.Branch ^merge |
-// +------------------------+ +------------------------+
-//
-//
-// ...
-// \ /
-// v
-// +-------------+
-// | merge block |
-// +-------------+
-//
-struct ConvertSelectionOpToSelect
- : public OpRewritePattern<spirv::SelectionOp> {
- using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
-
- PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
- PatternRewriter &rewriter) const override {
- auto *op = selectionOp.getOperation();
- auto &body = op->getRegion(0);
- // Verifier allows an empty region for `spv.selection`.
- if (body.empty()) {
- return matchFailure();
- }
-
- // Check that region consists of 4 blocks:
- // header block, `true` block, `false` block and merge block.
- if (std::distance(body.begin(), body.end()) != 4) {
- return matchFailure();
- }
-
- auto *headerBlock = selectionOp.getHeaderBlock();
- if (!onlyContainsBranchConditionalOp(headerBlock)) {
- return matchFailure();
- }
-
- auto brConditionalOp =
- cast<spirv::BranchConditionalOp>(headerBlock->front());
-
- auto *trueBlock = brConditionalOp.getSuccessor(0);
- auto *falseBlock = brConditionalOp.getSuccessor(1);
- auto *mergeBlock = selectionOp.getMergeBlock();
-
- if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
- return matchFailure();
- }
-
- auto trueValue = getSrcValue(trueBlock);
- auto falseValue = getSrcValue(falseBlock);
- auto ptrValue = getDstPtr(trueBlock);
- auto storeOpAttributes =
- cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
-
- auto selectOp = rewriter.create<spirv::SelectOp>(
- selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
- trueValue, falseValue);
- rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
- selectOp.getResult(), storeOpAttributes);
-
- // `spv.selection` is not needed anymore.
- rewriter.eraseOp(op);
- return matchSuccess();
- }
-
-private:
- // Checks that given blocks follow the following rules:
- // 1. Each conditional block consists of two operations, the first operation
- // is a `spv.Store` and the last operation is a `spv.Branch`.
- // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
- // 3. A control flow goes into the given merge block from the given
- // conditional blocks.
- PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
- Block *falseBlock,
- Block *mergeBlock) const;
-
- bool onlyContainsBranchConditionalOp(Block *block) const {
- return std::next(block->begin()) == block->end() &&
- isa<spirv::BranchConditionalOp>(block->front());
- }
-
- bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
- return lhs.getOperation()->getAttrList().getDictionary() ==
- rhs.getOperation()->getAttrList().getDictionary();
- }
-
- // Checks that given type is valid for `spv.SelectOp`.
- // According to SPIR-V spec:
- // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
- // Starting with version 1.4, Result Type can additionally be a composite type
- // other than a vector."
- bool isValidType(Type type) const {
- return spirv::SPIRVDialect::isValidScalarType(type) ||
- type.isa<VectorType>();
- }
-
- // Returns a source value for the given block.
- Value getSrcValue(Block *block) const {
- auto storeOp = cast<spirv::StoreOp>(block->front());
- return storeOp.value();
- }
-
- // Returns a destination value for the given block.
- Value getDstPtr(Block *block) const {
- auto storeOp = cast<spirv::StoreOp>(block->front());
- return storeOp.ptr();
- }
-};
-
-PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
- Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
- // Each block must consists of 2 operations.
- if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
- (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
- return matchFailure();
- }
-
- auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
- auto trueBrBranchOp =
- dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
- auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
- auto falseBrBranchOp =
- dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
-
- if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
- !falseBrBranchOp) {
- return matchFailure();
- }
-
- // Check that each `spv.Store` uses the same pointer, memory access
- // attributes and a valid type of the value.
- if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
- !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
- !isValidType(trueBrStoreOp.value().getType())) {
- return matchFailure();
- }
-
- if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
- (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
- return matchFailure();
- }
-
- return matchSuccess();
-}
-} // end anonymous namespace
-
-void spirv::SelectionOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
- results.insert<ConvertSelectionOpToSelect>(context);
-}
-
//===----------------------------------------------------------------------===//
// spv.specConstant
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list