[Mlir-commits] [mlir] [mlir][linalg] convert arith ops to destination-passing-style. (PR #157854)
Jakub Kuderski
llvmlistbot at llvm.org
Sun Sep 14 17:55:29 PDT 2025
================
@@ -603,6 +610,94 @@ Value linalg::bufferizeToAllocation(
}
namespace {
+/// Rewrites an arith op operating on tensors, e.g.
+/// `%z = arith.addf %x, %y : tensor<5xf32>`
+/// into an equivalent linalg.generic in destination-passing-style.
+/// ```mlir
+/// %0 = tensor.empty() : tensor<5xf32>
+/// %1 = linalg.generic ...
+/// ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
+/// outs(%0 : tensor<5xf32>) {
+/// ^bb0(%in: f32, %in_0: f32, %out: f32):
+/// %2 = arith.addf %in, %in_0 : f32
+/// linalg.yield %2 : f32
+/// } -> tensor<5xf32>
+/// ```
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+ // Reject ops such as `arith.constant` and `arith.select`.
+ // constants don't need dps conversion and select is a a `todo`.
+ auto numOperands = op->getNumOperands();
+ if (numOperands == 0 || numOperands > 2)
+ return failure();
+
+ // destination passing style rewrite is only for ops on tensor types.
+ Type resultType = op->getResult(0).getType();
+ auto tensorType = dyn_cast<RankedTensorType>(resultType);
+ if (!tensorType)
+ return failure();
+
+ auto loc = op.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+ // Create tensor.empty for `outs` of destination-passing-style.
+ Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+ // Create linalg.generic
+ auto rank = tensorType.getRank();
+ SmallVector<AffineMap> indexingMaps(numOperands + 1,
+ rewriter.getMultiDimIdentityMap(rank));
+ SmallVector<utils::IteratorType> iteratorTypes(rank,
+ utils::IteratorType::parallel);
+
+ // Check 'fast-math'. If present, propagate it.
+ auto fmfOpInterface =
+ llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation());
+
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, tensorType,
+ op->getOperands(), // inputs
+ ValueRange{outs}, // outputs
----------------
kuhar wrote:
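Use `/*argName=*/param` argument comments instead of the trailing `// inputs` / `// outputs` comments -- some tools can then warn if the names don't match the callee's parameter names.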
https://github.com/llvm/llvm-project/pull/157854
More information about the Mlir-commits
mailing list