[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jul 5 07:54:07 PDT 2024


================
@@ -2695,43 +2621,89 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!hasPureTensorSemantics()) {
+    // The decomposition assumes ranked tensors as input
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
   Value input = getInput();
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0; i < inputType.getRank(); i++) {
+    if (reductionDim != i) {
+      reduceShape.push_back(inputType.getDimSize(i));
+      if (inputType.isDynamicDim(i))
+        dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+    }
+  }
+
   // Step 1: Compute max along dim.
-  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  Value outputReduce =
+      b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+  auto maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
----------------
ftynse wrote:

Please spell out the type instead of using `auto` unless the type is obvious from the right-hand side (e.g., there's a cast or a `create` call). This applies here and below.

https://github.com/llvm/llvm-project/pull/97582


More information about the Mlir-commits mailing list