[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Adam Siemieniuk llvmlistbot at llvm.org
Wed Jul 3 10:55:16 PDT 2024


================
@@ -2706,32 +2631,60 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
   dims.erase(dims.begin() + reductionDim);
+
   // Step 1: Compute max along dim.
   Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+  Value neutralForMaxFInit = neutralMaxInitOp.result();
----------------
adam-smnk wrote:

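nit: this is no longer specific to MaxF, so the `neutralForMaxFInit` name is misleading. You can probably skip the temporary and use `neutralMaxInitOp.result()` directly.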
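
For context, Step 1 in the hunk fills the reduction output with the neutral element for max and then reduces along one dimension. Below is a minimal standalone sketch of that pattern in plain C++, not MLIR builder code. The helper `reduceMaxRows` is hypothetical, and using the lowest finite float as the neutral element is an assumption modeled on the `useOnlyFiniteValue=true` argument in the removed lines. The filled buffer is used directly as the accumulator, with no extra named temporary.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper (not part of MLIR): row-wise max of a row-major
// rows x cols matrix, reducing along the column dimension.
std::vector<float> reduceMaxRows(const std::vector<float> &input,
                                 std::size_t rows, std::size_t cols) {
  assert(input.size() == rows * cols && "shape mismatch");

  // "Fill" step: initialize every output element with the neutral element
  // for max. Assumption: lowest finite float rather than -infinity.
  std::vector<float> out(rows, std::numeric_limits<float>::lowest());

  // "Reduce" step: fold each row into its accumulator.
  for (std::size_t r = 0; r < rows; ++r)
    for (std::size_t c = 0; c < cols; ++c)
      out[r] = std::max(out[r], input[r * cols + c]);

  return out;
}
```

Starting from the neutral element means the first comparison in each row always picks the input value, so the result does not depend on the initial contents of the output buffer.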

https://github.com/llvm/llvm-project/pull/97582

