[Mlir-commits] [mlir] convert scfforall to scf for with shared outputs (PR #133032)
Saiem Irfan
llvmlistbot at llvm.org
Tue Mar 25 21:12:51 PDT 2025
https://github.com/CursedKeyboard created https://github.com/llvm/llvm-project/pull/133032
@matthias-springer I was unable to ping the intended reviewer, so I'm pinging the name I see most often in MLIR.
>From 2b0f5367c90c01fcbd7bee2befbfbfeb0dff3976 Mon Sep 17 00:00:00 2001
From: Saiem Irfan <sirfan at tesla.com>
Date: Tue, 25 Mar 2025 23:36:10 -0400
Subject: [PATCH] convert scfforall to scf for with shared outputs
---
mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +
.../Dialect/SCF/Transforms/ForallToFor.cpp | 105 +++++++++++++++++-
mlir/test/Dialect/SCF/forall-to-for.mlir | 39 ++++++-
3 files changed, 141 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1cfb866db0b51..e41be8cbc1aa1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,12 +19,14 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <optional>
using namespace mlir;
using namespace mlir::scf;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..c8960039a6ce1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -14,7 +14,12 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
@@ -35,16 +40,108 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
SmallVector<Value> steps = forallOp.getStep(rewriter);
- LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+ SmallVector<Value> iterArgs;
+ for (auto result : forallOp->getResults()) {
+ iterArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+ }
+
+ InParallelOp threadReduction =
+ cast<InParallelOp>(forallOp.getBody()->getTerminator());
+ SmallVector<tensor::ParallelInsertSliceOp> regionArgToSlice;
+ for (auto &op : threadReduction.getBody()->getOperations()) {
+ auto parallelInsert = dyn_cast<tensor::ParallelInsertSliceOp>(op);
+ if (!parallelInsert) {
+ return op.emitOpError() << "expected parallel insert slice op";
+ }
+ regionArgToSlice.push_back(parallelInsert);
+ }
+
+ function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
+ build = [&](OpBuilder &rewriter, Location loc, ValueRange ivs,
+ ValueRange regionArgs) -> ValueVector {
+ SmallVector<Value> res;
+ for (auto [i, val] : llvm::enumerate(regionArgs)) {
+ tensor::ParallelInsertSliceOp sliceOp = regionArgToSlice[i];
+
+ // Map new induction variables where applicable.
+
+ SmallVector<OpFoldResult> sliceOpOffsets = sliceOp.getMixedOffsets();
+ for (OpFoldResult offset : sliceOpOffsets) {
+ if (offset.is<Value>()) {
+ Value dynamicOffset = offset.get<Value>();
+ SmallVector<Value> originalInductionVars =
+ forallOp.getInductionVars();
+ auto *it = llvm::find(originalInductionVars, dynamicOffset);
+ if (it != originalInductionVars.end()) {
+ size_t index = std::distance(originalInductionVars.begin(), it);
+ offset = ivs[index];
+ }
+ }
+ }
+
+ SmallVector<OpFoldResult> sliceOpSizes = sliceOp.getMixedSizes();
+ for (OpFoldResult size : sliceOpSizes) {
+ if (size.is<Value>()) {
+ Value dynamicSize = size.get<Value>();
+ SmallVector<Value> originalInductionVars =
+ forallOp.getInductionVars();
+ auto *it = llvm::find(originalInductionVars, dynamicSize);
+ if (it != originalInductionVars.end()) {
+ size_t index = std::distance(originalInductionVars.begin(), it);
+ size = ivs[index];
+ }
+ }
+ }
+
+ SmallVector<OpFoldResult> sliceOpStrides = sliceOp.getMixedStrides();
+ for (OpFoldResult stride : sliceOpStrides) {
+ if (stride.is<Value>()) {
+ Value dynamicStride = stride.get<Value>();
+ SmallVector<Value> originalInductionVars =
+ forallOp.getInductionVars();
+ auto *it = llvm::find(originalInductionVars, dynamicStride);
+ if (it != originalInductionVars.end()) {
+ size_t index = std::distance(originalInductionVars.begin(), it);
+ stride = ivs[index];
+ }
+ }
+ }
+
+ res.push_back(rewriter.create<tensor::InsertSliceOp>(
+ sliceOp->getLoc(), sliceOp.getSource(), val, sliceOpOffsets,
+ sliceOpSizes, sliceOpStrides));
+ }
+ return res;
+ };
+ // Now we want to create our new loops with the innermost getting the tensor
+ // insert slices appropriately.
+ LoopNest loopNest =
+ scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, iterArgs, build);
SmallVector<Value> ivs = llvm::map_to_vector(
loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
+ rewriter.replaceAllOpUsesWith(forallOp,
+ {loopNest.loops.front()->getResults()});
+ // Erase the parallel inserts and associated shared outputs.
+ for (tensor::ParallelInsertSliceOp insertSlice :
+ llvm::make_early_inc_range(regionArgToSlice)) {
+ auto loopBlockArg = dyn_cast<BlockArgument>(insertSlice.getDest());
+ if (!loopBlockArg || loopBlockArg.getOwner()->getParentOp() != forallOp) {
+ insertSlice->emitOpError()
+ << "expected destination to be block argument in loop";
+ }
+ rewriter.eraseOp(insertSlice);
+ rewriter.modifyOpInPlace(forallOp, [&]() {
+ forallOp.getBody()->eraseArgument(loopBlockArg.getArgNumber());
+ });
+ }
+ rewriter.eraseOp(forallOp.getTerminator());
+
Block *innermostBlock = loopNest.loops.back().getBody();
- rewriter.eraseOp(forallOp.getBody()->getTerminator());
+
rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
- innermostBlock->getTerminator()->getIterator(),
- ivs);
+ innermostBlock->front().getIterator(), ivs);
rewriter.eraseOp(forallOp);
if (results) {
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..4d8390f0b62c4 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for,canonicalize))' -split-input-file | FileCheck %s
func.func private @callee(%i: index, %j: index)
@@ -55,3 +55,40 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
}
return
}
+
+// -----
+
+func.func @nested_with_result() -> tensor<4x2xf32> {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<4x2xf32>
+ %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+ %1 = tensor.empty() : tensor<1x1xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+ tensor<1x1xf32> into tensor<4x2xf32>
+ }
+ }
+ return %res: tensor<4x2xf32>
+}
+
+// CHECK-LABEL: func.func @nested_with_result() -> tensor<4x2xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[FILL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[REDUCED_RES:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[OUTER:.*]] = scf.for %[[IV_OUTER:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[OUTER_RES:.*]] = %[[REDUCED_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[INNER:.*]] = scf.for %[[IV_INNER:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[INNER_RES:.*]] = %[[OUTER_RES]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[ITERATION_TENS:.*]] = tensor.empty() : tensor<1x1xf32>
+// CHECK: %[[ITERATION_RES:.*]] = linalg.fill ins(%[[FILL]] : f32) outs(%[[ITERATION_TENS]] : tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK: %[[UPDATED_RES:.*]] = tensor.insert_slice %[[ITERATION_RES]] into %[[INNER_RES]]{{\[}}%[[IV_OUTER]], %[[IV_INNER]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x2xf32>
+// CHECK: scf.yield %[[UPDATED_RES]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: scf.yield %[[INNER]] : tensor<4x2xf32>
+// CHECK: }
+// CHECK: return %[[OUTER]] : tensor<4x2xf32>
+// CHECK: }
\ No newline at end of file
More information about the Mlir-commits mailing list.