[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 11 22:02:13 PDT 2024

@@ -2695,43 +2621,88 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!hasPureTensorSemantics()) {
+    // The decomposition assumes ranked tensors as input.
+    return failure();
+  }
   OpBuilder::InsertionGuard guard(b);
   Location loc = getLoc();
   Value input = getInput();
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (reductionDim == i)
+      continue;
+    reduceShape.push_back(inputType.getDimSize(i));
+    if (inputType.isDynamicDim(i))
+      dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
   // Step 1: Compute max along dim.
-  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  Value outputReduce =
+      b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+  TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralMaxInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result =
+            createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
   // Step 3: Compute sum along dim.
-  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
-                                       b, loc, /*useOnlyFiniteValue=*/true);
-  Value zeroInit =
-      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
-  Value denominator =
-      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), neutralSumInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        auto result =
+            createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
   // Step 4: Compute softmax.
-  Value result =
-      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
-  return SmallVector<Value>{result};
+  SmallVector<Value> dynDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (inputType.isDynamicDim(i))
+      dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
+  auto sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, getOutputOperandType().getShape(), elementType, dynDims);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
MaheshRavishankar wrote:

I think lowering to named ops this way is not great for softmax. If anything, it shows issues with named ops as they are currently defined and being developed. For example, linalg allows a more succinct representation of broadcasting behavior than having to add an explicit broadcast operation. I don't think this is a great way forward. It requires everything that relies on this decomposition "to do additional work" to get back to the state that the decomposition previously produced. That doesn't seem like a good idea to me. Maybe you can create an option that allows lowering softmax to named ops as you want here (and we can decide which one to make the default), but changing the decomposition this way is a red flag for me.

Also, the decomposition change needs to be a separate PR rather than rolled into the PR that adds the decomposition pass.


More information about the Mlir-commits mailing list