[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)
Petr Kurapov
llvmlistbot at llvm.org
Tue Jul 9 06:42:50 PDT 2024
================
@@ -2695,43 +2621,88 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
/// 4. Divide z and l. This gives the N-dimensional softmax.
/// softmax = z / l
///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+ if (!hasPureTensorSemantics()) {
+ // The decomposition assumes ranked tensors as input.
+ return failure();
+ }
+
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(*this);
Location loc = getLoc();
Value input = getInput();
ShapedType inputType = getInputOperandType();
Type elementType = inputType.getElementType();
int64_t reductionDim = getDimension();
- SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
Value output = getOutput();
- dims.erase(dims.begin() + reductionDim);
+
+ SmallVector<int64_t> reduceShape;
+ SmallVector<Value> dynReduceDims;
+ for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+ if (reductionDim == i)
+ continue;
+ reduceShape.push_back(inputType.getDimSize(i));
+ if (inputType.isDynamicDim(i))
+ dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
+
// Step 1: Compute max along dim.
- Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
- Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
- elementType, b, loc,
- /*useOnlyFiniteValue=*/true);
- Value neutralForMaxFInit =
- b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
- .result();
- Value max =
- reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+ Value outputReduce =
+ b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+ TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+ auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+ auto neutralMaxInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+
+ auto reduceMaxOp = b.create<linalg::ReduceOp>(
+ loc, input, neutralMaxInitOp.result(), reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ Value result =
+ createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 2: Subtract max from input and exponentiate.
- Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+ auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+ loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+ auto subOp = b.create<linalg::SubOp>(
+ loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+ ValueRange{output});
+ auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+ ValueRange{output});
// Step 3: Compute sum along dim.
- Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
- b, loc, /*useOnlyFiniteValue=*/true);
- Value zeroInit =
- b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
- Value denominator =
- reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+ TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+ auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+ auto neutralSumInitOp = b.create<linalg::FillOp>(
+ loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+ auto reduceSumOp = b.create<linalg::ReduceOp>(
+ loc, expOp.getResults(), neutralSumInitOp.result(), reductionDim,
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ auto result =
+ createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+ nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+ });
// Step 4: Compute softmax.
- Value result =
- buildDivOp(b, loc, numerator, denominator, output, reductionDim);
- return SmallVector<Value>{result};
+ SmallVector<Value> dynDims;
+ for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+ if (inputType.isDynamicDim(i))
+ dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+ }
+ auto sumBcastOutput = b.create<tensor::EmptyOp>(
+ loc, getOutputOperandType().getShape(), elementType, dynDims);
+ auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
----------------
kurapov-peter wrote:
What do you mean by that? What is confusing?
https://github.com/llvm/llvm-project/pull/97582
More information about the Mlir-commits
mailing list