[Mlir-commits] [mlir] 7e5b10b - [mlir][Linalg] Add support for tiling tensor.pad to scf.forall
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Apr 12 04:43:55 PDT 2023
Author: Nicolas Vasilache
Date: 2023-04-12T04:43:47-07:00
New Revision: 7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1
URL: https://github.com/llvm/llvm-project/commit/7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1
DIFF: https://github.com/llvm/llvm-project/commit/7e5b10b9f74d4e34a0eb2db806b98ecb1c31daf1.diff
LOG: [mlir][Linalg] Add support for tiling tensor.pad to scf.forall
Also, properly propagate the nofold attribute.
Differential Revision: https://reviews.llvm.org/D148114
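For illustration, the tiled form of a tensor.pad looks roughly like the sketch below (a hand-written approximation rather than the transform's verbatim output; tile sizes, shapes, and value names such as %fully_in_padding are hypothetical). Each tile is computed by an scf.if: a tile that lies entirely in the padding region is materialized with tensor.generate, while a tile overlapping the source becomes pad(extract_slice(source)) with nofold preserved. The value that is parallel-inserted back is the scf.if result rather than a result of a single op implementing TilingInterface, which is why Tiling.cpp now records the tiledValues reported by the tiling implementation instead of querying the tiled op's results.

func.func @tiled_pad_sketch(
    %source: tensor<?x?xf32>, %cst: f32, %init: tensor<20x40xf32>,
    %fully_in_padding: i1,
    %o0: index, %o1: index, %s0: index, %s1: index,
    %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<20x40xf32> {
  %res = scf.forall (%i, %j) in (20, 40) shared_outs(%out = %init) -> (tensor<20x40xf32>) {
    // Each tile is either entirely inside the padding region or overlaps the source.
    %tile = scf.if %fully_in_padding -> (tensor<1x1xf32>) {
      // Padding-only tile: materialize the padding value directly.
      %gen = tensor.generate {
      ^bb0(%a: index, %b: index):
        tensor.yield %cst : f32
      } : tensor<1x1xf32>
      scf.yield %gen : tensor<1x1xf32>
    } else {
      // Mixed tile: pad(extract_slice(source)), with the nofold attribute propagated.
      %slice = tensor.extract_slice %source[%o0, %o1] [%s0, %s1] [1, 1]
          : tensor<?x?xf32> to tensor<?x?xf32>
      %pad = tensor.pad %slice nofold low[%l0, %l1] high[%h0, %h1] {
      ^bb0(%a: index, %b: index):
        tensor.yield %cst : f32
      } : tensor<?x?xf32> to tensor<1x1xf32>
      scf.yield %pad : tensor<1x1xf32>
    }
    // The scf.if result is inserted back into the shared output tensor.
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %tile into %out[%i, %j] [1, 1] [1, 1]
          : tensor<1x1xf32> into tensor<20x40xf32>
    }
  }
  return %res : tensor<20x40xf32>
}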
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Dialect/Linalg/transform-op-tile.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 453fd85f19bb6..18d485ae5463b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -354,8 +354,6 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
return getValueOrCreateConstantIndexOp(b, loc, ofr);
}));
- Operation *tiledOp = nullptr;
-
// 1. Create the ForallOp. We don't use the lambda body-builder
// version because we require the use of RewriterBase in the body, so we
// manually move the insertion point to the body below.
@@ -371,6 +369,8 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
// 3. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+ Operation *tiledOp = nullptr;
+ SmallVector<Value> tiledValues;
{
// 3.a. RAII guard, inserting within forallOp, before terminator.
OpBuilder::InsertionGuard g(b);
@@ -395,13 +395,12 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
assert(tilingResult->tiledOps.size() == 1 &&
"expected a single produced tiled op");
tiledOp = tilingResult->tiledOps.front();
+ tiledValues = tilingResult->tiledValues;
}
// 5. Parallel insert back into the result tensor.
- auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
- assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
- tilingInterfaceOp->getResults(), destBbArgs)) {
+ tiledValues, destBbArgs)) {
// 5.a. Partial subset information is inserted just before the terminator.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(forallOp.getTerminator());
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 1e8d4762f0f80..63f7a5af6f5e4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -617,7 +617,8 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// Create pad(extract_slice(x)).
Value newSliceOp = b.create<tensor::ExtractSliceOp>(
loc, padOp.getSource(), newOffsets, newLengths, newStrides);
- auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs);
+ auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
+ /*nofold=*/padOp.getNofold());
// Copy region to new PadOp.
IRMapping bvm;
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index df48c2138fbd8..15e8f8c41eba7 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -123,3 +123,28 @@ func.func @tile_linalg_matmul(
-> tensor<128x128xf32>
return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
}
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+ %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
+ -> tensor<20x40xf32>
+{
+ // CHECK: scf.forall
+ // CHECK: scf.if
+ // CHECK: tensor.generate
+ // CHECK: else
+ // CHECK: tensor.pad {{.*}} nofold
+ %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+ ^bb0(%arg9: index, %arg10: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<20x40xf32>
+ return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+}