[Mlir-commits] [mlir] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing (PR #196539)
Davide Grohmann
llvmlistbot at llvm.org
Wed May 20 06:47:50 PDT 2026
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/196539
>From d835e11bfdee7a00da11fb24ea1b50c184096775 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 May 2026 11:50:35 +0200
Subject: [PATCH] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing
Introduce the initial TosaToSPIRVTosa conversion pass and library
wiring. This slice converts func.func regions to spirv.ARM.Graph
inside spirv.module, rewrites graph input/result types to SPIR-V ARM
tensor types, maps func.return to spirv.ARM.GraphOutputs, and adds
focused tests for type conversion, descriptor bindings, and nested
containers.
Change-Id: I1b9b80e7575fb21be2dfb0d88913a74d672edeb2
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 19 ++
.../TosaToSPIRVTosa/TosaToSPIRVTosa.h | 45 ++++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../Conversion/TosaToSPIRVTosa/CMakeLists.txt | 21 ++
.../TosaToSPIRVTosa/TosaToSPIRVTosa.cpp | 188 +++++++++++++++
.../TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp | 215 ++++++++++++++++++
.../descriptor-set-and-bindings.mlir | 19 ++
.../TosaToSPIRVTosa/op-nesting.mlir | 28 +++
.../TosaToSPIRVTosa/type-conversions.mlir | 67 ++++++
.../unsupported-func-calls.mlir | 28 +++
11 files changed, 632 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..82c7670296e52 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -76,6 +76,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..dda756ddab152 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1410,6 +1410,25 @@ def TosaToSCFPass : Pass<"tosa-to-scf"> {
}];
}
+//===----------------------------------------------------------------------===//
+// TOSA to SPIR-V Graph/TOSA
+//===----------------------------------------------------------------------===//
+
+// Rewrites every func.func into a spirv.ARM.Graph wrapped in its own
+// spirv.module. No anchor operation is given, so the pass can be scheduled on
+// any operation. The SPIR-V target environment is looked up from the IR and
+// falls back to the pass defaults when none is attached.
+def TosaToSPIRVTosa : Pass<"tosa-to-spirv-tosa"> {
+ let summary = "Lower TOSA IR to SPIR-V Graph/TOSA operations";
+ // The pass only creates SPIR-V dialect operations.
+ let dependentDialects = [
+ "spirv::SPIRVDialect",
+ ];
+ let description = [{
+ Converts TOSA programs to the SPIR-V Graph/TOSA representation by
+ wrapping converted functions in `spirv.module` and `spirv.ARM.Graph`,
+ and rewriting TOSA tensor and shape types to the corresponding SPIR-V ARM
+ tensor types.
+ }];
+
+ // Implemented by tosa::createTosaToSPIRVTosa() in TosaToSPIRVTosaPass.cpp.
+ let constructor = "tosa::createTosaToSPIRVTosa()";
+}
+
//===----------------------------------------------------------------------===//
// TosaToTensor
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h b/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
new file mode 100644
index 0000000000000..fc36d82ff20c1
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
@@ -0,0 +1,45 @@
+//===-- TosaToSPIRVTosa.h - TOSA to SPIR-V Graph/TOSA patterns --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides pass and patterns to lower TOSA IR to SPIR-V Graph/TOSA
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
+#define MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
+
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_TOSATOSPIRVTOSA
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+/// Creates the `tosa-to-spirv-tosa` pass, which rewrites each func.func into a
+/// spirv.ARM.Graph entry point nested in its own spirv.module.
+std::unique_ptr<Pass> createTosaToSPIRVTosa();
+
+/// Returns the SPIR-V version (1.5), capabilities and extensions that the
+/// pass assumes when no `spirv.target_env` is found. The set includes the
+/// GraphARM and TensorsARM capabilities and the SPV_ARM_graph and
+/// SPV_ARM_tensors extensions.
+spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context);
+
+/// Builds a `spirv.target_env` from getDefaultVerCapExtAttr() and the given
+/// device description. A null `limits` is replaced by
+/// spirv::getDefaultResourceLimits().
+spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(
+ MLIRContext *context, spirv::ResourceLimitsAttr limits = {},
+ spirv::ClientAPI clientAPI = spirv::ClientAPI::Unknown,
+ spirv::Vendor vendorID = spirv::Vendor::Unknown,
+ spirv::DeviceType deviceType = spirv::DeviceType::Unknown,
+ uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID);
+
+/// Adds patterns that rewrite func.func into spirv.ARM.Graph (placed in a new
+/// spirv.module that carries `targetAttr`) and func.return into
+/// spirv.ARM.GraphOutputs. Signature and block types are converted with
+/// `typeConverter`.
+void populateTosaToSPIRVTosaConversionPatterns(
+ SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
+ spirv::TargetEnvAttr targetAttr);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..f5e0bcf613e59 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -69,6 +69,7 @@ add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToMLProgram)
add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToSPIRVTosa)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt b/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
new file mode 100644
index 0000000000000..630278447fa42
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Conversion library backing the tosa-to-spirv-tosa pass.
+add_mlir_conversion_library(MLIRTosaToSPIRVTosa
+  TosaToSPIRVTosa.cpp
+  TosaToSPIRVTosaPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+  MLIRSPIRVConversion
+  MLIRSPIRVDialect
+  MLIRSupport
+  MLIRTosaDialect
+  MLIRTransformUtils
+)
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
new file mode 100644
index 0000000000000..92d2479fdf543
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
@@ -0,0 +1,188 @@
+//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
+
+namespace mlir::tosa {
+namespace {
+
+// Name of a source-side argument/result attribute that lets users choose the
+// descriptor set and binding id of a graph input or output. FuncGraphConvert
+// moves its value to `spirv.interface_var_abi` on the generated
+// `spirv.ARM.Graph` and then removes it. A separate name is needed because
+// `spirv.interface_var_abi` is verified by the SPIR-V dialect before this
+// conversion runs, and result attrs are only accepted on `spirv.ARM.Graph`.
+constexpr StringLiteral graphARMInterfaceVarABIAttrName =
+ "grapharm.interface_var_abi";
+
+/// Copies each attribute reported by `adaptor.getAttributes()` onto `graphOp`.
+/// Attributes that the graph op derives from the function itself are skipped:
+/// the symbol name, the function type, the per-argument and per-result
+/// attribute arrays, and the graph's own `entry_point` flag.
+void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                          spirv::GraphARMOp graphOp) {
+  const StringRef skippedNames[] = {
+      SymbolTable::getSymbolAttrName(),
+      funcOp.getFunctionTypeAttrName().getValue(),
+      funcOp.getArgAttrsAttrName().getValue(),
+      funcOp.getResAttrsAttrName().getValue(),
+      graphOp.getEntryPointAttrName().getValue()};
+
+  for (const NamedAttribute &namedAttr : adaptor.getAttributes()) {
+    const bool isSkipped =
+        llvm::is_contained(skippedNames, namedAttr.getName().getValue());
+    if (!isSkipped)
+      graphOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+  }
+}
+
+/// Converts a func.func into a spirv.ARM.Graph entry point placed in a new
+/// spirv.module named `_spirv_tosa_<function name>`. The signature and body
+/// block types are converted with the pattern's type converter, and every
+/// graph input/output receives a `spirv.interface_var_abi` attribute.
+struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
+  FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
+                   spirv::TargetEnvAttr targetAttr)
+      : OpConversionPattern<func::FuncOp>(typeConverter, context),
+        targetAttr(targetAttr) {}
+
+private:
+  /// Target environment attached to every generated spirv.module.
+  spirv::TargetEnvAttr targetAttr;
+
+  /// Sets `spirv.interface_var_abi` on argument (or result, when `isResult`)
+  /// `index` of `graphOp`. Precedence: an explicit source-side
+  /// `grapharm.interface_var_abi` annotation, then an already-canonical
+  /// `spirv.interface_var_abi` annotation, and otherwise a synthesized
+  /// (`defaultDescriptorSet`, `defaultBinding`) pair. The source-side
+  /// annotation is removed in every case.
+  void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
+                                    MLIRContext *context, unsigned index,
+                                    bool isResult,
+                                    uint32_t defaultDescriptorSet,
+                                    uint32_t defaultBinding) const {
+    auto abiInfo =
+        isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+                       index, graphARMInterfaceVarABIAttrName)
+                 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+                       index, graphARMInterfaceVarABIAttrName);
+
+    if (!abiInfo) {
+      abiInfo = isResult
+                    ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+                          index, spirv::getInterfaceVarABIAttrName())
+                    : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+                          index, spirv::getInterfaceVarABIAttrName());
+    }
+
+    if (!abiInfo) {
+      abiInfo = spirv::InterfaceVarABIAttr::get(
+          defaultDescriptorSet, defaultBinding, std::nullopt, context);
+    }
+
+    if (isResult) {
+      graphOp.setResultAttr(index, spirv::getInterfaceVarABIAttrName(),
+                            abiInfo);
+      graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
+    } else {
+      graphOp.setArgAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
+      graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
+    }
+  }
+
+  /// Normalizes the ABI attributes of all graph inputs and outputs. By default
+  /// inputs use bindings [0, inputs) and outputs use bindings
+  /// [inputs, inputs + outputs), all in descriptor set 0.
+  void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
+                                     MLIRContext *context, unsigned inputs,
+                                     unsigned outputs) const {
+    constexpr uint32_t defaultDescriptorSet = 0;
+    for (auto argIndex : llvm::seq<unsigned>(0, inputs)) {
+      normalizeInterfaceVarABIAttr(graphOp, context, argIndex, false,
+                                   defaultDescriptorSet, argIndex);
+    }
+    for (auto resIndex : llvm::seq<unsigned>(0, outputs)) {
+      normalizeInterfaceVarABIAttr(graphOp, context, resIndex, true,
+                                   defaultDescriptorSet, resIndex + inputs);
+    }
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = rewriter.getContext();
+    FunctionType ftype = adaptor.getFunctionType();
+
+    // Convert the signature before creating any IR. A failing pattern must not
+    // leave a half-built spirv.module behind for the driver to roll back.
+    TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
+    if (failed(getTypeConverter()->convertSignatureArgs(ftype.getInputs(),
+                                                        signatureConverter))) {
+      return funcOp.emitError("failed to convert function argument types");
+    }
+
+    SmallVector<Type, 2> newResultTypes;
+    if (failed(getTypeConverter()->convertTypes(ftype.getResults(),
+                                                newResultTypes))) {
+      return funcOp.emitError("failed to convert function result types");
+    }
+
+    // Each function gets its own spirv.module, created where the func.func
+    // currently is and tagged with the pass target environment.
+    StringRef name = adaptor.getSymName();
+    auto spvModule = spirv::ModuleOp::create(
+        rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
+        spirv::MemoryModel::Vulkan, std::nullopt,
+        ("_spirv_tosa_" + name).str());
+    spvModule->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
+    rewriter.setInsertionPointToStart(spvModule.getBody());
+
+    ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
+    ArrayAttr resAttrs = adaptor.getResAttrsAttr();
+
+    // TOSA graphs cannot contain nested funcs, so the converted GraphARM op is
+    // an entry point.
+    auto entryPointAttr = BoolAttr::get(context, true);
+    auto graphTy = GraphType::get(
+        context, signatureConverter.getConvertedTypes(), newResultTypes);
+    auto graphOp =
+        spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
+                                  resAttrs, entryPointAttr, name);
+    copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
+
+    rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
+                                graphOp.end());
+    if (failed(rewriter.convertRegionTypes(
+            &graphOp.getBody(), *getTypeConverter(), &signatureConverter))) {
+      return funcOp.emitError("failed to convert function regions");
+    }
+
+    normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
+                                  ftype.getNumResults());
+
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+
+/// Converts func.return to spirv.ARM.GraphOutputs.
+///
+/// A return is only rewritten once its body has been moved into a
+/// spirv.ARM.Graph by FuncGraphConvert. If the enclosing func.func failed to
+/// convert, the return is left alone. Otherwise a graph terminator would end
+/// up outside of any graph.
+struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isa<spirv::GraphARMOp>(returnOp->getParentOp()))
+      return rewriter.notifyMatchFailure(
+          returnOp, "expected parent to be a converted spirv.ARM.Graph");
+
+    rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(
+        returnOp, adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the func.func -> spirv.ARM.Graph and
+/// func.return -> spirv.ARM.GraphOutputs conversion patterns.
+void populateTosaToSPIRVTosaConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
+    spirv::TargetEnvAttr targetAttr) {
+  MLIRContext *ctx = patterns.getContext();
+  // The function pattern stamps `targetAttr` onto every module it creates.
+  patterns.add<FuncGraphConvert>(typeConverter, ctx, targetAttr);
+  // The terminator pattern only needs the type converter.
+  patterns.add<ReturnGraphOutputConvert>(typeConverter, ctx);
+}
+
+} // namespace mlir::tosa
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
new file mode 100644
index 0000000000000..f558d14c2dabd
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
@@ -0,0 +1,215 @@
+//===- TosaToSPIRVTosaPass.cpp - Lower TOSA to SPIR-V Graph/TOSA ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers TOSA IR to the SPIR-V Graph/TOSA representation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <algorithm>
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOSPIRVTOSA
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+/// Returns the SPIR-V 1.5 version/capability/extension triple used when the
+/// IR does not carry a `spirv.target_env`.
+spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context) {
+  // Capabilities: memory model and shader basics, the scalar element types
+  // TOSA may use, and the ARM tensor/graph features.
+  const spirv::Capability defaultCaps[] = {
+      spirv::Capability::VulkanMemoryModel,
+      spirv::Capability::Shader,
+      spirv::Capability::Int8,
+      spirv::Capability::Int16,
+      spirv::Capability::Int64,
+      spirv::Capability::Float16,
+      spirv::Capability::BFloat16TypeKHR,
+      spirv::Capability::Float8EXT,
+      spirv::Capability::TensorsARM,
+      spirv::Capability::GraphARM,
+      spirv::Capability::ReplicatedCompositesEXT,
+  };
+
+  // Extensions declaring the capabilities above.
+  const spirv::Extension defaultExts[] = {
+      spirv::Extension::SPV_ARM_tensors,
+      spirv::Extension::SPV_ARM_graph,
+      spirv::Extension::SPV_KHR_vulkan_memory_model,
+      spirv::Extension::SPV_EXT_replicated_composites,
+      spirv::Extension::SPV_KHR_bfloat16,
+      spirv::Extension::SPV_EXT_float8,
+  };
+
+  return spirv::VerCapExtAttr::get(spirv::Version::V_1_5, defaultCaps,
+                                   defaultExts, context);
+}
+
+/// Builds a target environment from getDefaultVerCapExtAttr() and the given
+/// device description. A null `limits` is replaced by the SPIR-V default
+/// resource limits.
+spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(
+    MLIRContext *context, spirv::ResourceLimitsAttr limits,
+    spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
+    spirv::DeviceType deviceType, uint32_t deviceID) {
+  spirv::ResourceLimitsAttr effectiveLimits =
+      limits ? limits : spirv::getDefaultResourceLimits(context);
+  spirv::VerCapExtAttr triple = getDefaultVerCapExtAttr(context);
+  return spirv::TargetEnvAttr::get(triple, effectiveLimits, clientAPI,
+                                   vendorID, deviceType, deviceID);
+}
+
+namespace {
+
+LogicalResult verifyGraphTargetEnv(Operation *op,
+ spirv::TargetEnvAttr targetAttr) {
+ spirv::TargetEnv targetEnv(targetAttr);
+ if (targetEnv.allows(spirv::Capability::GraphARM) &&
+ targetEnv.allows(spirv::Extension::SPV_ARM_graph) &&
+ targetEnv.allows(spirv::Extension::SPV_ARM_tensors)) {
+ return success();
+ }
+
+ return op->emitOpError()
+ << "requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
+ "extensions in spirv.target_env";
+}
+
+/// Walks `op` (post-order, including `op` itself) and reports the first
+/// func.call / func.call_indirect, or func.func nested inside another
+/// func.func. Both are rejected because the conversion cannot express them.
+LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
+  auto visitor = [](Operation *nested) -> WalkResult {
+    if (isa<func::CallOp, func::CallIndirectOp>(nested)) {
+      nested->emitOpError()
+          << "is not supported in TOSA to SPIR-V Graph conversion; inline "
+             "calls before running this pass";
+      return WalkResult::interrupt();
+    }
+
+    auto nestedFunc = dyn_cast<func::FuncOp>(nested);
+    if (nestedFunc && nestedFunc->getParentOfType<func::FuncOp>()) {
+      nestedFunc.emitOpError()
+          << "nesting is not supported in TOSA to SPIR-V Graph conversion";
+      return WalkResult::interrupt();
+    }
+
+    return WalkResult::advance();
+  };
+
+  const bool foundUnsupported = op->walk(visitor).wasInterrupted();
+  return failure(foundUnsupported);
+}
+
+/// Pass driver: validates the target environment and the input, sets up the
+/// TOSA-specific type conversions and runs a partial dialect conversion.
+struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    Operation *op = getOperation();
+
+    // Prefer a target environment found via spirv::lookupTargetEnv; fall back
+    // to the pass defaults otherwise.
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op);
+    if (!targetAttr) {
+      targetAttr = constructTargetEnvAttrWithCapExtDefaults(context);
+    }
+
+    if (failed(verifyGraphTargetEnv(op, targetAttr)) ||
+        failed(verifyNoUnsupportedFuncOps(op))) {
+      signalPassFailure();
+      return;
+    }
+
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+
+    target->addIllegalDialect<tosa::TosaDialect>();
+    target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
+
+    // Conversions added last are tried first, so these callbacks take
+    // precedence over the SPIRVTypeConverter defaults for the same types.
+    SPIRVTypeConverter typeConverter(targetAttr);
+    typeConverter.addConversion([this](IntegerType integerType) {
+      return this->convertIntegerType(integerType);
+    });
+    typeConverter.addConversion([this](TensorType tensorType) {
+      return this->convertTensorType(tensorType);
+    });
+    typeConverter.addConversion([this](tosa::shapeType shapeType) {
+      return this->convertShapeType(shapeType);
+    });
+
+    populateTosaToSPIRVTosaConversionPatterns(typeConverter, patterns,
+                                              targetAttr);
+
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    if (failed(applyPartialConversion(op, *target, frozenPatterns))) {
+      signalPassFailure();
+      return;
+    }
+
+    // func.func is not marked illegal, so the partial conversion succeeds even
+    // when a function could not be converted (e.g. a zero-sized tensor in its
+    // signature) and simply leaves it in place. FuncGraphConvert emits an
+    // error for its own failures, so only the pass status is updated here.
+    if (op->walk([](func::FuncOp) { return WalkResult::interrupt(); })
+            .wasInterrupted())
+      signalPassFailure();
+  }
+
+private:
+  /// Widens i48 to i64 and i4 to i8, preserving signedness. Other integer
+  /// types are returned unchanged.
+  IntegerType convertIntegerType(IntegerType integerType) {
+    if (integerType.getWidth() == 48) {
+      return IntegerType::get(&getContext(), 64, integerType.getSignedness());
+    }
+
+    if (integerType.getWidth() == 4) {
+      return IntegerType::get(&getContext(), 8, integerType.getSignedness());
+    }
+
+    return integerType;
+  }
+
+  /// Maps a ranked tensor shape to an ARM tensor shape. Returns std::nullopt
+  /// for shapes containing a zero-sized dimension.
+  std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
+    // Scalar ARM tensors are not supported, so convert them to
+    // tensors with shape [1].
+    if (shape.empty())
+      return SmallVector<int64_t>({1});
+
+    if (llvm::is_contained(shape, 0))
+      return std::nullopt;
+
+    bool isPartiallyDynamic =
+        llvm::is_contained(shape, ShapedType::kDynamic) &&
+        llvm::any_of(shape, [](int64_t dim) { return dim > 0; });
+    // Partially shaped ARM tensors are not supported, so convert them to
+    // unshaped tensors.
+    if (isPartiallyDynamic)
+      return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
+    return SmallVector<int64_t>(shape);
+  }
+
+  /// Converts a builtin tensor to a SPIR-V ARM tensor. Index elements become
+  /// i32 and integer elements are widened as in convertIntegerType. Unranked
+  /// tensors stay unranked.
+  ///
+  /// For unrepresentable shapes this returns a null type rather than
+  /// std::nullopt. A null type is a hard failure, whereas std::nullopt would
+  /// let the TypeConverter fall back to SPIRVTypeConverter's tensor-to-array
+  /// conversion.
+  std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
+    Type elementType = getElementTypeOrSelf(tensorType);
+    if (elementType.isIndex())
+      elementType = IntegerType::get(&getContext(), 32);
+    if (auto integerType = dyn_cast<IntegerType>(elementType))
+      elementType = convertIntegerType(integerType);
+
+    SmallVector<int64_t> shape;
+    if (tensorType.hasRank()) {
+      std::optional<SmallVector<int64_t>> convertedShape =
+          convertShape(tensorType.getShape());
+      if (!convertedShape)
+        return spirv::TensorArmType();
+      shape = std::move(*convertedShape);
+    }
+
+    return spirv::TensorArmType::get(shape, elementType);
+  }
+
+  /// Converts `!tosa.shape<N>` to a rank-1 i32 ARM tensor with N elements.
+  /// Rank-0 shapes become a single-element tensor.
+  spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
+    const int64_t rank = std::max(shapeType.getRank(), 1);
+    return spirv::TensorArmType::get({rank},
+                                     IntegerType::get(&getContext(), 32));
+  }
+};
+} // namespace
+
+/// Factory referenced by the `constructor` field of the TableGen pass
+/// definition.
+std::unique_ptr<Pass> createTosaToSPIRVTosa() {
+  auto pass = std::make_unique<TosaToSPIRVTosa>();
+  return pass;
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
new file mode 100644
index 0000000000000..bbbe72c00a15b
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s --implicit-check-not=grapharm.interface_var_abi
+
+// The source-side grapharm.interface_var_abi attribute must never reach the
+// output. It is rejected through the implicit negative pattern on the RUN
+// line, which applies to the whole output; a negative directive written before
+// the first positive match would only cover the text preceding that match.
+
+// Without annotations, inputs take bindings starting at 0 and outputs follow
+// them, all in descriptor set 0.
+// CHECK: spirv.module @_spirv_tosa_default_interface_var_abi Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @default_interface_var_abi(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @default_interface_var_abi(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi8>
+}
+
+// -----
+
+// User-provided grapharm ABI annotations are moved to spirv.interface_var_abi.
+// CHECK: spirv.module @_spirv_tosa_custom_grapharm_abi Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @custom_grapharm_abi(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(3, 9)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(7, 11)>}) attributes {entry_point = true} {
+func.func @custom_grapharm_abi(%arg0: tensor<1xi8> {grapharm.interface_var_abi = #spirv.interface_var_abi<(3, 9)>}) -> (tensor<1xi8> {grapharm.interface_var_abi = #spirv.interface_var_abi<(7, 11)>}) {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi8>
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
new file mode 100644
index 0000000000000..11e76d7df1637
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s
+
+// Functions nested in arbitrary containers are converted in place, and every
+// enclosing container is preserved.
+
+// CHECK: gpu.module @random_container
+gpu.module @random_container {
+  // CHECK: spirv.module @_spirv_tosa_nested Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+  // CHECK: spirv.ARM.Graph @nested(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+  func.func @nested(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+    // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+    return %arg0 : tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK: module @random_container {
+module @random_container {
+  // CHECK: module @yet_another_random_container {
+  module @yet_another_random_container {
+    // CHECK: gpu.module @another_random_container
+    gpu.module @another_random_container {
+      // CHECK: spirv.module @_spirv_tosa_nested Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+      // CHECK: spirv.ARM.Graph @nested(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+      func.func @nested(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+        // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+        return %arg0 : tensor<1xi8>
+      }
+    }
+  }
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
new file mode 100644
index 0000000000000..76ca038f5da2d
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s
+
+// CHECK: spirv.module @_spirv_tosa_i48_to_i64 Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @i48_to_i64(%[[ARG0:.*]]: !spirv.arm.tensor<1x2x3x4xi64> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1x2x3x4xi64> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @i48_to_i64(%arg0: tensor<1x2x3x4xi48>) -> tensor<1x2x3x4xi48> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1x2x3x4xi64>
+  return %arg0 : tensor<1x2x3x4xi48>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_i4_to_i8 Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @i4_to_i8(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @i4_to_i8(%arg0: tensor<1xi4>) -> tensor<1xi4> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi4>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_scalar_tensor_to_rank1_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @scalar_tensor_to_rank1_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @scalar_tensor_to_rank1_tensor(%arg0: tensor<i8>) -> tensor<i8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<i8>
+}
+
+// -----
+
+// Zero-sized dimensions cannot be represented as ARM tensors.
+// expected-error@below {{failed to convert function argument types}}
+func.func @zero_sized_tensor(%arg0: tensor<0xi8>) -> tensor<0xi8> {
+  return %arg0 : tensor<0xi8>
+}
+
+// -----
+
+// expected-error@below {{failed to convert function argument types}}
+func.func @mixed_zero_sized_tensor(%arg0: tensor<1x0x2xi8>) -> tensor<1x0x2xi8> {
+  return %arg0 : tensor<1x0x2xi8>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_partially_shaped_tensor_to_unshaped_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @partially_shaped_tensor_to_unshaped_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<?x?xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<?x?xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @partially_shaped_tensor_to_unshaped_tensor(%arg0: tensor<1x?xi8>) -> tensor<1x?xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<?x?xi8>
+  return %arg0 : tensor<1x?xi8>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_unranked_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @unranked_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<*xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<*xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @unranked_tensor(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<*xi8>
+  return %arg0 : tensor<*xi8>
+}
+
+// -----
+
+// Index elements are lowered to i32.
+// CHECK: spirv.module @_spirv_tosa_index_to_i32 Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @index_to_i32(%[[ARG0:.*]]: !spirv.arm.tensor<2xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<2xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @index_to_i32(%arg0: tensor<2xindex>) -> tensor<2xindex> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<2xi32>
+  return %arg0 : tensor<2xindex>
+}
+
+// -----
+
+// TOSA shapes become rank-1 i32 tensors with one element per dimension.
+// CHECK: spirv.module @_spirv_tosa_shape_to_i32_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @shape_to_i32_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<3xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<3xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @shape_to_i32_tensor(%arg0: !tosa.shape<3>) -> !tosa.shape<3> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<3xi32>
+  return %arg0 : !tosa.shape<3>
+}
+
+// -----
+
+// Rank-0 TOSA shapes become single-element tensors.
+// CHECK: spirv.module @_spirv_tosa_rank0_shape_to_i32_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @rank0_shape_to_i32_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<1xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @rank0_shape_to_i32_tensor(%arg0: !tosa.shape<0>) -> !tosa.shape<0> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi32>
+  return %arg0 : !tosa.shape<0>
+}
+
+// -----
+
+// expected-error@below {{'builtin.module' op requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors extensions in spirv.target_env}}
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>} {
+  func.func @unsupported_target_env(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+    return %arg0 : tensor<1xi8>
+  }
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir
new file mode 100644
index 0000000000000..52c1ad526727d
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s
+
+// Calls must be inlined before the conversion; direct and indirect calls are
+// rejected up front.
+
+func.func @direct_call(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // expected-error@below {{'func.call' op is not supported in TOSA to SPIR-V Graph conversion; inline calls before running this pass}}
+  %0 = func.call @direct_call(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+func.func @indirect_call(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  %callee = func.constant @indirect_call : (tensor<1xi8>) -> tensor<1xi8>
+  // expected-error@below {{'func.call_indirect' op is not supported in TOSA to SPIR-V Graph conversion; inline calls before running this pass}}
+  %0 = func.call_indirect %callee(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+// A func.func nested (at any depth) inside another func.func is rejected.
+func.func @nested_func(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  builtin.module {
+    // expected-error@below {{'func.func' op nesting is not supported in TOSA to SPIR-V Graph conversion}}
+    func.func @nested(%arg1: tensor<1xi8>) -> tensor<1xi8> {
+      return %arg1 : tensor<1xi8>
+    }
+  }
+  return %arg0 : tensor<1xi8>
+}
More information about the Mlir-commits
mailing list