[Mlir-commits] [mlir] [mlir][tosa] Add folding for TOSA ArgMax operator (PR #88871)

Jakub Kuderski llvmlistbot at llvm.org
Tue Apr 16 09:40:14 PDT 2024


================
@@ -507,6 +507,20 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
                                                             resultTy);
 }
 
+OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
+  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
+      !outputTy.hasStaticShape())
+    return {};
+
+  if (inputTy.getDimSize(getAxis()) == 1) {
+    return DenseElementsAttr::get(outputTy, 0);
+  }
----------------
kuhar wrote:

nit: for consistency with the if above, I'd either drop these braces or also add them above
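The fold in the quoted hunk is valid because an argmax along an axis of extent 1 has exactly one candidate index, so every output element is 0 whatever the input values are. Below is a minimal reference sketch of that reasoning. It is not the TOSA or MLIR implementation. `argmaxAlongAxis` is a hypothetical helper that assumes a dense row-major `float` buffer, drops the reduced axis from the output, and resolves ties to the first index. It does not try to match the TOSA spec's exact tie or NaN semantics.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference argmax over one axis of a dense row-major tensor.
// The output has the reduced axis removed. Ties resolve to the first index.
std::vector<int64_t> argmaxAlongAxis(const std::vector<float> &data,
                                     const std::vector<int64_t> &shape,
                                     size_t axis) {
  assert(axis < shape.size() && "axis out of range");

  // Split the shape into [outer..., n, inner...] around the reduced axis.
  int64_t outer = 1, inner = 1;
  for (size_t d = 0; d < axis; ++d)
    outer *= shape[d];
  for (size_t d = axis + 1; d < shape.size(); ++d)
    inner *= shape[d];
  const int64_t n = shape[axis];

  std::vector<int64_t> result;
  result.reserve(static_cast<size_t>(outer * inner));
  for (int64_t o = 0; o < outer; ++o) {
    for (int64_t i = 0; i < inner; ++i) {
      int64_t best = 0;
      // When n == 1 this loop never executes, so `best` stays 0.
      // That is the property the ArgMaxOp fold relies on.
      for (int64_t k = 1; k < n; ++k) {
        if (data[(o * n + k) * inner + i] > data[(o * n + best) * inner + i])
          best = k;
      }
      result.push_back(best);
    }
  }
  return result;
}
```

For example, take the six values 5, -1, 7, 2, 9, 0.5. Viewed as shape [2, 1, 3] and reduced along axis 1, all six results are 0. Viewed as shape [2, 3] and reduced along axis 1, the result is [2, 1], which depends on the data. Because the first case never depends on the data, the fold can return a constant splat of 0 without inspecting the input.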

https://github.com/llvm/llvm-project/pull/88871

