[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)

Jakub Kuderski llvmlistbot at llvm.org
Fri Jan 9 07:56:39 PST 2026


================
@@ -0,0 +1,55 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the TOSA operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TosaArgMaxOp::verify() {
+  auto inputTy = cast<ShapedType>(getInput().getType());
+  auto resultTy = cast<ShapedType>(getType());
+
+  if (inputTy.hasRank() && resultTy.hasRank() &&
+      resultTy.getRank() !=
+          (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+    return emitOpError("result rank must be max of 1 and (input rank - 1)");
+  }
+
+  Type resultETy = resultTy.getElementType();
+  if (!resultETy.isIntOrIndex()) {
+    return emitOpError("result is not of integer type");
+  }
+
+  IntegerAttr axisAttr;
+  if (!matchPattern(getAxis(), m_Constant(&axisAttr))) {
+    return emitOpError("axis type must be a constant integer");
+  }
+
+  const int axis = axisAttr.getInt();
+  if (inputTy.hasRank() && ((axis < 0) || axis >= inputTy.getRank())) {
+    return emitOpError("specified axis is outside the rank of input");
+  }
----------------
kuhar wrote:

If the SPIR-V spec requires something to be a constant, it must be an attribute at the MLIR level -- otherwise we can't enforce the constness requirement.

https://github.com/llvm/llvm-project/pull/174402


More information about the Mlir-commits mailing list