[Mlir-commits] [mlir] [mlir][TilingInterface] Make the tiling set tile sizes function use `OpFoldResult`. (PR #66566)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 18 13:13:53 PDT 2023
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/66566
>From dea0b83f6069cb0f4a70df646769f97f14049ce1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Fri, 15 Sep 2023 18:24:23 -0700
Subject: [PATCH] [mlir][TilingInterface] Make the tiling set tile sizes
function use `OpFoldResult`.
---
.../SCF/Transforms/TileUsingInterface.h | 11 +---
.../TransformOps/LinalgTransformOps.cpp | 24 ++++-----
.../SCF/Transforms/TileUsingInterface.cpp | 50 ++++++++-----------
.../Dialect/Linalg/transform-op-tile.mlir | 4 +-
.../TilingInterface/TestTilingInterface.cpp | 14 ++++--
5 files changed, 47 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index e7bcd062d96525d..ca641c596c7b7bb 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
namespace scf {
using SCFTileSizeComputationFunction =
- std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
+ std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
/// Options to use to control tiling.
struct SCFTilingOptions {
@@ -40,17 +40,10 @@ struct SCFTilingOptions {
tileSizeComputationFunction = std::move(fun);
return *this;
}
- /// Set the `tileSizeComputationFunction` to return the values `ts`. The
- /// values must not fold away when tiling. Otherwise, use a more robust
- /// `tileSizeComputationFunction`.
- SCFTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) {
- tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; };
- return *this;
- }
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
- SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+ SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4270ab38004a1..1819ca614a060fd 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -473,7 +473,9 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
scf::SCFTilingOptions tilingOptions;
tilingOptions.interchangeVector = tileInterchange;
- tilingOptions = tilingOptions.setTileSizes(tileSizes);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+ tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;
LogicalResult result = applyTilingToAll(
@@ -923,7 +925,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
auto diag = mlir::emitSilenceableFailure(getLoc())
- << "could not find next producer to fuse into container";
+ << "could not find next producer to fuse into container";
diag.attachNote(containingOp->getLoc()) << "containing op";
return diag;
}
@@ -1999,7 +2001,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
- SmallVector<Value, 4> tileSizes;
+ SmallVector<OpFoldResult> tileSizes;
Location loc = target.getLoc();
SmallVector<OpFoldResult> allShapeSizes =
target.createFlatListOfOperandDims(b, loc);
@@ -2012,9 +2014,8 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
// If the shape size is dynamic, tile by 1.
// Otherwise, do not tile (i.e. tile size 0).
for (OpFoldResult shapeSize : shapeSizes) {
- tileSizes.push_back(getConstantIntValue(shapeSize)
- ? b.create<arith::ConstantIndexOp>(loc, 0)
- : b.create<arith::ConstantIndexOp>(loc, 1));
+ tileSizes.push_back(getConstantIntValue(shapeSize) ? b.getIndexAttr(0)
+ : b.getIndexAttr(1));
}
return tileSizes;
});
@@ -2549,7 +2550,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
if (!tileSizes.empty()) {
tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
Operation *) {
- SmallVector<Value, 4> sizes;
+ SmallVector<OpFoldResult> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
@@ -2560,10 +2561,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
getLoc(), attr.cast<IntegerAttr>().getInt());
Value vscale =
b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
- sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
+ sizes.push_back(
+ b.create<arith::MulIOp>(getLoc(), val, vscale).getResult());
} else {
- sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), cast<IntegerAttr>(attr).getInt()));
+ sizes.push_back(attr);
}
continue;
}
@@ -2573,8 +2574,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
assert((dynamicSizes.empty() ^ params.empty()) &&
"expected either dynamic sizes or parameters");
if (!params.empty()) {
- sizes.push_back(
- b.create<arith::ConstantIndexOp>(getLoc(), params[index]));
+ sizes.push_back(b.getIndexAttr(params[index]));
} else {
sizes.push_back(dynamicSizes[index]->getResult(0));
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..c782583c32eb6a0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -31,19 +31,11 @@
using namespace mlir;
scf::SCFTilingOptions &
-scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
+scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
assert(!tileSizeComputationFunction && "tile sizes already set");
- SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
+ auto tileSizes = llvm::to_vector(ts);
tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
- OpBuilder::InsertionGuard guard(b);
- b.setInsertionPointToStart(
- &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
- ->getRegion(0)
- .front());
- return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
- Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
- return v;
- }));
+ return tileSizes;
};
return *this;
}
@@ -108,17 +100,16 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
/// Generate an empty loop nest that represents the tiled loop nest shell.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
/// the
/// tile processed within the inner most loop.
-static SmallVector<scf::ForOp>
-generateTileLoopNest(OpBuilder &builder, Location loc,
- ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
- SmallVector<OpFoldResult> &offsets,
- SmallVector<OpFoldResult> &sizes) {
+static SmallVector<scf::ForOp> generateTileLoopNest(
+ OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes) {
assert(!loopRanges.empty() && "expected at least one loop range");
- assert(loopRanges.size() == tileSizeVals.size() &&
+ assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
OpBuilder::InsertionGuard guard(builder);
SmallVector<scf::ForOp> loops;
@@ -130,7 +121,8 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
Value size =
getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
- Value tileSize = tileSizeVals[loopRange.index()];
+ Value tileSize = getValueOrCreateConstantIndexOp(
+ builder, loc, tileSizes[loopRange.index()]);
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (matchPattern(tileSize, m_Zero())) {
@@ -296,10 +288,10 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- SmallVector<Value> tileSizeVector =
+ SmallVector<OpFoldResult> tileSizeVector =
options.tileSizeComputationFunction(rewriter, op);
if (tileSizeVector.size() < iterationDomain.size()) {
- auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+ auto zero = rewriter.getIndexAttr(0);
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}
@@ -402,17 +394,17 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
FailureOr<scf::SCFReductionTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
PartialReductionOpInterface op,
- ArrayRef<OpFoldResult> tileSize) {
+ ArrayRef<OpFoldResult> tileSizes) {
Location loc = op.getLoc();
// Ops implementing PartialReductionOpInterface are expected to implement
// TilingInterface.
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
- SmallVector<Value> tileSizeVector =
- getValueOrCreateConstantIndexOp(b, loc, tileSize);
- if (tileSizeVector.size() < iterationDomain.size()) {
- auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
- tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
+ auto tileSizesVector = llvm::to_vector(tileSizes);
+ if (tileSizesVector.size() < iterationDomain.size()) {
+ auto zero = b.getIndexAttr(0);
+ tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
+ zero);
}
if (op->getNumResults() != 1)
return b.notifyMatchFailure(
@@ -429,7 +421,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 1. create the initial tensor value.
FailureOr<Operation *> identityTensor =
- op.generateInitialTensorForPartialReduction(b, loc, tileSize,
+ op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
@@ -437,7 +429,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 2. Create the nested loops.
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> loops = generateTileLoopNest(
- b, loc, iterationDomain, tileSizeVector, offsets, sizes);
+ b, loc, iterationDomain, tileSizesVector, offsets, sizes);
// 3. Generate the tiled implementation within the inner most loop.
b.setInsertionPoint(loops.back().getBody()->getTerminator());
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index ce2a3d6ca9c58da..9df19632506a73c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -190,16 +190,16 @@ transform.sequence failures(propagate) {
// -----
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
-// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[VS:.*]] = vector.vscale
// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
+// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 752c885e0b87bdb..2fcc7bcadb60450 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -450,7 +450,9 @@ static void addPatternForTiling(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
@@ -462,7 +464,9 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
- tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
@@ -475,8 +479,10 @@ static void addPatternForTileAndFuse(MLIRContext *context,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> interchange = {}) {
scf::SCFTileAndFuseOptions tileAndFuseOptions;
- tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
- interchange);
+ SmallVector<OpFoldResult> tileSizesOfr =
+ getAsIndexOpFoldResult(context, tileSizes);
+ tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
+ .setInterchange(interchange);
LinalgTransformationFilter filter(StringAttr::get(context, filterName),
StringAttr::get(context, "tiled"));
patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
More information about the Mlir-commits
mailing list