[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (PR #151934)

Jakub Kuderski llvmlistbot at llvm.org
Tue Aug 26 07:53:26 PDT 2025


================
@@ -0,0 +1,255 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly form of a graph op: a symbol name, a
+/// function-style signature, an optional `attributes` dictionary and an
+/// optional body region. The parsed pieces are recorded on `result`.
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  Builder &builder = parser.getBuilder();
+
+  // The graph name is parsed as a symbol and stored under the symbol
+  // attribute name.
+  StringAttr symbolName;
+  if (parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Signature. Variadic signatures are not allowed, so `isVariadic` is only
+  // an out-parameter required by the helper.
+  bool isVariadic = false;
+  SmallVector<OpAsmParser::Argument> arguments;
+  SmallVector<Type> outputTypes;
+  SmallVector<DictionaryAttr> outputAttrs;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, arguments, isVariadic, outputTypes,
+          outputAttrs))
+    return failure();
+
+  // Assemble the graph type from the parsed argument and result types.
+  SmallVector<Type> inputTypes;
+  inputTypes.reserve(arguments.size());
+  for (const OpAsmParser::Argument &argument : arguments)
+    inputTypes.push_back(argument.type);
+  GraphType graphType = builder.getGraphType(inputTypes, outputTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(graphType));
+
+  // Optional trailing `attributes { ... }` dictionary.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Record the per-argument and per-result attribute dictionaries.
+  assert(outputAttrs.size() == outputTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, arguments, outputAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // The body is optional; only a region that is present but malformed is an
+  // error.
+  Region *bodyRegion = result.addRegion();
+  OptionalParseResult bodyResult =
+      parser.parseOptionalRegion(*bodyRegion, arguments);
+  if (bodyResult.has_value() && failed(*bodyResult))
+    return failure();
+  return success();
+}
+
+/// Prints the custom assembly form of a graph op: symbol name, signature,
+/// remaining attributes, and the body region when it is non-empty.
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Graph name.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+
+  // Signature (never variadic).
+  GraphType graphType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, graphType.getInputs(),
+      /*isVariadic=*/false, graphType.getResults());
+
+  // Attributes; the type and argument/result attribute names are handed to
+  // the helper as the list of names to leave out.
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Nothing more to print when the body region is empty.
+  Region &region = getBody();
+  if (region.empty())
+    return;
+
+  // Entry block arguments are not printed with the region.
+  printer << ' ';
+  printer.printRegion(region, /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// Checks the graph type: a graph must declare at least one result.
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getNumResults() >= 1)
+    return success();
+  return emitOpError("there should be at least one result");
+}
+
+/// Verifies the graph signature and body:
+///  - every argument and result type is a `spirv::TensorArmType`;
+///  - for a non-external graph, the entry block arguments match the
+///    signature in number and type;
+///  - every nested `spirv::GraphOutputsARMOp` has operands matching the
+///    graph results in number and type.
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  // Signature types.
+  for (auto [argIdx, argType] : llvm::enumerate(getArgumentTypes())) {
+    if (isa<spirv::TensorArmType>(argType))
+      continue;
+    return emitOpError("type of argument #")
+           << argIdx << " must be a TensorArmType, but got " << argType;
+  }
+  for (auto [resIdx, resType] : llvm::enumerate(getResultTypes())) {
+    if (isa<spirv::TensorArmType>(resType))
+      continue;
+    return emitOpError("type of result #")
+           << resIdx << " must be a TensorArmType, but got " << resType;
+  }
+
+  // Entry block arguments, only checked when the graph has a body.
+  if (!isExternal()) {
+    Block &entry = front();
+
+    unsigned expectedArgCount = getNumArguments();
+    if (entry.getNumArguments() != expectedArgCount)
+      return emitOpError("entry block must have ")
+             << expectedArgCount << " arguments to match graph signature";
+
+    for (auto [argIdx, signatureType, blockType] :
+         llvm::enumerate(getArgumentTypes(), entry.getArgumentTypes())) {
+      if (blockType == signatureType)
+        continue;
+      return emitOpError("type of entry block argument #")
+             << argIdx << '(' << blockType
+             << ") must match the type of the corresponding argument in "
+             << "graph signature(" << signatureType << ')';
+    }
+  }
+
+  // Graph outputs: the walk is interrupted at the first mismatch, because a
+  // returned diagnostic converts to an interrupting walk result.
+  GraphType grType = getFunctionType();
+  auto outputsWalk = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
+    if (grType.getNumResults() != op.getNumOperands())
+      return op.emitOpError("is returning ")
+             << op.getNumOperands()
+             << " value(s) but enclosing spirv.ARM.Graph requires "
+             << grType.getNumResults() << " result(s)";
+
+    ValueTypeRange<OperandRange> operandTypes = op.getValue().getType();
+    for (auto [idx, operandType] : llvm::enumerate(operandTypes)) {
+      Type expectedType = grType.getResult(idx);
+      if (operandType == expectedType)
+        continue;
+      return op.emitError("type of return operand ")
+             << idx << " (" << operandType
+             << ") doesn't match graph result type (" << expectedType << ")";
+    }
+    return WalkResult::advance();
+  });
+
+  if (outputsWalk.wasInterrupted())
+    return failure();
+  return success();
+}
+
+/// Populates `state` for a graph op named `name` with graph type `type`, the
+/// extra attributes `attrs`, an entry-point flag, and one empty region.
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  // Symbol name.
+  StringAttr symbolName = builder.getStringAttr(name);
+  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
+
+  // Graph type.
+  TypeAttr graphTypeAttr = TypeAttr::get(type);
+  state.addAttribute(getFunctionTypeAttrName(state.name), graphTypeAttr);
+
+  // Caller-supplied attributes, appended after the name and type.
+  state.attributes.append(attrs);
+
+  // Entry-point flag.
+  BoolAttr entryPointAttr = builder.getBoolAttr(entryPoint);
+  state.addAttribute(getEntryPointAttrName(state.name), entryPointAttr);
+
+  // The body region is added empty.
+  state.addRegion();
+}
+
+/// Returns the input types of the graph's type.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  GraphType graphType = getFunctionType();
+  return graphType.getInputs();
+}
+
+/// Returns the result types of the graph's type.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  GraphType graphType = getFunctionType();
+  return graphType.getResults();
+}
+
+/// Returns the body region, or null when the graph is external.
+Region *spirv::GraphARMOp::getCallableRegion() {
+  if (isExternal())
+    return nullptr;
+  return &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  auto graph = cast<GraphARMOp>((*this)->getParentOp());
+
+  // The operand number and types must match the graph signature.
+  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing  spirv.ARM.Graph (@"
+           << graph.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0, size = results.size(); i < size; ++i)
----------------
kuhar wrote:

Use `llvm::enumerate`

https://github.com/llvm/llvm-project/pull/151934


More information about the Mlir-commits mailing list