[Mlir-commits] [mlir] [mlir][linalg] convert arith ops to destination-passing-style. (PR #157854)

Jakub Kuderski llvmlistbot at llvm.org
Sun Sep 14 17:55:29 PDT 2025


================
@@ -603,6 +610,94 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+/// Rewrites an arith op operating on tensors, e.g.
+///  `%z = arith.addf %x, %y : tensor<5xf32>`
+/// into an equivalent linalg.generic in destination-passing-style.
+/// ```mlir
+/// %0 = tensor.empty() : tensor<5xf32>
+/// %1 = linalg.generic ...
+///        ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
+///        outs(%0 : tensor<5xf32>) {
+///      ^bb0(%in: f32, %in_0: f32, %out: f32):
+///         %2 = arith.addf %in, %in_0 : f32
+///         linalg.yield %2 : f32
+///     } -> tensor<5xf32>
+/// ```
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+  // Reject ops such as `arith.constant` and `arith.select`:
+  // constants don't need DPS conversion and select is a TODO.
+  auto numOperands = op->getNumOperands();
+  if (numOperands == 0 || numOperands > 2)
+    return failure();
+
+  // The destination-passing-style rewrite only applies to ops on tensor types.
+  Type resultType = op->getResult(0).getType();
+  auto tensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!tensorType)
+    return failure();
+
+  auto loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+  // Create the tensor.empty used as the `outs` operand of the DPS op.
+  Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+  // Create the linalg.generic op.
+  auto rank = tensorType.getRank();
+  SmallVector<AffineMap> indexingMaps(numOperands + 1,
+                                      rewriter.getMultiDimIdentityMap(rank));
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+
+  // Check for fast-math flags. If present, propagate them.
+  auto fmfOpInterface =
+      llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation());
+
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, tensorType,
+      op->getOperands(), // inputs
+      ValueRange{outs},  // outputs
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value res;
+        if (args.size() == 2) {
+          if (fmfOpInterface) {
+            auto attr = fmfOpInterface.getFastMathFlagsAttr();
+            auto fmf = rewriter.getNamedAttr("fastmath", attr);
+            res = builder
+                      .create<OpTy>(loc, args[1].getType(), ValueRange{args[0]},
+                                    fmf)
+                      .getResult();
+          } else {
+            res = builder
+                      .create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+                      .getResult();
+          }
+        } else if (args.size() == 3) {
+          if (fmfOpInterface) {
+            auto attr = fmfOpInterface.getFastMathFlagsAttr();
+            auto fmf = rewriter.getNamedAttr("fastmath", attr);
+            res = builder
+                      .create<OpTy>(loc, args[2].getType(),
+                                    ValueRange{args[0], args[1]}, fmf)
+                      .getResult();
+          } else {
+            res = builder
+                      .create<OpTy>(loc, args[2].getType(),
+                                    ValueRange{args[0], args[1]})
+                      .getResult();
+          }
+        } else
+          llvm_unreachable("did not expect ops other than unary and binary");
+        linalg::YieldOp::create(builder, loc, res);
----------------
kuhar wrote:

nit: I think the convention is that if one if/else branch uses braces, all the other ones should too

```suggestion
        } else {
          llvm_unreachable("did not expect ops other than nary and binary");
        }
        linalg::YieldOp::create(builder, loc, res);
```
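
For context, here is an illustrative sketch (not taken from the patch or its tests; the value names `%x`, `%y`, `%res` are made up) of the IR this rewrite is expected to produce for a dynamically sized tensor: the dynamic extent is reified from an operand and passed to `tensor.empty` to build the DPS `outs` operand. The elided attributes follow the style of the patch's own doc comment.

```mlir
// Hypothetical input: %res = arith.addf %x, %y : tensor<?xf32>
%c0 = arith.constant 0 : index
%d0 = tensor.dim %x, %c0 : tensor<?xf32>
%empty = tensor.empty(%d0) : tensor<?xf32>
%res = linalg.generic ...
         ins(%x, %y : tensor<?xf32>, tensor<?xf32>)
         outs(%empty : tensor<?xf32>) {
       ^bb0(%in: f32, %in_0: f32, %out: f32):
         %s = arith.addf %in, %in_0 : f32
         linalg.yield %s : f32
       } -> tensor<?xf32>
```

Per the patch, any fast-math flags on the original arith op would also be carried over onto the scalar op inside the region.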


https://github.com/llvm/llvm-project/pull/157854


More information about the Mlir-commits mailing list