[Mlir-commits] [mlir] convert scf.forall to scf.for with shared outputs (PR #133032)

Saiem Irfan llvmlistbot at llvm.org
Tue Mar 25 21:12:51 PDT 2025


https://github.com/CursedKeyboard created https://github.com/llvm/llvm-project/pull/133032

@matthias-springer I was unable to ping the code owner directly, so I'm pinging the name I see most often in MLIR.

>From 2b0f5367c90c01fcbd7bee2befbfbfeb0dff3976 Mon Sep 17 00:00:00 2001
From: Saiem Irfan <sirfan at tesla.com>
Date: Tue, 25 Mar 2025 23:36:10 -0400
Subject: [PATCH] convert scf.forall to scf.for with shared outputs

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |   2 +
 .../Dialect/SCF/Transforms/ForallToFor.cpp    | 105 +++++++++++++++++-
 mlir/test/Dialect/SCF/forall-to-for.mlir      |  39 ++++++-
 3 files changed, 141 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1cfb866db0b51..e41be8cbc1aa1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -19,12 +19,14 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::scf;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1e1056e..c8960039a6ce1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -14,7 +14,12 @@
 
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_SCFFORALLTOFORLOOP
@@ -35,16 +40,108 @@ mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
   SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
   SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
   SmallVector<Value> steps = forallOp.getStep(rewriter);
-  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+  SmallVector<Value> iterArgs;
+  for (auto result : forallOp->getResults()) {
+    iterArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+  }
+
+  InParallelOp threadReduction =
+      cast<InParallelOp>(forallOp.getBody()->getTerminator());
+  SmallVector<tensor::ParallelInsertSliceOp> regionArgToSlice;
+  for (auto &op : threadReduction.getBody()->getOperations()) {
+    auto parallelInsert = dyn_cast<tensor::ParallelInsertSliceOp>(op);
+    if (!parallelInsert) {
+      return op.emitOpError() << "expected parallel insert slice op";
+    }
+    regionArgToSlice.push_back(parallelInsert);
+  }
+
+  function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
+      build = [&](OpBuilder &rewriter, Location loc, ValueRange ivs,
+                  ValueRange regionArgs) -> ValueVector {
+    SmallVector<Value> res;
+    for (auto [i, val] : llvm::enumerate(regionArgs)) {
+      tensor::ParallelInsertSliceOp sliceOp = regionArgToSlice[i];
+
+      // Map new induction variables where applicable.
+
+      SmallVector<OpFoldResult> sliceOpOffsets = sliceOp.getMixedOffsets();
+      for (OpFoldResult offset : sliceOpOffsets) {
+        if (offset.is<Value>()) {
+          Value dynamicOffset = offset.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicOffset);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            offset = ivs[index];
+          }
+        }
+      }
+
+      SmallVector<OpFoldResult> sliceOpSizes = sliceOp.getMixedSizes();
+      for (OpFoldResult size : sliceOpSizes) {
+        if (size.is<Value>()) {
+          Value dynamicSize = size.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicSize);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            size = ivs[index];
+          }
+        }
+      }
+
+      SmallVector<OpFoldResult> sliceOpStrides = sliceOp.getMixedStrides();
+      for (OpFoldResult stride : sliceOpStrides) {
+        if (stride.is<Value>()) {
+          Value dynamicStride = stride.get<Value>();
+          SmallVector<Value> originalInductionVars =
+              forallOp.getInductionVars();
+          auto *it = llvm::find(originalInductionVars, dynamicStride);
+          if (it != originalInductionVars.end()) {
+            size_t index = std::distance(originalInductionVars.begin(), it);
+            stride = ivs[index];
+          }
+        }
+      }
+
+      res.push_back(rewriter.create<tensor::InsertSliceOp>(
+          sliceOp->getLoc(), sliceOp.getSource(), val, sliceOpOffsets,
+          sliceOpSizes, sliceOpStrides));
+    }
+    return res;
+  };
 
+  // Now we want to create our new loops with the innermost getting the tensor
+  // insert slices appropriately.
+  LoopNest loopNest =
+      scf::buildLoopNest(rewriter, loc, lbs, ubs, steps, iterArgs, build);
   SmallVector<Value> ivs = llvm::map_to_vector(
       loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
 
+  rewriter.replaceAllOpUsesWith(forallOp,
+                                {loopNest.loops.front()->getResults()});
+  // Erase the parallel inserts and associated shared outputs.
+  for (tensor::ParallelInsertSliceOp insertSlice :
+       llvm::make_early_inc_range(regionArgToSlice)) {
+    auto loopBlockArg = dyn_cast<BlockArgument>(insertSlice.getDest());
+    if (!loopBlockArg || loopBlockArg.getOwner()->getParentOp() != forallOp) {
+      insertSlice->emitOpError()
+          << "expected destination to be block argument in loop";
+    }
+    rewriter.eraseOp(insertSlice);
+    rewriter.modifyOpInPlace(forallOp, [&]() {
+      forallOp.getBody()->eraseArgument(loopBlockArg.getArgNumber());
+    });
+  }
+  rewriter.eraseOp(forallOp.getTerminator());
+
   Block *innermostBlock = loopNest.loops.back().getBody();
-  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+
   rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
-                             innermostBlock->getTerminator()->getIterator(),
-                             ivs);
+                             innermostBlock->front().getIterator(), ivs);
   rewriter.eraseOp(forallOp);
 
   if (results) {
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
index e7d183fb9d2b5..4d8390f0b62c4 100644
--- a/mlir/test/Dialect/SCF/forall-to-for.mlir
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for,canonicalize))' -split-input-file | FileCheck %s
 
 func.func private @callee(%i: index, %j: index)
 
@@ -55,3 +55,40 @@ func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
   }
   return
 }
+
+// -----
+
+func.func @nested_with_result() -> tensor<4x2xf32> {  // 2-D scf.forall with a single shared_outs tensor.
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<4x2xf32>  // Initial value bound to the shared output %o.
+  %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+    %1 = tensor.empty() : tensor<1x1xf32>
+    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    scf.forall.in_parallel {  // Each iteration writes one 1x1 tile at [%arg0, %arg1].
+      tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+        tensor<1x1xf32> into tensor<4x2xf32>
+    }
+  }
+  return %res: tensor<4x2xf32>  // Expected to become the outer scf.for result after conversion.
+}
+
+// CHECK-LABEL:   func.func @nested_with_result() -> tensor<4x2xf32> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[C4:.*]] = arith.constant 4 : index
+// CHECK:           %[[FILL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[REDUCED_RES:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK:           %[[OUTER:.*]] = scf.for %[[IV_OUTER:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[OUTER_RES:.*]] = %[[REDUCED_RES]]) -> (tensor<4x2xf32>) {
+// CHECK:             %[[INNER:.*]] = scf.for %[[IV_INNER:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[INNER_RES:.*]] = %[[OUTER_RES]]) -> (tensor<4x2xf32>) {
+// CHECK:               %[[ITERATION_TENS:.*]] = tensor.empty() : tensor<1x1xf32>
+// CHECK:               %[[ITERATION_RES:.*]] = linalg.fill ins(%[[FILL]] : f32) outs(%[[ITERATION_TENS]] : tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK:               %[[UPDATED_RES:.*]] = tensor.insert_slice %[[ITERATION_RES]] into %[[INNER_RES]]{{\[}}%[[IV_OUTER]], %[[IV_INNER]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x2xf32>
+// CHECK:               scf.yield %[[UPDATED_RES]] : tensor<4x2xf32>
+// CHECK:             }
+// CHECK:             scf.yield %[[INNER]] : tensor<4x2xf32>
+// CHECK:           }
+// CHECK:           return %[[OUTER]] : tensor<4x2xf32>
+// CHECK:         }
\ No newline at end of file



More information about the Mlir-commits mailing list