[Mlir-commits] [mlir] 81b4e7d - [mlir][spirv] Extract more ops from the main implementation file. NFC.

Jakub Kuderski llvmlistbot at llvm.org
Thu Jul 20 14:13:56 PDT 2023

Author: varconst
Date: 2023-07-20T17:11:32-04:00
New Revision: 81b4e7d2b0e1d222c76637f600cfcb74b631dfca

URL: https://github.com/llvm/llvm-project/commit/81b4e7d2b0e1d222c76637f600cfcb74b631dfca
DIFF: https://github.com/llvm/llvm-project/commit/81b4e7d2b0e1d222c76637f600cfcb74b631dfca.diff

LOG: [mlir][spirv] Extract more ops from the main implementation file. NFC.

Continue the work outlined in D155747 and split the main SPIR-V ops
implementation file into a few smaller and quicker to compile files.

Move control flow and memory ops to their own implementation files.
Create new `.cpp` files for tablegened code.

After this change, the `SPIRVOps.cpp` is 2k LoC-long and takes a
reasonable amount of time to compile.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155883




diff  --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index d36e2ad8a73e85..f985bdd33e26ff 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -5,15 +5,19 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
+  ControlFlowOps.cpp
+  MemoryOps.cpp
+  SPIRVOpAvailability.cpp
+  SPIRVOpDefinition.cpp

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
new file mode 100644
index 00000000000000..e169cb5b65322e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -0,0 +1,562 @@
+//===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops  -----------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Defines the control flow operations in the SPIR-V dialect.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+using namespace mlir::spirv::AttrNames;
+namespace mlir::spirv {
+/// Parses Function, Selection and Loop control attributes. If no control is
+/// specified, "None" is used as a default.
+template <typename EnumAttrClass, typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    if (parser.parseLParen() ||
+        spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  // Set control to "None" otherwise.
+  Builder builder = parser.getBuilder();
+  state.addAttribute(attrName,
+                     builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
+  return success();
+// spirv.BranchOp
+SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  return SuccessorOperands(0, getTargetOperandsMutable());
+// spirv.BranchConditionalOp
+SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < 2 && "invalid successor index");
+  return SuccessorOperands(index == kTrueIndex
+                               ? getTrueTargetOperandsMutable()
+                               : getFalseTargetOperandsMutable());
+ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  auto &builder = parser.getBuilder();
+  OpAsmParser::UnresolvedOperand condInfo;
+  Block *dest;
+  // Parse the condition.
+  Type boolTy = builder.getI1Type();
+  if (parser.parseOperand(condInfo) ||
+      parser.resolveOperand(condInfo, boolTy, result.operands))
+    return failure();
+  // Parse the optional branch weights.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    IntegerAttr trueWeight, falseWeight;
+    NamedAttrList weights;
+    auto i32Type = builder.getIntegerType(32);
+    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
+        parser.parseComma() ||
+        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
+        parser.parseRSquare())
+      return failure();
+    result.addAttribute(kBranchWeightAttrName,
+                        builder.getArrayAttr({trueWeight, falseWeight}));
+  }
+  // Parse the true branch.
+  SmallVector<Value, 4> trueOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, trueOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(trueOperands);
+  // Parse the false branch.
+  SmallVector<Value, 4> falseOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, falseOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(falseOperands);
+  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
+                      builder.getDenseI32ArrayAttr(
+                          {1, static_cast<int32_t>(trueOperands.size()),
+                           static_cast<int32_t>(falseOperands.size())}));
+  return success();
+void BranchConditionalOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getCondition();
+  if (auto weights = getBranchWeights()) {
+    printer << " [";
+    llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
+      printer << llvm::cast<IntegerAttr>(a).getInt();
+    });
+    printer << "]";
+  }
+  printer << ", ";
+  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
+  printer << ", ";
+  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
+LogicalResult BranchConditionalOp::verify() {
+  if (auto weights = getBranchWeights()) {
+    if (weights->getValue().size() != 2) {
+      return emitOpError("must have exactly two branch weights");
+    }
+    if (llvm::all_of(*weights, [](Attribute attr) {
+          return llvm::cast<IntegerAttr>(attr).getValue().isZero();
+        }))
+      return emitOpError("branch weights cannot both be zero");
+  }
+  return success();
+// spirv.FunctionCall
+LogicalResult FunctionCallOp::verify() {
+  auto fnName = getCalleeAttr();
+  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
+      SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
+  if (!funcOp) {
+    return emitOpError("callee function '")
+           << fnName.getValue() << "' not found in nearest symbol table";
+  }
+  auto functionType = funcOp.getFunctionType();
+  if (getNumResults() > 1) {
+    return emitOpError(
+               "expected callee function to have 0 or 1 result, but provided ")
+           << getNumResults();
+  }
+  if (functionType.getNumInputs() != getNumOperands()) {
+    return emitOpError("has incorrect number of operands for callee: expected ")
+           << functionType.getNumInputs() << ", but provided "
+           << getNumOperands();
+  }
+  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i).getType() != functionType.getInput(i)) {
+      return emitOpError("operand type mismatch: expected operand type ")
+             << functionType.getInput(i) << ", but provided "
+             << getOperand(i).getType() << " for operand number " << i;
+    }
+  }
+  if (functionType.getNumResults() != getNumResults()) {
+    return emitOpError(
+               "has incorrect number of results has for callee: expected ")
+           << functionType.getNumResults() << ", but provided "
+           << getNumResults();
+  }
+  if (getNumResults() &&
+      (getResult(0).getType() != functionType.getResult(0))) {
+    return emitOpError("result type mismatch: expected ")
+           << functionType.getResult(0) << ", but provided "
+           << getResult(0).getType();
+  }
+  return success();
+CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
+  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+Operation::operand_range FunctionCallOp::getArgOperands() {
+  return getArguments();
+// spirv.mlir.loop
+void LoopOp::build(OpBuilder &builder, OperationState &state) {
+  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
+                                         spirv::LoopControl::None));
+  state.addRegion();
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
+                                                                        result))
+    return failure();
+  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
+void LoopOp::print(OpAsmPrinter &printer) {
+  auto control = getLoopControl();
+  if (control != spirv::LoopControl::None)
+    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
+/// given `dstBlock`.
+static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
+  // Check that there is only one op in the `srcBlock`.
+  if (!llvm::hasSingleElement(srcBlock))
+    return false;
+  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
+  return branchOp && branchOp.getSuccessor() == &dstBlock;
+/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
+static bool isMergeBlock(Block &block) {
+  return !block.empty() && std::next(block.begin()) == block.end() &&
+         isa<spirv::MergeOp>(block.front());
+LogicalResult LoopOp::verifyRegions() {
+  auto *op = getOperation();
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +-------------+
+  //                     | entry block |
+  //                     +-------------+
+  //                            |
+  //                            v
+  //                     +-------------+
+  //                     | loop header | <-----+
+  //                     +-------------+       |
+  //                                           |
+  //                           ...             |
+  //                          \ | /            |
+  //                            v              |
+  //                    +---------------+      |
+  //                    | loop continue | -----+
+  //                    +---------------+
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+  // The last block is the merge block.
+  Block &merge = region.back();
+  if (!isMergeBlock(merge))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+  if (std::next(region.begin()) == region.end())
+    return emitOpError(
+        "must have an entry block branching to the loop header block");
+  // The first block is the entry block.
+  Block &entry = region.front();
+  if (std::next(region.begin(), 2) == region.end())
+    return emitOpError(
+        "must have a loop header block branched from the entry block");
+  // The second block is the loop header block.
+  Block &header = *std::next(region.begin(), 1);
+  if (!hasOneBranchOpTo(entry, header))
+    return emitOpError(
+        "entry block must only have one 'spirv.Branch' op to the second block");
+  if (std::next(region.begin(), 3) == region.end())
+    return emitOpError(
+        "requires a loop continue block branching to the loop header block");
+  // The second to last block is the loop continue block.
+  Block &cont = *std::prev(region.end(), 2);
+  // Make sure that we have a branch from the loop continue block to the loop
+  // header block.
+  if (llvm::none_of(
+          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
+          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
+    return emitOpError("second to last block must be the loop continue "
+                       "block that branches to the loop header block");
+  // Make sure that no other blocks (except the entry and loop continue block)
+  // branches to the loop header block.
+  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
+                                      std::prev(region.end(), 2))) {
+    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
+      if (block.getSuccessor(i) == &header) {
+        return emitOpError("can only have the entry and loop continue "
+                           "block branching to the loop header block");
+      }
+    }
+  }
+  return success();
+Block *LoopOp::getEntryBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  return &getBody().front();
+Block *LoopOp::getHeaderBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The second block is the loop header block.
+  return &*std::next(getBody().begin());
+Block *LoopOp::getContinueBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The second to last block is the loop continue block.
+  return &*std::prev(getBody().end(), 2);
+Block *LoopOp::getMergeBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The last block is the loop merge block.
+  return &getBody().back();
+void LoopOp::addEntryAndMergeBlock() {
+  assert(getBody().empty() && "entry and merge block already exist");
+  getBody().push_back(new Block());
+  auto *mergeBlock = new Block();
+  getBody().push_back(mergeBlock);
+  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+  // Add a spirv.mlir.merge op into the merge block.
+  builder.create<spirv::MergeOp>(getLoc());
+// spirv.mlir.merge
+LogicalResult MergeOp::verify() {
+  auto *parentOp = (*this)->getParentOp();
+  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
+    return emitOpError(
+        "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
+  // TODO: This check should be done in `verifyRegions` of parent op.
+  Block &parentLastBlock = (*this)->getParentRegion()->back();
+  if (getOperation() != parentLastBlock.getTerminator())
+    return emitOpError("can only be used in the last block of "
+                       "'spirv.mlir.selection' or 'spirv.mlir.loop'");
+  return success();
+// spirv.Return
+LogicalResult ReturnOp::verify() {
+  // Verification is performed in spirv.func op.
+  return success();
+// spirv.ReturnValue
+LogicalResult ReturnValueOp::verify() {
+  // Verification is performed in spirv.func op.
+  return success();
+// spirv.Select
+LogicalResult SelectOp::verify() {
+  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
+    auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
+    if (!resultVectorTy) {
+      return emitOpError("result expected to be of vector type when "
+                         "condition is of vector type");
+    }
+    if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
+      return emitOpError("result should have the same number of elements as "
+                         "the condition when condition is of vector type");
+    }
+  }
+  return success();
+// spirv.mlir.selection
+ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseControlAttribute<spirv::SelectionControlAttr,
+                            spirv::SelectionControl>(parser, result))
+    return failure();
+  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
+void SelectionOp::print(OpAsmPrinter &printer) {
+  auto control = getSelectionControl();
+  if (control != spirv::SelectionControl::None)
+    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+LogicalResult SelectionOp::verifyRegions() {
+  auto *op = getOperation();
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +--------------+
+  //                     | header block |
+  //                     +--------------+
+  //                          / | \
+  //                           ...
+  //
+  //
+  //         +---------+   +---------+   +---------+
+  //         | case #0 |   | case #1 |   | case #2 |  ...
+  //         +---------+   +---------+   +---------+
+  //
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+  // The last block is the merge block.
+  if (!isMergeBlock(region.back()))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+  if (std::next(region.begin()) == region.end())
+    return emitOpError("must have a selection header block");
+  return success();
+Block *SelectionOp::getHeaderBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The first block is the selection header block.
+  return &getBody().front();
+Block *SelectionOp::getMergeBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The last block is the selection merge block.
+  return &getBody().back();
+void SelectionOp::addMergeBlock() {
+  assert(getBody().empty() && "entry and merge block already exist");
+  auto *mergeBlock = new Block();
+  getBody().push_back(mergeBlock);
+  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+  // Add a spirv.mlir.merge op into the merge block.
+  builder.create<spirv::MergeOp>(getLoc());
+SelectionOp::createIfThen(Location loc, Value condition,
+                          function_ref<void(OpBuilder &builder)> thenBody,
+                          OpBuilder &builder) {
+  auto selectionOp =
+      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+  selectionOp.addMergeBlock();
+  Block *mergeBlock = selectionOp.getMergeBlock();
+  Block *thenBlock = nullptr;
+  // Build the "then" block.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    thenBlock = builder.createBlock(mergeBlock);
+    thenBody(builder);
+    builder.create<spirv::BranchOp>(loc, mergeBlock);
+  }
+  // Build the header block.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.createBlock(thenBlock);
+    builder.create<spirv::BranchConditionalOp>(
+        loc, condition, thenBlock,
+        /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
+        /*falseArguments=*/ArrayRef<Value>());
+  }
+  return selectionOp;
+// spirv.Unreachable
+LogicalResult spirv::UnreachableOp::verify() {
+  auto *block = (*this)->getBlock();
+  // Fast track: if this is in the entry block, it's invalid. Otherwise, if the
+  // block has no predecessors, it's valid.
+  if (block->isEntryBlock())
+    return emitOpError("cannot be used in reachable block");
+  if (block->hasNoPredecessors())
+    return success();
+  // TODO: further verification needs to analyze reachability from
+  // the entry block.
+  return success();
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
new file mode 100644
index 00000000000000..6ee162a761e21e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -0,0 +1,751 @@
+//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops  ----------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Defines the memory operations in the SPIR-V dialect.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+#include "llvm/ADT/StringExtras.h"
+using namespace mlir::spirv::AttrNames;
+namespace mlir::spirv {
+// TODO Make sure to merge this and the previous function into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 in producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
+                                                     OperationState &state) {
+  // Parse an optional list of attributes starting with '['.
+  if (parser.parseOptionalLSquare()) {
+    // No '[' was found, so there is no attribute list to parse.
+    return success();
+  }
+  spirv::MemoryAccess memoryAccessAttr;
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+    return failure();
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
+    // Parse the integer attribute for the alignment that follows the comma.
+    Attribute alignmentAttr;
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+                              state.attributes)) {
+      return failure();
+    }
+  }
+  return parser.parseRSquare();
+// TODO Make sure to merge this and the previous function into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 in producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
+static void printSourceMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+  printer << ", ";
+  // Print optional memory access attribute.
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.getMemoryAccess())) {
+    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.getAlignment())) {
+        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+template <typename MemoryOpTy>
+static void printMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+  // Print optional memory access attribute.
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.getMemoryAccess())) {
+    elidedAttrs.push_back(kMemoryAccessAttrName);
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.getAlignment())) {
+        elidedAttrs.push_back(kAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+template <typename LoadStoreOpTy>
+static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
+                                                   Value val) {
+  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
+  // type of the pointer and the type of the value are the same
+  //
+  // TODO: Check that the value type satisfies restrictions of
+  // SPIR-V OpLoad/OpStore operations
+  if (val.getType() !=
+      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
+    return op.emitOpError("mismatch in result type and pointer type");
+  }
+  return success();
+/// Verifies that the memory-access attribute and the alignment attribute of
+/// `memoryOp` agree with each other: an alignment must be present when the
+/// memory access has the `Aligned` bit set and must be absent otherwise.
+/// Emits an op error and returns failure on any mismatch.
+template <typename MemoryOpTy>
+static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+  // Use dyn_cast rather than cast: cast asserts on a type mismatch and never
+  // yields null, which would make the diagnostic below unreachable.
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  if (!memAccess) {
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+// TODO Make sure to merge this and the previous function into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 in producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+/// Verifies that the source memory-access attribute and the source alignment
+/// attribute of `memoryOp` agree with each other: an alignment must be present
+/// when the memory access has the `Aligned` bit set and must be absent
+/// otherwise. Emits an op error and returns failure on any mismatch.
+template <typename MemoryOpTy>
+static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+  // Use dyn_cast rather than cast: cast asserts on a type mismatch and never
+  // yields null, which would make the diagnostic below unreachable.
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  if (!memAccess) {
+    // Report the attribute that was actually found; `memAccess` is null here.
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+// spirv.AccessChainOp
+static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType) {
+    emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
+                       "to composite type, but provided ")
+        << type;
+    return nullptr;
+  }
+  auto resultType = ptrType.getPointeeType();
+  auto resultStorageClass = ptrType.getStorageClass();
+  int32_t index = 0;
+  for (auto indexSSA : indices) {
+    auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
+    if (!cType) {
+      emitError(
+          baseLoc,
+          "'spirv.AccessChain' op cannot extract from non-composite type ")
+          << resultType << " with index " << index;
+      return nullptr;
+    }
+    index = 0;
+    if (llvm::isa<spirv::StructType>(resultType)) {
+      Operation *op = indexSSA.getDefiningOp();
+      if (!op) {
+        emitError(baseLoc, "'spirv.AccessChain' op index must be an "
+                           "integer spirv.Constant to access "
+                           "element of spirv.struct");
+        return nullptr;
+      }
+      // TODO: this should be relaxed to allow
+      // integer literals of other bitwidths.
+      if (failed(spirv::extractValueFromConstOp(op, index))) {
+        emitError(
+            baseLoc,
+            "'spirv.AccessChain' index must be an integer spirv.Constant to "
+            "access element of spirv.struct, but provided ")
+            << op->getName();
+        return nullptr;
+      }
+      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+        emitError(baseLoc, "'spirv.AccessChain' op index ")
+            << index << " out of bounds for " << resultType;
+        return nullptr;
+      }
+    }
+    resultType = cType.getElementType(index);
+  }
+  return spirv::PointerType::get(resultType, resultStorageClass);
+void AccessChainOp::build(OpBuilder &builder, OperationState &state,
+                          Value basePtr, ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, indices);
+ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand ptrInfo;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
+  Type type;
+  auto loc = parser.getCurrentLocation();
+  SmallVector<Type, 4> indicesTypes;
+  if (parser.parseOperand(ptrInfo) ||
+      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(ptrInfo, type, result.operands)) {
+    return failure();
+  }
+  // Check that the provided indices list is not empty before parsing their
+  // type list.
+  if (indicesInfo.empty()) {
+    return mlir::emitError(result.location,
+                           "'spirv.AccessChain' op expected at "
+                           "least one index ");
+  }
+  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
+    return failure();
+  // Check that the indices types list is not empty and that it has a one-to-one
+  // mapping to the provided indices.
+  if (indicesTypes.size() != indicesInfo.size()) {
+    return mlir::emitError(
+        result.location, "'spirv.AccessChain' op indices types' count must be "
+                         "equal to indices info count");
+  }
+  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
+    return failure();
+  auto resultType = getElementPtrType(
+      type, llvm::ArrayRef(result.operands).drop_front(), result.location);
+  if (!resultType) {
+    return failure();
+  }
+  result.addTypes(resultType);
+  return success();
+template <typename Op>
+static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
+  printer << ' ' << op.getBasePtr() << '[' << indices
+          << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
+void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
+  printAccessChain(*this, getIndices(), printer);
+template <typename Op>
+static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
+  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
+                                      indices, accessChainOp.getLoc());
+  if (!resultType)
+    return failure();
+  auto providedResultType =
+      llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
+  if (!providedResultType)
+    return accessChainOp.emitOpError(
+               "result type must be a pointer, but provided")
+           << providedResultType;
+  if (resultType != providedResultType)
+    return accessChainOp.emitOpError("invalid result type: expected ")
+           << resultType << ", but provided " << providedResultType;
+  return success();
+LogicalResult AccessChainOp::verify() {
+  return verifyAccessChain(*this, getIndices());
+// spirv.LoadOp
+void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
+                   MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
+  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
+  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
+        alignment);
+ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  OpAsmParser::UnresolvedOperand ptrInfo;
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
+      parseMemoryAccessAttributes(parser, result) ||
+      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
+    return failure();
+  }
+  result.addTypes(elementType);
+  return success();
+void LoadOp::print(OpAsmPrinter &printer) {
+  SmallVector<StringRef, 4> elidedAttrs;
+  StringRef sc = stringifyStorageClass(
+      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+  printer << " \"" << sc << "\" " << getPtr();
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+  printer << " : " << getType();
+LogicalResult LoadOp::verify() {
+  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
+  // type with fixed size; i.e., it cannot be, nor include, any
+  // OpTypeRuntimeArray types."
+  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
+    return failure();
+  }
+  return verifyMemoryAccessAttribute(*this);
+// spirv.StoreOp
+ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) ||
+      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
+                             result.operands)) {
+    return failure();
+  }
+  return success();
+void StoreOp::print(OpAsmPrinter &printer) {
+  SmallVector<StringRef, 4> elidedAttrs;
+  StringRef sc = stringifyStorageClass(
+      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
+  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+  printer << " : " << getValue().getType();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+LogicalResult StoreOp::verify() {
+  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
+  // OpTypePointer whose Type operand is the same as the type of Object."
+  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
+    return failure();
+  return verifyMemoryAccessAttribute(*this);
+// spirv.CopyMemory
+void CopyMemoryOp::print(OpAsmPrinter &printer) {
+  printer << ' ';
+  StringRef targetStorageClass = stringifyStorageClass(
+      llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
+  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
+  StringRef sourceStorageClass = stringifyStorageClass(
+      llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
+  printer << " \"" << sourceStorageClass << "\" " << getSource();
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
+                                   getSourceMemoryAccess(),
+                                   getSourceAlignment());
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+  Type pointeeType =
+      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+  printer << " : " << pointeeType;
+ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
+  spirv::StorageClass targetStorageClass;
+  OpAsmParser::UnresolvedOperand targetPtrInfo;
+  spirv::StorageClass sourceStorageClass;
+  OpAsmParser::UnresolvedOperand sourcePtrInfo;
+  Type elementType;
+  if (parseEnumStrAttr(targetStorageClass, parser) ||
+      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
+      parseEnumStrAttr(sourceStorageClass, parser) ||
+      parser.parseOperand(sourcePtrInfo) ||
+      parseMemoryAccessAttributes(parser, result)) {
+    return failure();
+  }
+  if (!parser.parseOptionalComma()) {
+    // Parse 2nd memory access attributes.
+    if (parseSourceMemoryAccessAttributes(parser, result)) {
+      return failure();
+    }
+  }
+  if (parser.parseColon() || parser.parseType(elementType))
+    return failure();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
+  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
+  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
+      parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
+    return failure();
+  }
+  return success();
+LogicalResult CopyMemoryOp::verify() {
+  Type targetType =
+      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
+  Type sourceType =
+      llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
+  if (targetType != sourceType)
+    return emitOpError("both operands must be pointers to the same type");
+  if (failed(verifyMemoryAccessAttribute(*this)))
+    return failure();
+  // TODO - According to the spec:
+  //
+  // If two masks are present, the first applies to Target and cannot include
+  // MakePointerVisible, and the second applies to Source and cannot include
+  // MakePointerAvailable.
+  //
+  // Add such verification here.
+  return verifySourceMemoryAccessAttribute(*this);
+static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
+                                             OpAsmParser &parser,
+                                             OperationState &state) {
+  OpAsmParser::UnresolvedOperand ptrInfo;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
+  Type type;
+  auto loc = parser.getCurrentLocation();
+  SmallVector<Type, 4> indicesTypes;
+  if (parser.parseOperand(ptrInfo) ||
+      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(ptrInfo, type, state.operands))
+    return failure();
+  // Check that the provided indices list is not empty before parsing their
+  // type list.
+  if (indicesInfo.empty())
+    return emitError(state.location) << opName << " expected element";
+  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
+    return failure();
+  // Check that the indices types list is not empty and that it has a one-to-one
+  // mapping to the provided indices.
+  if (indicesTypes.size() != indicesInfo.size())
+    return emitError(state.location)
+           << opName
+           << " indices types' count must be equal to indices info count";
+  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
+    return failure();
+  auto resultType = getElementPtrType(
+      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
+  if (!resultType)
+    return failure();
+  state.addTypes(resultType);
+  return success();
+template <typename Op>
+static auto concatElemAndIndices(Op op) {
+  SmallVector<Value> ret(op.getIndices().size() + 1);
+  ret[0] = op.getElement();
+  llvm::copy(op.getIndices(), ret.begin() + 1);
+  return ret;
+// spirv.InBoundsPtrAccessChainOp
+void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+                                     Value basePtr, Value element,
+                                     ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, element, indices);
+ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  return parsePtrAccessChainOpImpl(
+      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
+void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
+  printAccessChain(*this, concatElemAndIndices(*this), printer);
+LogicalResult InBoundsPtrAccessChainOp::verify() {
+  return verifyAccessChain(*this, getIndices());
+// spirv.PtrAccessChainOp
+void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+                             Value basePtr, Value element, ValueRange indices) {
+  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(type && "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, type, basePtr, element, indices);
+ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
+                                   parser, result);
+void PtrAccessChainOp::print(OpAsmPrinter &printer) {
+  printAccessChain(*this, concatElemAndIndices(*this), printer);
+LogicalResult PtrAccessChainOp::verify() {
+  return verifyAccessChain(*this, getIndices());
+// spirv.Variable
+ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse optional initializer
+  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
+  if (succeeded(parser.parseOptionalKeyword("init"))) {
+    initInfo = OpAsmParser::UnresolvedOperand();
+    if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
+        parser.parseRParen())
+      return failure();
+  }
+  if (parseVariableDecorations(parser, result)) {
+    return failure();
+  }
+  // Parse result pointer type
+  Type type;
+  if (parser.parseColon())
+    return failure();
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseType(type))
+    return failure();
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType)
+    return parser.emitError(loc, "expected spirv.ptr type");
+  result.addTypes(ptrType);
+  // Resolve the initializer operand
+  if (initInfo) {
+    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
+                              result.operands))
+      return failure();
+  }
+  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
+      ptrType.getStorageClass());
+  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
+  return success();
+void VariableOp::print(OpAsmPrinter &printer) {
+  SmallVector<StringRef, 4> elidedAttrs{
+      spirv::attributeName<spirv::StorageClass>()};
+  // Print optional initializer
+  if (getNumOperands() != 0)
+    printer << " init(" << getInitializer() << ")";
+  printVariableDecorations(*this, printer, elidedAttrs);
+  printer << " : " << getType();
+LogicalResult VariableOp::verify() {
+  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+  // object. It cannot be Generic. It must be the same as the Storage Class
+  // operand of the Result Type."
+  if (getStorageClass() != spirv::StorageClass::Function) {
+    return emitOpError(
+        "can only be used to model function-level variables. Use "
+        "spirv.GlobalVariable for module-level variables.");
+  }
+  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  if (getStorageClass() != pointerType.getStorageClass())
+    return emitOpError(
+        "storage class must match result pointer's storage class");
+  if (getNumOperands() != 0) {
+    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
+    // a global (module scope) OpVariable instruction".
+    auto *initOp = getOperand(0).getDefiningOp();
+    if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
+                        spirv::ReferenceOfOp, // for spec constant
+                        spirv::AddressOfOp>(initOp))
+      return emitOpError("initializer must be the result of a "
+                         "constant or spirv.GlobalVariable op");
+  }
+  // TODO: generate these strings using ODS.
+  auto *op = getOperation();
+  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
+  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
+    if (op->getAttr(attr))
+      return emitOpError("cannot have '")
+             << attr << "' attribute (only allowed in spirv.GlobalVariable)";
+  }
+  return success();
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp
new file mode 100644
index 00000000000000..93c27c8701cabb
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp
@@ -0,0 +1,22 @@
+//===- SPIRVOpAvailability.cpp - MLIR SPIR-V Availability Implementation --===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Defines the SPIR-V operation availability in the SPIR-V dialect.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
+namespace mlir::spirv {
+// TableGen'erated operation availability interface implementations.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
new file mode 100644
index 00000000000000..d8dfe164458e29
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -0,0 +1,76 @@
+//===- SPIRVOpDefinition.cpp - MLIR SPIR-V Op Definition Implementation ---===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Defines the TableGen'erated SPIR-V op implementation in the SPIR-V dialect.
+// These are placed in a separate file to reduce the total amount of code in
+// SPIRVOps.cpp and make that file faster to recompile.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "SPIRVParsingUtils.h"
+#include "mlir/IR/TypeUtilities.h"
+namespace mlir::spirv {
+/// Returns true if the given op is a function-like op or nested in a
+/// function-like op without a module-like op in the middle.
+static bool isNestedInFunctionOpInterface(Operation *op) {
+  if (!op)
+    return false;
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    return false;
+  if (isa<FunctionOpInterface>(op))
+    return true;
+  return isNestedInFunctionOpInterface(op->getParentOp());
+/// Returns true if the given op is a module-like op that maintains a symbol
+/// table.
+static bool isDirectInModuleLikeOp(Operation *op) {
+  return op && op->hasTrait<OpTrait::SymbolTable>();
+/// Result of a logical op must be a scalar or vector of boolean type.
+static Type getUnaryOpResultType(Type operandType) {
+  Builder builder(operandType.getContext());
+  Type resultType = builder.getIntegerType(1);
+  if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
+    return VectorType::get(vecType.getNumElements(), resultType);
+  return resultType;
+static ParseResult parseImageOperands(OpAsmParser &parser,
+                                      spirv::ImageOperandsAttr &attr) {
+  // Expect image operands
+  if (parser.parseOptionalLSquare())
+    return success();
+  spirv::ImageOperands imageOperands;
+  if (parseEnumStrAttr(imageOperands, parser))
+    return failure();
+  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
+  return parser.parseRSquare();
+static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
+                               spirv::ImageOperandsAttr attr) {
+  if (attr) {
+    auto strImageOperands = stringifyImageOperands(attr.getValue());
+    printer << "[\"" << strImageOperands << "\"]";
+  }
+} // namespace mlir::spirv
+// TableGen'erated operation definitions.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
index fff06bb5a7f207..e60fd53737d52d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
@@ -29,6 +29,9 @@ inline unsigned getBitWidth(Type type) {
   llvm_unreachable("unhandled bit width computation for type");
+void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
+                              SmallVectorImpl<StringRef> &elidedAttrs);
 LogicalResult extractValueFromConstOp(Operation *op, int32_t &value);
 LogicalResult verifyMemorySemantics(Operation *op,

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 47ffdc1cdad18f..6d7c8b9878f017 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -28,7 +28,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -45,6 +44,77 @@ using namespace mlir::spirv::AttrNames;
 // Common utility functions
+LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
+  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
+  if (!constOp) {
+    return failure();
+  }
+  auto valueAttr = constOp.getValue();
+  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
+  if (!integerValueAttr) {
+    return failure();
+  }
+  if (integerValueAttr.getType().isSignlessInteger())
+    value = integerValueAttr.getInt();
+  else
+    value = integerValueAttr.getSInt();
+  return success();
+spirv::verifyMemorySemantics(Operation *op,
+                             spirv::MemorySemantics memorySemantics) {
+  // According to the SPIR-V specification:
+  // "Despite being a mask and allowing multiple bits to be combined, it is
+  // invalid for more than one of these four bits to be set: Acquire, Release,
+  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
+  // Release semantics is done by setting the AcquireRelease bit, not by setting
+  // two bits."
+  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
+                        spirv::MemorySemantics::Release |
+                        spirv::MemorySemantics::AcquireRelease |
+                        spirv::MemorySemantics::SequentiallyConsistent;
+  auto bitCount =
+      llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
+  if (bitCount > 1) {
+    return op->emitError(
+        "expected at most one of these four memory constraints "
+        "to be set: `Acquire`, `Release`,"
+        "`AcquireRelease` or `SequentiallyConsistent`");
+  }
+  return success();
+void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
+                                     SmallVectorImpl<StringRef> &elidedAttrs) {
+  // Print optional descriptor binding
+  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
+  if (descriptorSet && binding) {
+    elidedAttrs.push_back(descriptorSetName);
+    elidedAttrs.push_back(bindingName);
+    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+            << ")";
+  }
+  // Print BuiltIn attribute if present
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
+  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
+    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
+    elidedAttrs.push_back(builtInName);
+  }
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                                    OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
@@ -93,177 +163,6 @@ static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
   p << " : " << resultType;
-/// Returns true if the given op is a function-like op or nested in a
-/// function-like op without a module-like op in the middle.
-static bool isNestedInFunctionOpInterface(Operation *op) {
-  if (!op)
-    return false;
-  if (op->hasTrait<OpTrait::SymbolTable>())
-    return false;
-  if (isa<FunctionOpInterface>(op))
-    return true;
-  return isNestedInFunctionOpInterface(op->getParentOp());
-/// Returns true if the given op is an module-like op that maintains a symbol
-/// table.
-static bool isDirectInModuleLikeOp(Operation *op) {
-  return op && op->hasTrait<OpTrait::SymbolTable>();
-LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
-  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
-  if (!constOp) {
-    return failure();
-  }
-  auto valueAttr = constOp.getValue();
-  auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
-  if (!integerValueAttr) {
-    return failure();
-  }
-  if (integerValueAttr.getType().isSignlessInteger())
-    value = integerValueAttr.getInt();
-  else
-    value = integerValueAttr.getSInt();
-  return success();
-/// Parses Function, Selection and Loop control attributes. If no control is
-/// specified, "None" is used as a default.
-template <typename EnumAttrClass, typename EnumClass>
-static ParseResult
-parseControlAttribute(OpAsmParser &parser, OperationState &state,
-                      StringRef attrName = spirv::attributeName<EnumClass>()) {
-  if (succeeded(parser.parseOptionalKeyword(kControl))) {
-    EnumClass control;
-    if (parser.parseLParen() ||
-        spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
-        parser.parseRParen())
-      return failure();
-    return success();
-  }
-  // Set control to "None" otherwise.
-  Builder builder = parser.getBuilder();
-  state.addAttribute(attrName,
-                     builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
-  return success();
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
-                                                     OperationState &state) {
-  // Parse an optional list of attributes staring with '['
-  if (parser.parseOptionalLSquare()) {
-    // Nothing to do
-    return success();
-  }
-  spirv::MemoryAccess memoryAccessAttr;
-  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
-          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
-    return failure();
-  if (spirv::bitEnumContainsAll(memoryAccessAttr,
-                                spirv::MemoryAccess::Aligned)) {
-    // Parse integer attribute for alignment.
-    Attribute alignmentAttr;
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
-                              state.attributes)) {
-      return failure();
-    }
-  }
-  return parser.parseRSquare();
-template <typename MemoryOpTy>
-static void printMemoryAccessAttribute(
-    MemoryOpTy memoryOp, OpAsmPrinter &printer,
-    SmallVectorImpl<StringRef> &elidedAttrs,
-    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
-    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
-  // Print optional memory access attribute.
-  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
-                                              : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kMemoryAccessAttrName);
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
-    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
-      // Print integer alignment attribute.
-      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
-                                               : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kAlignmentAttrName);
-        printer << ", " << *alignment;
-      }
-    }
-    printer << "]";
-  }
-  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-template <typename MemoryOpTy>
-static void printSourceMemoryAccessAttribute(
-    MemoryOpTy memoryOp, OpAsmPrinter &printer,
-    SmallVectorImpl<StringRef> &elidedAttrs,
-    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
-    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
-  printer << ", ";
-  // Print optional memory access attribute.
-  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
-                                              : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
-    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
-      // Print integer alignment attribute.
-      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
-                                               : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kSourceAlignmentAttrName);
-        printer << ", " << *alignment;
-      }
-    }
-    printer << "]";
-  }
-  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
-static ParseResult parseImageOperands(OpAsmParser &parser,
-                                      spirv::ImageOperandsAttr &attr) {
-  // Expect image operands
-  if (parser.parseOptionalLSquare())
-    return success();
-  spirv::ImageOperands imageOperands;
-  if (parseEnumStrAttr(imageOperands, parser))
-    return failure();
-  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
-  return parser.parseRSquare();
-static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
-                               spirv::ImageOperandsAttr attr) {
-  if (attr) {
-    auto strImageOperands = stringifyImageOperands(attr.getValue());
-    printer << "[\"" << strImageOperands << "\"]";
-  }
 template <typename Op>
 static LogicalResult verifyImageOperands(Op imageOp,
                                          spirv::ImageOperandsAttr attr,
@@ -292,130 +191,6 @@ static LogicalResult verifyImageOperands(Op imageOp,
   return success();
-template <typename MemoryOpTy>
-static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
-  // ODS checks for attributes values. Just need to verify that if the
-  // memory-access attribute is Aligned, then the alignment attribute must be
-  // present.
-  auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
-  if (!memAccessAttr) {
-    // Alignment attribute shouldn't be present if memory access attribute is
-    // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
-      return memoryOp.emitOpError(
-          "invalid alignment specification without aligned memory access "
-          "specification");
-    }
-    return success();
-  }
-  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
-  if (!memAccess) {
-    return memoryOp.emitOpError("invalid memory access specifier: ")
-           << memAccessAttr;
-  }
-  if (spirv::bitEnumContainsAll(memAccess.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
-      return memoryOp.emitOpError("missing alignment value");
-    }
-  } else {
-    if (op->getAttr(kAlignmentAttrName)) {
-      return memoryOp.emitOpError(
-          "invalid alignment specification with non-aligned memory access "
-          "specification");
-    }
-  }
-  return success();
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-template <typename MemoryOpTy>
-static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
-  // ODS checks for attributes values. Just need to verify that if the
-  // memory-access attribute is Aligned, then the alignment attribute must be
-  // present.
-  auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
-  if (!memAccessAttr) {
-    // Alignment attribute shouldn't be present if memory access attribute is
-    // not present.
-    if (op->getAttr(kSourceAlignmentAttrName)) {
-      return memoryOp.emitOpError(
-          "invalid alignment specification without aligned memory access "
-          "specification");
-    }
-    return success();
-  }
-  auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
-  if (!memAccess) {
-    return memoryOp.emitOpError("invalid memory access specifier: ")
-           << memAccess;
-  }
-  if (spirv::bitEnumContainsAll(memAccess.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kSourceAlignmentAttrName)) {
-      return memoryOp.emitOpError("missing alignment value");
-    }
-  } else {
-    if (op->getAttr(kSourceAlignmentAttrName)) {
-      return memoryOp.emitOpError(
-          "invalid alignment specification with non-aligned memory access "
-          "specification");
-    }
-  }
-  return success();
-spirv::verifyMemorySemantics(Operation *op,
-                             spirv::MemorySemantics memorySemantics) {
-  // According to the SPIR-V specification:
-  // "Despite being a mask and allowing multiple bits to be combined, it is
-  // invalid for more than one of these four bits to be set: Acquire, Release,
-  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
-  // Release semantics is done by setting the AcquireRelease bit, not by setting
-  // two bits."
-  auto atMostOneInSet = spirv::MemorySemantics::Acquire |
-                        spirv::MemorySemantics::Release |
-                        spirv::MemorySemantics::AcquireRelease |
-                        spirv::MemorySemantics::SequentiallyConsistent;
-  auto bitCount =
-      llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
-  if (bitCount > 1) {
-    return op->emitError(
-        "expected at most one of these four memory constraints "
-        "to be set: `Acquire`, `Release`,"
-        "`AcquireRelease` or `SequentiallyConsistent`");
-  }
-  return success();
-template <typename LoadStoreOpTy>
-static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
-                                                   Value val) {
-  // ODS already checks ptr is spirv::PointerType. Just check that the pointee
-  // type of the pointer and the type of the value are the same
-  //
-  // TODO: Check that the value type satisfies restrictions of
-  // SPIR-V OpLoad/OpStore operations
-  if (val.getType() !=
-      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
-    return op.emitOpError("mismatch in result type and pointer type");
-  }
-  return success();
 template <typename BlockReadWriteOpTy>
 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                         Value ptr, Value val) {
@@ -430,70 +205,6 @@ static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
   return success();
-static ParseResult parseVariableDecorations(OpAsmParser &parser,
-                                            OperationState &state) {
-  auto builtInName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::BuiltIn));
-  if (succeeded(parser.parseOptionalKeyword("bind"))) {
-    Attribute set, binding;
-    // Parse optional descriptor binding
-    auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
-        stringifyDecoration(spirv::Decoration::DescriptorSet));
-    auto bindingName = llvm::convertToSnakeFromCamelCase(
-        stringifyDecoration(spirv::Decoration::Binding));
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.parseLParen() ||
-        parser.parseAttribute(set, i32Type, descriptorSetName,
-                              state.attributes) ||
-        parser.parseComma() ||
-        parser.parseAttribute(binding, i32Type, bindingName,
-                              state.attributes) ||
-        parser.parseRParen()) {
-      return failure();
-    }
-  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
-    StringAttr builtIn;
-    if (parser.parseLParen() ||
-        parser.parseAttribute(builtIn, builtInName, state.attributes) ||
-        parser.parseRParen()) {
-      return failure();
-    }
-  }
-  // Parse other attributes
-  if (parser.parseOptionalAttrDict(state.attributes))
-    return failure();
-  return success();
-static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
-                                     SmallVectorImpl<StringRef> &elidedAttrs) {
-  // Print optional descriptor binding
-  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::DescriptorSet));
-  auto bindingName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::Binding));
-  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
-  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
-  if (descriptorSet && binding) {
-    elidedAttrs.push_back(descriptorSetName);
-    elidedAttrs.push_back(bindingName);
-    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
-            << ")";
-  }
-  // Print BuiltIn attribute if present
-  auto builtInName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::BuiltIn));
-  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
-    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
-    elidedAttrs.push_back(builtInName);
-  }
-  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 /// Walks the given type hierarchy with the given indices, potentially down
 /// to component granularity, to select an element type. Returns null type and
 /// emits errors with the given loc on failure.
@@ -564,12 +275,6 @@ static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
   return getElementType(type, indices, errorFn);
-/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
-static inline bool isMergeBlock(Block &block) {
-  return !block.empty() && std::next(block.begin()) == block.end() &&
-         isa<spirv::MergeOp>(block.front());
 template <typename ExtendedBinaryOp>
 static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
   auto resultType = llvm::cast<spirv::StructType>(op.getType());
@@ -617,15 +322,6 @@ static void printArithmeticExtendedBinaryOp(Operation *op,
   printer << " : " << op->getResultTypes().front();
-/// Result of a logical op must be a scalar or vector of boolean type.
-static Type getUnaryOpResultType(Type operandType) {
-  Builder builder(operandType.getContext());
-  Type resultType = builder.getIntegerType(1);
-  if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
-    return VectorType::get(vecType.getNumElements(), resultType);
-  return resultType;
 static LogicalResult verifyShiftOp(Operation *op) {
   if (op->getOperand(0).getType() != op->getResult(0).getType()) {
     return op->emitError("expected the same type for the first operand and "
@@ -636,152 +332,6 @@ static LogicalResult verifyShiftOp(Operation *op) {
   return success();
-// spirv.AccessChainOp
-static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
-  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
-  if (!ptrType) {
-    emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
-                       "to composite type, but provided ")
-        << type;
-    return nullptr;
-  }
-  auto resultType = ptrType.getPointeeType();
-  auto resultStorageClass = ptrType.getStorageClass();
-  int32_t index = 0;
-  for (auto indexSSA : indices) {
-    auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
-    if (!cType) {
-      emitError(
-          baseLoc,
-          "'spirv.AccessChain' op cannot extract from non-composite type ")
-          << resultType << " with index " << index;
-      return nullptr;
-    }
-    index = 0;
-    if (llvm::isa<spirv::StructType>(resultType)) {
-      Operation *op = indexSSA.getDefiningOp();
-      if (!op) {
-        emitError(baseLoc, "'spirv.AccessChain' op index must be an "
-                           "integer spirv.Constant to access "
-                           "element of spirv.struct");
-        return nullptr;
-      }
-      // TODO: this should be relaxed to allow
-      // integer literals of other bitwidths.
-      if (failed(spirv::extractValueFromConstOp(op, index))) {
-        emitError(
-            baseLoc,
-            "'spirv.AccessChain' index must be an integer spirv.Constant to "
-            "access element of spirv.struct, but provided ")
-            << op->getName();
-        return nullptr;
-      }
-      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
-        emitError(baseLoc, "'spirv.AccessChain' op index ")
-            << index << " out of bounds for " << resultType;
-        return nullptr;
-      }
-    }
-    resultType = cType.getElementType(index);
-  }
-  return spirv::PointerType::get(resultType, resultStorageClass);
-void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
-                                 Value basePtr, ValueRange indices) {
-  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
-  assert(type && "Unable to deduce return type based on basePtr and indices");
-  build(builder, state, type, basePtr, indices);
-ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
-                                        OperationState &result) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, result.operands)) {
-    return failure();
-  }
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty()) {
-    return mlir::emitError(result.location,
-                           "'spirv.AccessChain' op expected at "
-                           "least one index ");
-  }
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size()) {
-    return mlir::emitError(
-        result.location, "'spirv.AccessChain' op indices types' count must be "
-                         "equal to indices info count");
-  }
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
-    return failure();
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(result.operands).drop_front(), result.location);
-  if (!resultType) {
-    return failure();
-  }
-  result.addTypes(resultType);
-  return success();
-template <typename Op>
-static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
-  printer << ' ' << op.getBasePtr() << '[' << indices
-          << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
-void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, getIndices(), printer);
-template <typename Op>
-static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
-  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
-                                      indices, accessChainOp.getLoc());
-  if (!resultType)
-    return failure();
-  auto providedResultType =
-      llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
-  if (!providedResultType)
-    return accessChainOp.emitOpError(
-               "result type must be a pointer, but provided")
-           << providedResultType;
-  if (resultType != providedResultType)
-    return accessChainOp.emitOpError("invalid result type: expected ")
-           << resultType << ", but provided " << providedResultType;
-  return success();
-LogicalResult spirv::AccessChainOp::verify() {
-  return verifyAccessChain(*this, getIndices());
 // spirv.mlir.addressof
@@ -805,109 +355,6 @@ LogicalResult spirv::AddressOfOp::verify() {
   return success();
-// spirv.BranchOp
-SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
-  assert(index == 0 && "invalid successor index");
-  return SuccessorOperands(0, getTargetOperandsMutable());
-// spirv.BranchConditionalOp
-spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
-  assert(index < 2 && "invalid successor index");
-  return SuccessorOperands(index == kTrueIndex
-                               ? getTrueTargetOperandsMutable()
-                               : getFalseTargetOperandsMutable());
-ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
-                                              OperationState &result) {
-  auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand condInfo;
-  Block *dest;
-  // Parse the condition.
-  Type boolTy = builder.getI1Type();
-  if (parser.parseOperand(condInfo) ||
-      parser.resolveOperand(condInfo, boolTy, result.operands))
-    return failure();
-  // Parse the optional branch weights.
-  if (succeeded(parser.parseOptionalLSquare())) {
-    IntegerAttr trueWeight, falseWeight;
-    NamedAttrList weights;
-    auto i32Type = builder.getIntegerType(32);
-    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
-        parser.parseComma() ||
-        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
-        parser.parseRSquare())
-      return failure();
-    result.addAttribute(kBranchWeightAttrName,
-                        builder.getArrayAttr({trueWeight, falseWeight}));
-  }
-  // Parse the true branch.
-  SmallVector<Value, 4> trueOperands;
-  if (parser.parseComma() ||
-      parser.parseSuccessorAndUseList(dest, trueOperands))
-    return failure();
-  result.addSuccessors(dest);
-  result.addOperands(trueOperands);
-  // Parse the false branch.
-  SmallVector<Value, 4> falseOperands;
-  if (parser.parseComma() ||
-      parser.parseSuccessorAndUseList(dest, falseOperands))
-    return failure();
-  result.addSuccessors(dest);
-  result.addOperands(falseOperands);
-  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
-                      builder.getDenseI32ArrayAttr(
-                          {1, static_cast<int32_t>(trueOperands.size()),
-                           static_cast<int32_t>(falseOperands.size())}));
-  return success();
-void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
-  printer << ' ' << getCondition();
-  if (auto weights = getBranchWeights()) {
-    printer << " [";
-    llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
-      printer << llvm::cast<IntegerAttr>(a).getInt();
-    });
-    printer << "]";
-  }
-  printer << ", ";
-  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
-  printer << ", ";
-  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
-LogicalResult spirv::BranchConditionalOp::verify() {
-  if (auto weights = getBranchWeights()) {
-    if (weights->getValue().size() != 2) {
-      return emitOpError("must have exactly two branch weights");
-    }
-    if (llvm::all_of(*weights, [](Attribute attr) {
-          return llvm::cast<IntegerAttr>(attr).getValue().isZero();
-        }))
-      return emitOpError("branch weights cannot both be zero");
-  }
-  return success();
 // spirv.CompositeConstruct
@@ -1584,72 +1031,6 @@ ::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
   return getResAttrs().value_or(nullptr);
-// spirv.FunctionCall
-LogicalResult spirv::FunctionCallOp::verify() {
-  auto fnName = getCalleeAttr();
-  auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
-      SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
-  if (!funcOp) {
-    return emitOpError("callee function '")
-           << fnName.getValue() << "' not found in nearest symbol table";
-  }
-  auto functionType = funcOp.getFunctionType();
-  if (getNumResults() > 1) {
-    return emitOpError(
-               "expected callee function to have 0 or 1 result, but provided ")
-           << getNumResults();
-  }
-  if (functionType.getNumInputs() != getNumOperands()) {
-    return emitOpError("has incorrect number of operands for callee: expected ")
-           << functionType.getNumInputs() << ", but provided "
-           << getNumOperands();
-  }
-  for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
-    if (getOperand(i).getType() != functionType.getInput(i)) {
-      return emitOpError("operand type mismatch: expected operand type ")
-             << functionType.getInput(i) << ", but provided "
-             << getOperand(i).getType() << " for operand number " << i;
-    }
-  }
-  if (functionType.getNumResults() != getNumResults()) {
-    return emitOpError(
-               "has incorrect number of results has for callee: expected ")
-           << functionType.getNumResults() << ", but provided "
-           << getNumResults();
-  }
-  if (getNumResults() &&
-      (getResult(0).getType() != functionType.getResult(0))) {
-    return emitOpError("result type mismatch: expected ")
-           << functionType.getResult(0) << ", but provided "
-           << getResult(0).getType();
-  }
-  return success();
-CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
-void spirv::FunctionCallOp::setCalleeFromCallable(
-    CallInterfaceCallable callee) {
-  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
-Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
-  return getArguments();
 // spirv.GLFClampOp
@@ -1768,7 +1149,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
-  printVariableDecorations(*this, printer, elidedAttrs);
+  spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
@@ -1933,232 +1314,21 @@ void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
   ::printArithmeticExtendedBinaryOp(*this, printer);
-// spirv.UMulExtended
-LogicalResult spirv::UMulExtendedOp::verify() {
-  return ::verifyArithmeticExtendedBinaryOp(*this);
-ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
-                                         OperationState &result) {
-  return ::parseArithmeticExtendedBinaryOp(parser, result);
-void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
-  ::printArithmeticExtendedBinaryOp(*this, printer);
-// spirv.LoadOp
-void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
-                          Value basePtr, MemoryAccessAttr memoryAccess,
-                          IntegerAttr alignment) {
-  auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
-  build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
-        alignment);
-ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Parse the storage class specification
-  spirv::StorageClass storageClass;
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
-      parseMemoryAccessAttributes(parser, result) ||
-      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
-      parser.parseType(elementType)) {
-    return failure();
-  }
-  auto ptrType = spirv::PointerType::get(elementType, storageClass);
-  if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
-    return failure();
-  }
-  result.addTypes(elementType);
-  return success();
-void spirv::LoadOp::print(OpAsmPrinter &printer) {
-  SmallVector<StringRef, 4> elidedAttrs;
-  StringRef sc = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
-  printer << " \"" << sc << "\" " << getPtr();
-  printMemoryAccessAttribute(*this, printer, elidedAttrs);
-  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
-  printer << " : " << getType();
-LogicalResult spirv::LoadOp::verify() {
-  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
-  // type with fixed size; i.e., it cannot be, nor include, any
-  // OpTypeRuntimeArray types."
-  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
-    return failure();
-  }
-  return verifyMemoryAccessAttribute(*this);
-// spirv.mlir.loop
-void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
-  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
-                                         spirv::LoopControl::None));
-  state.addRegion();
-ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
-  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
-                                                                        result))
-    return failure();
-  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
-void spirv::LoopOp::print(OpAsmPrinter &printer) {
-  auto control = getLoopControl();
-  if (control != spirv::LoopControl::None)
-    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
-  printer << ' ';
-  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
-                      /*printBlockTerminators=*/true);
-/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
-/// given `dstBlock`.
-static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
-  // Check that there is only one op in the `srcBlock`.
-  if (!llvm::hasSingleElement(srcBlock))
-    return false;
-  auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
-  return branchOp && branchOp.getSuccessor() == &dstBlock;
-LogicalResult spirv::LoopOp::verifyRegions() {
-  auto *op = getOperation();
-  // We need to verify that the blocks follow the following layout:
-  //
-  //                     +-------------+
-  //                     | entry block |
-  //                     +-------------+
-  //                            |
-  //                            v
-  //                     +-------------+
-  //                     | loop header | <-----+
-  //                     +-------------+       |
-  //                                           |
-  //                           ...             |
-  //                          \ | /            |
-  //                            v              |
-  //                    +---------------+      |
-  //                    | loop continue | -----+
-  //                    +---------------+
-  //
-  //                           ...
-  //                          \ | /
-  //                            v
-  //                     +-------------+
-  //                     | merge block |
-  //                     +-------------+
-  auto &region = op->getRegion(0);
-  // Allow empty region as a degenerated case, which can come from
-  // optimizations.
-  if (region.empty())
-    return success();
-  // The last block is the merge block.
-  Block &merge = region.back();
-  if (!isMergeBlock(merge))
-    return emitOpError("last block must be the merge block with only one "
-                       "'spirv.mlir.merge' op");
-  if (std::next(region.begin()) == region.end())
-    return emitOpError(
-        "must have an entry block branching to the loop header block");
-  // The first block is the entry block.
-  Block &entry = region.front();
-  if (std::next(region.begin(), 2) == region.end())
-    return emitOpError(
-        "must have a loop header block branched from the entry block");
-  // The second block is the loop header block.
-  Block &header = *std::next(region.begin(), 1);
-  if (!hasOneBranchOpTo(entry, header))
-    return emitOpError(
-        "entry block must only have one 'spirv.Branch' op to the second block");
-  if (std::next(region.begin(), 3) == region.end())
-    return emitOpError(
-        "requires a loop continue block branching to the loop header block");
-  // The second to last block is the loop continue block.
-  Block &cont = *std::prev(region.end(), 2);
-  // Make sure that we have a branch from the loop continue block to the loop
-  // header block.
-  if (llvm::none_of(
-          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
-          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
-    return emitOpError("second to last block must be the loop continue "
-                       "block that branches to the loop header block");
-  // Make sure that no other blocks (except the entry and loop continue block)
-  // branches to the loop header block.
-  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
-                                      std::prev(region.end(), 2))) {
-    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
-      if (block.getSuccessor(i) == &header) {
-        return emitOpError("can only have the entry and loop continue "
-                           "block branching to the loop header block");
-      }
-    }
-  }
-  return success();
-Block *spirv::LoopOp::getEntryBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  return &getBody().front();
-Block *spirv::LoopOp::getHeaderBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  // The second block is the loop header block.
-  return &*std::next(getBody().begin());
+// spirv.UMulExtended
-Block *spirv::LoopOp::getContinueBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  // The second to last block is the loop continue block.
-  return &*std::prev(getBody().end(), 2);
+LogicalResult spirv::UMulExtendedOp::verify() {
+  return ::verifyArithmeticExtendedBinaryOp(*this);
-Block *spirv::LoopOp::getMergeBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  // The last block is the loop merge block.
-  return &getBody().back();
+ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return ::parseArithmeticExtendedBinaryOp(parser, result);
-void spirv::LoopOp::addEntryAndMergeBlock() {
-  assert(getBody().empty() && "entry and merge block already exist");
-  getBody().push_back(new Block());
-  auto *mergeBlock = new Block();
-  getBody().push_back(mergeBlock);
-  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
-  // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
+void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
+  ::printArithmeticExtendedBinaryOp(*this, printer);
@@ -2169,24 +1339,6 @@ LogicalResult spirv::MemoryBarrierOp::verify() {
   return verifyMemorySemantics(getOperation(), getMemorySemantics());
-// spirv.mlir.merge
-LogicalResult spirv::MergeOp::verify() {
-  auto *parentOp = (*this)->getParentOp();
-  if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
-    return emitOpError(
-        "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
-  // TODO: This check should be done in `verifyRegions` of parent op.
-  Block &parentLastBlock = (*this)->getParentRegion()->back();
-  if (getOperation() != parentLastBlock.getTerminator())
-    return emitOpError("can only be used in the last block of "
-                       "'spirv.mlir.selection' or 'spirv.mlir.loop'");
-  return success();
 // spirv.module
@@ -2382,158 +1534,6 @@ LogicalResult spirv::ReferenceOfOp::verify() {
   return success();
-// spirv.Return
-LogicalResult spirv::ReturnOp::verify() {
-  // Verification is performed in spirv.func op.
-  return success();
-// spirv.ReturnValue
-LogicalResult spirv::ReturnValueOp::verify() {
-  // Verification is performed in spirv.func op.
-  return success();
-// spirv.Select
-LogicalResult spirv::SelectOp::verify() {
-  if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
-    auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
-    if (!resultVectorTy) {
-      return emitOpError("result expected to be of vector type when "
-                         "condition is of vector type");
-    }
-    if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
-      return emitOpError("result should have the same number of elements as "
-                         "the condition when condition is of vector type");
-    }
-  }
-  return success();
-// spirv.mlir.selection
-ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
-                                      OperationState &result) {
-  if (parseControlAttribute<spirv::SelectionControlAttr,
-                            spirv::SelectionControl>(parser, result))
-    return failure();
-  return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
-void spirv::SelectionOp::print(OpAsmPrinter &printer) {
-  auto control = getSelectionControl();
-  if (control != spirv::SelectionControl::None)
-    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
-  printer << ' ';
-  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
-                      /*printBlockTerminators=*/true);
-LogicalResult spirv::SelectionOp::verifyRegions() {
-  auto *op = getOperation();
-  // We need to verify that the blocks follow the following layout:
-  //
-  //                     +--------------+
-  //                     | header block |
-  //                     +--------------+
-  //                          / | \
-  //                           ...
-  //
-  //
-  //         +---------+   +---------+   +---------+
-  //         | case #0 |   | case #1 |   | case #2 |  ...
-  //         +---------+   +---------+   +---------+
-  //
-  //
-  //                           ...
-  //                          \ | /
-  //                            v
-  //                     +-------------+
-  //                     | merge block |
-  //                     +-------------+
-  auto &region = op->getRegion(0);
-  // Allow empty region as a degenerated case, which can come from
-  // optimizations.
-  if (region.empty())
-    return success();
-  // The last block is the merge block.
-  if (!isMergeBlock(region.back()))
-    return emitOpError("last block must be the merge block with only one "
-                       "'spirv.mlir.merge' op");
-  if (std::next(region.begin()) == region.end())
-    return emitOpError("must have a selection header block");
-  return success();
-Block *spirv::SelectionOp::getHeaderBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  // The first block is the loop header block.
-  return &getBody().front();
-Block *spirv::SelectionOp::getMergeBlock() {
-  assert(!getBody().empty() && "op region should not be empty!");
-  // The last block is the loop merge block.
-  return &getBody().back();
-void spirv::SelectionOp::addMergeBlock() {
-  assert(getBody().empty() && "entry and merge block already exist");
-  auto *mergeBlock = new Block();
-  getBody().push_back(mergeBlock);
-  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
-  // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
-spirv::SelectionOp spirv::SelectionOp::createIfThen(
-    Location loc, Value condition,
-    function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
-  auto selectionOp =
-      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
-  selectionOp.addMergeBlock();
-  Block *mergeBlock = selectionOp.getMergeBlock();
-  Block *thenBlock = nullptr;
-  // Build the "then" block.
-  {
-    OpBuilder::InsertionGuard guard(builder);
-    thenBlock = builder.createBlock(mergeBlock);
-    thenBody(builder);
-    builder.create<spirv::BranchOp>(loc, mergeBlock);
-  }
-  // Build the header block.
-  {
-    OpBuilder::InsertionGuard guard(builder);
-    builder.createBlock(thenBlock);
-    builder.create<spirv::BranchConditionalOp>(
-        loc, condition, thenBlock,
-        /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
-        /*falseArguments=*/ArrayRef<Value>());
-  }
-  return selectionOp;
 // spirv.SpecConstant
@@ -2588,171 +1588,6 @@ LogicalResult spirv::SpecConstantOp::verify() {
       "default value can only be a bool, integer, or float scalar");
-// spirv.StoreOp
-ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Parse the storage class specification
-  spirv::StorageClass storageClass;
-  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
-  auto loc = parser.getCurrentLocation();
-  Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) ||
-      parser.parseOperandList(operandInfo, 2) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(elementType)) {
-    return failure();
-  }
-  auto ptrType = spirv::PointerType::get(elementType, storageClass);
-  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
-                             result.operands)) {
-    return failure();
-  }
-  return success();
-void spirv::StoreOp::print(OpAsmPrinter &printer) {
-  SmallVector<StringRef, 4> elidedAttrs;
-  StringRef sc = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
-  printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
-  printMemoryAccessAttribute(*this, printer, elidedAttrs);
-  printer << " : " << getValue().getType();
-  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
-LogicalResult spirv::StoreOp::verify() {
-  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
-  // OpTypePointer whose Type operand is the same as the type of Object."
-  if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
-    return failure();
-  return verifyMemoryAccessAttribute(*this);
-// spirv.Unreachable
-LogicalResult spirv::UnreachableOp::verify() {
-  auto *block = (*this)->getBlock();
-  // Fast track: if this is in entry block, its invalid. Otherwise, if no
-  // predecessors, it's valid.
-  if (block->isEntryBlock())
-    return emitOpError("cannot be used in reachable block");
-  if (block->hasNoPredecessors())
-    return success();
-  // TODO: further verification needs to analyze reachability from
-  // the entry block.
-  return success();
-// spirv.Variable
-ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  // Parse optional initializer
-  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
-  if (succeeded(parser.parseOptionalKeyword("init"))) {
-    initInfo = OpAsmParser::UnresolvedOperand();
-    if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
-        parser.parseRParen())
-      return failure();
-  }
-  if (parseVariableDecorations(parser, result)) {
-    return failure();
-  }
-  // Parse result pointer type
-  Type type;
-  if (parser.parseColon())
-    return failure();
-  auto loc = parser.getCurrentLocation();
-  if (parser.parseType(type))
-    return failure();
-  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
-  if (!ptrType)
-    return parser.emitError(loc, "expected spirv.ptr type");
-  result.addTypes(ptrType);
-  // Resolve the initializer operand
-  if (initInfo) {
-    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
-                              result.operands))
-      return failure();
-  }
-  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
-      ptrType.getStorageClass());
-  result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
-  return success();
-void spirv::VariableOp::print(OpAsmPrinter &printer) {
-  SmallVector<StringRef, 4> elidedAttrs{
-      spirv::attributeName<spirv::StorageClass>()};
-  // Print optional initializer
-  if (getNumOperands() != 0)
-    printer << " init(" << getInitializer() << ")";
-  printVariableDecorations(*this, printer, elidedAttrs);
-  printer << " : " << getType();
-LogicalResult spirv::VariableOp::verify() {
-  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
-  // object. It cannot be Generic. It must be the same as the Storage Class
-  // operand of the Result Type."
-  if (getStorageClass() != spirv::StorageClass::Function) {
-    return emitOpError(
-        "can only be used to model function-level variables. Use "
-        "spirv.GlobalVariable for module-level variables.");
-  }
-  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
-  if (getStorageClass() != pointerType.getStorageClass())
-    return emitOpError(
-        "storage class must match result pointer's storage class");
-  if (getNumOperands() != 0) {
-    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
-    // a global (module scope) OpVariable instruction".
-    auto *initOp = getOperand(0).getDefiningOp();
-    if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
-                        spirv::ReferenceOfOp, // for spec constant
-                        spirv::AddressOfOp>(initOp))
-      return emitOpError("initializer must be the result of a "
-                         "constant or spirv.GlobalVariable op");
-  }
-  // TODO: generate these strings using ODS.
-  auto *op = getOperation();
-  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::DescriptorSet));
-  auto bindingName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::Binding));
-  auto builtInName = llvm::convertToSnakeFromCamelCase(
-      stringifyDecoration(spirv::Decoration::BuiltIn));
-  for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
-    if (op->getAttr(attr))
-      return emitOpError("cannot have '")
-             << attr << "' attribute (only allowed in spirv.GlobalVariable)";
-  }
-  return success();
 // spirv.VectorShuffle
@@ -2804,100 +1639,6 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
   return success();
-// spirv.CopyMemory
-void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
-  printer << ' ';
-  StringRef targetStorageClass = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
-  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
-  StringRef sourceStorageClass = stringifyStorageClass(
-      llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
-  printer << " \"" << sourceStorageClass << "\" " << getSource();
-  SmallVector<StringRef, 4> elidedAttrs;
-  printMemoryAccessAttribute(*this, printer, elidedAttrs);
-  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
-                                   getSourceMemoryAccess(),
-                                   getSourceAlignment());
-  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
-  Type pointeeType =
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
-  printer << " : " << pointeeType;
-ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
-                                       OperationState &result) {
-  spirv::StorageClass targetStorageClass;
-  OpAsmParser::UnresolvedOperand targetPtrInfo;
-  spirv::StorageClass sourceStorageClass;
-  OpAsmParser::UnresolvedOperand sourcePtrInfo;
-  Type elementType;
-  if (parseEnumStrAttr(targetStorageClass, parser) ||
-      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
-      parseEnumStrAttr(sourceStorageClass, parser) ||
-      parser.parseOperand(sourcePtrInfo) ||
-      parseMemoryAccessAttributes(parser, result)) {
-    return failure();
-  }
-  if (!parser.parseOptionalComma()) {
-    // Parse 2nd memory access attributes.
-    if (parseSourceMemoryAccessAttributes(parser, result)) {
-      return failure();
-    }
-  }
-  if (parser.parseColon() || parser.parseType(elementType))
-    return failure();
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
-  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
-  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
-      parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
-    return failure();
-  }
-  return success();
-LogicalResult spirv::CopyMemoryOp::verify() {
-  Type targetType =
-      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
-  Type sourceType =
-      llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
-  if (targetType != sourceType)
-    return emitOpError("both operands must be pointers to the same type");
-  if (failed(verifyMemoryAccessAttribute(*this)))
-    return failure();
-  // TODO - According to the spec:
-  //
-  // If two masks are present, the first applies to Target and cannot include
-  // MakePointerVisible, and the second applies to Source and cannot include
-  // MakePointerAvailable.
-  //
-  // Add such verification here.
-  return verifySourceMemoryAccessAttribute(*this);
 // spirv.Transpose
@@ -3305,109 +2046,6 @@ LogicalResult spirv::ImageQuerySizeOp::verify() {
   return success();
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
-                                             OpAsmParser &parser,
-                                             OperationState &state) {
-  OpAsmParser::UnresolvedOperand ptrInfo;
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-  SmallVector<Type, 4> indicesTypes;
-  if (parser.parseOperand(ptrInfo) ||
-      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(ptrInfo, type, state.operands))
-    return failure();
-  // Check that the provided indices list is not empty before parsing their
-  // type list.
-  if (indicesInfo.empty())
-    return emitError(state.location) << opName << " expected element";
-  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
-    return failure();
-  // Check that the indices types list is not empty and that it has a one-to-one
-  // mapping to the provided indices.
-  if (indicesTypes.size() != indicesInfo.size())
-    return emitError(state.location)
-           << opName
-           << " indices types' count must be equal to indices info count";
-  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
-    return failure();
-  auto resultType = getElementPtrType(
-      type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
-  if (!resultType)
-    return failure();
-  state.addTypes(resultType);
-  return success();
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
-  SmallVector<Value> ret(op.getIndices().size() + 1);
-  ret[0] = op.getElement();
-  llvm::copy(op.getIndices(), ret.begin() + 1);
-  return ret;
-// spirv.InBoundsPtrAccessChainOp
-void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
-                                            OperationState &state,
-                                            Value basePtr, Value element,
-                                            ValueRange indices) {
-  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
-  assert(type && "Unable to deduce return type based on basePtr and indices");
-  build(builder, state, type, basePtr, element, indices);
-ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
-                                                   OperationState &result) {
-  return parsePtrAccessChainOpImpl(
-      spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
-  return verifyAccessChain(*this, getIndices());
-// spirv.PtrAccessChainOp
-void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
-                                    Value basePtr, Value element,
-                                    ValueRange indices) {
-  auto type = getElementPtrType(basePtr.getType(), indices, state.location);
-  assert(type && "Unable to deduce return type based on basePtr and indices");
-  build(builder, state, type, basePtr, element, indices);
-ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
-                                           OperationState &result) {
-  return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
-                                   parser, result);
-void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
-  printAccessChain(*this, concatElemAndIndices(*this), printer);
-LogicalResult spirv::PtrAccessChainOp::verify() {
-  return verifyAccessChain(*this, getIndices());
 // spirv.VectorTimesScalarOp
@@ -3420,18 +2058,3 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
     return emitOpError("scalar operand and result element type match");
   return success();
-// TableGen'erated operation interfaces for querying versions, extensions, and
-// capabilities.
-#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
-// TablenGen'erated operation definitions.
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
-namespace mlir {
-namespace spirv {
-// TableGen'erated operation availability interface implementations.
-#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
-} // namespace spirv
-} // namespace mlir

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 43c0beaccc0fd3..27c373300aee8e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,6 +12,8 @@
 #include "SPIRVParsingUtils.h"
+#include "llvm/ADT/StringExtras.h"
 using namespace mlir::spirv::AttrNames;
 namespace mlir::spirv {
@@ -45,4 +47,41 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   return parser.parseRSquare();
+ParseResult parseVariableDecorations(OpAsmParser &parser,
+                                     OperationState &state) {
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
+  if (succeeded(parser.parseOptionalKeyword("bind"))) {
+    Attribute set, binding;
+    // Parse optional descriptor binding
+    auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+        stringifyDecoration(spirv::Decoration::DescriptorSet));
+    auto bindingName = llvm::convertToSnakeFromCamelCase(
+        stringifyDecoration(spirv::Decoration::Binding));
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseLParen() ||
+        parser.parseAttribute(set, i32Type, descriptorSetName,
+                              state.attributes) ||
+        parser.parseComma() ||
+        parser.parseAttribute(binding, i32Type, bindingName,
+                              state.attributes) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
+    StringAttr builtIn;
+    if (parser.parseLParen() ||
+        parser.parseAttribute(builtIn, builtInName, state.attributes) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  }
+  // Parse other attributes
+  if (parser.parseOptionalAttrDict(state.attributes))
+    return failure();
+  return success();
 } // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index fd2faf4b7b333f..625c82f6e8e899 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -153,4 +153,7 @@ ParseResult parseMemoryAccessAttributes(
     OpAsmParser &parser, OperationState &state,
     StringRef attrName = AttrNames::kMemoryAccessAttrName);
+ParseResult parseVariableDecorations(OpAsmParser &parser,
+                                     OperationState &state);
 } // namespace mlir::spirv


More information about the Mlir-commits mailing list