[Mlir-commits] [mlir] 1e9321e - [mlir][spirv] NFC: move folders and canonicalizers in a separate file

Lei Zhang llvmlistbot at llvm.org
Wed Feb 26 09:41:47 PST 2020


Author: Lei Zhang
Date: 2020-02-26T12:41:14-05:00
New Revision: 1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af

URL: https://github.com/llvm/llvm-project/commit/1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af
DIFF: https://github.com/llvm/llvm-project/commit/1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af.diff

LOG: [mlir][spirv] NFC: move folders and canonicalizers in a separate file

This gives us better file organization and faster compilation time
by avoiding a gigantic SPIRVOps.cpp file.

Added: 
    mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp

Modified: 
    mlir/lib/Dialect/SPIRV/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index d0ff25ef68f0..85bb7390b716 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
 
 add_llvm_library(MLIRSPIRV
   LayoutUtils.cpp
+  SPIRVCanonicalization.cpp
   SPIRVDialect.cpp
   SPIRVOps.cpp
   SPIRVLowering.cpp

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
new file mode 100644
index 000000000000..32090f3d1ec0
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp
@@ -0,0 +1,367 @@
+//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the folders and canonicalization patterns for SPIR-V ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/Functional.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Common utility functions
+//===----------------------------------------------------------------------===//
+
+// Extracts an element from the given `composite` by following the given
+// `indices`. Returns a null Attribute if error happens.
+static Attribute extractCompositeElement(Attribute composite,
+                                         ArrayRef<unsigned> indices) {
+  // Check that given composite is a constant.
+  if (!composite)
+    return {};
+  // Return composite itself if we reach the end of the index chain.
+  if (indices.empty())
+    return composite;
+
+  if (auto vector = composite.dyn_cast<ElementsAttr>()) {
+    assert(indices.size() == 1 && "must have exactly one index for a vector");
+    return vector.getValue({indices[0]});
+  }
+
+  if (auto array = composite.dyn_cast<ArrayAttr>()) {
+    assert(!indices.empty() && "must have at least one index for an array");
+    return extractCompositeElement(array.getValue()[indices[0]],
+                                   indices.drop_front());
+  }
+
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'erated canonicalizers
+//===----------------------------------------------------------------------===//
+
+namespace {
+#include "SPIRVCanonicalization.inc"
+}
+
+//===----------------------------------------------------------------------===//
+// spv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Combines chained `spirv::AccessChainOp` operations into one
+/// `spirv::AccessChainOp` operation.
+struct CombineChainedAccessChain
+    : public OpRewritePattern<spirv::AccessChainOp> {
+  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
+                                     PatternRewriter &rewriter) const override {
+    auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
+        accessChainOp.base_ptr().getDefiningOp());
+
+    if (!parentAccessChainOp) {
+      return matchFailure();
+    }
+
+    // Combine indices.
+    SmallVector<Value, 4> indices(parentAccessChainOp.indices());
+    indices.append(accessChainOp.indices().begin(),
+                   accessChainOp.indices().end());
+
+    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
+        accessChainOp, parentAccessChainOp.base_ptr(), indices);
+
+    return matchSuccess();
+  }
+};
+} // end anonymous namespace
+
+void spirv::AccessChainOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<CombineChainedAccessChain>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BitcastOp
+//===----------------------------------------------------------------------===//
+
+void spirv::BitcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertChainedBitcast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtractOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
+  auto indexVector = functional::map(
+      [](Attribute attr) {
+        return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
+      },
+      indices());
+  return extractCompositeElement(operands[0], indexVector);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "spv.constant has no operands");
+  return value();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IAdd
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.IAdd expects two operands");
+  // x + 0 = x
+  if (matchPattern(operand2(), m_Zero()))
+    return operand1();
+
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a + b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.IMul
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "spv.IMul expects two operands");
+  // x * 0 == 0
+  if (matchPattern(operand2(), m_Zero()))
+    return operand2();
+  // x * 1 = x
+  if (matchPattern(operand2(), m_One()))
+    return operand1();
+
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a * b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ISub
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
+  // x - x = 0
+  if (operand1() == operand2())
+    return Builder(getContext()).getIntegerAttr(getType(), 0);
+
+  // According to the SPIR-V spec:
+  //
+  // The resulting value will equal the low-order N bits of the correct result
+  // R, where N is the component width and R is computed with enough precision
+  // to avoid overflow and underflow.
+  return constFoldBinaryOp<IntegerAttr>(operands,
+                                        [](APInt a, APInt b) { return a - b; });
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+void spirv::LogicalNotOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
+                 ConvertLogicalNotOfLogicalEqual,
+                 ConvertLogicalNotOfLogicalNotEqual>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// spv.selection
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Blocks from the given `spv.selection` operation must satisfy the following
+// layout:
+//
+//       +-----------------------------------------------+
+//       | header block                                  |
+//       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
+//       +-----------------------------------------------+
+//                            /   \
+//                             ...
+//
+//
+//   +------------------------+    +------------------------+
+//   | case #0                |    | case #1                |
+//   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
+//   | spv.Branch ^merge      |    | spv.Branch ^merge      |
+//   +------------------------+    +------------------------+
+//
+//
+//                             ...
+//                            \   /
+//                              v
+//                       +-------------+
+//                       | merge block |
+//                       +-------------+
+//
+struct ConvertSelectionOpToSelect
+    : public OpRewritePattern<spirv::SelectionOp> {
+  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
+                                     PatternRewriter &rewriter) const override {
+    auto *op = selectionOp.getOperation();
+    auto &body = op->getRegion(0);
+    // Verifier allows an empty region for `spv.selection`.
+    if (body.empty()) {
+      return matchFailure();
+    }
+
+    // Check that region consists of 4 blocks:
+    // header block, `true` block, `false` block and merge block.
+    if (std::distance(body.begin(), body.end()) != 4) {
+      return matchFailure();
+    }
+
+    auto *headerBlock = selectionOp.getHeaderBlock();
+    if (!onlyContainsBranchConditionalOp(headerBlock)) {
+      return matchFailure();
+    }
+
+    auto brConditionalOp =
+        cast<spirv::BranchConditionalOp>(headerBlock->front());
+
+    auto *trueBlock = brConditionalOp.getSuccessor(0);
+    auto *falseBlock = brConditionalOp.getSuccessor(1);
+    auto *mergeBlock = selectionOp.getMergeBlock();
+
+    if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
+      return matchFailure();
+    }
+
+    auto trueValue = getSrcValue(trueBlock);
+    auto falseValue = getSrcValue(falseBlock);
+    auto ptrValue = getDstPtr(trueBlock);
+    auto storeOpAttributes =
+        cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
+
+    auto selectOp = rewriter.create<spirv::SelectOp>(
+        selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
+        trueValue, falseValue);
+    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
+                                    selectOp.getResult(), storeOpAttributes);
+
+    // `spv.selection` is not needed anymore.
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+private:
+  // Checks that given blocks follow the following rules:
+  // 1. Each conditional block consists of two operations, the first operation
+  //    is a `spv.Store` and the last operation is a `spv.Branch`.
+  // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
+  // 3. A control flow goes into the given merge block from the given
+  //    conditional blocks.
+  PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
+                                              Block *falseBlock,
+                                              Block *mergeBlock) const;
+
+  bool onlyContainsBranchConditionalOp(Block *block) const {
+    return std::next(block->begin()) == block->end() &&
+           isa<spirv::BranchConditionalOp>(block->front());
+  }
+
+  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
+    return lhs.getOperation()->getAttrList().getDictionary() ==
+           rhs.getOperation()->getAttrList().getDictionary();
+  }
+
+  // Checks that given type is valid for `spv.SelectOp`.
+  // According to SPIR-V spec:
+  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
+  // Starting with version 1.4, Result Type can additionally be a composite type
+  // other than a vector."
+  bool isValidType(Type type) const {
+    return spirv::SPIRVDialect::isValidScalarType(type) ||
+           type.isa<VectorType>();
+  }
+
+  // Returns a source value for the given block.
+  Value getSrcValue(Block *block) const {
+    auto storeOp = cast<spirv::StoreOp>(block->front());
+    return storeOp.value();
+  }
+
+  // Returns a destination value for the given block.
+  Value getDstPtr(Block *block) const {
+    auto storeOp = cast<spirv::StoreOp>(block->front());
+    return storeOp.ptr();
+  }
+};
+
+PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
+    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
+  // Each block must consists of 2 operations.
+  if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
+      (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
+    return matchFailure();
+  }
+
+  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
+  auto trueBrBranchOp =
+      dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
+  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
+  auto falseBrBranchOp =
+      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
+
+  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
+      !falseBrBranchOp) {
+    return matchFailure();
+  }
+
+  // Check that each `spv.Store` uses the same pointer, memory access
+  // attributes and a valid type of the value.
+  if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
+      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
+      !isValidType(trueBrStoreOp.value().getType())) {
+    return matchFailure();
+  }
+
+  if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
+      (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
+    return matchFailure();
+  }
+
+  return matchSuccess();
+}
+} // end anonymous namespace
+
+void spirv::SelectionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ConvertSelectionOpToSelect>(context);
+}

diff  --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index 01197498a704..1dc4dd9aee0a 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -13,17 +13,13 @@
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 
 #include "mlir/Analysis/CallInterfaces.h"
-#include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/FunctionImplementation.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/Functional.h"
 #include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/bit.h"
 
@@ -360,31 +356,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }
 
-// Extracts an element from the given `composite` by following the given
-// `indices`. Returns a null Attribute if error happens.
-static Attribute extractCompositeElement(Attribute composite,
-                                         ArrayRef<unsigned> indices) {
-  // Check that given composite is a constant.
-  if (!composite)
-    return {};
-  // Return composite itself if we reach the end of the index chain.
-  if (indices.empty())
-    return composite;
-
-  if (auto vector = composite.dyn_cast<ElementsAttr>()) {
-    assert(indices.size() == 1 && "must have exactly one index for a vector");
-    return vector.getValue({indices[0]});
-  }
-
-  if (auto array = composite.dyn_cast<ArrayAttr>()) {
-    assert(!indices.empty() && "must have at least one index for an array");
-    return extractCompositeElement(array.getValue()[indices[0]],
-                                   indices.drop_front());
-  }
-
-  return {};
-}
-
 // Get bit width of types.
 static unsigned getBitWidth(Type type) {
   if (type.isa<spirv::PointerType>()) {
@@ -477,14 +448,6 @@ static inline bool isMergeBlock(Block &block) {
          isa<spirv::MergeOp>(block.front());
 }
 
-//===----------------------------------------------------------------------===//
-// TableGen'erated canonicalizers
-//===----------------------------------------------------------------------===//
-
-namespace {
-#include "SPIRVCanonicalization.inc"
-}
-
 //===----------------------------------------------------------------------===//
 // Common parsers and printers
 //===----------------------------------------------------------------------===//
@@ -848,41 +811,6 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
   return success();
 }
 
-namespace {
-
-/// Combines chained `spirv::AccessChainOp` operations into one
-/// `spirv::AccessChainOp` operation.
-struct CombineChainedAccessChain
-    : public OpRewritePattern<spirv::AccessChainOp> {
-  using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
-                                     PatternRewriter &rewriter) const override {
-    auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
-        accessChainOp.base_ptr().getDefiningOp());
-
-    if (!parentAccessChainOp) {
-      return matchFailure();
-    }
-
-    // Combine indices.
-    SmallVector<Value, 4> indices(parentAccessChainOp.indices());
-    indices.append(accessChainOp.indices().begin(),
-                   accessChainOp.indices().end());
-
-    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
-        accessChainOp, parentAccessChainOp.base_ptr(), indices);
-
-    return matchSuccess();
-  }
-};
-} // end anonymous namespace
-
-void spirv::AccessChainOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<CombineChainedAccessChain>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // spv._address_of
 //===----------------------------------------------------------------------===//
@@ -1013,11 +941,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
   return success();
 }
 
-void spirv::BitcastOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ConvertChainedBitcast>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // spv.BranchConditionalOp
 //===----------------------------------------------------------------------===//
@@ -1230,16 +1153,6 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
   return success();
 }
 
-OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
-  auto indexVector = functional::map(
-      [](Attribute attr) {
-        return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
-      },
-      indices());
-  return extractCompositeElement(operands[0], indexVector);
-}
-
 //===----------------------------------------------------------------------===//
 // spv.CompositeInsert
 //===----------------------------------------------------------------------===//
@@ -1390,11 +1303,6 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
   return success();
 }
 
-OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.empty() && "spv.constant has no operands");
-  return value();
-}
-
 bool spirv::ConstantOp::isBuildableWith(Type type) {
   // Must be valid SPIR-V type first.
   if (!SPIRVDialect::isValidType(type))
@@ -1890,65 +1798,6 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.IAdd
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spv.IAdd expects two operands");
-  // x + 0 = x
-  if (matchPattern(operand2(), m_Zero()))
-    return operand1();
-
-  // According to the SPIR-V spec:
-  //
-  // The resulting value will equal the low-order N bits of the correct result
-  // R, where N is the component width and R is computed with enough precision
-  // to avoid overflow and underflow.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a + b; });
-}
-
-//===----------------------------------------------------------------------===//
-// spv.IMul
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "spv.IMul expects two operands");
-  // x * 0 == 0
-  if (matchPattern(operand2(), m_Zero()))
-    return operand2();
-  // x * 1 = x
-  if (matchPattern(operand2(), m_One()))
-    return operand1();
-
-  // According to the SPIR-V spec:
-  //
-  // The resulting value will equal the low-order N bits of the correct result
-  // R, where N is the component width and R is computed with enough precision
-  // to avoid overflow and underflow.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a * b; });
-}
-
-//===----------------------------------------------------------------------===//
-// spv.ISub
-//===----------------------------------------------------------------------===//
-
-OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
-  // x - x = 0
-  if (operand1() == operand2())
-    return Builder(getContext()).getIntegerAttr(getType(), 0);
-
-  // According to the SPIR-V spec:
-  //
-  // The resulting value will equal the low-order N bits of the correct result
-  // R, where N is the component width and R is computed with enough precision
-  // to avoid overflow and underflow.
-  return constFoldBinaryOp<IntegerAttr>(operands,
-                                        [](APInt a, APInt b) { return a - b; });
-}
-
 //===----------------------------------------------------------------------===//
 // spv.LoadOp
 //===----------------------------------------------------------------------===//
@@ -2008,17 +1857,6 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
   return verifyMemoryAccessAttribute(loadOp);
 }
 
-//===----------------------------------------------------------------------===//
-// spv.LogicalNot
-//===----------------------------------------------------------------------===//
-
-void spirv::LogicalNotOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
-                 ConvertLogicalNotOfLogicalEqual,
-                 ConvertLogicalNotOfLogicalNotEqual>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // spv.loop
 //===----------------------------------------------------------------------===//
@@ -2547,170 +2385,6 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen(
   return selectionOp;
 }
 
-namespace {
-// Blocks from the given `spv.selection` operation must satisfy the following
-// layout:
-//
-//       +-----------------------------------------------+
-//       | header block                                  |
-//       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
-//       +-----------------------------------------------+
-//                            /   \
-//                             ...
-//
-//
-//   +------------------------+    +------------------------+
-//   | case #0                |    | case #1                |
-//   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
-//   | spv.Branch ^merge      |    | spv.Branch ^merge      |
-//   +------------------------+    +------------------------+
-//
-//
-//                             ...
-//                            \   /
-//                              v
-//                       +-------------+
-//                       | merge block |
-//                       +-------------+
-//
-struct ConvertSelectionOpToSelect
-    : public OpRewritePattern<spirv::SelectionOp> {
-  using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
-                                     PatternRewriter &rewriter) const override {
-    auto *op = selectionOp.getOperation();
-    auto &body = op->getRegion(0);
-    // Verifier allows an empty region for `spv.selection`.
-    if (body.empty()) {
-      return matchFailure();
-    }
-
-    // Check that region consists of 4 blocks:
-    // header block, `true` block, `false` block and merge block.
-    if (std::distance(body.begin(), body.end()) != 4) {
-      return matchFailure();
-    }
-
-    auto *headerBlock = selectionOp.getHeaderBlock();
-    if (!onlyContainsBranchConditionalOp(headerBlock)) {
-      return matchFailure();
-    }
-
-    auto brConditionalOp =
-        cast<spirv::BranchConditionalOp>(headerBlock->front());
-
-    auto *trueBlock = brConditionalOp.getSuccessor(0);
-    auto *falseBlock = brConditionalOp.getSuccessor(1);
-    auto *mergeBlock = selectionOp.getMergeBlock();
-
-    if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
-      return matchFailure();
-    }
-
-    auto trueValue = getSrcValue(trueBlock);
-    auto falseValue = getSrcValue(falseBlock);
-    auto ptrValue = getDstPtr(trueBlock);
-    auto storeOpAttributes =
-        cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
-
-    auto selectOp = rewriter.create<spirv::SelectOp>(
-        selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
-        trueValue, falseValue);
-    rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
-                                    selectOp.getResult(), storeOpAttributes);
-
-    // `spv.selection` is not needed anymore.
-    rewriter.eraseOp(op);
-    return matchSuccess();
-  }
-
-private:
-  // Checks that given blocks follow the following rules:
-  // 1. Each conditional block consists of two operations, the first operation
-  //    is a `spv.Store` and the last operation is a `spv.Branch`.
-  // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
-  // 3. A control flow goes into the given merge block from the given
-  //    conditional blocks.
-  PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
-                                              Block *falseBlock,
-                                              Block *mergeBlock) const;
-
-  bool onlyContainsBranchConditionalOp(Block *block) const {
-    return std::next(block->begin()) == block->end() &&
-           isa<spirv::BranchConditionalOp>(block->front());
-  }
-
-  bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
-    return lhs.getOperation()->getAttrList().getDictionary() ==
-           rhs.getOperation()->getAttrList().getDictionary();
-  }
-
-  // Checks that given type is valid for `spv.SelectOp`.
-  // According to SPIR-V spec:
-  // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
-  // Starting with version 1.4, Result Type can additionally be a composite type
-  // other than a vector."
-  bool isValidType(Type type) const {
-    return spirv::SPIRVDialect::isValidScalarType(type) ||
-           type.isa<VectorType>();
-  }
-
-  // Returns a source value for the given block.
-  Value getSrcValue(Block *block) const {
-    auto storeOp = cast<spirv::StoreOp>(block->front());
-    return storeOp.value();
-  }
-
-  // Returns a destination value for the given block.
-  Value getDstPtr(Block *block) const {
-    auto storeOp = cast<spirv::StoreOp>(block->front());
-    return storeOp.ptr();
-  }
-};
-
-PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
-    Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
-  // Each block must consists of 2 operations.
-  if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
-      (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
-    return matchFailure();
-  }
-
-  auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
-  auto trueBrBranchOp =
-      dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
-  auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
-  auto falseBrBranchOp =
-      dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
-
-  if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
-      !falseBrBranchOp) {
-    return matchFailure();
-  }
-
-  // Check that each `spv.Store` uses the same pointer, memory access
-  // attributes and a valid type of the value.
-  if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
-      !isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
-      !isValidType(trueBrStoreOp.value().getType())) {
-    return matchFailure();
-  }
-
-  if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
-      (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
-    return matchFailure();
-  }
-
-  return matchSuccess();
-}
-} // end anonymous namespace
-
-void spirv::SelectionOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<ConvertSelectionOpToSelect>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // spv.specConstant
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list