[Mlir-commits] [mlir] [mlir][tosa] Add folding for TOSA ArgMax operator (PR #88871)
Jakub Kuderski
llvmlistbot at llvm.org
Tue Apr 16 09:40:14 PDT 2024
================
@@ -507,6 +507,20 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
resultTy);
}
+OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
+  auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputTy || !outputTy || !inputTy.hasStaticShape() ||
+      !outputTy.hasStaticShape())
+    return {};
+
+  if (inputTy.getDimSize(getAxis()) == 1) {
+    return DenseElementsAttr::get(outputTy, 0);
+  }
----------------
kuhar wrote:
nit: for consistency with the if above, I'd either drop these braces or also add them above
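
To illustrate the first option (dropping the braces), here is a minimal standalone sketch. It does not use MLIR; `foldArgMaxSketch`, its parameters, and the `std::optional` return type are hypothetical stand-ins for the fold in the patch, chosen only so that the snippet compiles on its own.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for the fold above: returns the folded index when the
// reduced axis has size 1, and std::nullopt (no fold) otherwise.
std::optional<int64_t> foldArgMaxSketch(const std::vector<int64_t> &shape,
                                        bool hasStaticShape, size_t axis) {
  if (!hasStaticShape || axis >= shape.size())
    return std::nullopt;

  // Single-statement body, so no braces, matching the `if` above.
  if (shape[axis] == 1)
    return 0;

  return std::nullopt;
}
```

The other option, adding braces to the first `if` as well, would be equally consistent; the point is only that both conditionals use the same style.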
https://github.com/llvm/llvm-project/pull/88871