[Mlir-commits] [mlir] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing (PR #196539)

Davide Grohmann llvmlistbot at llvm.org
Wed May 20 06:47:50 PDT 2026


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/196539

>From d835e11bfdee7a00da11fb24ea1b50c184096775 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 May 2026 11:50:35 +0200
Subject: [PATCH] [mlir][tosa][spirv] Add TOSA to SPIR-V TOSA pass plumbing

Introduce the initial TosaToSPIRVTosa conversion pass and library
wiring. This slice converts func.func regions to spirv.ARM.Graph
inside spirv.module, rewrites graph input/result types to SPIR-V ARM
tensor types, maps func.return to spirv.ARM.GraphOutputs, and adds
focused tests for type conversion, descriptor bindings, nested
containers, and diagnostics for unsupported calls and nested functions.

Change-Id: I1b9b80e7575fb21be2dfb0d88913a74d672edeb2
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
---
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  19 ++
 .../TosaToSPIRVTosa/TosaToSPIRVTosa.h         |  45 ++++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../Conversion/TosaToSPIRVTosa/CMakeLists.txt |  21 ++
 .../TosaToSPIRVTosa/TosaToSPIRVTosa.cpp       | 188 +++++++++++++++
 .../TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp   | 215 ++++++++++++++++++
 .../descriptor-set-and-bindings.mlir          |  19 ++
 .../TosaToSPIRVTosa/op-nesting.mlir           |  28 +++
 .../TosaToSPIRVTosa/type-conversions.mlir     |  67 ++++++
 .../unsupported-func-calls.mlir               |  28 +++
 11 files changed, 632 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
 create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
 create mode 100644 mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
 create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
 create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
 create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
 create mode 100644 mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index a54b98004c3b6..82c7670296e52 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -76,6 +76,7 @@
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d401b56c7602d..dda756ddab152 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1410,6 +1410,25 @@ def TosaToSCFPass : Pass<"tosa-to-scf"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA to SPIR-V Graph/TOSA
+//===----------------------------------------------------------------------===//
+
+// Pass definition expanded into TosaToSPIRVTosa.h via
+// GEN_PASS_DECL_TOSATOSPIRVTOSA and into TosaToSPIRVTosaPass.cpp via
+// GEN_PASS_DEF_TOSATOSPIRVTOSA.
+def TosaToSPIRVTosa : Pass<"tosa-to-spirv-tosa"> {
+  let summary = "Lower TOSA IR to SPIR-V Graph/TOSA operations";
+  let dependentDialects = [
+    "spirv::SPIRVDialect",
+  ];
+  let description = [{
+    Converts TOSA programs to the SPIR-V Graph/TOSA representation by
+    wrapping converted functions in `spirv.module` and `spirv.ARM.Graph`,
+    and rewriting TOSA tensor and shape types to the corresponding SPIR-V ARM
+    tensor types.
+  }];
+
+  // Hand-written factory declared in TosaToSPIRVTosa.h.
+  let constructor = "tosa::createTosaToSPIRVTosa()";
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToTensor
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h b/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
new file mode 100644
index 0000000000000..fc36d82ff20c1
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h
@@ -0,0 +1,45 @@
+//===-- TosaToSPIRVTosa.h - TOSA to SPIR-V Graph/TOSA patterns --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Provides pass and patterns to lower TOSA IR to SPIR-V Graph/TOSA
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
+#define MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
+
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_TOSATOSPIRVTOSA
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+/// Creates the `tosa-to-spirv-tosa` pass. Each func.func is rewritten into a
+/// spirv.ARM.Graph entry point inside its own spirv.module.
+std::unique_ptr<Pass> createTosaToSPIRVTosa();
+
+/// Returns the version/capability/extension triple used when the IR carries
+/// no `spirv.target_env`: SPIR-V 1.5 with (among others) the GraphARM and
+/// TensorsARM capabilities and the SPV_ARM_graph/SPV_ARM_tensors extensions.
+spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context);
+
+/// Builds a `spirv.target_env` from getDefaultVerCapExtAttr() and the given
+/// device description. A null `limits` is replaced by the SPIR-V dialect's
+/// default resource limits.
+spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(
+    MLIRContext *context, spirv::ResourceLimitsAttr limits = {},
+    spirv::ClientAPI clientAPI = spirv::ClientAPI::Unknown,
+    spirv::Vendor vendorID = spirv::Vendor::Unknown,
+    spirv::DeviceType deviceType = spirv::DeviceType::Unknown,
+    uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID);
+
+/// Adds the func.func -> spirv.ARM.Graph and func.return ->
+/// spirv.ARM.GraphOutputs conversion patterns to `patterns`. `targetAttr` is
+/// attached as `spirv.target_env` to every spirv.module the patterns create.
+void populateTosaToSPIRVTosaConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
+    spirv::TargetEnvAttr targetAttr);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSPIRVTOSA_TOSATOSPIRVTOSA_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e17988b12cade..f5e0bcf613e59 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -69,6 +69,7 @@ add_subdirectory(TosaToArith)
 add_subdirectory(TosaToLinalg)
 add_subdirectory(TosaToMLProgram)
 add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToSPIRVTosa)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
 add_subdirectory(UBToSPIRV)
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt b/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
new file mode 100644
index 0000000000000..630278447fa42
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/CMakeLists.txt
@@ -0,0 +1,21 @@
+# Library for the `tosa-to-spirv-tosa` conversion: the func-to-graph
+# rewrite patterns (TosaToSPIRVTosa.cpp) and the pass driver
+# (TosaToSPIRVTosaPass.cpp).
+add_mlir_conversion_library(MLIRTosaToSPIRVTosa
+  TosaToSPIRVTosa.cpp
+  TosaToSPIRVTosaPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRIR
+  MLIRPass
+  MLIRSPIRVDialect
+  MLIRSPIRVConversion
+  MLIRSupport
+  MLIRTransformUtils
+  MLIRTosaDialect
+)
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
new file mode 100644
index 0000000000000..92d2479fdf543
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.cpp
@@ -0,0 +1,188 @@
+//===- TosaToSPIRVTosa.cpp - TOSA to SPIR-V Graph/TOSA patterns -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+
+#define DEBUG_TYPE "tosa-to-spirv-tosa-pattern"
+
+namespace mlir::tosa {
+namespace {
+
+// Allows users to specify descriptor sets and binding ids on the source
+// function inputs and outputs. Use a source-side GraphARM attribute because
+// `spirv.interface_var_abi` is verified by the SPIR-V dialect before this
+// conversion runs, and result attrs are only accepted on `spirv.ARM.Graph`.
+// FuncGraphConvert reads this attribute (as a #spirv.interface_var_abi value),
+// re-emits it as `spirv.interface_var_abi` on the graph, and then removes it.
+constexpr StringLiteral graphARMInterfaceVarABIAttrName =
+    "grapharm.interface_var_abi";
+
+/// Copies every attribute of the source function onto `graphOp`, except the
+/// symbol name, function type, argument/result attribute arrays and
+/// `entry_point`, which GraphARMOp::create already populates.
+void copyFuncAttrsToGraph(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                          spirv::GraphARMOp graphOp) {
+  const StringRef skippedNames[] = {
+      SymbolTable::getSymbolAttrName(),
+      funcOp.getFunctionTypeAttrName().getValue(),
+      funcOp.getArgAttrsAttrName().getValue(),
+      funcOp.getResAttrsAttrName().getValue(),
+      graphOp.getEntryPointAttrName().getValue(),
+  };
+  for (NamedAttribute namedAttr : adaptor.getAttributes()) {
+    if (!llvm::is_contained(skippedNames, namedAttr.getName().getValue()))
+      graphOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+  }
+}
+
+/// Converts a func.func into a spirv.ARM.Graph entry point placed inside a
+/// newly created spirv.module named `_spirv_tosa_<function name>`. The module
+/// is created where the func.func was and carries the pattern's target
+/// environment.
+struct FuncGraphConvert final : OpConversionPattern<func::FuncOp> {
+  FuncGraphConvert(SPIRVTypeConverter &typeConverter, MLIRContext *context,
+                   spirv::TargetEnvAttr targetAttr)
+      : OpConversionPattern<func::FuncOp>(typeConverter, context),
+        targetAttr(targetAttr) {}
+
+private:
+  /// Attached as `spirv.target_env` to every generated spirv.module.
+  spirv::TargetEnvAttr targetAttr;
+
+  // Prefer an explicit source-side GraphARM ABI annotation, then preserve an
+  // already-canonical SPIR-V ABI annotation, and otherwise synthesize the
+  // default descriptor set and binding id. The source-side attribute is always
+  // removed so only `spirv.interface_var_abi` remains on the graph.
+  void normalizeInterfaceVarABIAttr(spirv::GraphARMOp graphOp,
+                                    MLIRContext *context, unsigned index,
+                                    bool isResult,
+                                    uint32_t defaultDescriptorSet,
+                                    uint32_t defaultBinding) const {
+    auto abiInfo =
+        isResult ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+                       index, graphARMInterfaceVarABIAttrName)
+                 : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+                       index, graphARMInterfaceVarABIAttrName);
+
+    if (!abiInfo) {
+      abiInfo = isResult
+                    ? graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(
+                          index, spirv::getInterfaceVarABIAttrName())
+                    : graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+                          index, spirv::getInterfaceVarABIAttrName());
+    }
+
+    if (!abiInfo) {
+      abiInfo = spirv::InterfaceVarABIAttr::get(
+          defaultDescriptorSet, defaultBinding, std::nullopt, context);
+    }
+
+    if (isResult) {
+      graphOp.setResultAttr(index, spirv::getInterfaceVarABIAttrName(),
+                            abiInfo);
+      graphOp.removeResultAttr(index, graphARMInterfaceVarABIAttrName);
+    } else {
+      graphOp.setArgAttr(index, spirv::getInterfaceVarABIAttrName(), abiInfo);
+      graphOp.removeArgAttr(index, graphARMInterfaceVarABIAttrName);
+    }
+  }
+
+  // Default ABI layout: descriptor set 0 for everything, inputs bound to
+  // [0, inputs) and results to [inputs, inputs + outputs), unless an explicit
+  // annotation overrides an individual interface variable.
+  void normalizeInterfaceVarABIAttrs(spirv::GraphARMOp graphOp,
+                                     MLIRContext *context, unsigned inputs,
+                                     unsigned outputs) const {
+    constexpr uint32_t defaultDescriptorSet = 0;
+    for (auto argIndex : llvm::seq<unsigned>(0, inputs)) {
+      normalizeInterfaceVarABIAttr(graphOp, context, argIndex, false,
+                                   defaultDescriptorSet, argIndex);
+    }
+    for (auto resIndex : llvm::seq<unsigned>(0, outputs)) {
+      normalizeInterfaceVarABIAttr(graphOp, context, resIndex, true,
+                                   defaultDescriptorSet, resIndex + inputs);
+    }
+  }
+
+public:
+  LogicalResult
+  matchAndRewrite(func::FuncOp funcOp, func::FuncOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = rewriter.getContext();
+    const TypeConverter *converter = getTypeConverter();
+
+    FunctionType ftype = adaptor.getFunctionType();
+    ArrayAttr argAttrs = adaptor.getArgAttrsAttr();
+    ArrayAttr resAttrs = adaptor.getResAttrsAttr();
+
+    // Convert the signature before creating any IR. A pattern should not
+    // build a spirv.module and then fail, relying on the driver to roll the
+    // module back; that is not allowed when pattern rollback is disabled.
+    TypeConverter::SignatureConversion signatureConverter(ftype.getNumInputs());
+    if (failed(converter->convertSignatureArgs(ftype.getInputs(),
+                                               signatureConverter))) {
+      return funcOp.emitError("failed to convert function argument types");
+    }
+
+    SmallVector<Type, 2> newResultTypes;
+    if (failed(converter->convertTypes(ftype.getResults(), newResultTypes))) {
+      return funcOp.emitError("failed to convert function result types");
+    }
+
+    // Every graph gets its own spirv.module, created at the func.func's
+    // position so it stays in the same enclosing container.
+    StringRef name = adaptor.getSymName();
+    auto spvModule = spirv::ModuleOp::create(
+        rewriter, funcOp.getLoc(), spirv::AddressingModel::Logical,
+        spirv::MemoryModel::Vulkan, std::nullopt,
+        ("_spirv_tosa_" + name).str());
+    spvModule->setAttr(spirv::getTargetEnvAttrName(), targetAttr);
+    rewriter.setInsertionPointToStart(spvModule.getBody());
+
+    // TOSA graphs cannot contain nested funcs, so the converted GraphARM op is
+    // an entry point.
+    auto entryPointAttr = BoolAttr::get(context, true);
+    auto graphTy = GraphType::get(
+        context, signatureConverter.getConvertedTypes(), newResultTypes);
+    auto graphOp =
+        spirv::GraphARMOp::create(rewriter, funcOp.getLoc(), graphTy, argAttrs,
+                                  resAttrs, entryPointAttr, name);
+    copyFuncAttrsToGraph(funcOp, adaptor, graphOp);
+
+    rewriter.inlineRegionBefore(funcOp.getBody(), graphOp.getBody(),
+                                graphOp.end());
+    if (failed(rewriter.convertRegionTypes(&graphOp.getBody(), *converter,
+                                           &signatureConverter))) {
+      return funcOp.emitError("failed to convert function regions");
+    }
+
+    normalizeInterfaceVarABIAttrs(graphOp, context, ftype.getNumInputs(),
+                                  ftype.getNumResults());
+
+    rewriter.eraseOp(funcOp);
+    return success();
+  }
+};
+
+/// Lowers func.return to spirv.ARM.GraphOutputs. The (already type-converted)
+/// return operands become the graph outputs.
+struct ReturnGraphOutputConvert final : OpConversionPattern<func::ReturnOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto graphOutputs = adaptor.getOperands();
+    rewriter.replaceOpWithNewOp<spirv::GraphOutputsARMOp>(returnOp,
+                                                          graphOutputs);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the func-to-graph and return-to-outputs patterns. `targetAttr` is
+/// forwarded to FuncGraphConvert, which stamps it on each new spirv.module.
+void populateTosaToSPIRVTosaConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns,
+    spirv::TargetEnvAttr targetAttr) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<FuncGraphConvert>(typeConverter, ctx, targetAttr)
+      .add<ReturnGraphOutputConvert>(typeConverter, ctx);
+}
+
+} // namespace mlir::tosa
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
new file mode 100644
index 0000000000000..f558d14c2dabd
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
@@ -0,0 +1,215 @@
+//===- TosaToSPIRVTosaPass.cpp - Lower TOSA to SPIR-V Graph/TOSA ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers TOSA IR to the SPIR-V Graph/TOSA representation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosa.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <algorithm>
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOSPIRVTOSA
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+/// Default SPIR-V 1.5 version/capability/extension triple enabling ARM graphs
+/// and tensors plus the numeric types TOSA programs commonly use.
+spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context) {
+  static constexpr spirv::Capability graphCapabilities[] = {
+      spirv::Capability::VulkanMemoryModel,
+      spirv::Capability::Shader,
+      spirv::Capability::Int8,
+      spirv::Capability::Int16,
+      spirv::Capability::Int64,
+      spirv::Capability::Float16,
+      spirv::Capability::BFloat16TypeKHR,
+      spirv::Capability::Float8EXT,
+      spirv::Capability::TensorsARM,
+      spirv::Capability::GraphARM,
+      spirv::Capability::ReplicatedCompositesEXT,
+  };
+  static constexpr spirv::Extension graphExtensions[] = {
+      spirv::Extension::SPV_ARM_tensors,
+      spirv::Extension::SPV_ARM_graph,
+      spirv::Extension::SPV_KHR_vulkan_memory_model,
+      spirv::Extension::SPV_EXT_replicated_composites,
+      spirv::Extension::SPV_KHR_bfloat16,
+      spirv::Extension::SPV_EXT_float8,
+  };
+  return spirv::VerCapExtAttr::get(spirv::Version::V_1_5, graphCapabilities,
+                                   graphExtensions, context);
+}
+
+/// Combines getDefaultVerCapExtAttr() with the given device description into
+/// a `spirv.target_env`. Null `limits` fall back to the SPIR-V defaults.
+spirv::TargetEnvAttr constructTargetEnvAttrWithCapExtDefaults(
+    MLIRContext *context, spirv::ResourceLimitsAttr limits,
+    spirv::ClientAPI clientAPI, spirv::Vendor vendorID,
+    spirv::DeviceType deviceType, uint32_t deviceID) {
+  spirv::ResourceLimitsAttr effectiveLimits =
+      limits ? limits : spirv::getDefaultResourceLimits(context);
+  spirv::VerCapExtAttr triple = getDefaultVerCapExtAttr(context);
+  return spirv::TargetEnvAttr::get(triple, effectiveLimits, clientAPI,
+                                   vendorID, deviceType, deviceID);
+}
+
+namespace {
+
+/// Checks that `targetAttr` enables the GraphARM capability and the
+/// SPV_ARM_graph and SPV_ARM_tensors extensions. Emits an error on `op` and
+/// returns failure when any of them is missing.
+LogicalResult verifyGraphTargetEnv(Operation *op,
+                                   spirv::TargetEnvAttr targetAttr) {
+  spirv::TargetEnv env(targetAttr);
+  bool supportsGraphs = env.allows(spirv::Capability::GraphARM) &&
+                        env.allows(spirv::Extension::SPV_ARM_graph) &&
+                        env.allows(spirv::Extension::SPV_ARM_tensors);
+  if (!supportsGraphs) {
+    return op->emitOpError()
+           << "requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors "
+              "extensions in spirv.target_env";
+  }
+  return success();
+}
+
+/// Diagnoses constructs this conversion cannot handle: func.call and
+/// func.call_indirect (calls must be inlined first), and a func.func nested
+/// inside another func.func. Every offending op gets its own error in a
+/// single run rather than stopping at the first one. Returns failure if any
+/// unsupported op was found.
+LogicalResult verifyNoUnsupportedFuncOps(Operation *op) {
+  bool foundUnsupported = false;
+  op->walk([&](Operation *nested) {
+    if (isa<func::CallOp, func::CallIndirectOp>(nested)) {
+      nested->emitOpError()
+          << "is not supported in TOSA to SPIR-V Graph conversion; inline "
+             "calls before running this pass";
+      foundUnsupported = true;
+      return;
+    }
+    auto funcOp = dyn_cast<func::FuncOp>(nested);
+    if (funcOp && funcOp->getParentOfType<func::FuncOp>()) {
+      funcOp.emitOpError()
+          << "nesting is not supported in TOSA to SPIR-V Graph conversion";
+      foundUnsupported = true;
+    }
+  });
+  return failure(foundUnsupported);
+}
+
+/// Pass driver. Validates the target environment and input, layers the
+/// TOSA-specific type conversions on top of SPIRVTypeConverter, and applies
+/// the func-to-graph patterns with a partial conversion.
+struct TosaToSPIRVTosa final : impl::TosaToSPIRVTosaBase<TosaToSPIRVTosa> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    Operation *op = getOperation();
+
+    // Prefer a spirv.target_env found by lookupTargetEnv; otherwise use the
+    // defaults that enable the graph and tensor features this pass needs.
+    spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(op);
+    if (!targetAttr) {
+      targetAttr = constructTargetEnvAttrWithCapExtDefaults(context);
+    }
+
+    if (failed(verifyGraphTargetEnv(op, targetAttr)) ||
+        failed(verifyNoUnsupportedFuncOps(op))) {
+      signalPassFailure();
+      return;
+    }
+
+    std::unique_ptr<ConversionTarget> target =
+        SPIRVConversionTarget::get(targetAttr);
+
+    target->addIllegalDialect<tosa::TosaDialect>();
+    target->addIllegalOp<func::CallOp, func::CallIndirectOp>();
+
+    // Type conversions registered later are tried first, so these take
+    // precedence over earlier-registered SPIRVTypeConverter conversions for
+    // the same types.
+    SPIRVTypeConverter typeConverter(targetAttr);
+    typeConverter.addConversion([this](IntegerType integerType) {
+      return this->convertIntegerType(integerType);
+    });
+    typeConverter.addConversion([this](TensorType tensorType) {
+      return this->convertTensorType(tensorType);
+    });
+    typeConverter.addConversion([this](tosa::shapeType shapeType) {
+      return this->convertShapeType(shapeType);
+    });
+
+    populateTosaToSPIRVTosaConversionPatterns(typeConverter, patterns,
+                                              targetAttr);
+
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+    if (failed(applyPartialConversion(op, *target, frozenPatterns))) {
+      signalPassFailure();
+      return;
+    }
+
+    // The target above declares no legality for func.func, so a partial
+    // conversion accepts a func.func whose FuncGraphConvert pattern failed,
+    // even though that pattern has already emitted an error for it. Fail the
+    // pass when any func.func survives instead of succeeding after an error.
+    WalkResult leftover =
+        op->walk([](func::FuncOp) { return WalkResult::interrupt(); });
+    if (leftover.wasInterrupted())
+      signalPassFailure();
+  }
+
+private:
+  /// Widens integer widths without a direct SPIR-V equivalent (i48 -> i64,
+  /// i4 -> i8), keeping signedness. Other integer types pass through.
+  IntegerType convertIntegerType(IntegerType integerType) {
+    if (integerType.getWidth() == 48) {
+      return IntegerType::get(&getContext(), 64, integerType.getSignedness());
+    }
+
+    if (integerType.getWidth() == 4) {
+      return IntegerType::get(&getContext(), 8, integerType.getSignedness());
+    }
+
+    return integerType;
+  }
+
+  /// Maps a ranked tensor shape to an ARM tensor shape. Rank-0 shapes become
+  /// [1]. Shapes containing a zero extent are rejected (std::nullopt). Shapes
+  /// mixing static and dynamic extents become fully dynamic.
+  std::optional<SmallVector<int64_t>> convertShape(ArrayRef<int64_t> shape) {
+    // Scalar ARM tensors are not supported, so convert them to
+    // tensors with shape [1].
+    if (shape.empty())
+      return SmallVector<int64_t>({1});
+
+    if (llvm::is_contained(shape, 0))
+      return std::nullopt;
+
+    bool isPartiallyDynamic =
+        llvm::is_contained(shape, ShapedType::kDynamic) &&
+        llvm::any_of(shape, [](int64_t dim) { return dim > 0; });
+    // Partially shaped ARM tensors are not supported, so convert them to
+    // unshaped tensors.
+    if (isPartiallyDynamic)
+      return SmallVector<int64_t>(shape.size(), ShapedType::kDynamic);
+    return SmallVector<int64_t>(shape);
+  }
+
+  /// Converts a builtin tensor type to an ARM tensor type. Index elements
+  /// become i32, and integer elements go through convertIntegerType. Unranked
+  /// tensors are built with an empty shape list. Zero-sized tensors fail.
+  std::optional<spirv::TensorArmType> convertTensorType(TensorType tensorType) {
+    Type elementType = getElementTypeOrSelf(tensorType);
+    if (elementType.isIndex())
+      elementType = IntegerType::get(&getContext(), 32);
+    if (auto integerType = dyn_cast<IntegerType>(elementType))
+      elementType = convertIntegerType(integerType);
+
+    SmallVector<int64_t> shape;
+    if (tensorType.hasRank()) {
+      std::optional<SmallVector<int64_t>> convertedShape =
+          convertShape(tensorType.getShape());
+      // Return a null type rather than std::nullopt. std::nullopt tells the
+      // type converter to try other registered TensorType conversions; a null
+      // type makes the conversion fail outright.
+      if (!convertedShape)
+        return spirv::TensorArmType();
+      shape = std::move(*convertedShape);
+    }
+
+    return spirv::TensorArmType::get(shape, elementType);
+  }
+
+  /// Converts a tosa::shapeType of rank N to a 1-D i32 ARM tensor of length
+  /// max(N, 1).
+  spirv::TensorArmType convertShapeType(tosa::shapeType shapeType) {
+    const int64_t rank = std::max(shapeType.getRank(), 1);
+    return spirv::TensorArmType::get({rank},
+                                     IntegerType::get(&getContext(), 32));
+  }
+};
+} // namespace
+
+/// Factory for the `tosa-to-spirv-tosa` pass.
+std::unique_ptr<Pass> createTosaToSPIRVTosa() {
+  std::unique_ptr<Pass> pass = std::make_unique<TosaToSPIRVTosa>();
+  return pass;
+}
+
+} // namespace tosa
+} // namespace mlir
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
new file mode 100644
index 0000000000000..bbbe72c00a15b
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/descriptor-set-and-bindings.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s
+
+// NOTE(review): the negative match below only scans output that precedes the
+// first positive match, so it does not cover the second split. The exact
+// signature match in the second split is what ensures the source-side GraphARM
+// attribute is stripped there.
+// CHECK-NOT: grapharm.interface_var_abi
+
+// Without any ABI annotation, inputs are bound first and results continue
+// the numbering, all in descriptor set 0.
+// CHECK: spirv.module @_spirv_tosa_default_interface_var_abi Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @default_interface_var_abi(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @default_interface_var_abi(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi8>
+}
+
+// -----
+
+// An explicit source-side GraphARM ABI annotation overrides the default
+// descriptor set and binding and is re-emitted as spirv.interface_var_abi.
+// CHECK: spirv.module @_spirv_tosa_custom_grapharm_abi Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @custom_grapharm_abi(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(3, 9)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(7, 11)>}) attributes {entry_point = true} {
+func.func @custom_grapharm_abi(%arg0: tensor<1xi8> {grapharm.interface_var_abi = #spirv.interface_var_abi<(3, 9)>}) -> (tensor<1xi8> {grapharm.interface_var_abi = #spirv.interface_var_abi<(7, 11)>}) {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi8>
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
new file mode 100644
index 0000000000000..11e76d7df1637
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/op-nesting.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s
+
+// Functions inside arbitrary non-func containers are converted in place: the
+// generated spirv.module ends up in the same container as the original
+// func.func.
+
+// CHECK: gpu.module @random_container
+gpu.module @random_container {
+  // CHECK: spirv.module @_spirv_tosa_nested Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+  // CHECK: spirv.ARM.Graph @nested(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+  func.func @nested(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+    // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+    return %arg0 : tensor<1xi8>
+  }
+}
+
+// -----
+
+// CHECK: module @random_container {
+module @random_container {
+  module @yet_another_random_container {
+    // CHECK: gpu.module @another_random_container
+    gpu.module @another_random_container {
+      // CHECK: spirv.module @_spirv_tosa_nested Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+      // CHECK: spirv.ARM.Graph @nested(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+      func.func @nested(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+        // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+        return %arg0 : tensor<1xi8>
+      }
+    }
+  }
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
new file mode 100644
index 0000000000000..76ca038f5da2d
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/type-conversions.mlir
@@ -0,0 +1,67 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s | FileCheck %s
+
+// Type conversions applied to graph signatures: i48 -> i64 and i4 -> i8,
+// rank-0 tensors -> shape [1], partially dynamic -> fully dynamic, unranked
+// kept unranked. Zero-sized tensors and target environments without graph
+// support are rejected.
+
+// CHECK: spirv.module @_spirv_tosa_i48_to_i64 Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @i48_to_i64(%[[ARG0:.*]]: !spirv.arm.tensor<1x2x3x4xi64> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1x2x3x4xi64> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @i48_to_i64(%arg0: tensor<1x2x3x4xi48>) -> tensor<1x2x3x4xi48> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1x2x3x4xi64>
+  return %arg0 : tensor<1x2x3x4xi48>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_i4_to_i8 Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @i4_to_i8(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @i4_to_i8(%arg0: tensor<1xi4>) -> tensor<1xi4> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<1xi4>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_scalar_tensor_to_rank1_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @scalar_tensor_to_rank1_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<1xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @scalar_tensor_to_rank1_tensor(%arg0: tensor<i8>) -> tensor<i8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<1xi8>
+  return %arg0 : tensor<i8>
+}
+
+// -----
+
+// expected-error@below {{failed to convert function argument types}}
+func.func @zero_sized_tensor(%arg0: tensor<0xi8>) -> tensor<0xi8> {
+  return %arg0 : tensor<0xi8>
+}
+
+// -----
+
+// expected-error@below {{failed to convert function argument types}}
+func.func @mixed_zero_sized_tensor(%arg0: tensor<1x0x2xi8>) -> tensor<1x0x2xi8> {
+  return %arg0 : tensor<1x0x2xi8>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_partially_shaped_tensor_to_unshaped_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @partially_shaped_tensor_to_unshaped_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<?x?xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<?x?xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @partially_shaped_tensor_to_unshaped_tensor(%arg0: tensor<1x?xi8>) -> tensor<1x?xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<?x?xi8>
+  return %arg0 : tensor<1x?xi8>
+}
+
+// -----
+
+// CHECK: spirv.module @_spirv_tosa_unranked_tensor Logical Vulkan attributes {spirv.target_env = #spirv.target_env<
+// CHECK: spirv.ARM.Graph @unranked_tensor(%[[ARG0:.*]]: !spirv.arm.tensor<*xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}) -> (!spirv.arm.tensor<*xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+func.func @unranked_tensor(%arg0: tensor<*xi8>) -> tensor<*xi8> {
+  // CHECK: spirv.ARM.GraphOutputs %[[ARG0]] : !spirv.arm.tensor<*xi8>
+  return %arg0 : tensor<*xi8>
+}
+
+// -----
+
+// expected-error@below {{'builtin.module' op requires GraphARM capability and SPV_ARM_graph/SPV_ARM_tensors extensions in spirv.target_env}}
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>} {
+  func.func @unsupported_target_env(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+    return %arg0 : tensor<1xi8>
+  }
+}
diff --git a/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir b/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir
new file mode 100644
index 0000000000000..52c1ad526727d
--- /dev/null
+++ b/mlir/test/Conversion/TosaToSPIRVTosa/unsupported-func-calls.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --split-input-file --tosa-to-spirv-tosa --verify-diagnostics %s
+
+// Calls and nested functions are rejected before any conversion happens.
+
+func.func @direct_call(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  // expected-error@below {{'func.call' op is not supported in TOSA to SPIR-V Graph conversion; inline calls before running this pass}}
+  %0 = func.call @direct_call(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+func.func @indirect_call(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  %callee = func.constant @indirect_call : (tensor<1xi8>) -> tensor<1xi8>
+  // expected-error@below {{'func.call_indirect' op is not supported in TOSA to SPIR-V Graph conversion; inline calls before running this pass}}
+  %0 = func.call_indirect %callee(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+
+func.func @nested_func(%arg0: tensor<1xi8>) -> tensor<1xi8> {
+  builtin.module {
+    // expected-error@below {{'func.func' op nesting is not supported in TOSA to SPIR-V Graph conversion}}
+    func.func @nested(%arg1: tensor<1xi8>) -> tensor<1xi8> {
+      return %arg1 : tensor<1xi8>
+    }
+  }
+  return %arg0 : tensor<1xi8>
+}



More information about the Mlir-commits mailing list