[Mlir-commits] [mlir] [mlir][spirv] Initial support for TOSA Extended Instruction Set (0010… (PR #174402)
Jakub Kuderski
llvmlistbot at llvm.org
Mon Jan 5 07:02:04 PST 2026
================
@@ -0,0 +1,72 @@
+//===- SPIRVTosaOps.cpp - MLIR SPIR-V operations --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Tosa operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Verifiers.
+//===----------------------------------------------------------------------===//
+
+// Extracts the constant behind `value` as an attribute of kind `TAttr`.
+//
+// Two producers are recognized:
+//  * spirv::ConstantOp: its value attribute is used directly.
+//  * spirv::EXTConstantCompositeReplicateOp: the replicated scalar is expanded
+//    into a splat elements attribute of the op's (shaped) result type.
+//
+// On return `valAttr` holds the extracted attribute, or null. Returns success
+// only when an attribute of kind `TAttr` was obtained. This is used from op
+// verifiers, so unexpected IR must produce failure() rather than an assertion.
+template <typename TAttr>
+static LogicalResult getConstAttr(Value value, TAttr &valAttr) {
+  // Reset first so a stale non-null attribute passed in by the caller cannot
+  // turn a non-constant operand into a spurious success.
+  valAttr = {};
+
+  if (auto constOp = value.getDefiningOp<spirv::ConstantOp>()) {
+    valAttr = dyn_cast<TAttr>(constOp.getValue());
+  } else if (auto replicateOp =
+                 value.getDefiningOp<spirv::EXTConstantCompositeReplicateOp>()) {
+    // The result may be a non-shaped composite (e.g. spirv.array), and the
+    // replicated value may itself be a composite. DenseElementsAttr::get
+    // asserts on a dynamic shape and on values that are not int/float
+    // attributes of exactly the element type, so only expand the safe case.
+    auto shapedTy = dyn_cast<ShapedType>(replicateOp.getType());
+    Attribute splatAttr = replicateOp.getValue();
+    if (shapedTy && shapedTy.hasStaticShape() &&
+        isa<IntegerAttr, FloatAttr>(splatAttr) &&
+        cast<TypedAttr>(splatAttr).getType() == shapedTy.getElementType()) {
+      auto denseValAttr = SplatElementsAttr::get(shapedTy, splatAttr);
+      valAttr = dyn_cast<TAttr>(denseValAttr);
+    }
+  }
+
+  return valAttr ? success() : failure();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.TosaArgMaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::TosaArgMaxOp::verify() {
+ auto inputTy = cast<ShapedType>(getInput().getType());
+ auto resultTy = cast<ShapedType>(getType());
+
+ if (inputTy.hasRank() && resultTy.hasRank() &&
+ resultTy.getRank() !=
+ (inputTy.getRank() > 1 ? inputTy.getRank() - 1 : 1)) {
+ return emitOpError("result rank must be max of 1 and (input rank - 1)");
+ }
+
+ auto resultETy = resultTy.getElementType();
+ if (!resultETy.isIntOrIndex()) {
+ return emitOpError("result is not of integer type");
+ }
+
+ IntegerAttr axisAttr;
+ if (getConstAttr(getAxis(), axisAttr).failed()) {
----------------
kuhar wrote:
`failed(...)`
https://github.com/llvm/llvm-project/pull/174402
More information about the Mlir-commits
mailing list