[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 12 13:44:16 PDT 2024
================
@@ -2695,43 +2621,88 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
/// 4. Divide z and l. This gives the N-dimensional softmax.
/// softmax = z / l
///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+ if (!hasPureTensorSemantics()) {
+ // The decomposition assumes ranked tensors as input.
+ return failure();
+ }
+
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
Value input = getInput();
ShapedType inputType = getInputOperandType();
Type elementType = inputType.getElementType();
int64_t reductionDim = getDimension();
- SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
- dims.erase(dims.begin() + reductionDim);
+
+ SmallVector<int64_t> reduceShape;
+ SmallVector<Value> dynReduceDims;
+ for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+ if (reductionDim == i)
+ continue;
+ reduceShape.push_back(inputType.getDimSize(i));
+ if (inputType.isDynamicDim(i))
+ dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
+
// Step 1: Compute max along dim.
- Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
- elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
- Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
- .result();
- Value max =
- reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+ Value outputReduce =
+ b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+ TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+ auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+ auto neutralMaxInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+
+ auto reduceMaxOp = b.create<linalg::ReduceOp>(
+ loc, input, neutralMaxInitOp.result(), reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value result =
+ createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 2: Subtract max from input and exponentiate.
- Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+ auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+ auto subOp = b.create<linalg::SubOp>(
+ loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+ ValueRange{output});
+ auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+ ValueRange{output});
// Step 3: Compute sum along dim.
- Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
- b, loc, /*useOnlyFiniteValue=*/true);
- Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
- Value denominator =
- reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+ TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+ auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+ auto neutralSumInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+ auto reduceSumOp = b.create<linalg::ReduceOp>(
+ loc, expOp.getResults(), neutralSumInitOp.result(), reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 4: Compute softmax.
- Value result =
- buildDivOp(b, loc, numerator, denominator, output, reductionDim);
- return SmallVector<Value>{result};
+ SmallVector<Value> dynDims;
+ for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+ if (inputType.isDynamicDim(i))
+ dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
+ auto sumBcastOutput = b.create<tensor::EmptyOp>(
+ loc, getOutputOperandType().getShape(), elementType, dynDims);
+ auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
----------------
MaheshRavishankar wrote:
All I am saying is that today this is lowered to what I consider a more succinct representation, and with this change it does not seem straightforward to get back to the previous state. So, to avoid breaking downstream users, it would be better to keep the current default behavior and introduce an option that allows lowering to named ops. Going from named ops back to the previous state does not seem straightforward to me (you essentially need to know this is a softmax and handle it accordingly).
With respect to `fill`, I acknowledge your concern, but `fill` has always felt like a "special" op to me. To preserve the perfectly nested loop nature of Linalg ops, what is
```
C = A * B
```
is converted to
```
C = zeros(...);
D = C + A*B;
```
but any sane backend needs to "fuse" the fill with the matmul operation to generate efficient code.
Anyway, that is a minor digression. My bigger concern is changing the default to lower to named ops, which in turn enforce (what I consider outdated) explicit-broadcasting semantics, when Linalg has a perfectly succinct and unambiguous way to represent broadcasts. If named ops allowed representing broadcasts on par with what Linalg inherently allows, there would be no issue here.
https://github.com/llvm/llvm-project/pull/97582
More information about the Mlir-commits
mailing list