[Mlir-commits] [mlir] b686fdb - [mlir][Linalg] Drop output tensor from `linalg.pad_tensor` op.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 31 11:12:51 PDT 2021
Author: MaheshRavishankar
Date: 2021-08-31T11:12:24-07:00
New Revision: b686fdbf92ea5b495804afdf1c7c4d4aab30ef33
URL: https://github.com/llvm/llvm-project/commit/b686fdbf92ea5b495804afdf1c7c4d4aab30ef33
DIFF: https://github.com/llvm/llvm-project/commit/b686fdbf92ea5b495804afdf1c7c4d4aab30ef33.diff
LOG: [mlir][Linalg] Drop output tensor from `linalg.pad_tensor` op.
The output tensor was added for tiling purposes. With the use of
`TilingInterface` for tiling pad operations, there is no longer any need
for an explicit operand describing the shape of the result of the
`linalg.pad_tensor` op. The interface allows the tiling pattern to query
the value that can be used as the "init" tensor needed for tiling dynamically.
Differential Revision: https://reviews.llvm.org/D108613
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 9b062f2ebd746..332993d8b0022 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -147,12 +147,11 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
dimension, i.e `low`.
* high: A list contains the padding along the end of each
dimension, i.e. `high`.
- * output: An optional output operand.
The result tensor dimensions are `low` + `dim` + `high` along that
dimension. The number of elements of `low` and `high` must match
- the rank of the input tensor (which is also the rank of the output
- tensor). They can be either a constant or a dynamic value.
+ the rank of the input tensor. They can be either a constant or a
+ dynamic value.
The region of the `pad_tensor` operation returns the value to use
for the padding. The arguments of the region represent the index
@@ -196,8 +195,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
Variadic<Index>:$low,
Variadic<Index>:$high,
I64ArrayAttr:$static_low,
- I64ArrayAttr:$static_high,
- Optional<AnyTensor>:$output);
+ I64ArrayAttr:$static_high);
let regions = (region SizedRegion<1>:$region);
@@ -208,9 +206,7 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
$source
`low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
- (`into` $output^ )?
$region attr-dict `:` type($source) `to` type($result)
- custom<InferType>(ref($output), type($output), ref(type($result)))
}];
let extraClassDeclaration = [{
@@ -300,11 +296,6 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
OpBuilder<(ins "Type":$resultType, "Value":$source,
"ArrayRef<OpFoldResult>":$low, "ArrayRef<OpFoldResult>":$high,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
- // Build a PadTensorOp with mixed static and dynamic entries and custom
- // result type.
- OpBuilder<(ins "Type":$resultType, "Value":$source,
- "ArrayRef<Value>":$low, "ArrayRef<Value>":$high, "ArrayAttr":$staticLow,
- "ArrayAttr":$staticHigh)>
];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 1e759baecf1f4..abb3c44158381 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1040,9 +1040,6 @@ static LogicalResult verify(PadTensorOp op) {
<< resultType << " does not match the inferred type "
<< expectedType;
}
- if (op.output() && op.output().getType() != op.getResultType()) {
- op.emitError("expected that output operand type equals result type");
- }
auto ®ion = op.region();
unsigned rank = resultType.getRank();
@@ -1089,7 +1086,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Value source,
auto sourceType = source.getType().cast<RankedTensorType>();
auto resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow),
- b.getI64ArrayAttr(staticHigh), /*output=*/Value());
+ b.getI64ArrayAttr(staticHigh));
result.addAttributes(attrs);
}
@@ -1126,15 +1123,7 @@ void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
PadTensorOp::inferResultType(sourceType, staticLow, staticHigh);
}
build(b, result, resultType, source, dynamicLow, dynamicHigh,
- b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh),
- /*output=*/Value());
-}
-
-void PadTensorOp::build(OpBuilder &b, OperationState &result, Type resultType,
- Value source, ArrayRef<Value> low, ArrayRef<Value> high,
- ArrayAttr staticLow, ArrayAttr staticHigh) {
- build(b, result, resultType, source, low, high, staticLow, staticHigh,
- /*output=*/{});
+ b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh));
}
PadTensorOp PadTensorOp::createPadScalarOp(Type type, Value source, Value pad,
@@ -1221,7 +1210,8 @@ static Value getAsValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
SmallVector<Value> PadTensorOp::getDestinationOperands(OpBuilder &b) {
ReifiedRankedShapedTypeDims reifiedShapes;
(void)reifyResultShapes(b, reifiedShapes);
- Value initTensor = b.create<InitTensorOp>(getLoc(), reifiedShapes[0],
+ SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
+ Value initTensor = b.create<InitTensorOp>(getLoc(), mixedSizes,
getResultType().getElementType());
return {initTensor};
}
@@ -1465,21 +1455,6 @@ struct FoldStaticZeroPadding : public OpRewritePattern<PadTensorOp> {
}
};
-// Fold tensor.dim(pad_tensor(%input, %output)) to tensor.dim(%output).
-struct FoldToDimOfOutputOperand : public OpRewritePattern<tensor::DimOp> {
- using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(tensor::DimOp dimOp,
- PatternRewriter &rewriter) const override {
- auto padTensorOp = dimOp.source().getDefiningOp<PadTensorOp>();
- if (!padTensorOp || !padTensorOp.output())
- return failure();
- rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, padTensorOp.output(),
- dimOp.index());
- return success();
- }
-};
-
// Fold CastOp into PadTensorOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
@@ -1503,7 +1478,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
auto newOp = rewriter.create<PadTensorOp>(
padTensorOp->getLoc(), newResultType, padTensorOp.source(),
padTensorOp.low(), padTensorOp.high(), padTensorOp.static_low(),
- padTensorOp.static_high(), /*output=*/nullptr);
+ padTensorOp.static_high());
BlockAndValueMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
@@ -1517,8 +1492,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldStaticZeroPadding, FoldToDimOfOutputOperand,
- FoldSourceTensorCast>(context);
+ results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
}
/// Return the padding value of the PadTensorOp if it constant. In this context,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index dff35353405bb..acef26a281437 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -357,10 +357,6 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
PadTensorOp &newPadOp, LoopNest &loopNest,
const LinalgTilingOptions &options) {
- // Can tile only PadTensorOp that have an output operand.
- if (!op.output())
- return failure();
-
Location loc = op.getLoc();
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPoint(op);
@@ -383,8 +379,9 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
}
}
// Generate loop nest: One loop per dimension.
+ SmallVector<Value> destOperand = op.getDestinationOperands(builder);
loopNest = mlir::scf::buildLoopNest(
- builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(op.output()),
+ builder, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
[&](OpBuilder &b, Location loc, ValueRange localIvs,
ValueRange iterArgs) -> scf::ValueVector {
// Compute offsets and sizes of ExtractSliceOp.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 41a4bfe9c9800..9b6be7920a999 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -904,24 +904,6 @@ func @rank_reducing_init_extract(%sz : index, %idx : index) -> tensor<2xf32> {
// -----
-// CHECK-LABEL: func @dim_of_pad_tensor(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG1]], %[[C0]]
-// CHECK: return %[[RESULT]]
-func @dim_of_pad_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
- %pad_value: f32) -> index {
- %c0 = constant 0 : index
- %0 = linalg.pad_tensor %arg0 low[2, 3] high[4, 5] into %arg1 {
- ^bb0(%arg2: index, %arg3: index):
- linalg.yield %pad_value : f32
- } : tensor<?x?xf32> to tensor<?x?xf32>
- %r = tensor.dim %0, %c0 : tensor<?x?xf32>
- return %r : index
-}
-
-// -----
-
// CHECK-LABEL: func @dim_of_tiled_loop_input(
// CHECK-SAME: %[[arg0:.*]]: tensor<?x?xf32>, %[[arg1:.*]]: tensor<?x?xf32>, %[[arg2:.*]]: tensor<?x?xf32>
// CHECK: %[[c0:.*]] = constant 0 : index
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 36860272a80c8..3592d592acc3a 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -459,18 +459,6 @@ func @pad_result_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32) -> t
// -----
-// expected-note at +1 {{prior use here}}
-func @pad_output_type(%arg0: tensor<?x2x3x4xi32>, %arg1: index, %arg2: i32, %output: tensor<?x6x6x7xf32>) -> tensor<?x?x?x8xf32> {
- // expected-error @+1 {{use of value '%output' expects different type than prior uses: 'tensor<?x5x6x7xf32>' vs 'tensor<?x6x6x7xf32>'}}
- %0 = linalg.pad_tensor %arg0 low[1, 1, 1, 1] high[2, 2, 2, 2] into %output {
- ^bb0(%arg3: index, %arg4: index): // no predecessors
- linalg.yield %arg2 : i32
- } : tensor<?x2x3x4xi32> to tensor<?x5x6x7xf32>
- return %0 : tensor<?x5x6x7xf32>
-}
-
-// -----
-
func @pad_number_of_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> {
// expected-error @+1 {{expected the block to have 2 arguments}}
%0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] {
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index e0d7ab2dfb24f..23e29e0ab082a 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -49,24 +49,6 @@ func @pad_static(%arg0: tensor<3x4xf32>, %pad_value: f32) -> tensor<6x9xf32> {
// -----
-func @pad_static_with_output(%arg0: tensor<3x4xf32>,
- %out_tensor : tensor<6x9xf32>,
- %pad_value: f32)
- -> tensor<6x9xf32> {
- %0 = linalg.pad_tensor %arg0 low[1, 2] high[2, 3] into %out_tensor {
- ^bb0(%arg1 : index, %arg2 : index):
- linalg.yield %pad_value : f32
- } : tensor<3x4xf32> to tensor<6x9xf32>
- return %0 : tensor<6x9xf32>
-}
-// CHECK-LABEL: func @pad_static
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: tensor<3x4xf32>,
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: tensor<6x9xf32>,
-// CHECK: linalg.pad_tensor %[[ARG0]] low[1, 2] high[2, 3] into %[[ARG1]]
-// CHECK: : tensor<3x4xf32> to tensor<6x9xf32>
-
-// -----
-
func @pad_asymmetrical(%arg0: tensor<2x3xf32>, %ub0: index, %ub1: index,
%pad_value: f32) -> tensor<?x?xf32> {
%0 = linalg.pad_tensor %arg0 low[0, 0] high[%ub0, %ub1] {
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index 67be544db78ad..e2a22fa104b22 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -1,12 +1,12 @@
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
// RUN: FileCheck %s -check-prefix=TILE2
-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -cse -split-input-file | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \
// RUN: FileCheck %s -check-prefix=TILE1
// TILE2-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
// TILE2-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
// TILE2: func @dynamic_pad_tensor(
-// TILE2-SAME: %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+// TILE2-SAME: %[[IN:.*]]: tensor<?x?xf32>
// TILE2-DAG: %[[C0:.*]] = constant 0 : index
// TILE2-DAG: %[[C1:.*]] = constant 1 : index
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -25,16 +25,18 @@
// TILE2: tensor.insert_slice %[[SWAP_RESULT]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
// TILE2: return %[[RESULT]]
-// TILE1-DAG: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 7)>
+// TILE1-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 7)>
+// TILE1-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
// TILE1: func @dynamic_pad_tensor(
-// TILE1-SAME: %[[IN:.*]]: tensor<?x?xf32>, %[[OUT:.*]]: tensor<?x?xf32>
+// TILE1-SAME: %[[IN:.*]]: tensor<?x?xf32>
// TILE1-DAG: %[[C0:.*]] = constant 0 : index
// TILE1-DAG: %[[C1:.*]] = constant 1 : index
// TILE1-DAG: %[[C3:.*]] = constant 3 : index
// TILE1: %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]]
-// TILE1: %[[DIM1:.*]] = affine.apply #[[MAP]]()[%[[DIM_IN1]]]
+// TILE1: %[[DIM1:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN1]]]
+// TILE1: %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]]
+// TILE1: %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN0]]]
// TILE1: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
-// TILE1: %[[DIM0:.*]] = tensor.dim %[[OUT]], %[[C0]]
// TILE1: %[[SWAP_RESULT:.*]] = scf.if
// TILE1: tensor.generate
// TILE1: else
@@ -44,10 +46,8 @@
// TILE1: return %[[RESULT]]
func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
- %output_tensor: tensor<?x?xf32>,
%pad_value: f32) -> tensor<?x?xf32> {
- %0 = linalg.pad_tensor %input_tensor
- low[3, 4] high[5, 3] into %output_tensor{
+ %0 = linalg.pad_tensor %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
linalg.yield %pad_value : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
@@ -57,7 +57,7 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
// -----
// TILE2-LABEL: func @static_pad_tensor(
-// TILE2-SAME: %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+// TILE2-SAME: %[[IN:.*]]: tensor<7x9xf32>
// TILE2-DAG: %[[C0:.*]] = constant 0 : index
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
// TILE2-DAG: %[[C3:.*]] = constant 3 : index
@@ -75,7 +75,7 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
// TILE1-LABEL: func @static_pad_tensor(
-// TILE1-SAME: %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<15x16xf32>
+// TILE1-SAME: %[[IN:.*]]: tensor<7x9xf32>
// TILE1-DAG: %[[C0:.*]] = constant 0 : index
// TILE1-DAG: %[[C3:.*]] = constant 3 : index
// TILE1-DAG: %[[C16:.*]] = constant 16 : index
@@ -89,10 +89,8 @@ func @dynamic_pad_tensor(%input_tensor: tensor<?x?xf32>,
// TILE1: return %[[RESULT]]
func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
- %output_tensor: tensor<15x16xf32>,
%pad_value: f32) -> tensor<15x16xf32> {
- %0 = linalg.pad_tensor %input_tensor
- low[3, 4] high[5, 3] into %output_tensor {
+ %0 = linalg.pad_tensor %input_tensor low[3, 4] high[5, 3] {
^bb0(%arg1: index, %arg2: index):
linalg.yield %pad_value : f32
} : tensor<7x9xf32> to tensor<15x16xf32>
@@ -112,7 +110,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
// TILE1: scf.yield %[[GEN]] : tensor<14x3xf32>
// TILE1: else
// TILE1: %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
-// TILE1: %[[PAD:.*]] = linalg.pad_tensor %8 low[0, 0] high[7, %{{.*}}]
+// TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
@@ -121,8 +119,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
%output_tensor: tensor<14x15xf32>,
%pad_value: f32) -> tensor<14x15xf32> {
- %0 = linalg.pad_tensor %input_tensor
- low[0, 0] high[7, 6] into %output_tensor {
+ %0 = linalg.pad_tensor %input_tensor low[0, 0] high[7, 6] {
^bb0(%arg1: index, %arg2: index):
linalg.yield %pad_value : f32
} : tensor<7x9xf32> to tensor<14x15xf32>
More information about the Mlir-commits
mailing list