[Mlir-commits] [mlir] 3a8f161 - [mlir] Add a pattern to fold single- and zero-iteration scf.forall ops.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Mar 21 03:59:39 PDT 2023
Author: Alexander Belyaev
Date: 2023-03-21T11:59:25+01:00
New Revision: 3a8f161a3401edeb58e018e2d389dd2413a6417f
URL: https://github.com/llvm/llvm-project/commit/3a8f161a3401edeb58e018e2d389dd2413a6417f
DIFF: https://github.com/llvm/llvm-project/commit/3a8f161a3401edeb58e018e2d389dd2413a6417f.diff
LOG: [mlir] Add a pattern to fold single- and zero-iteration scf.forall ops.
Differential Revision: https://reviews.llvm.org/D145368
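For illustration (not part of the patch; the constants %c0, %c1, %c8, %cst and the
init tensor %0 are assumed to be defined as in the @inline_forall_loop test added
below), a forall whose every dimension has a static trip count of 1:

  %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1) step (%c8, %c8)
       shared_outs (%o = %0) -> (tensor<8x8xf32>) {
    %fill = linalg.fill ins(%cst : f32) outs(%o : tensor<8x8xf32>)
        -> tensor<8x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %fill into %o[%i, %j] [8, 8] [1, 1]
          : tensor<8x8xf32> into tensor<8x8xf32>
    }
  }

is canonicalized to roughly:

  %fill = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x8xf32>)
      -> tensor<8x8xf32>
  %1 = tensor.insert_slice %fill into %0[%c0, %c0] [8, 8] [1, 1]
      : tensor<8x8xf32> into tensor<8x8xf32>

The induction variables are replaced by the lower bounds, the shared_outs block
arguments by the outputs, and each tensor.parallel_insert_slice in the terminator
by a tensor.insert_slice whose result replaces the corresponding loop result. A
forall with an empty iteration domain (lb == ub in some dimension) is replaced by
its shared_outs operands, and single-iteration dimensions are dropped from loops
that still have other dimensions.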
Added:
Modified:
mlir/include/mlir/Dialect/SCF/IR/SCF.h
mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
mlir/test/Dialect/SCF/canonicalize.mlir
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 7f714d0a07646..cb399b78c406d 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -62,6 +62,14 @@ ForallOp getForallOpThreadIndexOwner(Value val);
// TODO: Consider moving this functionality to RegionBranchOpInterface.
bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
+/// Promotes the loop body of a scf::ForallOp to its containing block if the
+/// loop is known to have a single iteration.
+LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
+ scf::ForallOp forallOp);
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);
+
/// An owning vector of values, handy to return from functions.
using ValueVector = SmallVector<Value>;
using LoopVector = SmallVector<scf::ForOp>;
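For context, a minimal sketch (not part of the patch; the pattern name is made up
for this note) of how the new scf::promoteIfSingleIteration entry point can be
driven from a rewrite pattern. The ForallOpSingleOrZeroIterationDimsFolder
canonicalization added in SCF.cpp below performs the equivalent work and also
handles the per-dimension cases:

  #include "mlir/Dialect/SCF/IR/SCF.h"
  #include "mlir/IR/PatternMatch.h"

  namespace {
  struct PromoteSingleIterationForall
      : public mlir::OpRewritePattern<mlir::scf::ForallOp> {
    using mlir::OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;

    mlir::LogicalResult
    matchAndRewrite(mlir::scf::ForallOp op,
                    mlir::PatternRewriter &rewriter) const override {
      // Succeeds only when every dimension has a static trip count of 1;
      // otherwise the op is left untouched.
      return mlir::scf::promoteIfSingleIteration(rewriter, op);
    }
  };
  } // namespace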
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 27c27756b3918..47910e2069761 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,11 @@ SmallVector<int64_t>
getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
llvm::function_ref<bool(Attribute, Attribute)> compare);
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+ OpFoldResult step);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 4e7bcc499be3d..e212159442844 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -534,6 +534,61 @@ void ForOp::getSuccessorRegions(std::optional<unsigned> index,
regions.push_back(RegionSuccessor(getResults()));
}
+/// Promotes the loop body of a forallOp to its containing block if it can be
+/// determined that the loop has a single iteration.
+LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
+ scf::ForallOp forallOp) {
+ for (auto [lb, ub, step] :
+ llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep())) {
+ auto tripCount = constantTripCount(lb, ub, step);
+ if (!tripCount.has_value() || *tripCount != 1)
+ return failure();
+ }
+
+ promote(rewriter, forallOp);
+ return success();
+}
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
+ IRMapping mapping;
+ mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
+ mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
+ for (auto &bodyOp : forallOp.getBody()->without_terminator())
+ rewriter.clone(bodyOp, mapping);
+
+ SmallVector<Value> results;
+ results.reserve(forallOp.getResults().size());
+ scf::InParallelOp terminator = forallOp.getTerminator();
+ for (auto &yieldingOp : terminator.getYieldingOps()) {
+ auto parallelInsertSliceOp =
+ cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+
+ Value dst = parallelInsertSliceOp.getDest();
+ Value src = parallelInsertSliceOp.getSource();
+
+ auto getMappedValues = [&](ValueRange values) {
+ return llvm::to_vector(llvm::map_range(
+ values, [&](Value value) { return mapping.lookupOrDefault(value); }));
+ };
+
+ Value srcVal = mapping.lookupOrDefault(src);
+ if (srcVal.getType().isa<TensorType>()) {
+ results.push_back(rewriter.create<tensor::InsertSliceOp>(
+ forallOp.getLoc(), dst.getType(), srcVal,
+ mapping.lookupOrDefault(dst),
+ getMappedValues(parallelInsertSliceOp.getOffsets()),
+ getMappedValues(parallelInsertSliceOp.getSizes()),
+ getMappedValues(parallelInsertSliceOp.getStrides()),
+ parallelInsertSliceOp.getStaticOffsets(),
+ parallelInsertSliceOp.getStaticSizes(),
+ parallelInsertSliceOp.getStaticStrides()));
+ }
+ }
+ rewriter.replaceOp(forallOp, results);
+}
+
LoopNest mlir::scf::buildLoopNest(
OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
ValueRange steps, ValueRange iterArgs,
@@ -1452,16 +1507,99 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
op.getDynamicStepMutable().assign(dynamicStep);
op.setStaticStep(staticStep);
+
+ op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
+ rewriter.getDenseI32ArrayAttr(
+ {static_cast<int32_t>(dynamicLowerBound.size()),
+ static_cast<int32_t>(dynamicUpperBound.size()),
+ static_cast<int32_t>(dynamicStep.size()),
+ static_cast<int32_t>(op.getNumResults())}));
});
return success();
}
};
+struct ForallOpSingleOrZeroIterationDimsFolder
+ : public OpRewritePattern<ForallOp> {
+ using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForallOp op,
+ PatternRewriter &rewriter) const override {
+ // Do not fold dimensions if they are mapped to processing units.
+ if (op.getMapping().has_value())
+ return failure();
+ Location loc = op.getLoc();
+
+ // Compute new loop bounds that omit all single-iteration loop dimensions.
+ SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
+ newMixedSteps;
+ IRMapping mapping;
+ for (auto [lb, ub, step, iv] :
+ llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+ op.getMixedStep(), op.getInductionVars())) {
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (numIterations.has_value()) {
+ // Remove the loop if it performs zero iterations.
+ if (*numIterations == 0) {
+ rewriter.replaceOp(op, op.getOutputs());
+ return success();
+ }
+ // Replace the loop induction variable by the lower bound if the loop
+ // performs a single iteration. Otherwise, copy the loop bounds.
+ if (*numIterations == 1) {
+ mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ continue;
+ }
+ }
+ newMixedLowerBounds.push_back(lb);
+ newMixedUpperBounds.push_back(ub);
+ newMixedSteps.push_back(step);
+ }
+ // Exit if none of the loop dimensions perform a single iteration.
+ if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+ return rewriter.notifyMatchFailure(
+ op, "no dimensions have 0 or 1 iterations");
+ }
+
+ // All of the loop dimensions perform a single iteration. Inline loop body.
+ if (newMixedLowerBounds.empty()) {
+ promote(rewriter, op);
+ return success();
+ }
+
+ // Replace the loop by a lower-dimensional loop.
+ ForallOp newOp;
+ newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
+ newMixedUpperBounds, newMixedSteps,
+ op.getOutputs(), std::nullopt, nullptr);
+ newOp.getBodyRegion().getBlocks().clear();
+ // The new loop needs to keep all attributes from the old one, except for
+ // "operand_segment_sizes" and static loop bound attributes which capture
+ // the outdated information of the old iteration domain.
+ SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
+ newOp.getStaticLowerBoundAttrName(),
+ newOp.getStaticUpperBoundAttrName(),
+ newOp.getStaticStepAttrName()};
+ for (const auto &namedAttr : op->getAttrs()) {
+ if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
+ continue;
+ rewriter.updateRootInPlace(newOp, [&]() {
+ newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+ });
+ }
+ rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
+ newOp.getRegion().begin(), mapping);
+ rewriter.replaceOp(op, newOp.getResults());
+ return success();
+ }
+};
+
} // namespace
void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
+ results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+ ForallOpSingleOrZeroIterationDimsFolder>(context);
}
//===----------------------------------------------------------------------===//
@@ -2615,41 +2753,37 @@ ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
namespace {
// Collapse loop dimensions that perform a single iteration.
-struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
+struct ParallelOpSingleOrZeroIterationDimsFolder
+ : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ParallelOp op,
PatternRewriter &rewriter) const override {
- IRMapping mapping;
+ Location loc = op.getLoc();
+
// Compute new loop bounds that omit all single-iteration loop dimensions.
- SmallVector<Value, 2> newLowerBounds;
- SmallVector<Value, 2> newUpperBounds;
- SmallVector<Value, 2> newSteps;
- newLowerBounds.reserve(op.getLowerBound().size());
- newUpperBounds.reserve(op.getUpperBound().size());
- newSteps.reserve(op.getStep().size());
- for (auto [lowerBound, upperBound, step, iv] :
+ SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
+ IRMapping mapping;
+ for (auto [lb, ub, step, iv] :
llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
op.getInductionVars())) {
- // Collect the statically known loop bounds.
- auto lowerBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
- auto upperBoundConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
- auto stepConstant =
- dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
- // Replace the loop induction variable by the lower bound if the loop
- // performs a single iteration. Otherwise, copy the loop bounds.
- if (lowerBoundConstant && upperBoundConstant && stepConstant &&
- (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
- (upperBoundConstant.value() - lowerBoundConstant.value()) <=
- stepConstant.value()) {
- mapping.map(iv, lowerBound);
- } else {
- newLowerBounds.push_back(lowerBound);
- newUpperBounds.push_back(upperBound);
- newSteps.push_back(step);
+ auto numIterations = constantTripCount(lb, ub, step);
+ if (numIterations.has_value()) {
+ // Remove the loop if it performs zero iterations.
+ if (*numIterations == 0) {
+ rewriter.replaceOp(op, op.getInitVals());
+ return success();
+ }
+ // Replace the loop induction variable by the lower bound if the loop
+ // performs a single iteration. Otherwise, copy the loop bounds.
+ if (*numIterations == 1) {
+ mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+ continue;
+ }
}
+ newLowerBounds.push_back(lb);
+ newUpperBounds.push_back(ub);
+ newSteps.push_back(step);
}
// Exit if none of the loop dimensions perform a single iteration.
if (newLowerBounds.size() == op.getLowerBound().size())
@@ -2694,23 +2828,6 @@ struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
}
};
-/// Removes parallel loops in which at least one lower/upper bound pair consists
-/// of the same values - such loops have an empty iteration domain.
-struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
- using OpRewritePattern<ParallelOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ParallelOp op,
- PatternRewriter &rewriter) const override {
- for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
- if (std::get<0>(dim) == std::get<1>(dim)) {
- rewriter.replaceOp(op, op.getInitVals());
- return success();
- }
- }
- return failure();
- }
-};
-
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
using OpRewritePattern<ParallelOp>::OpRewritePattern;
@@ -2773,8 +2890,9 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
- MergeNestedParallelLoops>(context);
+ results
+ .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
+ context);
}
//===----------------------------------------------------------------------===//
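On the scf.parallel side, the behavior of the removed RemoveEmptyParallelLoops
pattern is preserved by the trip-count check. A rough sketch (the names %c1, %c4,
%cst and %buf are invented for this note): a loop whose lower and upper bound
coincide, e.g.

  scf.parallel (%i) = (%c4) to (%c4) step (%c1) {
    memref.store %cst, %buf[%i] : memref<8xf32>
    scf.yield
  }

has a static trip count of 0, so ParallelOpSingleOrZeroIterationDimsFolder erases
it (replacing any loop results with the init values), while single-iteration
dimensions keep being folded into the lower bound as before.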
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6eca0ef9f69cf..e16e2881185a9 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -381,18 +381,12 @@ static void replaceIterArgsAndYieldResults(scf::ForOp forOp) {
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration.
LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
- auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
- auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
- auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
- if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
- ubCstOp.value() < 0 || stepCstOp.value() < 0)
- return failure();
- int64_t tripCount =
- mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
- if (tripCount != 1)
+ std::optional<int64_t> tripCount = constantTripCount(
+ forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+ if (!tripCount.has_value() || tripCount != 1)
return failure();
auto iv = forOp.getInductionVar();
- iv.replaceAllUsesWith(lbCstOp);
+ iv.replaceAllUsesWith(forOp.getLowerBound());
replaceIterArgsAndYieldResults(forOp);
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index e646de95a76c9..45edd5f89ffed 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/APSInt.h"
namespace mlir {
@@ -228,4 +229,24 @@ getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
return getValuesSortedByKeyImpl(keys, values, compare);
}
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+ OpFoldResult step) {
+ if (lb == ub)
+ return 0;
+
+ std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+ if (!lbConstant)
+ return std::nullopt;
+ std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+ if (!ubConstant)
+ return std::nullopt;
+ std::optional<int64_t> stepConstant = getConstantIntValue(step);
+ if (!stepConstant)
+ return std::nullopt;
+
+ return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
+}
+
} // namespace mlir
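A small usage sketch of the new helper (not part of the patch; the function name
hasOnlyUnitTripCounts is made up for this note), mirroring how
scf::promoteIfSingleIteration consumes it in SCF.cpp above:

  #include "mlir/Dialect/SCF/IR/SCF.h"
  #include "mlir/Dialect/Utils/StaticValueUtils.h"
  #include "llvm/ADT/STLExtras.h"
  #include <optional>

  // Returns true if every dimension of the forall has a static trip count of
  // exactly 1, i.e. the body is known to execute once.
  static bool hasOnlyUnitTripCounts(mlir::scf::ForallOp forallOp) {
    for (auto [lb, ub, step] :
         llvm::zip(forallOp.getMixedLowerBound(),
                   forallOp.getMixedUpperBound(), forallOp.getMixedStep())) {
      std::optional<int64_t> trips = mlir::constantTripCount(lb, ub, step);
      if (!trips || *trips != 1)
        return false;
    }
    return true;
  }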
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index a3ce8a63d4c9f..f69cf196597e2 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1544,3 +1544,110 @@ func.func @forall_fold_control_operands(
return %result : tensor<?x10xf32>
}
// CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
+
+// -----
+
+func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
+ step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<8x8xf32> to tensor<2x3xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+ -> tensor<2x3xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<2x3xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @inline_forall_loop
+// CHECK-NOT: scf.forall
+// CHECK: %[[OUT:.*]] = tensor.empty
+
+// CHECK-NEXT: %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
+// CHECK-SAME: : tensor<8x8xf32> to tensor<2x3xf32>
+
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[SLICE]]
+
+// CHECK-NEXT: tensor.insert_slice %[[FILL]]
+// CHECK-SAME: : tensor<2x3xf32> into tensor<8x8xf32>
+
+// -----
+
+func.func @do_not_inline_distributed_forall_loop(
+ %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+ shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<8x8xf32> to tensor<2x3xf32>
+ %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+ -> tensor<2x3xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+ : tensor<2x3xf32> into tensor<8x8xf32>
+ }
+ }{ mapping = [#gpu.thread<y>, #gpu.thread<x>] }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @do_not_inline_distributed_forall_loop
+// CHECK: scf.forall
+
+// -----
+
+func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
+ step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+ : tensor<8x8xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @collapse_one_dim_parallel
+// CHECK: scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
+// CHECK: linalg.fill
+// CHECK: tensor.parallel_insert_slice
+
+// -----
+
+func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %c8 = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<8x8xf32>
+ %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
+ step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+ %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+ : tensor<8x8xf32> into tensor<8x8xf32>
+ }
+ }
+ return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @remove_empty_forall
+// CHECK-NOT: scf.forall
+// CHECK: %[[EMPTY:.*]] = tensor.empty
+// CHECK: return %[[EMPTY]]
+
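For reference, the new cases can be reproduced locally with the canonicalizer,
e.g. something like

  mlir-opt -canonicalize mlir/test/Dialect/SCF/canonicalize.mlir

(the file's actual RUN line may pass additional options).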
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
index 2358ddeb5b01b..750a8d0edf0e2 100644
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -86,7 +86,7 @@ func.func @insert_slice_rank_reducing_dynamic_shape(
// CHECK-LABEL: func.func @parallel_insert_slice
// CHECK-NOT: tensor.insert_slice
-// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index