[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 3 (PR #156845)

Davide Grohmann llvmlistbot at llvm.org
Thu Sep 11 01:28:42 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/156845

>From a1b6cf3d1a29c6642c8ae1338605244c7e18b513 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 11 Aug 2025 13:43:37 +0200
Subject: [PATCH 1/2] [mlir][spirv] Add support for SPV_ARM_graph extension -
 part 3
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is the third patch to add support for the `SPV_ARM_graph` SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
`Graph` abstraction for expressing dataflow computations over full
resources.

This third part includes:

- ABI lowering support for graph entry points via `LowerABIAttributesPass`.
- Verifier support for SPIR-V attributes on `spirv.ARM.Graph` results.
- Tests covering ABI handling.

Graphs currently support only `SPV_ARM_tensors`, but are designed to
generalize to other resource types, such as images.

Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I31896806a3e3a856530149ffd919b8568d5b6208
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |   9 +-
 .../Transforms/LowerABIAttributesPass.cpp     | 114 +++++++++++++++++-
 .../test/Dialect/SPIRV/IR/target-and-abi.mlir |   8 ++
 .../SPIRV/Transforms/abi-interface.mlir       |  22 ++++
 4 files changed, 150 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..44c86bc8777e4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1066,7 +1066,12 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
 }
 
 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+    Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
     NamedAttribute attribute) {
-  return op->emitError("cannot attach SPIR-V attributes to region result");
+  if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op)) // Graph results only.
+    return verifyRegionAttribute(
+        op->getLoc(), graphOp.getResultTypes()[resultIndex], attribute);
+  return op->emitError(
+      "cannot attach SPIR-V attributes to region result which is "
+      "not part of a spirv::GraphARMOp type");
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 3911ec08fcc27..91aa0e3823a31 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -85,10 +85,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
                                          abiInfo.getBinding());
 }
 
+/// Creates a global variable for the `index`-th graph argument or result.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+                                  unsigned index, bool isArg,
+                                  spirv::InterfaceVarABIAttr abiInfo) {
+  auto spirvModule = graphOp->getParentOfType<spirv::ModuleOp>();
+  if (!spirvModule)
+    return nullptr; // Graph is not nested in a spirv.module.
+
+  OpBuilder::InsertionGuard moduleInsertionGuard(builder);
+  builder.setInsertionPoint(graphOp.getOperation());
+  std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
+                        std::to_string(index);
+
+  Type varType = isArg ? graphOp.getFunctionType().getInput(index)
+                       : graphOp.getFunctionType().getResult(index);
+  // Storage class defaults to UniformConstant when the ABI leaves it unset.
+  auto pointerType = spirv::PointerType::get(
+      varType,
+      abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
+
+  return builder.create<spirv::GlobalVariableOp>(
+      graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
+      abiInfo.getBinding());
+}
+
 /// Gets the global variables that need to be specified as interface variable
 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
 static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
                       SmallVectorImpl<Attribute> &interfaceVars) {
   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
   if (!module) {
@@ -224,6 +250,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A pattern to lower interface variable ABI attributes on entry-point graphs.
+///
+/// Specifically, this pattern creates a global variable for every argument and
+/// result of the graph (each must carry such an attribute), removes the
+/// attributes, and emits a spirv.ARM.GraphEntryPoint op naming the variables.
+class ProcessGraphInterfaceVarABI final
+    : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+  using OpConversionPattern<spirv::GraphARMOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pass to implement the ABI information specified as attributes.
 class LowerABIAttributesPass final
     : public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -297,6 +338,65 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   return success();
 }
 
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+    spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Non-entry point graphs are not handled.
+  if (!graphOp.getEntryPoint().value_or(false))
+    return failure();
+  // NOTE(review): `signatureConverter` is never used below; consider removing.
+  TypeConverter::SignatureConversion signatureConverter(
+      graphOp.getFunctionType().getNumInputs());
+
+  StringRef attrName = spirv::getInterfaceVarABIAttrName();
+  SmallVector<Attribute, 4> interfaceVars;
+
+  // Create one interface variable per argument; each must carry the ABI attr.
+  unsigned numInputs = graphOp.getFunctionType().getNumInputs();
+  unsigned numResults = graphOp.getFunctionType().getNumResults();
+  for (unsigned index = 0; index < numInputs; ++index) {
+    auto abiInfo =
+        graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index, attrName);
+    if (!abiInfo)
+      return failure();
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, true, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+  // Likewise create one interface variable per result.
+  for (unsigned index = 0; index < numResults; ++index) {
+    auto abiInfo = graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+        index, attrName);
+    if (!abiInfo)
+      return failure();
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, false, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+
+  // Update signature.
+  rewriter.modifyOpInPlace(graphOp, [&] {
+    for (unsigned index = 0; index < numInputs; ++index) {
+      graphOp.removeArgAttr(index, attrName);
+    }
+    for (unsigned index = 0; index < numResults; ++index) {
+      graphOp.removeResultAttr(index, rewriter.getStringAttr(attrName));
+    }
+  });
+
+  OpBuilder::InsertionGuard insertionGuard(rewriter);
+  rewriter.setInsertionPoint(graphOp);
+  rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
+                                               interfaceVars);
+  return success();
+}
+
 void LowerABIAttributesPass::runOnOperation() {
   // Uses the signature conversion methodology of the dialect conversion
   // framework to implement the conversion.
@@ -323,6 +423,7 @@ void LowerABIAttributesPass::runOnOperation() {
 
   RewritePatternSet patterns(context);
   patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+  patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.
@@ -333,6 +434,17 @@ void LowerABIAttributesPass::runOnOperation() {
         return false;
     return true;
   });
+  target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+    StringRef attrName = spirv::getInterfaceVarABIAttrName();
+    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+      if (op.getArgAttr(i, attrName))
+        return false;
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+      if (op.getResultAttr(i, attrName))
+        return false;
+    return true;
+  });
+
   // All other SPIR-V ops are legal.
   target.markUnknownOpDynamicallyLegal([](Operation *op) {
     return op->getDialect()->getNamespace() ==
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf06eb052..63dea6af83556 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -101,6 +101,14 @@ func.func @interface_var(
 
 // -----
 
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+spirv.ARM.Graph @interface_var(%arg: !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (
+    !spirv.arm.tensor<1xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { spirv.ARM.GraphOutputs %arg : !spirv.arm.tensor<1xf32> }
+// Above: interface_var_abi is accepted on both graph arguments and results.
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.resource_limits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index f3a3218e5aec0..04667c828bbd1 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+module attributes {
+  spirv.target_env = #spirv.target_env<
+     #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>, #spirv.resource_limits<>>
+} {
+// Entry-point graph with interface_var_abi on its argument and result.
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+  // The new globals are the interface variables of the graph entry point.
+  //      CHECK:    spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  //      CHECK:    spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+  spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+                  -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+} // end spirv.module
+
+} // end module
+
+// -----
+
 module {
 // expected-error @+1 {{'spirv.module' op missing SPIR-V target env attribute}}
 spirv.module Logical GLSL450 {}

>From b0ef295ec2564c8e0a9a7ad0b323628d83968559 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 11 Sep 2025 10:26:29 +0200
Subject: [PATCH 2/2] resolve code review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I83992517e77c9dc53fd5da2e839ba4fc22f9f6a7
---
 .../Transforms/LowerABIAttributesPass.cpp     | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 91aa0e3823a31..8d9295aa25fcd 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
 
 namespace mlir {
 namespace spirv {
@@ -96,8 +97,8 @@ createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
 
   OpBuilder::InsertionGuard moduleInsertionGuard(builder);
   builder.setInsertionPoint(graphOp.getOperation());
-  std::string varName = graphOp.getName().str() + (isArg ? "_arg_" : "_res_") +
-                        std::to_string(index);
+  std::string varName = llvm::formatv("{0}_{1}_{2}", graphOp.getName(),
+                                      isArg ? "arg" : "res", index);
 
   Type varType = isArg ? graphOp.getFunctionType().getInput(index)
                        : graphOp.getFunctionType().getResult(index);
@@ -106,9 +107,9 @@ createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
       varType,
       abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant));
 
-  return builder.create<spirv::GlobalVariableOp>(
-      graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
-      abiInfo.getBinding());
+  return spirv::GlobalVariableOp::create(builder, graphOp.getLoc(), pointerType,
+                                         varName, abiInfo.getDescriptorSet(),
+                                         abiInfo.getBinding());
 }
 
 /// Gets the global variables that need to be specified as interface variable
@@ -380,7 +381,7 @@ LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
         SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
   }
 
-  // Update signature.
+  // Update graph signature.
   rewriter.modifyOpInPlace(graphOp, [&] {
     for (unsigned index = 0; index < numInputs; ++index) {
       graphOp.removeArgAttr(index, attrName);
@@ -390,10 +391,8 @@ LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
     }
   });
 
-  OpBuilder::InsertionGuard insertionGuard(rewriter);
-  rewriter.setInsertionPoint(graphOp);
-  rewriter.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
-                                               interfaceVars);
+  spirv::GraphEntryPointARMOp::create(rewriter, graphOp.getLoc(), graphOp,
+                                      interfaceVars);
   return success();
 }
 
@@ -422,8 +421,8 @@ void LowerABIAttributesPass::runOnOperation() {
   });
 
   RewritePatternSet patterns(context);
-  patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
-  patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
+  patterns.add<ProcessInterfaceVarABI, ProcessGraphInterfaceVarABI>(
+      typeConverter, context);
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.



More information about the Mlir-commits mailing list