[Mlir-commits] [mlir] 81b4e7d - [mlir][spirv] Extract more ops from the main implementation file. NFC.
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jul 20 14:13:56 PDT 2023
Author: varconst
Date: 2023-07-20T17:11:32-04:00
New Revision: 81b4e7d2b0e1d222c76637f600cfcb74b631dfca
URL: https://github.com/llvm/llvm-project/commit/81b4e7d2b0e1d222c76637f600cfcb74b631dfca
DIFF: https://github.com/llvm/llvm-project/commit/81b4e7d2b0e1d222c76637f600cfcb74b631dfca.diff
LOG: [mlir][spirv] Extract more ops from the main implementation file. NFC.
Continue the work outlined in D155747 and split the main SPIR-V ops
implementation file into a few smaller, quicker-to-compile files.
Move control flow and memory ops to their own implementation files.
Create new `.cpp` files for tablegened code.
After this change, `SPIRVOps.cpp` is 2k LoC long and takes a
reasonable amount of time to compile.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D155883
Added:
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
Modified:
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index d36e2ad8a73e85..f985bdd33e26ff 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -5,15 +5,19 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
AtomicOps.cpp
CastOps.cpp
+ ControlFlowOps.cpp
CooperativeMatrixOps.cpp
GroupOps.cpp
IntegerDotProductOps.cpp
JointMatrixOps.cpp
+ MemoryOps.cpp
SPIRVAttributes.cpp
SPIRVCanonicalization.cpp
SPIRVGLCanonicalization.cpp
SPIRVDialect.cpp
SPIRVEnums.cpp
+ SPIRVOpAvailability.cpp
+ SPIRVOpDefinition.cpp
SPIRVOps.cpp
SPIRVParsingUtils.cpp
SPIRVTypes.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
new file mode 100644
index 00000000000000..e169cb5b65322e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -0,0 +1,562 @@
+//===- ControlFlowOps.cpp - MLIR SPIR-V Control Flow Ops -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the control flow operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+/// Parses Function, Selection and Loop control attributes. If no control is
+/// specified, "None" is used as a default.
+///
+/// Accepts an optional `control(<Keyword>)` clause. On success exactly one
+/// attribute named `attrName` is added to `state`: the parsed enumerant when
+/// the clause is present, otherwise the enum value 0 (the "None" case).
+template <typename EnumAttrClass, typename EnumClass>
+static ParseResult
+parseControlAttribute(OpAsmParser &parser, OperationState &state,
+                      StringRef attrName = spirv::attributeName<EnumClass>()) {
+  if (succeeded(parser.parseOptionalKeyword(kControl))) {
+    EnumClass control;
+    // Forward `attrName` so an explicitly parsed control is stored under the
+    // same attribute name as the implicit "None" default below; otherwise a
+    // caller-supplied name would only be honored on the default path.
+    if (parser.parseLParen() ||
+        spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state,
+                                                   attrName) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  // Set control to "None" otherwise.
+  Builder builder = parser.getBuilder();
+  state.addAttribute(attrName,
+                     builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BranchOp
+//===----------------------------------------------------------------------===//
+
+/// The single successor (index 0) receives all target operands; the branch
+/// itself produces none of them.
+SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
+  assert(index == 0 && "invalid successor index");
+  MutableOperandRange forwarded = getTargetOperandsMutable();
+  return SuccessorOperands(/*producedOperandCount=*/0, forwarded);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BranchConditionalOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the operands forwarded to the true (`kTrueIndex`) or false
+/// successor.
+SuccessorOperands BranchConditionalOp::getSuccessorOperands(unsigned index) {
+  assert(index < 2 && "invalid successor index");
+  if (index == kTrueIndex)
+    return SuccessorOperands(getTrueTargetOperandsMutable());
+  return SuccessorOperands(getFalseTargetOperandsMutable());
+}
+
+/// Parses `%cond [ [tw, fw] ], ^true(args...), ^false(args...)`.
+ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  auto &b = parser.getBuilder();
+
+  // Condition operand, always resolved as i1.
+  OpAsmParser::UnresolvedOperand cond;
+  if (parser.parseOperand(cond) ||
+      parser.resolveOperand(cond, b.getI1Type(), result.operands))
+    return failure();
+
+  // Optional `[<true weight>, <false weight>]`. Each weight is parsed as an
+  // i32 attribute into a scratch list; both are then stored together as one
+  // array attribute.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    Type i32Ty = b.getIntegerType(32);
+    NamedAttrList scratch;
+    IntegerAttr trueWeight, falseWeight;
+    if (parser.parseAttribute(trueWeight, i32Ty, "weight", scratch) ||
+        parser.parseComma() ||
+        parser.parseAttribute(falseWeight, i32Ty, "weight", scratch) ||
+        parser.parseRSquare())
+      return failure();
+    result.addAttribute(kBranchWeightAttrName,
+                        b.getArrayAttr({trueWeight, falseWeight}));
+  }
+
+  // Parses `, ^block(args...)`, then registers the successor and appends the
+  // operands forwarded to it.
+  auto parseSuccessor = [&](SmallVectorImpl<Value> &args) -> ParseResult {
+    Block *target = nullptr;
+    if (parser.parseComma() || parser.parseSuccessorAndUseList(target, args))
+      return failure();
+    result.addSuccessors(target);
+    result.addOperands(args);
+    return success();
+  };
+
+  SmallVector<Value, 4> trueArgs;
+  SmallVector<Value, 4> falseArgs;
+  if (parseSuccessor(trueArgs) || parseSuccessor(falseArgs))
+    return failure();
+
+  // Operand segments: condition, true-branch args, false-branch args.
+  result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
+                      b.getDenseI32ArrayAttr(
+                          {1, static_cast<int32_t>(trueArgs.size()),
+                           static_cast<int32_t>(falseArgs.size())}));
+  return success();
+}
+
+/// Prints the condition, the optional ` [tw, fw]` weights as plain integers,
+/// and both successors with their forwarded operands.
+void BranchConditionalOp::print(OpAsmPrinter &printer) {
+  printer << ' ' << getCondition();
+
+  auto weights = getBranchWeights();
+  if (weights) {
+    printer << " [";
+    llvm::interleaveComma(weights->getValue(), printer, [&](Attribute weight) {
+      printer << llvm::cast<IntegerAttr>(weight).getInt();
+    });
+    printer << "]";
+  }
+
+  printer << ", ";
+  printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
+  printer << ", ";
+  printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
+}
+
+/// When branch weights are present, there must be exactly two of them and
+/// they may not both be zero.
+LogicalResult BranchConditionalOp::verify() {
+  auto weights = getBranchWeights();
+  if (!weights)
+    return success();
+
+  if (weights->getValue().size() != 2)
+    return emitOpError("must have exactly two branch weights");
+
+  bool allZero = llvm::all_of(*weights, [](Attribute weight) {
+    return llvm::cast<IntegerAttr>(weight).getValue().isZero();
+  });
+  if (allZero)
+    return emitOpError("branch weights cannot both be zero");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.FunctionCall
+//===----------------------------------------------------------------------===//
+
+/// Resolves the callee through the nearest symbol table and checks that the
+/// call's operand and result counts and types match the callee's signature.
+LogicalResult FunctionCallOp::verify() {
+  auto calleeName = getCalleeAttr();
+  Operation *symbol = SymbolTable::lookupNearestSymbolFrom(
+      (*this)->getParentOp(), calleeName);
+  auto callee = dyn_cast_or_null<spirv::FuncOp>(symbol);
+  if (!callee)
+    return emitOpError("callee function '")
+           << calleeName.getValue() << "' not found in nearest symbol table";
+
+  auto calleeType = callee.getFunctionType();
+  unsigned numResults = getNumResults();
+  unsigned numOperands = getNumOperands();
+
+  // SPIR-V functions return at most one value.
+  if (numResults > 1)
+    return emitOpError(
+               "expected callee function to have 0 or 1 result, but provided ")
+           << numResults;
+
+  unsigned numInputs = calleeType.getNumInputs();
+  if (numInputs != numOperands)
+    return emitOpError("has incorrect number of operands for callee: expected ")
+           << numInputs << ", but provided " << numOperands;
+
+  for (unsigned idx = 0; idx < numInputs; ++idx) {
+    Type expected = calleeType.getInput(idx);
+    Type actual = getOperand(idx).getType();
+    if (actual != expected)
+      return emitOpError("operand type mismatch: expected operand type ")
+             << expected << ", but provided " << actual
+             << " for operand number " << idx;
+  }
+
+  if (calleeType.getNumResults() != numResults)
+    return emitOpError(
+               "has incorrect number of results has for callee: expected ")
+           << calleeType.getNumResults() << ", but provided " << numResults;
+
+  if (numResults == 0)
+    return success();
+
+  Type expectedResult = calleeType.getResult(0);
+  Type actualResult = getResult(0).getType();
+  if (actualResult != expectedResult)
+    return emitOpError("result type mismatch: expected ")
+           << expectedResult << ", but provided " << actualResult;
+
+  return success();
+}
+
+/// The callee is identified by the symbol reference stored in `kCallee`.
+CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
+  auto calleeRef = (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+  return calleeRef;
+}
+
+/// Replaces the callee symbol; `callee` must hold a SymbolRefAttr.
+void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  auto calleeRef = callee.get<SymbolRefAttr>();
+  (*this)->setAttr(kCallee, calleeRef);
+}
+
+/// All operands of the call are forwarded as call arguments.
+Operation::operand_range FunctionCallOp::getArgOperands() {
+  return getArguments();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.loop
+//===----------------------------------------------------------------------===//
+
+/// Builds a loop with "None" loop control and an empty body region.
+void LoopOp::build(OpBuilder &builder, OperationState &state) {
+  auto noneControl =
+      builder.getAttr<spirv::LoopControlAttr>(spirv::LoopControl::None);
+  state.addAttribute("loop_control", noneControl);
+  state.addRegion();
+}
+
+/// Parses an optional `control(...)` clause followed by the body region.
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (failed(parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(
+          parser, result)))
+    return failure();
+  // The body region has no entry-block arguments.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{});
+}
+
+/// Prints a non-default loop control clause, then the body region including
+/// block terminators.
+void LoopOp::print(OpAsmPrinter &printer) {
+  auto control = getLoopControl();
+  bool isDefaultControl = control == spirv::LoopControl::None;
+  if (!isDefaultControl)
+    printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Returns true when `srcBlock` consists of exactly one op, and that op is a
+/// `spirv.Branch` targeting `dstBlock`.
+static bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
+  if (!llvm::hasSingleElement(srcBlock))
+    return false;
+
+  if (auto branch = dyn_cast<spirv::BranchOp>(srcBlock.back()))
+    return branch.getSuccessor() == &dstBlock;
+  return false;
+}
+
+/// Returns true when `block` holds exactly one op and that op is a
+/// `spirv.mlir.merge`.
+static bool isMergeBlock(Block &block) {
+  return llvm::hasSingleElement(block) && isa<spirv::MergeOp>(block.front());
+}
+
+/// Verifies the structured-loop block layout: the entry block branches only
+/// to the header, the second-to-last (continue) block branches back to the
+/// header, no other block besides entry/continue targets the header, and the
+/// last block is a merge block. An empty region is accepted.
+LogicalResult LoopOp::verifyRegions() {
+  auto *op = getOperation();
+
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +-------------+
+  //                     | entry block |
+  //                     +-------------+
+  //                            |
+  //                            v
+  //                     +-------------+
+  //                     | loop header | <-----+
+  //                     +-------------+       |
+  //                                           |
+  //                           ...             |
+  //                          \ | /            |
+  //                            v              |
+  //                    +---------------+      |
+  //                    | loop continue | -----+
+  //                    +---------------+
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+
+  // The last block is the merge block.
+  Block &merge = region.back();
+  if (!isMergeBlock(merge))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+
+  if (std::next(region.begin()) == region.end())
+    return emitOpError(
+        "must have an entry block branching to the loop header block");
+  // The first block is the entry block.
+  Block &entry = region.front();
+
+  if (std::next(region.begin(), 2) == region.end())
+    return emitOpError(
+        "must have a loop header block branched from the entry block");
+  // The second block is the loop header block.
+  Block &header = *std::next(region.begin(), 1);
+
+  if (!hasOneBranchOpTo(entry, header))
+    return emitOpError(
+        "entry block must only have one 'spirv.Branch' op to the second block");
+
+  if (std::next(region.begin(), 3) == region.end())
+    return emitOpError(
+        "requires a loop continue block branching to the loop header block");
+  // The second to last block is the loop continue block.
+  Block &cont = *std::prev(region.end(), 2);
+
+  // Make sure that we have a branch from the loop continue block to the loop
+  // header block.
+  if (llvm::none_of(
+          llvm::seq<unsigned>(0, cont.getNumSuccessors()),
+          [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
+    return emitOpError("second to last block must be the loop continue "
+                       "block that branches to the loop header block");
+
+  // Make sure that no other blocks (except the entry and loop continue block)
+  // branch to the loop header block.
+  for (auto &block : llvm::make_range(std::next(region.begin(), 2),
+                                      std::prev(region.end(), 2))) {
+    for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
+      if (block.getSuccessor(i) == &header) {
+        return emitOpError("can only have the entry and loop continue "
+                           "block branching to the loop header block");
+      }
+    }
+  }
+
+  return success();
+}
+
+/// Block #0 of the body: the loop entry.
+Block *LoopOp::getEntryBlock() {
+  Region &body = getBody();
+  assert(!body.empty() && "op region should not be empty!");
+  return &body.front();
+}
+
+/// Block #1 of the body: the loop header.
+Block *LoopOp::getHeaderBlock() {
+  Region &body = getBody();
+  assert(!body.empty() && "op region should not be empty!");
+  return &*std::next(body.begin());
+}
+
+/// Second-to-last block of the body: the loop continue block.
+Block *LoopOp::getContinueBlock() {
+  Region &body = getBody();
+  assert(!body.empty() && "op region should not be empty!");
+  return &*std::prev(body.end(), 2);
+}
+
+/// Last block of the body: the loop merge block.
+Block *LoopOp::getMergeBlock() {
+  Region &body = getBody();
+  assert(!body.empty() && "op region should not be empty!");
+  return &body.back();
+}
+
+/// Populates an empty body with an empty entry block followed by a merge
+/// block that holds only a `spirv.mlir.merge` op.
+void LoopOp::addEntryAndMergeBlock() {
+  assert(getBody().empty() && "entry and merge block already exist");
+  Region &body = getBody();
+  auto *entry = new Block();
+  body.push_back(entry);
+  auto *merge = new Block();
+  body.push_back(merge);
+
+  OpBuilder mergeBuilder = OpBuilder::atBlockEnd(merge);
+  mergeBuilder.create<spirv::MergeOp>(getLoc());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.merge
+//===----------------------------------------------------------------------===//
+
+/// `spirv.mlir.merge` must be nested directly in a selection or loop op and
+/// must terminate the last block of that op's region.
+LogicalResult MergeOp::verify() {
+  Operation *parent = (*this)->getParentOp();
+  bool inStructuredParent =
+      parent && isa<spirv::SelectionOp, spirv::LoopOp>(parent);
+  if (!inStructuredParent)
+    return emitOpError(
+        "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
+
+  // TODO: This check should be done in `verifyRegions` of parent op.
+  Operation *lastTerminator =
+      (*this)->getParentRegion()->back().getTerminator();
+  if (lastTerminator != getOperation())
+    return emitOpError("can only be used in the last block of "
+                       "'spirv.mlir.selection' or 'spirv.mlir.loop'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Return
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnOp::verify() {
+  // Nothing to check locally; the enclosing spirv.func verifies its returns.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ReturnValue
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReturnValueOp::verify() {
+  // Nothing to check locally; the enclosing spirv.func verifies its returns.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Select
+//===----------------------------------------------------------------------===//
+
+/// A vector condition requires a vector result with the same number of
+/// elements; scalar conditions are not checked here.
+LogicalResult SelectOp::verify() {
+  auto conditionVecTy = llvm::dyn_cast<VectorType>(getCondition().getType());
+  if (!conditionVecTy)
+    return success();
+
+  auto resultVecTy = llvm::dyn_cast<VectorType>(getResult().getType());
+  if (!resultVecTy)
+    return emitOpError("result expected to be of vector type when "
+                       "condition is of vector type");
+
+  if (conditionVecTy.getNumElements() != resultVecTy.getNumElements())
+    return emitOpError("result should have the same number of elements as "
+                       "the condition when condition is of vector type");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.selection
+//===----------------------------------------------------------------------===//
+
+/// Parses an optional `control(...)` clause followed by the body region.
+ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (failed(parseControlAttribute<spirv::SelectionControlAttr,
+                                   spirv::SelectionControl>(parser, result)))
+    return failure();
+  // The body region has no entry-block arguments.
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{});
+}
+
+/// Prints a non-default selection control clause, then the body region
+/// including block terminators.
+void SelectionOp::print(OpAsmPrinter &printer) {
+  auto control = getSelectionControl();
+  bool isDefaultControl = control == spirv::SelectionControl::None;
+  if (!isDefaultControl)
+    printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+
+  printer << ' ';
+  printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Verifies the structured-selection block layout: the last block is a merge
+/// block, and there is at least one block (the header) before it. An empty
+/// region is accepted.
+LogicalResult SelectionOp::verifyRegions() {
+  auto *op = getOperation();
+
+  // We need to verify that the blocks follow the following layout:
+  //
+  //                     +--------------+
+  //                     | header block |
+  //                     +--------------+
+  //                          / | \
+  //                           ...
+  //
+  //
+  //         +---------+   +---------+   +---------+
+  //         | case #0 |   | case #1 |   | case #2 |  ...
+  //         +---------+   +---------+   +---------+
+  //
+  //
+  //                           ...
+  //                          \ | /
+  //                            v
+  //                     +-------------+
+  //                     | merge block |
+  //                     +-------------+
+
+  auto &region = op->getRegion(0);
+  // Allow empty region as a degenerated case, which can come from
+  // optimizations.
+  if (region.empty())
+    return success();
+
+  // The last block is the merge block.
+  if (!isMergeBlock(region.back()))
+    return emitOpError("last block must be the merge block with only one "
+                       "'spirv.mlir.merge' op");
+
+  if (std::next(region.begin()) == region.end())
+    return emitOpError("must have a selection header block");
+
+  return success();
+}
+
+/// Returns the first block of the body, i.e. the selection header block.
+Block *SelectionOp::getHeaderBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The first block is the selection header block.
+  return &getBody().front();
+}
+
+/// Returns the last block of the body, i.e. the selection merge block.
+Block *SelectionOp::getMergeBlock() {
+  assert(!getBody().empty() && "op region should not be empty!");
+  // The last block is the selection merge block.
+  return &getBody().back();
+}
+
+/// Appends a block holding only a `spirv.mlir.merge` op to the body region,
+/// which must still be empty. Unlike `LoopOp::addEntryAndMergeBlock`, no entry
+/// block is created here.
+void SelectionOp::addMergeBlock() {
+  assert(getBody().empty() &&
+         "op region should be empty before adding the merge block");
+  auto *mergeBlock = new Block();
+  getBody().push_back(mergeBlock);
+  OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
+
+  // Add a spirv.mlir.merge op into the merge block.
+  builder.create<spirv::MergeOp>(getLoc());
+}
+
+/// Builds `spirv.mlir.selection` with blocks [header, then, merge]: the
+/// header conditionally branches to "then" (whose body is produced by
+/// `thenBody`) or directly to the merge block; "then" falls through to merge.
+SelectionOp
+SelectionOp::createIfThen(Location loc, Value condition,
+                          function_ref<void(OpBuilder &builder)> thenBody,
+                          OpBuilder &builder) {
+  auto selection =
+      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+  selection.addMergeBlock();
+  Block *merge = selection.getMergeBlock();
+
+  // "then" block, inserted before the merge block.
+  Block *thenBlock = [&]() {
+    OpBuilder::InsertionGuard guard(builder);
+    Block *block = builder.createBlock(merge);
+    thenBody(builder);
+    builder.create<spirv::BranchOp>(loc, merge);
+    return block;
+  }();
+
+  // Header block, inserted before the "then" block.
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.createBlock(thenBlock);
+    ArrayRef<Value> noArgs;
+    builder.create<spirv::BranchConditionalOp>(loc, condition, thenBlock,
+                                               /*trueArguments=*/noArgs, merge,
+                                               /*falseArguments=*/noArgs);
+  }
+
+  return selection;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Unreachable
+//===----------------------------------------------------------------------===//
+
+/// `spirv.Unreachable` may not appear in the entry block, which is always
+/// reachable. Every other block is currently accepted.
+LogicalResult spirv::UnreachableOp::verify() {
+  Block *parentBlock = (*this)->getBlock();
+  if (parentBlock->isEntryBlock())
+    return emitOpError("cannot be used in reachable block");
+
+  // TODO: further verification needs to analyze reachability from the entry
+  // block. Until then, a non-entry block is accepted whether or not it has
+  // predecessors.
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
new file mode 100644
index 00000000000000..6ee162a761e21e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -0,0 +1,751 @@
+//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the memory operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+/// Parses an optional `["<MemoryAccess>" (, <alignment>)?]` clause for the
+/// *source* operand of a memory op. The parsed values are stored under
+/// `kSourceMemoryAccessAttrName` and, when the access includes `Aligned`,
+/// `kSourceAlignmentAttrName` (parsed as i32). Returns success without
+/// consuming anything when no `[` follows.
+///
+// TODO Make sure to merge this and the (non-source) parseMemoryAccessAttributes
+// helper into one template parameterized by memory access attribute name and
+// alignment. Doing so now results in VS2017 producing an internal error (at
+// the call site) that's not detailed enough to understand what is happening.
+static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
+                                                     OperationState &state) {
+  // Parse an optional list of attributes starting with '['
+  if (parser.parseOptionalLSquare()) {
+    // Nothing to do
+    return success();
+  }
+
+  spirv::MemoryAccess memoryAccessAttr;
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+    return failure();
+
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
+    // Parse integer attribute for alignment.
+    Attribute alignmentAttr;
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+                              state.attributes)) {
+      return failure();
+    }
+  }
+  return parser.parseRSquare();
+}
+
+/// Prints `, ` followed by an optional ` ["<MemoryAccess>" (, <alignment>)?]`
+/// clause for the *source* operand of a memory op. Explicit
+/// `memoryAccessAtrrValue` / `alignmentAttrValue` override the op's own
+/// getters. Printed attributes, and always the storage class attribute, are
+/// appended to `elidedAttrs`.
+///
+/// NOTE(review): the leading ", " is emitted even when no source memory access
+/// is printed; the matching parser is assumed to tolerate that — confirm.
+///
+// TODO Make sure to merge this and printMemoryAccessAttribute (below) into one
+// template parameterized by memory access attribute name and alignment. Doing
+// so now results in VS2017 producing an internal error (at the call site)
+// that's not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
+static void printSourceMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+
+  printer << ", ";
+
+  // Print optional memory access attribute.
+  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                              : memoryOp.getMemoryAccess())) {
+    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+
+    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
+      // Print integer alignment attribute.
+      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
+                                               : memoryOp.getAlignment())) {
+        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+}
+
+/// Prints an optional ` ["<MemoryAccess>" (, <alignment>)?]` clause.
+/// Explicit `memoryAccessAtrrValue` / `alignmentAttrValue` take precedence
+/// over the op's own getters. Every printed attribute, plus the storage class
+/// attribute (unconditionally), is appended to `elidedAttrs`.
+template <typename MemoryOpTy>
+static void printMemoryAccessAttribute(
+    MemoryOpTy memoryOp, OpAsmPrinter &printer,
+    SmallVectorImpl<StringRef> &elidedAttrs,
+    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
+    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
+  auto memAccess = memoryAccessAtrrValue ? memoryAccessAtrrValue
+                                         : memoryOp.getMemoryAccess();
+  if (memAccess) {
+    elidedAttrs.push_back(kMemoryAccessAttrName);
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
+
+    bool isAligned =
+        spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned);
+    if (isAligned) {
+      auto alignment =
+          alignmentAttrValue ? alignmentAttrValue : memoryOp.getAlignment();
+      if (alignment) {
+        elidedAttrs.push_back(kAlignmentAttrName);
+        printer << ", " << *alignment;
+      }
+    }
+    printer << "]";
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+}
+
+/// Requires the pointee type of `ptr` to equal the type of `val`. ODS already
+/// guarantees that `ptr` is a spirv::PointerType.
+///
+/// TODO: Check that the value type satisfies restrictions of
+/// SPIR-V OpLoad/OpStore operations
+template <typename LoadStoreOpTy>
+static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
+                                                   Value val) {
+  Type pointeeType =
+      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType();
+  if (val.getType() == pointeeType)
+    return success();
+  return op.emitOpError("mismatch in result type and pointer type");
+}
+
+/// Checks consistency between the memory access attribute and the alignment
+/// attribute: alignment must be present iff the access includes `Aligned`.
+template <typename MemoryOpTy>
+static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+
+  // dyn_cast (not cast): `cast` asserts instead of returning null, which made
+  // the diagnostic below unreachable.
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  if (!memAccess) {
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+}
+
+/// Source-operand counterpart of verifyMemoryAccessAttribute: the source
+/// alignment attribute must be present iff the source memory access includes
+/// `Aligned`.
+///
+// TODO Make sure to merge this and the previous function into one template
+// parameterized by memory access attribute name and alignment. Doing so now
+// results in VS2017 producing an internal error (at the call site) that's
+// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
+static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
+  // ODS checks for attributes values. Just need to verify that if the
+  // memory-access attribute is Aligned, then the alignment attribute must be
+  // present.
+  auto *op = memoryOp.getOperation();
+  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  if (!memAccessAttr) {
+    // Alignment attribute shouldn't be present if memory access attribute is
+    // not present.
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification without aligned memory access "
+          "specification");
+    }
+    return success();
+  }
+
+  // dyn_cast (not cast) so a wrong attribute kind reaches the diagnostic
+  // below; report the offending attribute itself, since `memAccess` is null
+  // on that path.
+  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
+  if (!memAccess) {
+    return memoryOp.emitOpError("invalid memory access specifier: ")
+           << memAccessAttr;
+  }
+
+  if (spirv::bitEnumContainsAll(memAccess.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    if (!op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError("missing alignment value");
+    }
+  } else {
+    if (op->getAttr(kSourceAlignmentAttrName)) {
+      return memoryOp.emitOpError(
+          "invalid alignment specification with non-aligned memory access "
+          "specification");
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AccessChainOp
+//===----------------------------------------------------------------------===//
+
+/// Computes the result type of a `spirv.AccessChain` whose base pointer has
+/// type `type` and which is indexed by `indices`. The result is a pointer, in
+/// the base pointer's storage class, to the element reached by walking the
+/// pointee type with each index in turn. Emits an error at `baseLoc` and
+/// returns a null type on failure.
+///
+/// Struct members must be selected by an integer spirv.Constant; its value is
+/// read via extractValueFromConstOp and bounds-checked. For other composite
+/// types the index value is not inspected, and the type of element 0 is used.
+static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
+ auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+ if (!ptrType) {
+ emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
+ "to composite type, but provided ")
+ << type;
+ return nullptr;
+ }
+
+ auto resultType = ptrType.getPointeeType();
+ auto resultStorageClass = ptrType.getStorageClass();
+ // Value of the most recently applied index. It stays 0 for non-struct steps.
+ int32_t index = 0;
+
+ for (auto indexSSA : indices) {
+ auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
+ if (!cType) {
+ // NOTE(review): `index` here is the value from the *previous* step,
+ // which is 0 unless that step indexed a struct.
+ emitError(
+ baseLoc,
+ "'spirv.AccessChain' op cannot extract from non-composite type ")
+ << resultType << " with index " << index;
+ return nullptr;
+ }
+ index = 0;
+ if (llvm::isa<spirv::StructType>(resultType)) {
+ // Struct members need a compile-time index, so it must come from a
+ // defining constant op (block arguments have no defining op).
+ Operation *op = indexSSA.getDefiningOp();
+ if (!op) {
+ emitError(baseLoc, "'spirv.AccessChain' op index must be an "
+ "integer spirv.Constant to access "
+ "element of spirv.struct");
+ return nullptr;
+ }
+
+ // TODO: this should be relaxed to allow
+ // integer literals of other bitwidths.
+ if (failed(spirv::extractValueFromConstOp(op, index))) {
+ emitError(
+ baseLoc,
+ "'spirv.AccessChain' index must be an integer spirv.Constant to "
+ "access element of spirv.struct, but provided ")
+ << op->getName();
+ return nullptr;
+ }
+ if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+ emitError(baseLoc, "'spirv.AccessChain' op index ")
+ << index << " out of bounds for " << resultType;
+ return nullptr;
+ }
+ }
+ resultType = cType.getElementType(index);
+ }
+ return spirv::PointerType::get(resultType, resultStorageClass);
+}
+
+/// Builds an access chain whose result type is inferred from `basePtr` and
+/// `indices`; the inference must succeed (asserted).
+void AccessChainOp::build(OpBuilder &builder, OperationState &state,
+                          Value basePtr, ValueRange indices) {
+  Type resultTy =
+      getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(resultTy &&
+         "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, resultTy, basePtr, indices);
+}
+
+/// Parses `%base[%i0, ...] : <base ptr type>, <index types...>` and infers the
+/// result pointer type from the base type and the resolved indices.
+ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Location at the start of the op's operands, used when resolving indices.
+  auto startLoc = parser.getCurrentLocation();
+
+  OpAsmParser::UnresolvedOperand basePtr;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
+  Type basePtrType;
+  if (parser.parseOperand(basePtr) ||
+      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(basePtrType) ||
+      parser.resolveOperand(basePtr, basePtrType, result.operands))
+    return failure();
+
+  // An empty index list is rejected before the index type list is parsed.
+  if (indexOperands.empty())
+    return mlir::emitError(result.location,
+                           "'spirv.AccessChain' op expected at "
+                           "least one index ");
+
+  SmallVector<Type, 4> indexTypes;
+  if (parser.parseComma() || parser.parseTypeList(indexTypes))
+    return failure();
+
+  // Every index operand needs exactly one type.
+  if (indexTypes.size() != indexOperands.size())
+    return mlir::emitError(
+        result.location, "'spirv.AccessChain' op indices types' count must be "
+                         "equal to indices info count");
+
+  if (parser.resolveOperands(indexOperands, indexTypes, startLoc,
+                             result.operands))
+    return failure();
+
+  // result.operands is [base pointer, indices...]; infer from the indices.
+  Type resultTy = getElementPtrType(
+      basePtrType, llvm::ArrayRef(result.operands).drop_front(),
+      result.location);
+  if (!resultTy)
+    return failure();
+
+  result.addTypes(resultTy);
+  return success();
+}
+
+/// Prints ` %base[<indices>] : <base-type>, <index-types>` for any access-chain
+/// style op; callers choose which values make up `indices`.
+template <typename Op>
+static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
+  Value basePtr = op.getBasePtr();
+  printer << ' ' << basePtr << '[' << indices << "] : ";
+  printer << basePtr.getType() << ", " << indices.getTypes();
+}
+
+void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
+  // All index operands appear inside the brackets.
+  ValueRange chainIndices = getIndices();
+  printAccessChain(*this, chainIndices, printer);
+}
+
+/// Verifies that the op's declared result type equals the pointer type deduced
+/// from its base pointer and `indices`.
+///
+/// Returns failure (with a diagnostic) if deduction fails, if the declared
+/// result is not a spirv.ptr, or if the two types differ.
+template <typename Op>
+static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
+  auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
+                                      indices, accessChainOp.getLoc());
+  // getElementPtrType has already emitted a diagnostic on failure.
+  if (!resultType)
+    return failure();
+
+  Type declaredType = accessChainOp.getType();
+  auto providedResultType = llvm::dyn_cast<spirv::PointerType>(declaredType);
+  // Report the actual declared type: `providedResultType` is null here, so
+  // streaming it would not tell the user what was provided.
+  if (!providedResultType)
+    return accessChainOp.emitOpError(
+               "result type must be a pointer, but provided ")
+           << declaredType;
+
+  if (resultType != providedResultType)
+    return accessChainOp.emitOpError("invalid result type: expected ")
+           << resultType << ", but provided " << providedResultType;
+
+  return success();
+}
+
+LogicalResult AccessChainOp::verify() {
+  // Every operand after the base pointer is an index.
+  ValueRange chainIndices = getIndices();
+  return verifyAccessChain(*this, chainIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.LoadOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a load whose result type is the pointee type of `basePtr`.
+void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
+                   MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
+  Type loadedType =
+      llvm::cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
+  build(builder, state, loadedType, basePtr, memoryAccess, alignment);
+}
+
+/// Parses `"<storage-class>" %ptr [memory-access] {attrs} : <type>`. The
+/// pointer operand type is rebuilt from the storage class and `<type>`.
+ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
+  spirv::StorageClass storageClass;
+  OpAsmParser::UnresolvedOperand ptrOperand;
+  Type loadedType;
+
+  if (parseEnumStrAttr(storageClass, parser))
+    return failure();
+  if (parser.parseOperand(ptrOperand) ||
+      parseMemoryAccessAttributes(parser, result) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(loadedType))
+    return failure();
+
+  auto ptrType = spirv::PointerType::get(loadedType, storageClass);
+  if (parser.resolveOperand(ptrOperand, ptrType, result.operands))
+    return failure();
+
+  result.addTypes(loadedType);
+  return success();
+}
+
+/// Prints `"<storage-class>" %ptr [memory-access] {attrs} : <type>`.
+void LoadOp::print(OpAsmPrinter &printer) {
+  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
+  printer << " \"" << stringifyStorageClass(ptrType.getStorageClass())
+          << "\" " << getPtr();
+
+  // Attributes printed inline by the helper are collected here so they are
+  // left out of the trailing dictionary.
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+
+  printer << " : " << getType();
+}
+
+LogicalResult LoadOp::verify() {
+  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
+  // type with fixed size; i.e., it cannot be, nor include, any
+  // OpTypeRuntimeArray types."
+  // NOTE: this function checks that the result type matches the pointee type
+  // and validates the memory-access attributes; the runtime-array restriction
+  // quoted above is not enforced here.
+  if (succeeded(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
+    return verifyMemoryAccessAttribute(*this);
+  return failure();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.StoreOp
+//===----------------------------------------------------------------------===//
+
+/// Parses `"<storage-class>" %ptr, %value [memory-access] : <type> {attrs}`.
+///
+/// StoreOp::print emits the optional attribute dictionary *after* the type, so
+/// it is accepted there as well; without this, any op carrying discardable
+/// attributes could not be re-parsed from its printed form.
+ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) ||
+      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
+      parser.parseType(elementType) ||
+      parser.parseOptionalAttrDict(result.attributes)) {
+    return failure();
+  }
+
+  // Operands are (pointer, value); the pointer type is rebuilt from the
+  // storage class and the stored element type.
+  auto ptrType = spirv::PointerType::get(elementType, storageClass);
+  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
+                             result.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+/// Prints `"<storage-class>" %ptr, %value [memory-access] : <type> {attrs}`.
+void StoreOp::print(OpAsmPrinter &printer) {
+  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
+  StringRef storageClassStr = stringifyStorageClass(ptrType.getStorageClass());
+  printer << " \"" << storageClassStr << "\" " << getPtr() << ", "
+          << getValue();
+
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+
+  // Unlike LoadOp::print, the attribute dictionary follows the type here.
+  printer << " : " << getValue().getType();
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+}
+
+LogicalResult StoreOp::verify() {
+  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
+  // OpTypePointer whose Type operand is the same as the type of Object."
+  LogicalResult typesMatch =
+      verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue());
+  return failed(typesMatch) ? failure() : verifyMemoryAccessAttribute(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.CopyMemory
+//===----------------------------------------------------------------------===//
+
+/// Prints the target and source operands with their storage classes, the
+/// target and source memory-access specs, the attribute dictionary, and
+/// finally the shared pointee type.
+void CopyMemoryOp::print(OpAsmPrinter &printer) {
+  auto targetPtrType = llvm::cast<spirv::PointerType>(getTarget().getType());
+  auto sourcePtrType = llvm::cast<spirv::PointerType>(getSource().getType());
+
+  // NOTE: the leading ' ' combined with the `" \""` prefix (and the ", "
+  // separator before the source) produces two spaces before each quoted
+  // storage class. This is kept as-is to preserve the printed form.
+  printer << ' ';
+  printer << " \"" << stringifyStorageClass(targetPtrType.getStorageClass())
+          << "\" " << getTarget() << ", ";
+  printer << " \"" << stringifyStorageClass(sourcePtrType.getStorageClass())
+          << "\" " << getSource();
+
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(*this, printer, elidedAttrs);
+  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
+                                   getSourceMemoryAccess(),
+                                   getSourceAlignment());
+  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+
+  // The verifier requires both pointees to match; the target's is printed.
+  printer << " : " << targetPtrType.getPointeeType();
+}
+
+/// Parses
+///   `"<sc>" %target, "<sc>" %source [mem-access] [, [src-mem-access]]
+///    {attrs} : <type>`
+///
+/// CopyMemoryOp::print emits the attribute dictionary *before* the colon, so
+/// it is accepted there to make the printed form round-trip. For backward
+/// compatibility, a dictionary after the type is still accepted when none was
+/// given before the colon (never both, to avoid duplicate keys).
+ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
+  spirv::StorageClass targetStorageClass;
+  OpAsmParser::UnresolvedOperand targetPtrInfo;
+
+  spirv::StorageClass sourceStorageClass;
+  OpAsmParser::UnresolvedOperand sourcePtrInfo;
+
+  Type elementType;
+
+  if (parseEnumStrAttr(targetStorageClass, parser) ||
+      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
+      parseEnumStrAttr(sourceStorageClass, parser) ||
+      parser.parseOperand(sourcePtrInfo) ||
+      parseMemoryAccessAttributes(parser, result)) {
+    return failure();
+  }
+
+  if (!parser.parseOptionalComma()) {
+    // Parse 2nd memory access attributes.
+    if (parseSourceMemoryAccessAttributes(parser, result)) {
+      return failure();
+    }
+  }
+
+  // Dictionary in the position the printer uses.
+  size_t numAttrsBeforeDict = result.attributes.size();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  bool sawLeadingDict = result.attributes.size() != numAttrsBeforeDict;
+
+  if (parser.parseColon() || parser.parseType(elementType))
+    return failure();
+
+  // Legacy position: after the type.
+  if (!sawLeadingDict && parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
+  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
+
+  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
+      parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+LogicalResult CopyMemoryOp::verify() {
+  // Target and source must point at the same type.
+  auto pointeeOf = [](Value ptr) {
+    return llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType();
+  };
+  if (pointeeOf(getTarget()) != pointeeOf(getSource()))
+    return emitOpError("both operands must be pointers to the same type");
+
+  if (failed(verifyMemoryAccessAttribute(*this)))
+    return failure();
+
+  // TODO - According to the spec:
+  //
+  // If two masks are present, the first applies to Target and cannot include
+  // MakePointerVisible, and the second applies to Source and cannot include
+  // MakePointerAvailable.
+  //
+  // Add such verification here.
+
+  return verifySourceMemoryAccessAttribute(*this);
+}
+
+/// Shared parser for spirv.PtrAccessChain and spirv.InBoundsPtrAccessChain:
+/// `%base[%element, %idx...] : <base-ptr-type>, <types>`. `opName` prefixes
+/// the diagnostics emitted here.
+static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
+                                             OpAsmParser &parser,
+                                             OperationState &state) {
+  OpAsmParser::UnresolvedOperand basePtrOperand;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexOperands;
+  SmallVector<Type, 4> indexTypes;
+  Type basePtrType;
+  auto operandsLoc = parser.getCurrentLocation();
+
+  if (parser.parseOperand(basePtrOperand) ||
+      parser.parseOperandList(indexOperands, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(basePtrType) ||
+      parser.resolveOperand(basePtrOperand, basePtrType, state.operands))
+    return failure();
+
+  // The bracketed list must contain at least the element operand.
+  if (indexOperands.empty())
+    return emitError(state.location) << opName << " expected element";
+
+  if (parser.parseComma() || parser.parseTypeList(indexTypes))
+    return failure();
+
+  if (indexOperands.size() != indexTypes.size())
+    return emitError(state.location)
+           << opName
+           << " indices types' count must be equal to indices info count";
+
+  if (parser.resolveOperands(indexOperands, indexTypes, operandsLoc,
+                             state.operands))
+    return failure();
+
+  // Skip the base pointer and the element operand; only the remaining
+  // operands are used to deduce the result type.
+  auto elementPtrType = getElementPtrType(
+      basePtrType, llvm::ArrayRef(state.operands).drop_front(2),
+      state.location);
+  if (!elementPtrType)
+    return failure();
+
+  state.addTypes(elementPtrType);
+  return success();
+}
+
+/// Returns `[element, indices...]` for ops that carry a separate element
+/// operand, so it can be printed as one bracketed list.
+template <typename Op>
+static auto concatElemAndIndices(Op op) {
+  SmallVector<Value> elementAndIndices;
+  elementAndIndices.reserve(op.getIndices().size() + 1);
+  elementAndIndices.push_back(op.getElement());
+  llvm::append_range(elementAndIndices, op.getIndices());
+  return elementAndIndices;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.InBoundsPtrAccessChainOp
+//===----------------------------------------------------------------------===//
+
+void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+                                     Value basePtr, Value element,
+                                     ValueRange indices) {
+  // The result type is deduced from `indices` only; `element` is not used.
+  auto resultPtrType =
+      getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(resultPtrType &&
+         "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, resultPtrType, basePtr, element, indices);
+}
+
+ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  // Same custom syntax as spirv.PtrAccessChain.
+  StringRef opName = spirv::InBoundsPtrAccessChainOp::getOperationName();
+  return parsePtrAccessChainOpImpl(opName, parser, result);
+}
+
+void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
+  // The element operand is printed as the first bracketed entry.
+  SmallVector<Value> elementAndIndices = concatElemAndIndices(*this);
+  printAccessChain(*this, elementAndIndices, printer);
+}
+
+LogicalResult InBoundsPtrAccessChainOp::verify() {
+  // Only the trailing indices (not the element operand) determine the type.
+  ValueRange chainIndices = getIndices();
+  return verifyAccessChain(*this, chainIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.PtrAccessChainOp
+//===----------------------------------------------------------------------===//
+
+void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
+                             Value basePtr, Value element, ValueRange indices) {
+  // The result type is deduced from `indices` only; `element` is not used.
+  auto resultPtrType =
+      getElementPtrType(basePtr.getType(), indices, state.location);
+  assert(resultPtrType &&
+         "Unable to deduce return type based on basePtr and indices");
+  build(builder, state, resultPtrType, basePtr, element, indices);
+}
+
+ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  // Same custom syntax as spirv.InBoundsPtrAccessChain.
+  StringRef opName = spirv::PtrAccessChainOp::getOperationName();
+  return parsePtrAccessChainOpImpl(opName, parser, result);
+}
+
+void PtrAccessChainOp::print(OpAsmPrinter &printer) {
+  // The element operand is printed as the first bracketed entry.
+  SmallVector<Value> elementAndIndices = concatElemAndIndices(*this);
+  printAccessChain(*this, elementAndIndices, printer);
+}
+
+LogicalResult PtrAccessChainOp::verify() {
+  // Only the trailing indices (not the element operand) determine the type.
+  ValueRange chainIndices = getIndices();
+  return verifyAccessChain(*this, chainIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.Variable
+//===----------------------------------------------------------------------===//
+
+/// Parses `[init(%value)] <decorations> {attrs} : !spirv.ptr<...>`. The
+/// storage class attribute is taken from the result pointer type.
+ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Optional `init(%value)` clause.
+  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
+  if (succeeded(parser.parseOptionalKeyword("init"))) {
+    OpAsmParser::UnresolvedOperand initOperand;
+    if (parser.parseLParen() || parser.parseOperand(initOperand) ||
+        parser.parseRParen())
+      return failure();
+    initInfo = initOperand;
+  }
+
+  if (parseVariableDecorations(parser, result))
+    return failure();
+
+  // The result type must be a spirv.ptr.
+  if (parser.parseColon())
+    return failure();
+  auto typeLoc = parser.getCurrentLocation();
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType)
+    return parser.emitError(typeLoc, "expected spirv.ptr type");
+  result.addTypes(ptrType);
+
+  // An initializer is resolved against the pointee type.
+  if (initInfo && parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
+                                        result.operands))
+    return failure();
+
+  result.addAttribute(spirv::attributeName<spirv::StorageClass>(),
+                      parser.getBuilder().getAttr<spirv::StorageClassAttr>(
+                          ptrType.getStorageClass()));
+  return success();
+}
+
+/// Prints `[init(%value)] <decorations> {attrs} : <type>`.
+void VariableOp::print(OpAsmPrinter &printer) {
+  if (getNumOperands() > 0)
+    printer << " init(" << getInitializer() << ")";
+
+  // The storage class is recovered from the result type when parsing, so it
+  // is never printed in the attribute dictionary.
+  SmallVector<StringRef, 4> elidedAttrs{
+      spirv::attributeName<spirv::StorageClass>()};
+  printVariableDecorations(*this, printer, elidedAttrs);
+  printer << " : " << getType();
+}
+
+LogicalResult VariableOp::verify() {
+  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
+  // object. It cannot be Generic. It must be the same as the Storage Class
+  // operand of the Result Type."
+  if (getStorageClass() != spirv::StorageClass::Function)
+    return emitOpError(
+        "can only be used to model function-level variables. Use "
+        "spirv.GlobalVariable for module-level variables.");
+
+  auto resultPtrType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  if (resultPtrType.getStorageClass() != getStorageClass())
+    return emitOpError(
+        "storage class must match result pointer's storage class");
+
+  // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
+  // a global (module scope) OpVariable instruction".
+  if (getNumOperands() != 0) {
+    Operation *initOp = getOperand(0).getDefiningOp();
+    bool isValidInit = initOp && isa<spirv::ConstantOp,    // normal constant
+                                     spirv::ReferenceOfOp, // spec constant
+                                     spirv::AddressOfOp>(initOp);
+    if (!isValidInit)
+      return emitOpError("initializer must be the result of a "
+                         "constant or spirv.GlobalVariable op");
+  }
+
+  // Decorations that are rejected on function-level variables.
+  // TODO: generate these strings using ODS.
+  auto snakeDecorationName = [](spirv::Decoration decoration) {
+    return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
+  };
+  for (const std::string &attrName :
+       {snakeDecorationName(spirv::Decoration::DescriptorSet),
+        snakeDecorationName(spirv::Decoration::Binding),
+        snakeDecorationName(spirv::Decoration::BuiltIn)}) {
+    if (getOperation()->getAttr(attrName))
+      return emitOpError("cannot have '")
+             << attrName << "' attribute (only allowed in spirv.GlobalVariable)";
+  }
+
+  return success();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp
new file mode 100644
index 00000000000000..93c27c8701cabb
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpAvailability.cpp
@@ -0,0 +1,22 @@
+//===- SPIRVOpAvailability.cpp - MLIR SPIR-V Availability Implementation --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the SPIR-V operation availability in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
+
+namespace mlir::spirv {
+// TableGen'erated operation availability interface implementations.
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
new file mode 100644
index 00000000000000..d8dfe164458e29
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -0,0 +1,76 @@
+//===- SPIRVOpDefinition.cpp - MLIR SPIR-V Op Definition Implementation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the TableGen'erated SPIR-V op implementation in the SPIR-V dialect.
+// These are placed in a separate file to reduce the total amount of code in
+// SPIRVOps.cpp and make that file faster to recompile.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+/// Returns true if `op` is a function-like op, or is nested inside one with
+/// no symbol-table (module-like) op in between. A null `op` yields false.
+static bool isNestedInFunctionOpInterface(Operation *op) {
+  // Walk outward; a symbol table ends the search before any enclosing
+  // function-like op can be found.
+  for (Operation *current = op; current; current = current->getParentOp()) {
+    if (current->hasTrait<OpTrait::SymbolTable>())
+      return false;
+    if (isa<FunctionOpInterface>(current))
+      return true;
+  }
+  return false;
+}
+
+/// Returns true if `op` is non-null and is a module-like op, i.e. one that
+/// maintains a symbol table.
+static bool isDirectInModuleLikeOp(Operation *op) {
+  if (!op)
+    return false;
+  return op->hasTrait<OpTrait::SymbolTable>();
+}
+
+/// Result of a logical op must be a scalar or vector of boolean type: returns
+/// i1 for a non-vector operand, or a vector of i1 with the operand's element
+/// count for a vector operand.
+static Type getUnaryOpResultType(Type operandType) {
+  Type boolType = Builder(operandType.getContext()).getIntegerType(1);
+  auto vectorType = llvm::dyn_cast<VectorType>(operandType);
+  if (!vectorType)
+    return boolType;
+  return VectorType::get(vectorType.getNumElements(), boolType);
+}
+
+/// Parses an optional `["<image-operands>"]` clause into `attr`. If no `[`
+/// is present, `attr` is left untouched and parsing succeeds.
+static ParseResult parseImageOperands(OpAsmParser &parser,
+                                      spirv::ImageOperandsAttr &attr) {
+  if (failed(parser.parseOptionalLSquare()))
+    return success();
+
+  spirv::ImageOperands imageOperands;
+  if (failed(parseEnumStrAttr(imageOperands, parser)))
+    return failure();
+
+  attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
+  return parser.parseRSquare();
+}
+
+/// Prints `["<image-operands>"]` when `attr` is set; prints nothing otherwise.
+static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
+                               spirv::ImageOperandsAttr attr) {
+  if (!attr)
+    return;
+  printer << "[\"" << stringifyImageOperands(attr.getValue()) << "\"]";
+}
+
+} // namespace mlir::spirv
+
+// TableGen'erated operation definitions.
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
index fff06bb5a7f207..e60fd53737d52d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
@@ -29,6 +29,9 @@ inline unsigned getBitWidth(Type type) {
llvm_unreachable("unhandled bit width computation for type");
}
+/// Prints the variable decorations of `op` that have a custom syntax: a
+/// `bind(<descriptor-set>, <binding>)` clause when both attributes are present,
+/// and a built-in clause when the BuiltIn decoration is present. Then prints
+/// the remaining attribute dictionary, skipping those decorations and every
+/// name already in `elidedAttrs`.
+void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
+ SmallVectorImpl<StringRef> &elidedAttrs);
+
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value);
LogicalResult verifyMemorySemantics(Operation *op,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 47ffdc1cdad18f..6d7c8b9878f017 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -28,7 +28,6 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
@@ -45,6 +44,77 @@ using namespace mlir::spirv::AttrNames;
// Common utility functions
//===----------------------------------------------------------------------===//
+/// Reads the integer value of a spirv.Constant op into `value`. Fails if `op`
+/// is null, is not a spirv.Constant, or does not hold an IntegerAttr.
+LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
+  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
+  auto intAttr =
+      constOp ? llvm::dyn_cast<IntegerAttr>(constOp.getValue()) : IntegerAttr();
+  if (!intAttr)
+    return failure();
+
+  // Signless integers are read with getInt(); all others with getSInt().
+  value = intAttr.getType().isSignlessInteger() ? intAttr.getInt()
+                                                : intAttr.getSInt();
+  return success();
+}
+
+/// Checks that at most one of the Acquire, Release, AcquireRelease and
+/// SequentiallyConsistent bits is set in `memorySemantics`; emits an error on
+/// `op` otherwise.
+LogicalResult
+spirv::verifyMemorySemantics(Operation *op,
+                             spirv::MemorySemantics memorySemantics) {
+  // According to the SPIR-V specification:
+  // "Despite being a mask and allowing multiple bits to be combined, it is
+  // invalid for more than one of these four bits to be set: Acquire, Release,
+  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
+  // Release semantics is done by setting the AcquireRelease bit, not by setting
+  // two bits."
+  const auto orderingBits = spirv::MemorySemantics::Acquire |
+                            spirv::MemorySemantics::Release |
+                            spirv::MemorySemantics::AcquireRelease |
+                            spirv::MemorySemantics::SequentiallyConsistent;
+
+  int numOrderingBitsSet =
+      llvm::popcount(static_cast<uint32_t>(memorySemantics & orderingBits));
+  if (numOrderingBitsSet <= 1)
+    return success();
+
+  return op->emitError(
+      "expected at most one of these four memory constraints "
+      "to be set: `Acquire`, `Release`,"
+      "`AcquireRelease` or `SequentiallyConsistent`");
+}
+
+/// Prints `bind(<set>, <binding>)` when both descriptor-set and binding
+/// attributes are present, a built-in clause when the BuiltIn attribute is
+/// present, and then the remaining attribute dictionary (excluding those
+/// decorations and everything already listed in `elidedAttrs`).
+///
+/// The decoration names are local std::strings, so the StringRefs this
+/// function adds to `elidedAttrs` are removed again before returning.
+/// Otherwise the caller would be left holding dangling references.
+void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
+                                     SmallVectorImpl<StringRef> &elidedAttrs) {
+  const size_t callerElidedCount = elidedAttrs.size();
+
+  // Print optional descriptor binding
+  auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
+  if (descriptorSet && binding) {
+    elidedAttrs.push_back(descriptorSetName);
+    elidedAttrs.push_back(bindingName);
+    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+            << ")";
+  }
+
+  // Print BuiltIn attribute if present
+  auto builtInName = llvm::convertToSnakeFromCamelCase(
+      stringifyDecoration(spirv::Decoration::BuiltIn));
+  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
+    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
+    elidedAttrs.push_back(builtInName);
+  }
+
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+
+  // Drop the entries that reference the local strings above; they would
+  // dangle once this function returns.
+  elidedAttrs.truncate(callerElidedCount);
+}
+
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
@@ -93,177 +163,6 @@ static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
p << " : " << resultType;
}
-/// Returns true if the given op is a function-like op or nested in a
-/// function-like op without a module-like op in the middle.
-static bool isNestedInFunctionOpInterface(Operation *op) {
- if (!op)
- return false;
- if (op->hasTrait<OpTrait::SymbolTable>())
- return false;
- if (isa<FunctionOpInterface>(op))
- return true;
- return isNestedInFunctionOpInterface(op->getParentOp());
-}
-
-/// Returns true if the given op is an module-like op that maintains a symbol
-/// table.
-static bool isDirectInModuleLikeOp(Operation *op) {
- return op && op->hasTrait<OpTrait::SymbolTable>();
-}
-
-LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
- auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
- if (!constOp) {
- return failure();
- }
- auto valueAttr = constOp.getValue();
- auto integerValueAttr = llvm::dyn_cast<IntegerAttr>(valueAttr);
- if (!integerValueAttr) {
- return failure();
- }
-
- if (integerValueAttr.getType().isSignlessInteger())
- value = integerValueAttr.getInt();
- else
- value = integerValueAttr.getSInt();
-
- return success();
-}
-
-/// Parses Function, Selection and Loop control attributes. If no control is
-/// specified, "None" is used as a default.
-template <typename EnumAttrClass, typename EnumClass>
-static ParseResult
-parseControlAttribute(OpAsmParser &parser, OperationState &state,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- if (succeeded(parser.parseOptionalKeyword(kControl))) {
- EnumClass control;
- if (parser.parseLParen() ||
- spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
- parser.parseRParen())
- return failure();
- return success();
- }
- // Set control to "None" otherwise.
- Builder builder = parser.getBuilder();
- state.addAttribute(attrName,
- builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
- return success();
-}
-
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
- OperationState &state) {
- // Parse an optional list of attributes staring with '['
- if (parser.parseOptionalLSquare()) {
- // Nothing to do
- return success();
- }
-
- spirv::MemoryAccess memoryAccessAttr;
- if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
- memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
- return failure();
-
- if (spirv::bitEnumContainsAll(memoryAccessAttr,
- spirv::MemoryAccess::Aligned)) {
- // Parse integer attribute for alignment.
- Attribute alignmentAttr;
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
- state.attributes)) {
- return failure();
- }
- }
- return parser.parseRSquare();
-}
-
-template <typename MemoryOpTy>
-static void printMemoryAccessAttribute(
- MemoryOpTy memoryOp, OpAsmPrinter &printer,
- SmallVectorImpl<StringRef> &elidedAttrs,
- std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
- std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
- // Print optional memory access attribute.
- if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
- : memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kMemoryAccessAttrName);
-
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
-
- if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
- // Print integer alignment attribute.
- if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
- : memoryOp.getAlignment())) {
- elidedAttrs.push_back(kAlignmentAttrName);
- printer << ", " << *alignment;
- }
- }
- printer << "]";
- }
- elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
-}
-
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-template <typename MemoryOpTy>
-static void printSourceMemoryAccessAttribute(
- MemoryOpTy memoryOp, OpAsmPrinter &printer,
- SmallVectorImpl<StringRef> &elidedAttrs,
- std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
- std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
-
- printer << ", ";
-
- // Print optional memory access attribute.
- if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
- : memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kSourceMemoryAccessAttrName);
-
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
-
- if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
- // Print integer alignment attribute.
- if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
- : memoryOp.getAlignment())) {
- elidedAttrs.push_back(kSourceAlignmentAttrName);
- printer << ", " << *alignment;
- }
- }
- printer << "]";
- }
- elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
-}
-
-static ParseResult parseImageOperands(OpAsmParser &parser,
- spirv::ImageOperandsAttr &attr) {
- // Expect image operands
- if (parser.parseOptionalLSquare())
- return success();
-
- spirv::ImageOperands imageOperands;
- if (parseEnumStrAttr(imageOperands, parser))
- return failure();
-
- attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
-
- return parser.parseRSquare();
-}
-
-static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
- spirv::ImageOperandsAttr attr) {
- if (attr) {
- auto strImageOperands = stringifyImageOperands(attr.getValue());
- printer << "[\"" << strImageOperands << "\"]";
- }
-}
-
template <typename Op>
static LogicalResult verifyImageOperands(Op imageOp,
spirv::ImageOperandsAttr attr,
@@ -292,130 +191,6 @@ static LogicalResult verifyImageOperands(Op imageOp,
return success();
}
-template <typename MemoryOpTy>
-static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
- // ODS checks for attributes values. Just need to verify that if the
- // memory-access attribute is Aligned, then the alignment attribute must be
- // present.
- auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
- if (!memAccessAttr) {
- // Alignment attribute shouldn't be present if memory access attribute is
- // not present.
- if (op->getAttr(kAlignmentAttrName)) {
- return memoryOp.emitOpError(
- "invalid alignment specification without aligned memory access "
- "specification");
- }
- return success();
- }
-
- auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
-
- if (!memAccess) {
- return memoryOp.emitOpError("invalid memory access specifier: ")
- << memAccessAttr;
- }
-
- if (spirv::bitEnumContainsAll(memAccess.getValue(),
- spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kAlignmentAttrName)) {
- return memoryOp.emitOpError("missing alignment value");
- }
- } else {
- if (op->getAttr(kAlignmentAttrName)) {
- return memoryOp.emitOpError(
- "invalid alignment specification with non-aligned memory access "
- "specification");
- }
- }
- return success();
-}
-
-// TODO Make sure to merge this and the previous function into one template
-// parameterized by memory access attribute name and alignment. Doing so now
-// results in VS2017 in producing an internal error (at the call site) that's
-// not detailed enough to understand what is happening.
-template <typename MemoryOpTy>
-static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
- // ODS checks for attributes values. Just need to verify that if the
- // memory-access attribute is Aligned, then the alignment attribute must be
- // present.
- auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
- if (!memAccessAttr) {
- // Alignment attribute shouldn't be present if memory access attribute is
- // not present.
- if (op->getAttr(kSourceAlignmentAttrName)) {
- return memoryOp.emitOpError(
- "invalid alignment specification without aligned memory access "
- "specification");
- }
- return success();
- }
-
- auto memAccess = llvm::cast<spirv::MemoryAccessAttr>(memAccessAttr);
-
- if (!memAccess) {
- return memoryOp.emitOpError("invalid memory access specifier: ")
- << memAccess;
- }
-
- if (spirv::bitEnumContainsAll(memAccess.getValue(),
- spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kSourceAlignmentAttrName)) {
- return memoryOp.emitOpError("missing alignment value");
- }
- } else {
- if (op->getAttr(kSourceAlignmentAttrName)) {
- return memoryOp.emitOpError(
- "invalid alignment specification with non-aligned memory access "
- "specification");
- }
- }
- return success();
-}
-
-LogicalResult
-spirv::verifyMemorySemantics(Operation *op,
- spirv::MemorySemantics memorySemantics) {
- // According to the SPIR-V specification:
- // "Despite being a mask and allowing multiple bits to be combined, it is
- // invalid for more than one of these four bits to be set: Acquire, Release,
- // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
- // Release semantics is done by setting the AcquireRelease bit, not by setting
- // two bits."
- auto atMostOneInSet = spirv::MemorySemantics::Acquire |
- spirv::MemorySemantics::Release |
- spirv::MemorySemantics::AcquireRelease |
- spirv::MemorySemantics::SequentiallyConsistent;
-
- auto bitCount =
- llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
- if (bitCount > 1) {
- return op->emitError(
- "expected at most one of these four memory constraints "
- "to be set: `Acquire`, `Release`,"
- "`AcquireRelease` or `SequentiallyConsistent`");
- }
- return success();
-}
-
-template <typename LoadStoreOpTy>
-static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
- Value val) {
- // ODS already checks ptr is spirv::PointerType. Just check that the pointee
- // type of the pointer and the type of the value are the same
- //
- // TODO: Check that the value type satisfies restrictions of
- // SPIR-V OpLoad/OpStore operations
- if (val.getType() !=
- llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType()) {
- return op.emitOpError("mismatch in result type and pointer type");
- }
- return success();
-}
-
template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
Value ptr, Value val) {
@@ -430,70 +205,6 @@ static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
return success();
}
-static ParseResult parseVariableDecorations(OpAsmParser &parser,
- OperationState &state) {
- auto builtInName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::BuiltIn));
- if (succeeded(parser.parseOptionalKeyword("bind"))) {
- Attribute set, binding;
- // Parse optional descriptor binding
- auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::Binding));
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.parseLParen() ||
- parser.parseAttribute(set, i32Type, descriptorSetName,
- state.attributes) ||
- parser.parseComma() ||
- parser.parseAttribute(binding, i32Type, bindingName,
- state.attributes) ||
- parser.parseRParen()) {
- return failure();
- }
- } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
- StringAttr builtIn;
- if (parser.parseLParen() ||
- parser.parseAttribute(builtIn, builtInName, state.attributes) ||
- parser.parseRParen()) {
- return failure();
- }
- }
-
- // Parse other attributes
- if (parser.parseOptionalAttrDict(state.attributes))
- return failure();
-
- return success();
-}
-
-static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
- SmallVectorImpl<StringRef> &elidedAttrs) {
- // Print optional descriptor binding
- auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::Binding));
- auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
- auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
- if (descriptorSet && binding) {
- elidedAttrs.push_back(descriptorSetName);
- elidedAttrs.push_back(bindingName);
- printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
- << ")";
- }
-
- // Print BuiltIn attribute if present
- auto builtInName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::BuiltIn));
- if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
- printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
- elidedAttrs.push_back(builtInName);
- }
-
- printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
-}
-
/// Walks the given type hierarchy with the given indices, potentially down
/// to component granularity, to select an element type. Returns null type and
/// emits errors with the given loc on failure.
@@ -564,12 +275,6 @@ static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
return getElementType(type, indices, errorFn);
}
-/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
-static inline bool isMergeBlock(Block &block) {
- return !block.empty() && std::next(block.begin()) == block.end() &&
- isa<spirv::MergeOp>(block.front());
-}
-
template <typename ExtendedBinaryOp>
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
auto resultType = llvm::cast<spirv::StructType>(op.getType());
@@ -617,15 +322,6 @@ static void printArithmeticExtendedBinaryOp(Operation *op,
printer << " : " << op->getResultTypes().front();
}
-/// Result of a logical op must be a scalar or vector of boolean type.
-static Type getUnaryOpResultType(Type operandType) {
- Builder builder(operandType.getContext());
- Type resultType = builder.getIntegerType(1);
- if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
- return VectorType::get(vecType.getNumElements(), resultType);
- return resultType;
-}
-
static LogicalResult verifyShiftOp(Operation *op) {
if (op->getOperand(0).getType() != op->getResult(0).getType()) {
return op->emitError("expected the same type for the first operand and "
@@ -636,152 +332,6 @@ static LogicalResult verifyShiftOp(Operation *op) {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.AccessChainOp
-//===----------------------------------------------------------------------===//
-
-static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
- if (!ptrType) {
- emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
- "to composite type, but provided ")
- << type;
- return nullptr;
- }
-
- auto resultType = ptrType.getPointeeType();
- auto resultStorageClass = ptrType.getStorageClass();
- int32_t index = 0;
-
- for (auto indexSSA : indices) {
- auto cType = llvm::dyn_cast<spirv::CompositeType>(resultType);
- if (!cType) {
- emitError(
- baseLoc,
- "'spirv.AccessChain' op cannot extract from non-composite type ")
- << resultType << " with index " << index;
- return nullptr;
- }
- index = 0;
- if (llvm::isa<spirv::StructType>(resultType)) {
- Operation *op = indexSSA.getDefiningOp();
- if (!op) {
- emitError(baseLoc, "'spirv.AccessChain' op index must be an "
- "integer spirv.Constant to access "
- "element of spirv.struct");
- return nullptr;
- }
-
- // TODO: this should be relaxed to allow
- // integer literals of other bitwidths.
- if (failed(spirv::extractValueFromConstOp(op, index))) {
- emitError(
- baseLoc,
- "'spirv.AccessChain' index must be an integer spirv.Constant to "
- "access element of spirv.struct, but provided ")
- << op->getName();
- return nullptr;
- }
- if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
- emitError(baseLoc, "'spirv.AccessChain' op index ")
- << index << " out of bounds for " << resultType;
- return nullptr;
- }
- }
- resultType = cType.getElementType(index);
- }
- return spirv::PointerType::get(resultType, resultStorageClass);
-}
-
-void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
- Value basePtr, ValueRange indices) {
- auto type = getElementPtrType(basePtr.getType(), indices, state.location);
- assert(type && "Unable to deduce return type based on basePtr and indices");
- build(builder, state, type, basePtr, indices);
-}
-
-ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, result.operands)) {
- return failure();
- }
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty()) {
- return mlir::emitError(result.location,
- "'spirv.AccessChain' op expected at "
- "least one index ");
- }
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size()) {
- return mlir::emitError(
- result.location, "'spirv.AccessChain' op indices types' count must be "
- "equal to indices info count");
- }
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(result.operands).drop_front(), result.location);
- if (!resultType) {
- return failure();
- }
-
- result.addTypes(resultType);
- return success();
-}
-
-template <typename Op>
-static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
- printer << ' ' << op.getBasePtr() << '[' << indices
- << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
-}
-
-void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, getIndices(), printer);
-}
-
-template <typename Op>
-static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
- auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
- indices, accessChainOp.getLoc());
- if (!resultType)
- return failure();
-
- auto providedResultType =
- llvm::dyn_cast<spirv::PointerType>(accessChainOp.getType());
- if (!providedResultType)
- return accessChainOp.emitOpError(
- "result type must be a pointer, but provided")
- << providedResultType;
-
- if (resultType != providedResultType)
- return accessChainOp.emitOpError("invalid result type: expected ")
- << resultType << ", but provided " << providedResultType;
-
- return success();
-}
-
-LogicalResult spirv::AccessChainOp::verify() {
- return verifyAccessChain(*this, getIndices());
-}
-
//===----------------------------------------------------------------------===//
// spirv.mlir.addressof
//===----------------------------------------------------------------------===//
@@ -805,109 +355,6 @@ LogicalResult spirv::AddressOfOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.BranchOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
- assert(index == 0 && "invalid successor index");
- return SuccessorOperands(0, getTargetOperandsMutable());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BranchConditionalOp
-//===----------------------------------------------------------------------===//
-
-SuccessorOperands
-spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
- assert(index < 2 && "invalid successor index");
- return SuccessorOperands(index == kTrueIndex
- ? getTrueTargetOperandsMutable()
- : getFalseTargetOperandsMutable());
-}
-
-ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
- OperationState &result) {
- auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand condInfo;
- Block *dest;
-
- // Parse the condition.
- Type boolTy = builder.getI1Type();
- if (parser.parseOperand(condInfo) ||
- parser.resolveOperand(condInfo, boolTy, result.operands))
- return failure();
-
- // Parse the optional branch weights.
- if (succeeded(parser.parseOptionalLSquare())) {
- IntegerAttr trueWeight, falseWeight;
- NamedAttrList weights;
-
- auto i32Type = builder.getIntegerType(32);
- if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
- parser.parseComma() ||
- parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
- parser.parseRSquare())
- return failure();
-
- result.addAttribute(kBranchWeightAttrName,
- builder.getArrayAttr({trueWeight, falseWeight}));
- }
-
- // Parse the true branch.
- SmallVector<Value, 4> trueOperands;
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, trueOperands))
- return failure();
- result.addSuccessors(dest);
- result.addOperands(trueOperands);
-
- // Parse the false branch.
- SmallVector<Value, 4> falseOperands;
- if (parser.parseComma() ||
- parser.parseSuccessorAndUseList(dest, falseOperands))
- return failure();
- result.addSuccessors(dest);
- result.addOperands(falseOperands);
- result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr(
- {1, static_cast<int32_t>(trueOperands.size()),
- static_cast<int32_t>(falseOperands.size())}));
-
- return success();
-}
-
-void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
- printer << ' ' << getCondition();
-
- if (auto weights = getBranchWeights()) {
- printer << " [";
- llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
- printer << llvm::cast<IntegerAttr>(a).getInt();
- });
- printer << "]";
- }
-
- printer << ", ";
- printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
- printer << ", ";
- printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
-}
-
-LogicalResult spirv::BranchConditionalOp::verify() {
- if (auto weights = getBranchWeights()) {
- if (weights->getValue().size() != 2) {
- return emitOpError("must have exactly two branch weights");
- }
- if (llvm::all_of(*weights, [](Attribute attr) {
- return llvm::cast<IntegerAttr>(attr).getValue().isZero();
- }))
- return emitOpError("branch weights cannot both be zero");
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//
@@ -1584,72 +1031,6 @@ ::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
return getResAttrs().value_or(nullptr);
}
-//===----------------------------------------------------------------------===//
-// spirv.FunctionCall
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::FunctionCallOp::verify() {
- auto fnName = getCalleeAttr();
-
- auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
- SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
- if (!funcOp) {
- return emitOpError("callee function '")
- << fnName.getValue() << "' not found in nearest symbol table";
- }
-
- auto functionType = funcOp.getFunctionType();
-
- if (getNumResults() > 1) {
- return emitOpError(
- "expected callee function to have 0 or 1 result, but provided ")
- << getNumResults();
- }
-
- if (functionType.getNumInputs() != getNumOperands()) {
- return emitOpError("has incorrect number of operands for callee: expected ")
- << functionType.getNumInputs() << ", but provided "
- << getNumOperands();
- }
-
- for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
- if (getOperand(i).getType() != functionType.getInput(i)) {
- return emitOpError("operand type mismatch: expected operand type ")
- << functionType.getInput(i) << ", but provided "
- << getOperand(i).getType() << " for operand number " << i;
- }
- }
-
- if (functionType.getNumResults() != getNumResults()) {
- return emitOpError(
- "has incorrect number of results has for callee: expected ")
- << functionType.getNumResults() << ", but provided "
- << getNumResults();
- }
-
- if (getNumResults() &&
- (getResult(0).getType() != functionType.getResult(0))) {
- return emitOpError("result type mismatch: expected ")
- << functionType.getResult(0) << ", but provided "
- << getResult(0).getType();
- }
-
- return success();
-}
-
-CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
-}
-
-void spirv::FunctionCallOp::setCalleeFromCallable(
- CallInterfaceCallable callee) {
- (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
-}
-
-Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
- return getArguments();
-}
-
//===----------------------------------------------------------------------===//
// spirv.GLFClampOp
//===----------------------------------------------------------------------===//
@@ -1768,7 +1149,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
}
elidedAttrs.push_back(kTypeAttrName);
- printVariableDecorations(*this, printer, elidedAttrs);
+ spirv::printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
}
@@ -1933,232 +1314,21 @@ void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
::printArithmeticExtendedBinaryOp(*this, printer);
}
-//===----------------------------------------------------------------------===//
-// spirv.UMulExtended
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::UMulExtendedOp::verify() {
- return ::verifyArithmeticExtendedBinaryOp(*this);
-}
-
-ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseArithmeticExtendedBinaryOp(parser, result);
-}
-
-void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
- ::printArithmeticExtendedBinaryOp(*this, printer);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.LoadOp
-//===----------------------------------------------------------------------===//
-
-void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
- Value basePtr, MemoryAccessAttr memoryAccess,
- IntegerAttr alignment) {
- auto ptrType = llvm::cast<spirv::PointerType>(basePtr.getType());
- build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
- alignment);
-}
-
-ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the storage class specification
- spirv::StorageClass storageClass;
- OpAsmParser::UnresolvedOperand ptrInfo;
- Type elementType;
- if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
- parseMemoryAccessAttributes(parser, result) ||
- parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
- parser.parseType(elementType)) {
- return failure();
- }
-
- auto ptrType = spirv::PointerType::get(elementType, storageClass);
- if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
- return failure();
- }
-
- result.addTypes(elementType);
- return success();
-}
-
-void spirv::LoadOp::print(OpAsmPrinter &printer) {
- SmallVector<StringRef, 4> elidedAttrs;
- StringRef sc = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
- printer << " \"" << sc << "\" " << getPtr();
-
- printMemoryAccessAttribute(*this, printer, elidedAttrs);
-
- printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
- printer << " : " << getType();
-}
-
-LogicalResult spirv::LoadOp::verify() {
- // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
- // type with fixed size; i.e., it cannot be, nor include, any
- // OpTypeRuntimeArray types."
- if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
- return failure();
- }
- return verifyMemoryAccessAttribute(*this);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.mlir.loop
-//===----------------------------------------------------------------------===//
-
-void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
- state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
- spirv::LoopControl::None));
- state.addRegion();
-}
-
-ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
- if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
- result))
- return failure();
- return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
-}
-
-void spirv::LoopOp::print(OpAsmPrinter &printer) {
- auto control = getLoopControl();
- if (control != spirv::LoopControl::None)
- printer << " control(" << spirv::stringifyLoopControl(control) << ")";
- printer << ' ';
- printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/true);
-}
-
-/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
-/// given `dstBlock`.
-static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
- // Check that there is only one op in the `srcBlock`.
- if (!llvm::hasSingleElement(srcBlock))
- return false;
-
- auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
- return branchOp && branchOp.getSuccessor() == &dstBlock;
-}
-
-LogicalResult spirv::LoopOp::verifyRegions() {
- auto *op = getOperation();
-
- // We need to verify that the blocks follow the following layout:
- //
- // +-------------+
- // | entry block |
- // +-------------+
- // |
- // v
- // +-------------+
- // | loop header | <-----+
- // +-------------+ |
- // |
- // ... |
- // \ | / |
- // v |
- // +---------------+ |
- // | loop continue | -----+
- // +---------------+
- //
- // ...
- // \ | /
- // v
- // +-------------+
- // | merge block |
- // +-------------+
-
- auto ®ion = op->getRegion(0);
- // Allow empty region as a degenerated case, which can come from
- // optimizations.
- if (region.empty())
- return success();
-
- // The last block is the merge block.
- Block &merge = region.back();
- if (!isMergeBlock(merge))
- return emitOpError("last block must be the merge block with only one "
- "'spirv.mlir.merge' op");
-
- if (std::next(region.begin()) == region.end())
- return emitOpError(
- "must have an entry block branching to the loop header block");
- // The first block is the entry block.
- Block &entry = region.front();
-
- if (std::next(region.begin(), 2) == region.end())
- return emitOpError(
- "must have a loop header block branched from the entry block");
- // The second block is the loop header block.
- Block &header = *std::next(region.begin(), 1);
-
- if (!hasOneBranchOpTo(entry, header))
- return emitOpError(
- "entry block must only have one 'spirv.Branch' op to the second block");
-
- if (std::next(region.begin(), 3) == region.end())
- return emitOpError(
- "requires a loop continue block branching to the loop header block");
- // The second to last block is the loop continue block.
- Block &cont = *std::prev(region.end(), 2);
-
- // Make sure that we have a branch from the loop continue block to the loop
- // header block.
- if (llvm::none_of(
- llvm::seq<unsigned>(0, cont.getNumSuccessors()),
- [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
- return emitOpError("second to last block must be the loop continue "
- "block that branches to the loop header block");
-
- // Make sure that no other blocks (except the entry and loop continue block)
- // branches to the loop header block.
- for (auto &block : llvm::make_range(std::next(region.begin(), 2),
- std::prev(region.end(), 2))) {
- for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
- if (block.getSuccessor(i) == &header) {
- return emitOpError("can only have the entry and loop continue "
- "block branching to the loop header block");
- }
- }
- }
-
- return success();
-}
-
-Block *spirv::LoopOp::getEntryBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- return &getBody().front();
-}
-
-Block *spirv::LoopOp::getHeaderBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- // The second block is the loop header block.
- return &*std::next(getBody().begin());
-}
+//===----------------------------------------------------------------------===//
+// spirv.UMulExtended
+//===----------------------------------------------------------------------===//
-Block *spirv::LoopOp::getContinueBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- // The second to last block is the loop continue block.
- return &*std::prev(getBody().end(), 2);
+LogicalResult spirv::UMulExtendedOp::verify() {
+ return ::verifyArithmeticExtendedBinaryOp(*this);
}
-Block *spirv::LoopOp::getMergeBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- // The last block is the loop merge block.
- return &getBody().back();
+ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return ::parseArithmeticExtendedBinaryOp(parser, result);
}
-void spirv::LoopOp::addEntryAndMergeBlock() {
- assert(getBody().empty() && "entry and merge block already exist");
- getBody().push_back(new Block());
- auto *mergeBlock = new Block();
- getBody().push_back(mergeBlock);
- OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
-
- // Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
+ ::printArithmeticExtendedBinaryOp(*this, printer);
}
//===----------------------------------------------------------------------===//
@@ -2169,24 +1339,6 @@ LogicalResult spirv::MemoryBarrierOp::verify() {
return verifyMemorySemantics(getOperation(), getMemorySemantics());
}
-//===----------------------------------------------------------------------===//
-// spirv.mlir.merge
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::MergeOp::verify() {
- auto *parentOp = (*this)->getParentOp();
- if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
- return emitOpError(
- "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
-
- // TODO: This check should be done in `verifyRegions` of parent op.
- Block &parentLastBlock = (*this)->getParentRegion()->back();
- if (getOperation() != parentLastBlock.getTerminator())
- return emitOpError("can only be used in the last block of "
- "'spirv.mlir.selection' or 'spirv.mlir.loop'");
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.module
//===----------------------------------------------------------------------===//
@@ -2382,158 +1534,6 @@ LogicalResult spirv::ReferenceOfOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.Return
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ReturnOp::verify() {
- // Verification is performed in spirv.func op.
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ReturnValue
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ReturnValueOp::verify() {
- // Verification is performed in spirv.func op.
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.Select
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::SelectOp::verify() {
- if (auto conditionTy = llvm::dyn_cast<VectorType>(getCondition().getType())) {
- auto resultVectorTy = llvm::dyn_cast<VectorType>(getResult().getType());
- if (!resultVectorTy) {
- return emitOpError("result expected to be of vector type when "
- "condition is of vector type");
- }
- if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
- return emitOpError("result should have the same number of elements as "
- "the condition when condition is of vector type");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.mlir.selection
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
- OperationState &result) {
- if (parseControlAttribute<spirv::SelectionControlAttr,
- spirv::SelectionControl>(parser, result))
- return failure();
- return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
-}
-
-void spirv::SelectionOp::print(OpAsmPrinter &printer) {
- auto control = getSelectionControl();
- if (control != spirv::SelectionControl::None)
- printer << " control(" << spirv::stringifySelectionControl(control) << ")";
- printer << ' ';
- printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/true);
-}
-
-LogicalResult spirv::SelectionOp::verifyRegions() {
- auto *op = getOperation();
-
- // We need to verify that the blocks follow the following layout:
- //
- // +--------------+
- // | header block |
- // +--------------+
- // / | \
- // ...
- //
- //
- // +---------+ +---------+ +---------+
- // | case #0 | | case #1 | | case #2 | ...
- // +---------+ +---------+ +---------+
- //
- //
- // ...
- // \ | /
- // v
- // +-------------+
- // | merge block |
- // +-------------+
-
- auto ®ion = op->getRegion(0);
- // Allow empty region as a degenerated case, which can come from
- // optimizations.
- if (region.empty())
- return success();
-
- // The last block is the merge block.
- if (!isMergeBlock(region.back()))
- return emitOpError("last block must be the merge block with only one "
- "'spirv.mlir.merge' op");
-
- if (std::next(region.begin()) == region.end())
- return emitOpError("must have a selection header block");
-
- return success();
-}
-
-Block *spirv::SelectionOp::getHeaderBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- // The first block is the loop header block.
- return &getBody().front();
-}
-
-Block *spirv::SelectionOp::getMergeBlock() {
- assert(!getBody().empty() && "op region should not be empty!");
- // The last block is the loop merge block.
- return &getBody().back();
-}
-
-void spirv::SelectionOp::addMergeBlock() {
- assert(getBody().empty() && "entry and merge block already exist");
- auto *mergeBlock = new Block();
- getBody().push_back(mergeBlock);
- OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
-
- // Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
-}
-
-spirv::SelectionOp spirv::SelectionOp::createIfThen(
- Location loc, Value condition,
- function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
- auto selectionOp =
- builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
-
- selectionOp.addMergeBlock();
- Block *mergeBlock = selectionOp.getMergeBlock();
- Block *thenBlock = nullptr;
-
- // Build the "then" block.
- {
- OpBuilder::InsertionGuard guard(builder);
- thenBlock = builder.createBlock(mergeBlock);
- thenBody(builder);
- builder.create<spirv::BranchOp>(loc, mergeBlock);
- }
-
- // Build the header block.
- {
- OpBuilder::InsertionGuard guard(builder);
- builder.createBlock(thenBlock);
- builder.create<spirv::BranchConditionalOp>(
- loc, condition, thenBlock,
- /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
- /*falseArguments=*/ArrayRef<Value>());
- }
-
- return selectionOp;
-}
-
//===----------------------------------------------------------------------===//
// spirv.SpecConstant
//===----------------------------------------------------------------------===//
@@ -2588,171 +1588,6 @@ LogicalResult spirv::SpecConstantOp::verify() {
"default value can only be a bool, integer, or float scalar");
}
-//===----------------------------------------------------------------------===//
-// spirv.StoreOp
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
- // Parse the storage class specification
- spirv::StorageClass storageClass;
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
- auto loc = parser.getCurrentLocation();
- Type elementType;
- if (parseEnumStrAttr(storageClass, parser) ||
- parser.parseOperandList(operandInfo, 2) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(elementType)) {
- return failure();
- }
-
- auto ptrType = spirv::PointerType::get(elementType, storageClass);
- if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
- result.operands)) {
- return failure();
- }
- return success();
-}
-
-void spirv::StoreOp::print(OpAsmPrinter &printer) {
- SmallVector<StringRef, 4> elidedAttrs;
- StringRef sc = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getPtr().getType()).getStorageClass());
- printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
-
- printMemoryAccessAttribute(*this, printer, elidedAttrs);
-
- printer << " : " << getValue().getType();
- printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
-}
-
-LogicalResult spirv::StoreOp::verify() {
- // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
- // OpTypePointer whose Type operand is the same as the type of Object."
- if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
- return failure();
- return verifyMemoryAccessAttribute(*this);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.Unreachable
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::UnreachableOp::verify() {
- auto *block = (*this)->getBlock();
- // Fast track: if this is in entry block, its invalid. Otherwise, if no
- // predecessors, it's valid.
- if (block->isEntryBlock())
- return emitOpError("cannot be used in reachable block");
- if (block->hasNoPredecessors())
- return success();
-
- // TODO: further verification needs to analyze reachability from
- // the entry block.
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.Variable
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
- OperationState &result) {
- // Parse optional initializer
- std::optional<OpAsmParser::UnresolvedOperand> initInfo;
- if (succeeded(parser.parseOptionalKeyword("init"))) {
- initInfo = OpAsmParser::UnresolvedOperand();
- if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
- parser.parseRParen())
- return failure();
- }
-
- if (parseVariableDecorations(parser, result)) {
- return failure();
- }
-
- // Parse result pointer type
- Type type;
- if (parser.parseColon())
- return failure();
- auto loc = parser.getCurrentLocation();
- if (parser.parseType(type))
- return failure();
-
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
- if (!ptrType)
- return parser.emitError(loc, "expected spirv.ptr type");
- result.addTypes(ptrType);
-
- // Resolve the initializer operand
- if (initInfo) {
- if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
- result.operands))
- return failure();
- }
-
- auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
- ptrType.getStorageClass());
- result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
-
- return success();
-}
-
-void spirv::VariableOp::print(OpAsmPrinter &printer) {
- SmallVector<StringRef, 4> elidedAttrs{
- spirv::attributeName<spirv::StorageClass>()};
- // Print optional initializer
- if (getNumOperands() != 0)
- printer << " init(" << getInitializer() << ")";
-
- printVariableDecorations(*this, printer, elidedAttrs);
- printer << " : " << getType();
-}
-
-LogicalResult spirv::VariableOp::verify() {
- // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
- // object. It cannot be Generic. It must be the same as the Storage Class
- // operand of the Result Type."
- if (getStorageClass() != spirv::StorageClass::Function) {
- return emitOpError(
- "can only be used to model function-level variables. Use "
- "spirv.GlobalVariable for module-level variables.");
- }
-
- auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
- if (getStorageClass() != pointerType.getStorageClass())
- return emitOpError(
- "storage class must match result pointer's storage class");
-
- if (getNumOperands() != 0) {
- // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
- // a global (module scope) OpVariable instruction".
- auto *initOp = getOperand(0).getDefiningOp();
- if (!initOp || !isa<spirv::ConstantOp, // for normal constant
- spirv::ReferenceOfOp, // for spec constant
- spirv::AddressOfOp>(initOp))
- return emitOpError("initializer must be the result of a "
- "constant or spirv.GlobalVariable op");
- }
-
- // TODO: generate these strings using ODS.
- auto *op = getOperation();
- auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::DescriptorSet));
- auto bindingName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::Binding));
- auto builtInName = llvm::convertToSnakeFromCamelCase(
- stringifyDecoration(spirv::Decoration::BuiltIn));
-
- for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
- if (op->getAttr(attr))
- return emitOpError("cannot have '")
- << attr << "' attribute (only allowed in spirv.GlobalVariable)";
- }
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.VectorShuffle
//===----------------------------------------------------------------------===//
@@ -2804,100 +1639,6 @@ LogicalResult spirv::MatrixTimesScalarOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.CopyMemory
-//===----------------------------------------------------------------------===//
-
-void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
- printer << ' ';
-
- StringRef targetStorageClass = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
- printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
-
- StringRef sourceStorageClass = stringifyStorageClass(
- llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
- printer << " \"" << sourceStorageClass << "\" " << getSource();
-
- SmallVector<StringRef, 4> elidedAttrs;
- printMemoryAccessAttribute(*this, printer, elidedAttrs);
- printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
- getSourceMemoryAccess(),
- getSourceAlignment());
-
- printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
-
- Type pointeeType =
- llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
- printer << " : " << pointeeType;
-}
-
-ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
- OperationState &result) {
- spirv::StorageClass targetStorageClass;
- OpAsmParser::UnresolvedOperand targetPtrInfo;
-
- spirv::StorageClass sourceStorageClass;
- OpAsmParser::UnresolvedOperand sourcePtrInfo;
-
- Type elementType;
-
- if (parseEnumStrAttr(targetStorageClass, parser) ||
- parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
- parseEnumStrAttr(sourceStorageClass, parser) ||
- parser.parseOperand(sourcePtrInfo) ||
- parseMemoryAccessAttributes(parser, result)) {
- return failure();
- }
-
- if (!parser.parseOptionalComma()) {
- // Parse 2nd memory access attributes.
- if (parseSourceMemoryAccessAttributes(parser, result)) {
- return failure();
- }
- }
-
- if (parser.parseColon() || parser.parseType(elementType))
- return failure();
-
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
- auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
-
- if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
- parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-LogicalResult spirv::CopyMemoryOp::verify() {
- Type targetType =
- llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
-
- Type sourceType =
- llvm::cast<spirv::PointerType>(getSource().getType()).getPointeeType();
-
- if (targetType != sourceType)
- return emitOpError("both operands must be pointers to the same type");
-
- if (failed(verifyMemoryAccessAttribute(*this)))
- return failure();
-
- // TODO - According to the spec:
- //
- // If two masks are present, the first applies to Target and cannot include
- // MakePointerVisible, and the second applies to Source and cannot include
- // MakePointerAvailable.
- //
- // Add such verification here.
-
- return verifySourceMemoryAccessAttribute(*this);
-}
-
//===----------------------------------------------------------------------===//
// spirv.Transpose
//===----------------------------------------------------------------------===//
@@ -3305,109 +2046,6 @@ LogicalResult spirv::ImageQuerySizeOp::verify() {
return success();
}
-static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
- OpAsmParser &parser,
- OperationState &state) {
- OpAsmParser::UnresolvedOperand ptrInfo;
- SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
- Type type;
- auto loc = parser.getCurrentLocation();
- SmallVector<Type, 4> indicesTypes;
-
- if (parser.parseOperand(ptrInfo) ||
- parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseColonType(type) ||
- parser.resolveOperand(ptrInfo, type, state.operands))
- return failure();
-
- // Check that the provided indices list is not empty before parsing their
- // type list.
- if (indicesInfo.empty())
- return emitError(state.location) << opName << " expected element";
-
- if (parser.parseComma() || parser.parseTypeList(indicesTypes))
- return failure();
-
- // Check that the indices types list is not empty and that it has a one-to-one
- // mapping to the provided indices.
- if (indicesTypes.size() != indicesInfo.size())
- return emitError(state.location)
- << opName
- << " indices types' count must be equal to indices info count";
-
- if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
- return failure();
-
- auto resultType = getElementPtrType(
- type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
- if (!resultType)
- return failure();
-
- state.addTypes(resultType);
- return success();
-}
-
-template <typename Op>
-static auto concatElemAndIndices(Op op) {
- SmallVector<Value> ret(op.getIndices().size() + 1);
- ret[0] = op.getElement();
- llvm::copy(op.getIndices(), ret.begin() + 1);
- return ret;
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.InBoundsPtrAccessChainOp
-//===----------------------------------------------------------------------===//
-
-void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
- OperationState &state,
- Value basePtr, Value element,
- ValueRange indices) {
- auto type = getElementPtrType(basePtr.getType(), indices, state.location);
- assert(type && "Unable to deduce return type based on basePtr and indices");
- build(builder, state, type, basePtr, element, indices);
-}
-
-ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(
- spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
-}
-
-void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
-LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
- return verifyAccessChain(*this, getIndices());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.PtrAccessChainOp
-//===----------------------------------------------------------------------===//
-
-void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
- Value basePtr, Value element,
- ValueRange indices) {
- auto type = getElementPtrType(basePtr.getType(), indices, state.location);
- assert(type && "Unable to deduce return type based on basePtr and indices");
- build(builder, state, type, basePtr, element, indices);
-}
-
-ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
- parser, result);
-}
-
-void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
- printAccessChain(*this, concatElemAndIndices(*this), printer);
-}
-
-LogicalResult spirv::PtrAccessChainOp::verify() {
- return verifyAccessChain(*this, getIndices());
-}
-
//===----------------------------------------------------------------------===//
// spirv.VectorTimesScalarOp
//===----------------------------------------------------------------------===//
@@ -3420,18 +2058,3 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
return emitOpError("scalar operand and result element type match");
return success();
}
-
-// TableGen'erated operation interfaces for querying versions, extensions, and
-// capabilities.
-#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
-
-// TablenGen'erated operation definitions.
-#define GET_OP_CLASSES
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
-
-namespace mlir {
-namespace spirv {
-// TableGen'erated operation availability interface implementations.
-#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
-} // namespace spirv
-} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 43c0beaccc0fd3..27c373300aee8e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,6 +12,8 @@
#include "SPIRVParsingUtils.h"
+#include "llvm/ADT/StringExtras.h"
+
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
@@ -45,4 +47,41 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
return parser.parseRSquare();
}
+ParseResult parseVariableDecorations(OpAsmParser &parser,
+ OperationState &state) {
+ auto builtInName = llvm::convertToSnakeFromCamelCase(
+ stringifyDecoration(spirv::Decoration::BuiltIn));
+ if (succeeded(parser.parseOptionalKeyword("bind"))) {
+ Attribute set, binding;
+ // Parse optional descriptor binding
+ auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
+ stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName = llvm::convertToSnakeFromCamelCase(
+ stringifyDecoration(spirv::Decoration::Binding));
+ Type i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseLParen() ||
+ parser.parseAttribute(set, i32Type, descriptorSetName,
+ state.attributes) ||
+ parser.parseComma() ||
+ parser.parseAttribute(binding, i32Type, bindingName,
+ state.attributes) ||
+ parser.parseRParen()) {
+ return failure();
+ }
+ } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
+ StringAttr builtIn;
+ if (parser.parseLParen() ||
+ parser.parseAttribute(builtIn, builtInName, state.attributes) ||
+ parser.parseRParen()) {
+ return failure();
+ }
+ }
+
+ // Parse other attributes
+ if (parser.parseOptionalAttrDict(state.attributes))
+ return failure();
+
+ return success();
+}
+
} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index fd2faf4b7b333f..625c82f6e8e899 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -153,4 +153,7 @@ ParseResult parseMemoryAccessAttributes(
OpAsmParser &parser, OperationState &state,
StringRef attrName = AttrNames::kMemoryAccessAttrName);
+ParseResult parseVariableDecorations(OpAsmParser &parser,
+ OperationState &state);
+
} // namespace mlir::spirv
More information about the Mlir-commits
mailing list