[Mlir-commits] [mlir] 1e790b7 - [mlir][Linalg] Replace SimplePad with PadTensor in hoist-padding
Hanhan Wang
llvmlistbot at llvm.org
Thu Jan 28 11:11:00 PST 2021
Author: Hanhan Wang
Date: 2021-01-28T11:09:57-08:00
New Revision: 1e790b745d7e3b0c79deec2de202a4de7e7a66c3
URL: https://github.com/llvm/llvm-project/commit/1e790b745d7e3b0c79deec2de202a4de7e7a66c3
DIFF: https://github.com/llvm/llvm-project/commit/1e790b745d7e3b0c79deec2de202a4de7e7a66c3.diff
LOG: [mlir][Linalg] Replace SimplePad with PadTensor in hoist-padding
This is the last revision in the migration from SimplePadOp to PadTensorOp;
SimplePadOp is removed in this patch. SliceAnalysis is updated slightly because,
unlike SimplePadOp, PadTensorOp has a region. PadTensorOp is not covered by the
existing LinalgOp check because it is not a structured op.
Also, remove a duplicated comment from the .cpp file, since the same comment is
already present in the header file, and update the pseudo-MLIR example in that
comment.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D95615
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
mlir/lib/Analysis/SliceAnalysis.cpp
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/test/Dialect/Linalg/hoist-padding.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 67c0615ddd88..4f2a8afcdbc8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -194,6 +194,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
return "static_high";
}
+ RankedTensorType getSourceType() {
+ return source().getType().cast<RankedTensorType>();
+ }
+ RankedTensorType getResultType() {
+ return getResult().getType().cast<RankedTensorType>();
+ }
+
// Infer the shape of the result tensor given the static shapes
// and element type of the result tensor.
static RankedTensorType inferResultType(RankedTensorType sourceType,
@@ -487,38 +494,6 @@ def Linalg_SliceOp : Linalg_Op<"slice", [
let hasFolder = 1;
}
-def Linalg_SimplePadOp : Linalg_Op<"simple_pad", [NoSideEffect]> {
- let summary = "TODO: replace with pad_tensors when ready.";
-
- let description = [{
- `linalg.simple_pad` is a tmp placeholder for padding and packing on tensors.
- Its semantics are to pad a partially dynamic tensor to a fully static tensor
- where the static sizes are assumed to be greater than the dynamic sizes. The
- op perforrms "high" padding (i.e. it adds trailing padding values until the
- desired size is met).
- }];
-
- let arguments = (ins AnyRankedTensor:$tensor, AnyType:$padding);
- let results = (outs AnyRankedTensor:$result);
-
- // TODO: verify all static result, some dynamic input, static shapes match,
- // element types match, ranks match etc. Use pad_tensors when ready but for
- // now just let it ne fully specified by traits.
- let verifier = ?;
-
- let extraClassDeclaration = [{
- RankedTensorType getSourceType() {
- return tensor().getType().cast<RankedTensorType>(); }
- RankedTensorType getResultType() {
- return getResult().getType().cast<RankedTensorType>(); }
- }];
-
- let assemblyFormat = [{
- $tensor `pad` $padding attr-dict `:`
- type($tensor) `to` type($result) `pad` type($padding)
- }];
-}
-
def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
Arguments<(ins Variadic<AnyType>:$values)> {
let summary = "Linalg yield operation";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 4d44b3717991..de604c972dca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -14,7 +14,7 @@ class FuncOp;
struct LogicalResult;
namespace linalg {
-class SimplePadOp;
+class PadTensorOp;
/// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing
/// scf::ForOp if both conditions are true:
@@ -44,7 +44,7 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
/// Mechanically hoist padding operations on tensors by `nLoops` into a new,
/// generally larger tensor. This achieves packing of multiple padding ops into
-/// a larger tensor. On success, `simplePadOp` is replaced by the cloned version
+/// a larger tensor. On success, `padTensorOp` is replaced by the cloned version
/// in the packing loop so the caller can continue reasoning about the padding
/// operation.
///
@@ -55,8 +55,10 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
/// ```
/// scf.for (%i, %j, %k)
/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
-/// %0 = linalg.simple_pad %st0 pad %pad :
-/// tensor<?x?xf32> to tensor<4x8xf32>
+/// %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
+/// ^bb0( ... ):
+/// linalg.yield %pad
+/// } : tensor<?x?xf32> to tensor<4x8xf32>
/// compute(%0)
/// ```
///
@@ -65,10 +67,13 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
/// ```
/// scf.for (%i) {
/// %packed_init = linalg.init_tensor range(%j) : tensor<?x4x8xf32>
-/// %packed = scf.for (%k) iter_args(%p : %packed_init)
+/// %packed = scf.for (%k) iter_args(%p : %packed_init) {
/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
-/// %0 = linalg.simple_pad %st0 pad %pad :
-/// tensor<?x?xf32> to tensor<4x8xf32>
+/// %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
+/// ^bb0( ... ):
+/// linalg.yield %pad
+/// } : tensor<?x?xf32> to tensor<4x8xf32>
+/// %1 = subtensor_insert %0 ... : tensor<4x8xf32> to tensor<?x4x8xf32>
/// scf.yield %1: tensor<?x4x8xf32>
/// } -> tensor<?x4x8xf32>
/// scf.for (%j, %k) {
@@ -78,7 +83,7 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
/// }
/// }
/// ```
-LogicalResult hoistPaddingOnTensors(SimplePadOp &simplePadOp, unsigned nLoops);
+LogicalResult hoistPaddingOnTensors(PadTensorOp &padTensorOp, unsigned nLoops);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 55122b83d585..07cbca8298c4 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -86,7 +86,8 @@ static void getBackwardSliceImpl(Operation *op,
return;
assert((op->getNumRegions() == 0 ||
- isa<AffineForOp, scf::ForOp, linalg::LinalgOp>(op)) &&
+ isa<AffineForOp, scf::ForOp, linalg::LinalgOp, linalg::PadTensorOp>(
+ op)) &&
"unexpected generic op with regions");
// Evaluate whether we should keep this def.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 5c67c8e61829..cf546152da8d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -337,7 +337,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
/// Ensure prerequisites that guarantee pad op hoisting can occur.
/// Return failure in the cases when we cannot perform hoisting; i.e. if either:
-/// 1. There exists a use of `simplePadOp` that is not a linalg input operand.
+/// 1. There exists a use of `padTensorOp` that is not a linalg input operand.
/// 2. There isn't an enclosing `outermostEnclosingForOp` loop.
/// 3. There exists an op with a region that is dominated by
/// `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
@@ -353,12 +353,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
/// remain in `backwardSlice` but that are not in `packingLoops` are
/// dimensions of reuse.
static LogicalResult
-hoistPaddingOnTensorsPrerequisites(linalg::SimplePadOp simplePadOp, int nLevels,
+hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
llvm::SetVector<Operation *> &backwardSlice,
llvm::SetVector<Operation *> &packingLoops) {
// Bail on any use that isn't an input of a Linalg op.
// Hoisting of inplace updates happens after vectorization.
- for (OpOperand &use : simplePadOp.result().getUses()) {
+ for (OpOperand &use : padTensorOp.result().getUses()) {
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (!linalgUser || !linalgUser.isInputTensor(&use))
return failure();
@@ -368,7 +368,7 @@ hoistPaddingOnTensorsPrerequisites(linalg::SimplePadOp simplePadOp, int nLevels,
SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
Operation *outermostEnclosingForOp = nullptr,
*nextEnclosingForOp =
- simplePadOp->getParentOfType<LoopLikeOpInterface>();
+ padTensorOp->getParentOfType<LoopLikeOpInterface>();
while (nLevels-- > 0 && nextEnclosingForOp) {
outermostEnclosingForOp = nextEnclosingForOp;
reverseEnclosingLoops.push_back(outermostEnclosingForOp);
@@ -378,28 +378,13 @@ hoistPaddingOnTensorsPrerequisites(linalg::SimplePadOp simplePadOp, int nLevels,
if (!outermostEnclosingForOp)
return failure();
- // Get the backwards slice from `simplePadOp` that is dominated by the
+ // Get the backwards slice from `padTensorOp` that is dominated by the
// outermost enclosing loop.
DominanceInfo domInfo(outermostEnclosingForOp);
- getBackwardSlice(simplePadOp, &backwardSlice, [&](Operation *op) {
+ getBackwardSlice(padTensorOp, &backwardSlice, [&](Operation *op) {
return domInfo.dominates(outermostEnclosingForOp, op);
});
- #if 0
-
- // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
- // Bail on any op with side effects that is not a LoopLikeInterface.
- if (llvm::any_of(backwardSlice, [](Operation *op) {
- if (isa<LoopLikeOpInterface>(op))
- return false;
- if (!MemoryEffectOpInterface::hasNoEffect(op))
- return true;
- return op->getNumRegions() > 0 && !isa<LinalgOp>(op);
- }))
- return failure();
-
- #else
-
// Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
if (llvm::any_of(backwardSlice, [](Operation *op) {
return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
@@ -407,8 +392,6 @@ hoistPaddingOnTensorsPrerequisites(linalg::SimplePadOp simplePadOp, int nLevels,
}))
return failure();
- #endif
-
// Filter out the loops whose induction variable is not used to compute the
// padded result. As a first approximation, just look for IVs that have no use
// in the backwardSlice.
@@ -444,54 +427,18 @@ static Value buildLoopTripCount(OpBuilder &b, Operation *op) {
ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
}
-/// Mechanically hoist padding operations on tensors by at most `nLoops` into a
-/// new, generally larger tensor. This achieves packing of multiple padding ops
-/// into a larger tensor. On success, `simplePadOp` is replaced by the cloned
-/// version in the packing loop so the caller can continue reasoning about the
-/// padding operation.
-///
-/// Example in pseudo-mlir:
-/// =======================
-///
-/// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR.
-/// ```
-/// scf.for (%i, %j, %k)
-/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
-/// %0 = linalg.simple_pad %st0 pad %pad :
-/// tensor<?x?xf32> to tensor<4x8xf32>
-/// compute(%0)
-/// ```
-///
-/// IR resembling the following is produced:
-///
-/// ```
-/// scf.for (%i) {
-/// %packed_init = linalg.init_tensor range(%j) : tensor<?x4x8xf32>
-/// %packed = scf.for (%k) iter_args(%p : %packed_init)
-/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
-/// %0 = linalg.simple_pad %st0 pad %pad :
-/// tensor<?x?xf32> to tensor<4x8xf32>
-/// scf.yield %1: tensor<?x4x8xf32>
-/// } -> tensor<?x4x8xf32>
-/// scf.for (%j, %k) {
-/// %st0 = subtensor %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
-/// tensor<?x4x8xf32> to tensor<4x8xf32>
-/// compute(%st0)
-/// }
-/// }
-/// ```
-LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
+LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
unsigned nLoops) {
llvm::SetVector<Operation *> backwardSlice, packingLoops;
- if (failed(hoistPaddingOnTensorsPrerequisites(simplePadOp, nLoops,
+ if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
backwardSlice, packingLoops)))
return failure();
// Update actual number of loops, which may be smaller.
nLoops = packingLoops.size();
- Location loc = simplePadOp->getLoc();
- RankedTensorType paddedTensorType = simplePadOp.getResultType();
+ Location loc = padTensorOp->getLoc();
+ RankedTensorType paddedTensorType = padTensorOp.getResultType();
unsigned paddedRank = paddedTensorType.getRank();
// Backward slice is a topologically sorted list of ops starting at
@@ -503,7 +450,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
// Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
// padding.
SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
- // TODO: go grab dims when necessary, for now SimplePadOp returns a static
+ // TODO: go grab dims when necessary, for now PadTensorOp returns a static
// tensor.
llvm::append_range(packedShape, paddedTensorType.getShape());
auto packedTensorType =
@@ -526,10 +473,10 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
clonedLoopIvs.reserve(nLoops);
BlockAndValueMapping bvm;
// Stack step 1. iteratively clone loops and push `packedTensor`.
- // Insert `simplePadOp` into the backwardSlice so we clone it too.
- backwardSlice.insert(simplePadOp);
+ // Insert `padTensorOp` into the backwardSlice so we clone it too.
+ backwardSlice.insert(padTensorOp);
for (Operation *op : backwardSlice) {
- if (op->getNumRegions() == 0) {
+ if (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op)) {
b.clone(*op, bvm);
continue;
}
@@ -556,7 +503,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
// sizes = [1 .. 1, paddedShape].
SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
for (int64_t sz : paddedTensorType.getShape()) {
- // TODO: go grab dims when necessary, for now SimplePadOp returns a static
+ // TODO: go grab dims when necessary, for now PadTensorOp returns a static
// tensor.
assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
sizes.push_back(b.getIndexAttr(sz));
@@ -565,7 +512,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));
Value inserted =
- b.create<SubTensorInsertOp>(loc, bvm.lookup(simplePadOp.result()),
+ b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
packedTensor, offsets, sizes, strides);
// Stack step 3. iteratively pop the stack and propagate the yield.
@@ -579,7 +526,7 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
// Now the packed tensor is ready, replace the original padding op by a
// 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
- b.setInsertionPoint(simplePadOp);
+ b.setInsertionPoint(padTensorOp);
SmallVector<Value> originalLoopIvs =
llvm::to_vector<4>(llvm::map_range(packingLoops, [](Operation *loop) {
return cast<scf::ForOp>(loop).getInductionVar();
@@ -591,16 +538,16 @@ LogicalResult mlir::linalg::hoistPaddingOnTensors(SimplePadOp &simplePadOp,
// strides = [1 .. 1] (defined above)
packedTensor =
scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
- simplePadOp.replaceAllUsesWith(
- b.create<SubTensorOp>(loc, simplePadOp.getResultType(), packedTensor,
+ padTensorOp.replaceAllUsesWith(
+ b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
offsets, sizes, strides)
->getResult(0));
- Operation *toErase = simplePadOp;
+ Operation *toErase = padTensorOp;
- // Make the newly cloned `simplePadOp` available to the caller.
- simplePadOp =
- cast<SimplePadOp>(bvm.lookup(simplePadOp.result()).getDefiningOp());
+ // Make the newly cloned `padTensorOp` available to the caller.
+ padTensorOp =
+ cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());
toErase->erase();
diff --git a/mlir/test/Dialect/Linalg/hoist-padding.mlir b/mlir/test/Dialect/Linalg/hoist-padding.mlir
index 27750ea8a024..8685df44db3f 100644
--- a/mlir/test/Dialect/Linalg/hoist-padding.mlir
+++ b/mlir/test/Dialect/Linalg/hoist-padding.mlir
@@ -27,7 +27,8 @@ func @matmul_tensors(
// CHECK: %[[A:.*]] = scf.for
// CHECK-NOT: scf.for
// CHECK: subtensor %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: linalg.simple_pad %{{.*}} : tensor<?x?xf32> to tensor<2x4xf32> pad f32
+ // CHECK: linalg.pad_tensor %{{.*}}
+ // CHECK: : tensor<?x?xf32> to tensor<2x4xf32>
// CHECK: subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, 0, 0]
// CHECK-SAME: [1, 2, 4] [1, 1, 1] : tensor<2x4xf32> into tensor<?x2x4xf32>
// 2-D loop
@@ -36,7 +37,8 @@ func @matmul_tensors(
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: subtensor %{{.*}} [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- // CHECK: linalg.simple_pad %{{.*}} : tensor<?x?xf32> to tensor<4x3xf32> pad f32
+ // CHECK: linalg.pad_tensor %{{.*}}
+ // CHECK: : tensor<?x?xf32> to tensor<4x3xf32>
// CHECK: subtensor_insert %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
// CHECK-SAME: [1, 1, 4, 3] [1, 1, 1, 1] : tensor<4x3xf32> into tensor<?x?x4x3xf32>
// 2-D loop
@@ -47,8 +49,8 @@ func @matmul_tensors(
// CHECK-SAME: tensor<?x2x4xf32> to tensor<2x4xf32>
// CHECK: %[[stB:.*]] = subtensor %[[B]][%[[K]], %[[J]], 0, 0] [1, 1, 4, 3] [1, 1, 1, 1] :
// CHECK-SAME: tensor<?x?x4x3xf32> to tensor<4x3xf32>
- // CHECK: %[[stC:.*]] = linalg.simple_pad %{{.*}} pad %{{.*}} :
- // CHECK-SAME: tensor<?x?xf32> to tensor<2x3xf32> pad f32
+ // CHECK: %[[stC:.*]] = linalg.pad_tensor %{{.*}}
+ // CHECK: : tensor<?x?xf32> to tensor<2x3xf32>
// CHECK: linalg.matmul ins(%[[stA]], %[[stB]] : tensor<2x4xf32>, tensor<4x3xf32>)
// CHECK-SAME: outs(%[[stC]] : tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = scf.for %arg3 = %c0 to %0 step %c2 iter_args(%arg4 = %arg2) -> (tensor<?x?xf32>) {
@@ -69,13 +71,28 @@ func @matmul_tensors(
%18 = dim %arg8, %c1 : tensor<?x?xf32>
%19 = affine.min #map4(%18, %arg5)
%20 = subtensor %arg8[%arg3, %arg5] [%17, %19] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
- %21 = linalg.simple_pad %10 pad %cst : tensor<?x?xf32> to tensor<2x4xf32> pad f32
- %22 = linalg.simple_pad %15 pad %cst : tensor<?x?xf32> to tensor<4x3xf32> pad f32
- %23 = linalg.simple_pad %20 pad %cst : tensor<?x?xf32> to tensor<2x3xf32> pad f32
- %24 = linalg.matmul ins(%21, %22 : tensor<2x4xf32>, tensor<4x3xf32>) outs(%23 : tensor<2x3xf32>) -> tensor<2x3xf32>
- %25 = subtensor %24[0, 0] [%7, %14] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
- %26 = subtensor_insert %25 into %arg8[%arg3, %arg5] [%17, %19] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
- scf.yield %26 : tensor<?x?xf32>
+ %21 = subi %c2, %7 : index
+ %22 = subi %c4, %9 : index
+ %23 = linalg.pad_tensor %10 low[%c0, %c0] high[%21, %22] {
+ ^bb0(%arg9: index, %arg10: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<2x4xf32>
+ %24 = subi %c4, %12 : index
+ %25 = subi %c3, %14 : index
+ %26 = linalg.pad_tensor %15 low[%c0, %c0] high[%24, %25] {
+ ^bb0(%arg9: index, %arg10: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<4x3xf32>
+ %27 = subi %c2, %17 : index
+ %28 = subi %c3, %19 : index
+ %29 = linalg.pad_tensor %20 low[%c0, %c0] high[%27, %28] {
+ ^bb0(%arg9: index, %arg10: index): // no predecessors
+ linalg.yield %cst : f32
+ } : tensor<?x?xf32> to tensor<2x3xf32>
+ %30 = linalg.matmul ins(%23, %26 : tensor<2x4xf32>, tensor<4x3xf32>) outs(%29 : tensor<2x3xf32>) -> tensor<2x3xf32>
+ %31 = subtensor %30[0, 0] [%7, %14] [1, 1] : tensor<2x3xf32> to tensor<?x?xf32>
+ %32 = subtensor_insert %31 into %arg8[%arg3, %arg5] [%17, %19] [%c1, %c1] : tensor<?x?xf32> into tensor<?x?xf32>
+ scf.yield %32 : tensor<?x?xf32>
}
scf.yield %5 : tensor<?x?xf32>
}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 879dfa3ed08e..0b3cc28d79d0 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -833,13 +833,3 @@ func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
-
-// -----
-
-// TODO: this op should disappear once pad_tensors is available and connected.
-// CHECK-LABEL: func @simple_pad
-func @simple_pad(%0: tensor<?x4x?xf32>, %pad: f32) {
-// CHECK: linalg.simple_pad %{{.+}} pad %{{.+}}: tensor<?x4x?xf32> to tensor<8x4x8xf32>
- %1 = linalg.simple_pad %0 pad %pad: tensor<?x4x?xf32> to tensor<8x4x8xf32> pad f32
- return
-}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index db05d60ad8c7..126bbc3639af 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -572,8 +572,8 @@ void TestLinalgTransforms::runOnFunction() {
if (testTileAndPadPattern)
return applyTileAndPadPattern(getFunction());
if (testHoistPadding2Levels) {
- getFunction().walk([](linalg::SimplePadOp simplePadOp) {
- linalg::hoistPaddingOnTensors(simplePadOp, 2);
+ getFunction().walk([](linalg::PadTensorOp padTensorOp) {
+ linalg::hoistPaddingOnTensors(padTensorOp, 2);
});
}
}
More information about the Mlir-commits
mailing list