[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Renato Golin llvmlistbot at llvm.org
Fri Jul 12 05:21:13 PDT 2024


================
@@ -2695,43 +2621,88 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!hasPureTensorSemantics()) {
+    // The decomposition assumes ranked tensors as input.
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
   Value input = getInput();
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (reductionDim == i)
+      continue;
+    reduceShape.push_back(inputType.getDimSize(i));
+    if (inputType.isDynamicDim(i))
+      dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
+
   // Step 1: Compute max along dim.
-  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  Value outputReduce =
+      b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+  TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralMaxInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result =
+            createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
 
   // Step 3: Compute sum along dim.
-  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
-                                       b, loc, /*useOnlyFiniteValue=*/true);
-  Value zeroInit =
-      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
-  Value denominator =
-      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), neutralSumInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        auto result =
+            createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 4: Compute softmax.
-  Value result =
-      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
-  return SmallVector<Value>{result};
+  SmallVector<Value> dynDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (inputType.isDynamicDim(i))
+      dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
+  auto sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, getOutputOperandType().getShape(), elementType, dynDims);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
----------------
rengolin wrote:

First, regardless of the merits of lowering `softmax` to generics or not, I disagree that a more succinct representation is always preferable. We have discussed this at length in the canonicalization threads on the forum.

Second, what this PR is adding is just the semantics that (see the sketch below):
* Decomposition is breaking named ops into further named ops.
* Generalization is lowering named ops into generics.
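
To make the distinction concrete, here is a rough, hand-written sketch (not taken from this PR's tests; shapes and SSA names are invented) of what decomposition means for `softmax`: the aggregate op becomes a DAG of simpler *named* ops, with no `linalg.generic` involved.

```mlir
// Before: the aggregate named op.
%s = linalg.softmax dimension(1)
       ins(%in : tensor<16x64xf32>) outs(%out : tensor<16x64xf32>) -> tensor<16x64xf32>

// After decomposition (abridged): still named ops only.
%init  = linalg.fill ins(%neutral : f32) outs(%red : tensor<16xf32>) -> tensor<16xf32>
%max   = linalg.reduce { arith.maxnumf }
           ins(%in : tensor<16x64xf32>) outs(%init : tensor<16xf32>) dimensions = [1]
%bcast = linalg.broadcast ins(%max : tensor<16xf32>)
           outs(%out : tensor<16x64xf32>) dimensions = [1]
%sub   = linalg.sub ins(%in, %bcast : tensor<16x64xf32>, tensor<16x64xf32>)
           outs(%out : tensor<16x64xf32>) -> tensor<16x64xf32>
%exp   = linalg.exp ins(%sub : tensor<16x64xf32>)
           outs(%out : tensor<16x64xf32>) -> tensor<16x64xf32>
// ... followed by the add-reduction, broadcast and division steps.
```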

Above, you mention that you prefer `fill` over its generic form, even though the step that test exercises is _generalization_; yet here you're advocating for the generic form, even though the step is _decomposition_. This seems totally arbitrary to me.
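
For reference, a rough sketch of the two spellings of `fill` being compared (types chosen arbitrarily); the generic form is roughly what `-linalg-generalize-named-ops` would produce:

```mlir
// Named form.
%f = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x8xf32>) -> tensor<4x8xf32>

// Generic form.
%g = linalg.generic {
       indexing_maps = [affine_map<(d0, d1) -> ()>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%cst : f32) outs(%empty : tensor<4x8xf32>) {
     ^bb0(%in: f32, %out: f32):
       linalg.yield %in : f32
     } -> tensor<4x8xf32>
```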

The linalg dialect and its transforms should be consistent and predictable. If you want to generalize only some ops and not others, that seems to me like a job for a local pass (see the sketch below).
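
Such a local, selective step can already be expressed, for example with the transform dialect, by matching only the ops one wants to generalize (a rough sketch, not part of this PR):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Generalize only linalg.fill; every other named op stays named.
    %fills = transform.structured.match ops{["linalg.fill"]} in %root
        : (!transform.any_op) -> !transform.any_op
    %generics = transform.structured.generalize %fills
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```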

https://github.com/llvm/llvm-project/pull/97582

