[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri Jul 5 07:54:07 PDT 2024


================
@@ -2695,43 +2621,89 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!hasPureTensorSemantics()) {
+    // The decomposition assumes ranked tensors as input
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
   Value input = getInput();
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0; i < inputType.getRank(); i++) {
+    if (reductionDim != i) {
----------------
ftynse wrote:

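Please use an early `continue` here instead of nesting the loop body under `if (reductionDim != i)`; see the LLVM coding standards: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code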

https://github.com/llvm/llvm-project/pull/97582


More information about the Mlir-commits mailing list