[Mlir-commits] [mlir] [mlir][vector] Add pattern for dropping unit dims from for loops (PR #109585)
Quinn Dawkins
llvmlistbot at llvm.org
Sun Sep 22 11:02:54 PDT 2024
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/109585
This adds a pattern for dropping unit dims from the iter_args of scf.for ops using vector.shape_cast. This composes with the other patterns for dropping unit dims from elementwise ops and transposes.
From 8f3b21204806d74190ca3e43879e5b0e703306a4 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sun, 22 Sep 2024 13:27:04 -0400
Subject: [PATCH] [mlir][vector] Add pattern for dropping unit dims from for
loops
This adds a pattern for dropping unit dims from the iter_args of scf.for
ops using vector.shape_cast. This composes with the other patterns for
dropping unit dims from elementwise ops and transposes.
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 11 ++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 142 +++++++++---------
.../Vector/Transforms/VectorTransforms.cpp | 60 +++++++-
.../drop-unit-dims-with-shape-cast.mlir | 75 +++++++++
4 files changed, 215 insertions(+), 73 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index 644118ca884c6b..d89d566ece62c1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -107,6 +107,17 @@ LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
function_ref<void(OpBuilder &, Location, ValueRange)>
bodyBuilder = nullptr);
+/// Replace one iter OpOperand of an scf.for with the `replacement` value, which
+/// has a different type. The `castFn` callback inserts cast ops inside the
+/// loop body and after the loop to account for the type difference.
+using ValueTypeCastFnTy =
+ std::function<Value(OpBuilder &, Location loc, Type, Value)>;
+SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
+ scf::ForOp forOp,
+ OpOperand &operand,
+ Value replacement,
+ const ValueTypeCastFnTy &castFn);
+
} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_SCF_SCF_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 6d47ff3890977a..d1c9fd2d217dad 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -772,6 +772,70 @@ LoopNest mlir::scf::buildLoopNest(
});
}
+SmallVector<Value>
+mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
+ OpOperand &operand, Value replacement,
+ const ValueTypeCastFnTy &castFn) {
+ assert(operand.getOwner() == forOp);
+ Type oldType = operand.get().getType(), newType = replacement.getType();
+
+ // 1. Create new iter operands, exactly 1 is replaced.
+ assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
+ "expected an iter OpOperand");
+ assert(operand.get().getType() != replacement.getType() &&
+ "Expected a different type");
+ SmallVector<Value> newIterOperands;
+ for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
+ if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
+ newIterOperands.push_back(replacement);
+ continue;
+ }
+ newIterOperands.push_back(opOperand.get());
+ }
+
+ // 2. Create the new forOp shell.
+ scf::ForOp newForOp = rewriter.create<scf::ForOp>(
+ forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
+ forOp.getStep(), newIterOperands);
+ newForOp->setAttrs(forOp->getAttrs());
+ Block &newBlock = newForOp.getRegion().front();
+ SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
+ newBlock.getArguments().end());
+
+ // 3. Inject an incoming cast op at the beginning of the block for the bbArg
+ // corresponding to the `replacement` value.
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(&newBlock, newBlock.begin());
+ BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
+ &newForOp->getOpOperand(operand.getOperandNumber()));
+ Value castIn = castFn(rewriter, newForOp.getLoc(), oldType, newRegionIterArg);
+ newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
+
+ // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
+ Block &oldBlock = forOp.getRegion().front();
+ rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+
+ // 5. Inject an outgoing cast op at the end of the block and yield it instead.
+ auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+ rewriter.setInsertionPoint(clonedYieldOp);
+ unsigned yieldIdx =
+ newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
+ Value castOut = castFn(rewriter, newForOp.getLoc(), newType,
+ clonedYieldOp.getOperand(yieldIdx));
+ SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
+ newYieldOperands[yieldIdx] = castOut;
+ rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
+ rewriter.eraseOp(clonedYieldOp);
+
+ // 6. Inject an outgoing cast op after the forOp.
+ rewriter.setInsertionPointAfter(newForOp);
+ SmallVector<Value> newResults = newForOp.getResults();
+ newResults[yieldIdx] =
+ castFn(rewriter, newForOp.getLoc(), oldType, newResults[yieldIdx]);
+
+ return newResults;
+}
+
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
@@ -973,76 +1037,6 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
}
};
-/// Perform a replacement of one iter OpOperand of an scf.for to the
-/// `replacement` value which is expected to be the source of a tensor.cast.
-/// tensor.cast ops are inserted inside the block to account for the type cast.
-static SmallVector<Value>
-replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
- Value replacement) {
- Type oldType = operand.get().getType(), newType = replacement.getType();
- assert(llvm::isa<RankedTensorType>(oldType) &&
- llvm::isa<RankedTensorType>(newType) &&
- "expected ranked tensor types");
-
- // 1. Create new iter operands, exactly 1 is replaced.
- ForOp forOp = cast<ForOp>(operand.getOwner());
- assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
- "expected an iter OpOperand");
- assert(operand.get().getType() != replacement.getType() &&
- "Expected a different type");
- SmallVector<Value> newIterOperands;
- for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
- if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
- newIterOperands.push_back(replacement);
- continue;
- }
- newIterOperands.push_back(opOperand.get());
- }
-
- // 2. Create the new forOp shell.
- scf::ForOp newForOp = rewriter.create<scf::ForOp>(
- forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
- newForOp->setAttrs(forOp->getAttrs());
- Block &newBlock = newForOp.getRegion().front();
- SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
- newBlock.getArguments().end());
-
- // 3. Inject an incoming cast op at the beginning of the block for the bbArg
- // corresponding to the `replacement` value.
- OpBuilder::InsertionGuard g(rewriter);
- rewriter.setInsertionPoint(&newBlock, newBlock.begin());
- BlockArgument newRegionIterArg = newForOp.getTiedLoopRegionIterArg(
- &newForOp->getOpOperand(operand.getOperandNumber()));
- Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
- newRegionIterArg);
- newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;
-
- // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
- Block &oldBlock = forOp.getRegion().front();
- rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
-
- // 5. Inject an outgoing cast op at the end of the block and yield it instead.
- auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
- rewriter.setInsertionPoint(clonedYieldOp);
- unsigned yieldIdx =
- newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
- Value castOut = rewriter.create<tensor::CastOp>(
- newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
- SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
- newYieldOperands[yieldIdx] = castOut;
- rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
- rewriter.eraseOp(clonedYieldOp);
-
- // 6. Inject an outgoing cast op after the forOp.
- rewriter.setInsertionPointAfter(newForOp);
- SmallVector<Value> newResults = newForOp.getResults();
- newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
- newForOp.getLoc(), oldType, newResults[yieldIdx]);
-
- return newResults;
-}
-
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
@@ -1090,9 +1084,13 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
continue;
// Create a new ForOp with that iter operand replaced.
+ ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc, Type type,
+ Value source) {
+ return b.create<tensor::CastOp>(loc, type, source);
+ };
rewriter.replaceOp(
- op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
- incomingCast.getSource()));
+ op, replaceAndCastForOpIterArg(rewriter, op, iterOpOperand,
+ incomingCast.getSource(), castFn));
return success();
}
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index ad4e42b31962e1..ba32583fc3cdc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1796,6 +1796,63 @@ struct DropUnitDimsFromTransposeOp final
}
};
+/// A pattern to drop unit dims from the iter_args of an scf.for.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %res = scf.for ... iter_args(%iter = %init) -> vector<[4]x1x1x4xf32> {
+/// ...
+/// scf.yield ...
+/// }
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %drop = vector.shape_cast %init
+/// : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
+/// %new_loop = scf.for ... iter_args(%iter = %drop) -> vector<[4]x4xf32> {
+/// %new_iter = vector.shape_cast %iter
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ...
+/// }
+/// %res = vector.shape_cast %new_loop
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ```
+struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::ForOp forOp,
+ PatternRewriter &rewriter) const override {
+ /// Find the first iter_arg with droppable unit dims. Further applications
+ /// of this pattern will apply to later arguments.
+ for (OpOperand &operand : forOp.getInitArgsMutable()) {
+ auto vectorType = dyn_cast<VectorType>(operand.get().getType());
+ if (!vectorType)
+ continue;
+
+ VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType);
+ if (vectorType == newVectorType)
+ continue;
+
+ // Create a new ForOp with that iter operand replaced.
+ mlir::scf::ValueTypeCastFnTy castFn = [](OpBuilder &b, Location loc,
+ Type type, Value source) {
+ return b.create<vector::ShapeCastOp>(loc, type, source);
+ };
+
+ Value replacement =
+ castFn(rewriter, forOp.getLoc(), newVectorType, operand.get());
+ rewriter.replaceOp(forOp,
+ replaceAndCastForOpIterArg(rewriter, forOp, operand,
+ replacement, castFn));
+ return success();
+ }
+ return failure();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -2001,7 +2058,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
- ShapeCastOpFolder>(patterns.getContext(), benefit);
+ ShapeCastOpFolder, DropUnitDimsFromScfForOp>(
+ patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
index af3fc924c1dbe7..8249400a43c757 100644
--- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir
@@ -207,3 +207,78 @@ func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vect
// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
// CHECK-NOT: vector.shape_cast
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromScfForOp]
+///----------------------------------------------------------------------------------------
+
+func.func @scf_for_with_internal_unit_dims(%vec: vector<4x1x1x[4]xf32>) -> vector<4x1x1x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<4x1x1x[4]xf32> {
+ %s = math.sqrt %iter : vector<4x1x1x[4]xf32>
+ scf.yield %s : vector<4x1x1x[4]xf32>
+ }
+ return %res : vector<4x1x1x[4]xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_internal_unit_dims
+// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<4x1x1x[4]xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<4x1x1x[4]xf32> to vector<4x[4]xf32>
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<4x[4]xf32>
+// CHECK: scf.yield %[[SQRT]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<4x[4]xf32> to vector<4x1x1x[4]xf32>
+// CHECK: return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res = scf.for %i = %c0 to %c4 step %c1 iter_args(%iter = %vec) -> vector<1x1xf32> {
+ %s = math.sqrt %iter : vector<1x1xf32>
+ scf.yield %s : vector<1x1xf32>
+ }
+ return %res : vector<1x1xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_all_unit_dims
+// CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32>
+// CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]])
+// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32>
+// CHECK: scf.yield %[[SQRT]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[CASTBACK]]
+
+// -----
+
+func.func @scf_for_with_multiple_operands(%idx: index, %vec0: vector<1x4xf32>, %vec1: vector<1x4xf32>) -> vector<1x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %res:3 = scf.for %i = %c0 to %c4 step %c1
+ iter_args(%id = %idx, %iter0 = %vec0, %iter1 = %vec1) -> (index, vector<1x4xf32>, vector<1x4xf32>) {
+ %add = arith.addf %iter0, %iter1 : vector<1x4xf32>
+ scf.yield %id, %add, %add : index, vector<1x4xf32>, vector<1x4xf32>
+ }
+ return %res#1 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: func.func @scf_for_with_multiple_operands
+// CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
+// CHECK-SAME: %[[VEC0:[A-Za-z0-9]+]]: vector<1x4xf32>
+// CHECK-SAME: %[[VEC1:[A-Za-z0-9]+]]: vector<1x4xf32>
+// CHECK-DAG: %[[CAST0:.+]] = vector.shape_cast %[[VEC0]] : vector<1x4xf32> to vector<4xf32>
+// CHECK-DAG: %[[CAST1:.+]] = vector.shape_cast %[[VEC1]] : vector<1x4xf32> to vector<4xf32>
+// CHECK: %[[LOOP:.+]]:3 = scf.for
+// CHECK-SAME: iter_args(%{{.*}} = %[[IDX]], %[[ITER0:.+]] = %[[CAST0]], %[[ITER1:.+]] = %[[CAST1]])
+// CHECK: %[[ADD:.+]] = arith.addf %[[ITER0]], %[[ITER1]] : vector<4xf32>
+// CHECK: scf.yield %{{.*}}, %[[ADD]], %[[ADD]]
+// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]]#1 : vector<4xf32> to vector<1x4xf32>
+// CHECK: return %[[CASTBACK]]
More information about the Mlir-commits
mailing list