[Mlir-commits] [mlir] [MLIR][Linalg] Add aggregate ops decomposition pass and softmax decom… (PR #97582)

Petr Kurapov llvmlistbot at llvm.org
Tue Jul 9 06:42:50 PDT 2024


================
@@ -2695,43 +2621,88 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
 /// 4. Divide z and l. This gives the N-dimensional softmax.
 ///    softmax = z / l
 ///
-FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+FailureOr<DecompositionResult> SoftmaxOp::decomposeOperation(OpBuilder &b) {
+  if (!hasPureTensorSemantics()) {
+    // The decomposition assumes ranked tensors as input.
+    return failure();
+  }
+
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(*this);
   Location loc = getLoc();
   Value input = getInput();
   ShapedType inputType = getInputOperandType();
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
-  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
   Value output = getOutput();
-  dims.erase(dims.begin() + reductionDim);
+
+  SmallVector<int64_t> reduceShape;
+  SmallVector<Value> dynReduceDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (reductionDim == i)
+      continue;
+    reduceShape.push_back(inputType.getDimSize(i));
+    if (inputType.isDynamicDim(i))
+      dynReduceDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
+
   // Step 1: Compute max along dim.
-  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
-  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
-                                                 elementType, b, loc,
-                                                 /*useOnlyFiniteValue=*/true);
-  Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
-          .result();
-  Value max =
-      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
+  Value outputReduce =
+      b.create<tensor::EmptyOp>(loc, reduceShape, elementType, dynReduceDims);
+  TypedAttr maxFillValAttr = createInitValueForReduceMaxOp(elementType, b);
+  auto maxFillValue = b.create<arith::ConstantOp>(loc, maxFillValAttr);
+  auto neutralMaxInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{maxFillValue}, ValueRange{outputReduce});
+
+  auto reduceMaxOp = b.create<linalg::ReduceOp>(
+      loc, input, neutralMaxInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        Value result =
+            createLinalgReduceMaxBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 2: Subtract max from input and exponentiate.
-  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
+  auto maxBroadcastOp = b.create<linalg::BroadcastOp>(
+      loc, reduceMaxOp.getResult(0), output, reduceMaxOp.getDimensionsAttr());
+
+  auto subOp = b.create<linalg::SubOp>(
+      loc, ValueRange{input, maxBroadcastOp.getResults().front()},
+      ValueRange{output});
+  auto expOp = b.create<linalg::ExpOp>(loc, ValueRange{subOp.getResult(0)},
+                                       ValueRange{output});
 
   // Step 3: Compute sum along dim.
-  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
-                                       b, loc, /*useOnlyFiniteValue=*/true);
-  Value zeroInit =
-      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
-  Value denominator =
-      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
+  TypedAttr sumFillValAttr = createInitValueForReduceSumOp(elementType, b);
+  auto sumFillValue = b.create<arith::ConstantOp>(loc, sumFillValAttr);
+  auto neutralSumInitOp = b.create<linalg::FillOp>(
+      loc, ValueRange{sumFillValue}, ValueRange{outputReduce});
+  auto reduceSumOp = b.create<linalg::ReduceOp>(
+      loc, expOp.getResults(), neutralSumInitOp.result(), reductionDim,
+      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+        auto result =
+            createLinalgReduceSumBody(b, nestedLoc, args, elementType);
+        nestedBuilder.create<linalg::YieldOp>(nestedLoc, result);
+      });
 
   // Step 4: Compute softmax.
-  Value result =
-      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
-  return SmallVector<Value>{result};
+  SmallVector<Value> dynDims;
+  for (unsigned i = 0, e = inputType.getRank(); i < e; i++) {
+    if (inputType.isDynamicDim(i))
+      dynDims.push_back(b.create<tensor::DimOp>(loc, input, i));
+  }
+  auto sumBcastOutput = b.create<tensor::EmptyOp>(
+      loc, getOutputOperandType().getShape(), elementType, dynDims);
+  auto sumBroadcastOp = b.create<linalg::BroadcastOp>(
----------------
kurapov-peter wrote:

What do you mean by that? What is confusing?

https://github.com/llvm/llvm-project/pull/97582


More information about the Mlir-commits mailing list