[Mlir-commits] [mlir] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing (PR #196539)

Davide Grohmann llvmlistbot at llvm.org
Tue May 12 05:57:50 PDT 2026


================
@@ -0,0 +1,182 @@
+//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
+
+namespace mlir::tosa {
+namespace {
+
+/// Name of the GraphARM-specific interface variable ABI attribute that may be
+/// attached to function arguments/results. During conversion it is folded
+/// into the standard SPIR-V interface_var_abi attribute and then removed.
+constexpr StringLiteral graphARMInterfaceVarABIAttrName{
+    "spv.grapharm.interface_var_abi"};
+
+/// Forwards the attributes of the source `func.func` onto the new
+/// `spirv.ARM.Graph` op.
+///
+/// Attributes that the graph op is built with explicitly are skipped, and so
+/// is the graph's own entry-point attribute: the symbol name, the function
+/// type, the per-argument and per-result attribute arrays, and the
+/// entry-point flag. Every other attribute is copied over unchanged.
+void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                          spirv::GraphARMOp graphOp) {
+  // Attribute names that must not be copied verbatim; resolved once up front
+  // rather than on every loop iteration.
+  const StringRef excludedNames[] = {
+      SymbolTable::getSymbolAttrName(),
+      funcOp.getFunctionTypeAttrName().getValue(),
+      funcOp.getArgAttrsAttrName().getValue(),
+      funcOp.getResAttrsAttrName().getValue(),
+      graphOp.getEntryPointAttrName().getValue()};
+
+  for (const NamedAttribute &namedAttr : adaptor.getAttributes()) {
+    if (llvm::is_contained(excludedNames, namedAttr.getName().getValue()))
+      continue;
+    graphOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+  }
+}
+
+struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
+  using Base::Base;
+
+private:
+  /// Ensures the argument (or result, if \p isResult) at \p index carries a
+  /// standard SPIR-V interface_var_abi attribute.
+  ///
+  /// The ABI info is taken from the first available source, in this order:
+  ///  1. the GraphARM-specific attribute (graphARMInterfaceVarABIAttrName);
+  ///  2. an already-present standard interface_var_abi attribute;
+  ///  3. a synthesized attribute built from \p defaultDescriptorSet and
+  ///     \p defaultBinding (no storage class).
+  /// The result is stored under the standard name, and any GraphARM-specific
+  /// attribute is removed afterwards.
+  void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
+                                    MLIRContext *context, unsigned index,
+                                    bool isResult,
+                                    uint32_t defaultDescriptorSet,
+                                    uint32_t defaultBinding) const {
+    // Reads an InterfaceVarABIAttr stored under `attrName` on the selected
+    // argument/result; yields a null attribute when absent or of another kind.
+    auto lookup = [&](StringRef attrName) -> spirv::InterfaceVarABIAttr {
+      if (isResult)
+        return graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+            index, attrName);
+      return graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index,
+                                                                  attrName);
+    };
+
+    spirv::InterfaceVarABIAttr abiInfo =
+        lookup(graphARMInterfaceVarABIAttrName);
+    if (!abiInfo)
+      abiInfo = lookup(spirv::getInterfaceVarABIAttrName());
+    if (!abiInfo)
+      abiInfo = spirv::InterfaceVarABIAttr::get(
+          defaultDescriptorSet, defaultBinding, std::nullopt, context);
+
+    if (!isResult) {
+      graphOp.setArgAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
+      graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
+      return;
+    }
+    graphOp.setResultAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
+    graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
+  }
+
+  /// Normalizes the interface_var_abi attribute on all \p inputs arguments
+  /// and all \p outputs results of \p graphOp.
+  ///
+  /// Interface slots are numbered contiguously: arguments take slots
+  /// [0, inputs), and results take slots [inputs, inputs + outputs). A slot's
+  /// number is used as its default binding inside \p descriptorSet whenever
+  /// no explicit ABI attribute is present.
+  void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
+                                     MLIRContext *context, unsigned inputs,
+                                     unsigned outputs,
+                                     uint32_t descriptorSet) const {
+    for (unsigned slot : llvm::seq<unsigned>(0, inputs + outputs)) {
+      const bool isResult = slot >= inputs;
+      const unsigned localIndex = isResult ? slot - inputs : slot;
+      normalizeInterfaceVarABIAttr(graphOp, context, localIndex, isResult,
+                                   descriptorSet, slot);
+    }
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = rewriter.getContext();
+
+    StringRef name = adaptor.getSymName();
+
+    bool entryPoint = !isa<func::FuncOp>(funcOp->getParentOp());
+    if (entryPoint) {
+      auto spvModule = spirv::ModuleOp::create(
+          rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
+          spirv::MemoryModel::Vulkan, std::nullopt,
+          ("_spirv_tosa_" + name).str());
+
+      rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
+    }
+
+    FunctionType ftype = adaptor.getFunctionType();
+    ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
+    ArrayAttr resAttrs = adaptor.getResAttrsAttr();
+
+    TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
+    if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
+                                                   signatureConverter))) {
+      return funcOp.emitError("failed to convert function argument types");
+    }
+
+    // Update the signature of the function.
+    SmallVector<Type, 2> newResultTypes;
+    if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
+                                                newResultTypes))) {
+      return funcOp.emitError("failed to convert function result types");
+    }
+
+    auto graphTy = GraphType::get(
+        context, signatureConverter.getConvertedTypes(), newResultTypes);
+    auto entryPointAttr = BoolAttr::get(context, entryPoint);
+    auto graphOp =
+        spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
+                                  resAttrs, entryPointAttr, name);
+    copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
+    rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
+                                graphOp.end());
+    if (failed(rewriter.convertRegionTypes(
+            &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
+      return funcOp.emitError("failed to convert function regions");
+    }
+
+    if (entryPoint) {
+      uint32_t descriptorSet = 0;
+      if (auto descriptorSetAttr =
+              funcOp->getAttrOfType<IntegerAttr>("descriptor_set")) {
+        descriptorSet = static_cast<uint32_t>(descriptorSetAttr.getUInt());
+      }
+
+      normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
+                                    ftype.getNumResults(), descriptorSet);
+    }
+
+    rewriter.eraseOp(funcOp);
----------------
davidegrohmann wrote:

 Agreed. For this conversion, the source `func.func` is being used as the container for a single TOSA graph body, not as a general callable function abstraction. There is no SPIR-V Graph invoke/call operation we can lower `func.call` to, so preserving callable helpers or translating calls would need a separate design.

I think the right behavior for this patch is to reject `func.call`/`func.call_indirect` explicitly and require callers to inline/lower calls before running `tosa-to-spirv-tosa`. I will mark those ops illegal in the conversion target and add failing conversion patterns so the pass reports a clear diagnostic instead of leaving dangling calls inside the generated `spirv.ARM.Graph`.


https://github.com/llvm/llvm-project/pull/196539


More information about the Mlir-commits mailing list