[Mlir-commits] [mlir] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing (PR #196539)
Davide Grohmann
llvmlistbot at llvm.org
Tue May 12 05:57:50 PDT 2026
================
@@ -0,0 +1,182 @@
+//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
+
+namespace mlir::tosa {
+namespace {
+
+/// Graph-specific spelling of the SPIR-V interface-variable ABI attribute
+/// that may appear on graph arguments/results. During conversion it is read
+/// in preference to, and then replaced by, the generic
+/// `spirv::getInterfaceVarABIAttrName()` spelling.
+constexpr StringLiteral graphARMInterfaceVarABIAttrName{
+    "spv.grapharm.interface_var_abi"};
+
+/// Forwards the attributes of the source `func.func` onto `graphOp`.
+///
+/// Attributes the conversion rebuilds itself are skipped: the symbol name,
+/// the function type, the per-argument and per-result attribute arrays, and
+/// the graph's own entry-point flag. Everything else is copied verbatim.
+///
+/// NOTE(review): attributes such as `sym_visibility` or `descriptor_set`
+/// (the latter is read separately by the pattern) are not in the skip list
+/// and therefore end up on the graph op as well; confirm this is intended.
+void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                          spirv::GraphARMOp graphOp) {
+  // Resolve the names to skip once; they do not vary between attributes.
+  const StringRef rebuiltAttrNames[] = {
+      SymbolTable::getSymbolAttrName(),
+      funcOp.getFunctionTypeAttrName().getValue(),
+      funcOp.getArgAttrsAttrName().getValue(),
+      funcOp.getResAttrsAttrName().getValue(),
+      graphOp.getEntryPointAttrName().getValue()};
+
+  for (NamedAttribute namedAttr : adaptor.getAttributes()) {
+    if (llvm::is_contained(rebuiltAttrNames, namedAttr.getName().getValue()))
+      continue;
+    graphOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+  }
+}
+
+struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
+ using Base::Base;
+
+private:
+  /// Ensures argument (or result, when `isResult`) number `index` of
+  /// `graphOp` carries exactly one interface-variable ABI attribute, stored
+  /// under the generic `spirv::getInterfaceVarABIAttrName()` name.
+  ///
+  /// The value is taken from the graph-specific spelling if present, else
+  /// from the generic spelling, else synthesized from `defaultDescriptorSet`
+  /// and `defaultBinding` with no storage class. The graph-specific spelling
+  /// is removed afterwards.
+  void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
+                                    MLIRContext *context, unsigned index,
+                                    bool isResult,
+                                    uint32_t defaultDescriptorSet,
+                                    uint32_t defaultBinding) const {
+    StringRef canonicalName = spirv::getInterfaceVarABIAttrName();
+
+    // Reads the ABI attribute stored under `attrName` on the selected
+    // argument/result; yields null when absent or of a different type.
+    auto lookup = [&](StringRef attrName) -> spirv::InterfaceVarABIAttr {
+      if (isResult)
+        return graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+            index, attrName);
+      return graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(index,
+                                                                  attrName);
+    };
+
+    spirv::InterfaceVarABIAttr abi = lookup(graphARMInterfaceVarABIAttrName);
+    if (!abi)
+      abi = lookup(canonicalName);
+    if (!abi)
+      abi = spirv::InterfaceVarABIAttr::get(
+          defaultDescriptorSet, defaultBinding, std::nullopt, context);
+
+    if (!isResult) {
+      graphOp.setArgAttr(index, canonicalName, abi);
+      graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
+      return;
+    }
+    graphOp.setResultAttr(index, canonicalName, abi);
+    graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
+  }
+
+  /// Normalizes the interface-variable ABI attribute of every argument and
+  /// result of `graphOp`, all in `descriptorSet`.
+  ///
+  /// Default bindings are assigned densely: argument `i` gets binding `i`,
+  /// result `j` gets binding `inputs + j`. Arguments are processed before
+  /// results.
+  void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
+                                     MLIRContext *context, unsigned inputs,
+                                     unsigned outputs,
+                                     uint32_t descriptorSet) const {
+    for (unsigned binding = 0, end = inputs + outputs; binding < end;
+         ++binding) {
+      const bool isResult = binding >= inputs;
+      const unsigned index = isResult ? binding - inputs : binding;
+      normalizeInterfaceVarABIAttr(graphOp, context, index, isResult,
+                                   descriptorSet, binding);
+    }
+  }
+
+public:
+ LogicalResult
+ matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MLIRContext *context = rewriter.getContext();
+
+ StringRef name = adaptor.getSymName();
+
+ bool entryPoint = !isa<func::FuncOp>(funcOp->getParentOp());
+ if (entryPoint) {
+ auto spvModule = spirv::ModuleOp::create(
+ rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
+ spirv::MemoryModel::Vulkan, std::nullopt,
+ ("_spirv_tosa_" + name).str());
+
+ rewriter.setInsertionPoint(spvModule.getBody(), spvModule.begin());
+ }
+
+ FunctionType ftype = adaptor.getFunctionType();
+ ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
+ ArrayAttr resAttrs = adaptor.getResAttrsAttr();
+
+ TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
+ if (failed(typeConverter->convertSignatureArgs(ftype.getInputs(),
+ signatureConverter))) {
+ return funcOp.emitError("failed to convert function argument types");
+ }
+
+ // Update the signature of the function.
+ SmallVector<Type, 2> newResultTypes;
+ if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
+ newResultTypes))) {
+ return funcOp.emitError("failed to convert function result types");
+ }
+
+ auto graphTy = GraphType::get(
+ context, signatureConverter.getConvertedTypes(), newResultTypes);
+ auto entryPointAttr = BoolAttr::get(context, entryPoint);
+ auto graphOp =
+ spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
+ resAttrs, entryPointAttr, name);
+ copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
+ rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
+ graphOp.end());
+ if (failed(rewriter.convertRegionTypes(
+ &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
+ return funcOp.emitError("failed to convert function regions");
+ }
+
+ if (entryPoint) {
+ uint32_t descriptorSet = 0;
+ if (auto descriptorSetAttr =
+ funcOp->getAttrOfType<IntegerAttr>("descriptor_set")) {
+ descriptorSet = static_cast<uint32_t>(descriptorSetAttr.getUInt());
+ }
+
+ normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
+ ftype.getNumResults(), descriptorSet);
+ }
+
+ rewriter.eraseOp(funcOp);
----------------
davidegrohmann wrote:
Agreed. For this conversion, the source `func.func` is being used as the container for a single TOSA graph body, not as a general callable function abstraction. There is no SPIR-V Graph invoke/call operation we can lower `func.call` to, so preserving callable helpers or translating calls would need a separate design.
I think the right behavior for this patch is to reject `func.call`/`func.call_indirect` explicitly and require callers to inline/lower calls before running `tosa-to-spirv-tosa`. I will mark those ops illegal in the conversion target and add failing conversion patterns so the pass reports a clear diagnostic instead of leaving dangling calls inside the generated `spirv.ARM.Graph`.
https://github.com/llvm/llvm-project/pull/196539
More information about the Mlir-commits
mailing list