[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (PR #151934)

Jakub Kuderski llvmlistbot at llvm.org
Sat Aug 30 10:42:41 PDT 2025


================
@@ -0,0 +1,253 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form of a graph op: a symbol name, a
+/// function-like signature (variadic arguments are not accepted), an optional
+/// keyword-introduced attribute dictionary, and an optional body region.
+/// Everything parsed is recorded on `result`; returns failure as soon as any
+/// component fails to parse.
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  // The graph name is a symbol, stored under the generic symbol attribute.
+  StringAttr symbolName;
+  if (failed(parser.parseSymbolName(symbolName,
+                                    SymbolTable::getSymbolAttrName(),
+                                    result.attributes)))
+    return failure();
+
+  // Signature: arguments (with their attributes) followed by the results.
+  SmallVector<OpAsmParser::Argument> arguments;
+  SmallVector<Type> outputTypes;
+  SmallVector<DictionaryAttr> outputAttrs;
+  bool sawVariadic = false;
+  if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, arguments, sawVariadic, outputTypes,
+          outputAttrs)))
+    return failure();
+
+  // Build the graph type from the parsed argument and result types and attach
+  // it as the op's function-type attribute.
+  Builder &builder = parser.getBuilder();
+  SmallVector<Type> inputTypes;
+  inputTypes.reserve(arguments.size());
+  for (const OpAsmParser::Argument &argument : arguments)
+    inputTypes.push_back(argument.type);
+  GraphType signature = builder.getGraphType(inputTypes, outputTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(signature));
+
+  // Any extra attributes come next, if the dictionary is present.
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+
+  // Record the per-argument and per-result attribute dictionaries.
+  assert(outputAttrs.size() == outputTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, arguments, outputAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // The body is optional: an absent region is fine, a malformed one is not.
+  Region *bodyRegion = result.addRegion();
+  OptionalParseResult bodyResult =
+      parser.parseOptionalRegion(*bodyRegion, arguments);
+  if (bodyResult.has_value() && failed(*bodyResult))
+    return failure();
+  return success();
+}
+
+/// Prints the custom assembly form of a graph op: the symbol name, the
+/// signature, any attributes not already implied by the signature, and the
+/// body region when there is one.
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Name and signature.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  GraphType signature = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, signature.getInputs(),
+      /*isVariadic=*/false, signature.getResults());
+
+  // The function type and the argument/result attributes are excluded from
+  // the generic attribute list passed to the printer.
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Nothing more to print when the op carries no body.
+  Region &region = getBody();
+  if (region.empty())
+    return;
+
+  // Entry block arguments are not printed with the region; terminators are
+  // printed.
+  printer << ' ';
+  printer.printRegion(region, /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Checks the graph's type: a graph must declare at least one result.
+/// Emits an op error and returns failure otherwise.
+LogicalResult spirv::GraphARMOp::verifyType() {
+  const unsigned resultCount = getFunctionType().getNumResults();
+  if (resultCount >= 1)
+    return success();
+  return emitOpError("there should be at least one result");
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+    if (!isa<spirv::TensorArmType>(graphArgType)) {
+      return emitOpError("type of argument #")
+             << index << " must be a TensorArmType, but got " << graphArgType;
+    }
+  }
+  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+    if (!isa<spirv::TensorArmType>(graphResType)) {
+      return emitOpError("type of result #")
+             << index << " must be a TensorArmType, but got " << graphResType;
+    }
+  }
+
+  if (!isExternal()) {
+    Block &entryBlock = front();
+
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << '(' << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature(" << grArgType << ')';
+      }
+    }
+  }
+
+  GraphType grType = getFunctionType();
+  auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
+    if (grType.getNumResults() != op.getNumOperands())
+      return op.emitOpError("is returning ")
+             << op.getNumOperands()
+             << " value(s) but enclosing spirv.ARM.Graph requires "
+             << grType.getNumResults() << " result(s)";
+
+    ValueTypeRange<OperandRange> graphOutputOperandTypes =
+        op.getValue().getType();
+    for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; ++i) {
+      Type graphOutputOperandType = graphOutputOperandTypes[i];
----------------
kuhar wrote:

You can use `llvm::enumerate` here

https://github.com/llvm/llvm-project/pull/151934


More information about the Mlir-commits mailing list