[Mlir-commits] [mlir] 24199f5 - [mlir][linalg] Lower subtensor(pad_tensor) to pad_tensor(subtensor)
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 18 21:45:11 PDT 2021
Author: Matthias Springer
Date: 2021-06-19T13:44:47+09:00
New Revision: 24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0
URL: https://github.com/llvm/llvm-project/commit/24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0
DIFF: https://github.com/llvm/llvm-project/commit/24199f534f61d9ac7d2d9dcde7b9cac93c84d4f0.diff
LOG: [mlir][linalg] Lower subtensor(pad_tensor) to pad_tensor(subtensor)
Only high padding is supported at the moment. Low padding will be added in a separate commit.
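For illustration, a minimal before/after sketch of the rewrite (shapes and
offsets taken from the @static_mixed_data_high_pad test case added below;
SSA names are illustrative):

  // Before: the subtensor reads from the padded tensor.
  %0 = linalg.pad_tensor %src low[0, 0] high[7, 8] {
  ^bb0(%i: index, %j: index):
    linalg.yield %pad : f32
  } : tensor<4x5xf32> to tensor<11x13xf32>
  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<11x13xf32> to tensor<3x4xf32>

  // After: the subtensor reads from the source and padding is re-applied.
  %0 = subtensor %src[2, 4] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
  %1 = linalg.pad_tensor %0 low[0, 0] high[1, 3] {
  ^bb0(%i: index, %j: index):
    linalg.yield %pad : f32
  } : tensor<2x1xf32> to tensor<3x4xf32>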
Differential Revision: https://reviews.llvm.org/D104357
Added:
mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9df8fbb2e4693..8841af104a360 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1076,6 +1076,15 @@ LogicalResult applyStagedPatterns(
const FrozenRewritePatternSet &stage2Patterns,
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
+/// Rewrite subtensor(pad_tensor(x)) into pad_tensor(subtensor(x)).
+struct SubTensorOfPadTensorSwapPattern
+ : public OpRewritePattern<SubTensorOp> {
+ using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+ PatternRewriter &rewriter) const override;
+};
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index efd0c3b2079d1..4c2df05f52cb1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -700,3 +700,225 @@ LogicalResult PadTensorOpTransformationPattern::matchAndRewrite(
return success();
}
+
+/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
+/// it must be an IntegerAttr.
+static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
+ if (auto val = ofr.dyn_cast<Value>())
+ return val;
+ auto intVal = getConstantIntValue(ofr);
+ assert(intVal && "expected Value or IntegerAttr");
+ return builder.create<ConstantIndexOp>(loc, *intVal);
+}
+
+/// Given a value, try to extract a constant index-type integer as an Attribute.
+/// If this fails, return the original value.
+static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
+ if (auto constInt = getConstantIntValue(val))
+ return builder.getIndexAttr(*constInt);
+ return val;
+}
+
+LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
+ SubTensorOp subTensorOp, PatternRewriter &rewriter) const {
+ auto padOp = subTensorOp.source().getDefiningOp<PadTensorOp>();
+ if (!padOp)
+ return failure();
+ // Only unit stride supported.
+ if (!subTensorOp.hasUnitStride())
+ return failure();
+ // Only constant padding value supported.
+ Value padValue = padOp.getConstantPaddingValue();
+ if (!padValue)
+ return failure();
+ // Only zero low padding supported at the moment.
+ if (!padOp.hasZeroLowPad())
+ return failure();
+
+ // Helper variables and functions for various arithmetic operations. These are
+ // used extensively for computing new offset/length and padding values.
+ Location loc = subTensorOp.getLoc();
+ AffineExpr dim0, dim1;
+ bindDims(rewriter.getContext(), dim0, dim1);
+ // Add two integers.
+ auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
+ auto add = [&](Value v1, Value v2) {
+ return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
+ ValueRange{v1, v2});
+ };
+ // Subtract two integers.
+ auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+ auto sub = [&](Value v1, Value v2) {
+ return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
+ ValueRange{v1, v2});
+ };
+ // Take the minimum of two integers.
+ auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+ auto min = [&](Value v1, Value v2) {
+ return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+ };
+ // Take the maximum of two integers.
+ auto max = [&](Value v1, Value v2) {
+ return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+ };
+ // Zero index-typed integer.
+ auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+
+ // Helper function for filling the static/dynamic low/high padding index
+ // vectors of PadTensorOp.
+ auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
+ SmallVector<int64_t> &staticIndices) {
+ if (auto constInt = getConstantIntValue(val)) {
+ staticIndices.push_back(*constInt);
+ } else {
+ staticIndices.push_back(ShapedType::kDynamicSize);
+ dynIndices.push_back(val);
+ }
+ };
+
+ // Compute new offsets, lengths, low padding, high padding.
+ SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+ SmallVector<Value> newLows, newHighs;
+ SmallVector<int64_t> staticNewLows, staticNewHighs;
+ // Set to true if the original data source is not read at all.
+ bool hasZeroLen = false;
+ // Same as hasZeroLen, but for dynamic dimension sizes. This condition
+ // is true if the original data source turns out to be unused at runtime.
+ Value dynHasZeroLenCond;
+
+ int64_t rank = padOp.getSourceType().getRank();
+ for (unsigned dim = 0; dim < rank; ++dim) {
+ auto offset = asValue(rewriter, loc, subTensorOp.getMixedOffsets()[dim]);
+ auto length = asValue(rewriter, loc, subTensorOp.getMixedSizes()[dim]);
+ auto srcSize = rewriter.createOrFold<memref::DimOp>(
+ loc, padOp.source(), dim);
+
+ // Existing low padding is zero, so new low padding is also zero.
+ Value newLow = zero;
+ appendIndex(newLow, newLows, staticNewLows);
+
+ // There is no low padding, so the offset usually remains unchanged. The
+ // exception is when the SubTensorOp starts reading from a position within
+ // the high padding: in that case, clamp the offset to the end of the source
+ // tensor. The new SubTensorOp length will then be zero, effectively reading
+ // no data from the source.
+ Value newOffset = min(offset, srcSize);
+ newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+
+ // The new SubTensorOp starts reading at `newOffset` and reads until
+ // `offset + length`. This position may be outside of the source (i.e.,
+ // within the high padding). In that case, read only until the end of the
+ // source. In mathematical terms:
+ //
+ // endLoc = min(offset + length, srcSize)
+ //
+ // The new SubTensorOp length is `endLoc - newOffset`.
+ Value newLength = sub(min(add(offset, length), srcSize), newOffset);
+ newLengths.push_back(asOpFoldResult(rewriter, newLength));
+ if (auto newLengthInt = getConstantIntValue(newLength)) {
+ hasZeroLen |= *newLengthInt == 0;
+ } else {
+ Value check = rewriter.create<CmpIOp>(
+ loc, CmpIPredicate::eq, newLength, zero);
+ dynHasZeroLenCond = dynHasZeroLenCond
+ ? rewriter.create<OrOp>(loc, check, dynHasZeroLenCond) : check;
+ }
+
+ // The number of elements available to read from the source (starting from
+ // the new offset) is `maxRead = srcSize - newOffset`. The original
+ // SubTensorOp may have requested more elements, i.e., `length > maxRead`.
+ // In that case, the missing `length - maxRead` elements must be padded.
+ // (If `maxRead >= length`, enough data is available to read and no high
+ // padding is needed.)
+ Value newHigh = max(zero, add(sub(newOffset, srcSize), length));
+ appendIndex(newHigh, newHighs, staticNewHighs);
+
+ // Only unit stride supported.
+ newStrides.push_back(rewriter.getIndexAttr(1));
+ }
+
+ // Insert cast to ensure that types match. (May be folded away.)
+ auto castResult = [&](Value val) -> Value {
+ auto castOp = rewriter.create<tensor::CastOp>(
+ loc, subTensorOp.getType(), val);
+ return castOp;
+ };
+
+ // In cases where the original data source is unused: Emit a GenerateOp and
+ // do not generate a SubTensorOp. (The result shape of the SubTensorOp would
+ // have a dimension of size 0, the semantics of which is unclear.)
+ auto createGenerateOp = [&]() {
+ // The shape of the GenerateOp matches the result shape of the SubTensorOp.
+ RankedTensorType type = subTensorOp.getType();
+ SmallVector<Value> dynDims;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ if (type.isDynamicDim(i))
+ dynDims.push_back(
+ asValue(rewriter, loc, subTensorOp.getMixedSizes()[i]));
+ }
+
+ // Create GenerateOp.
+ auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);
+
+ // Copy region to new op.
+ BlockAndValueMapping bvm;
+ padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+ // Rewrite linalg::YieldOp to tensor::YieldOp.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto yieldOp = dyn_cast<linalg::YieldOp>(
+ generateOp.getRegion().front().getTerminator());
+ assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+ assert(yieldOp.values().size() == 1);
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<tensor::YieldOp>(
+ yieldOp, yieldOp.values()[0]);
+ }
+
+ return castResult(generateOp);
+ };
+
+ // Emit a SubTensorOp and a PadTensorOp. Should not be used in cases where
+ // the result shape of the new SubTensorOp has a zero dimension.
+ auto createPadTensorOfSubTensor = [&]() {
+ // Create pad_tensor(subtensor(x)).
+ auto newSubTensorOp = rewriter.create<SubTensorOp>(
+ loc, padOp.source(), newOffsets, newLengths, newStrides);
+ auto newPadTensorOp = rewriter.create<PadTensorOp>(
+ loc, newSubTensorOp, staticNewLows, staticNewHighs, newLows, newHighs);
+
+ // Copy region to new PadTensorOp.
+ BlockAndValueMapping bvm;
+ padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);
+
+ // Cast result and return.
+ return castResult(newPadTensorOp);
+ };
+
+ // Rewrite subtensor(pad_tensor(x)) into a GenerateOp if it is statically
+ // known that the original data source x is not used.
+ if (hasZeroLen) {
+ rewriter.replaceOp(subTensorOp, createGenerateOp());
+ return success();
+ }
+
+ // If there are dynamic dimensions: Generate an scf.if check to avoid creating
+ // SubTensorOps with result dimensions of size 0 at runtime.
+ if (dynHasZeroLenCond) {
+ auto result = rewriter.create<scf::IfOp>(
+ loc, subTensorOp.getType(), dynHasZeroLenCond,
+ /*thenBuilder=*/[&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, createGenerateOp());
+ },
+ /*elseBuilder=*/[&](OpBuilder &b, Location loc) {
+ b.create<scf::YieldOp>(loc, createPadTensorOfSubTensor());
+ });
+ rewriter.replaceOp(subTensorOp, result.getResult(0));
+ return success();
+ }
+
+ // All shapes are static and the data source is actually used. Rewrite into
+ // pad_tensor(subtensor(x)).
+ rewriter.replaceOp(subTensorOp, createPadTensorOfSubTensor());
+ return success();
+}
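To make the offset/length/padding arithmetic above concrete, here is the
per-dimension computation for the @static_mixed_data_high_pad test case
below, where srcSize = (4, 5), offset = (2, 4), and length = (3, 4):

  dim 0: newOffset = min(2, 4)         = 2
         newLength = min(2 + 3, 4) - 2 = 2
         newHigh   = max(0, 2 - 4 + 3) = 1
  dim 1: newOffset = min(4, 5)         = 4
         newLength = min(4 + 4, 5) - 4 = 1
         newHigh   = max(0, 4 - 5 + 4) = 3

This yields subtensor %src[2, 4] [2, 1] [1, 1] followed by a pad of
low[0, 0] high[1, 3], matching the CHECK lines of that test.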
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
new file mode 100644
index 0000000000000..7d9c770946e63
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-swap-subtensor-padtensor -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @static_data_only(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>
+// CHECK: %[[RESULT:.*]] = subtensor %[[ARG0]][1, 2] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+// CHECK: return %[[RESULT]]
+func @static_data_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+ -> tensor<2x1xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<4x5xf32> to tensor<11x13xf32>
+ %1 = subtensor %0[1, 2] [2, 1] [1, 1] : tensor<11x13xf32> to tensor<2x1xf32>
+ return %1 : tensor<2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_high_pad_only
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-NOT: subtensor
+// CHECK: %[[RESULT:.*]] = tensor.generate
+// CHECK: tensor.yield %[[PAD]]
+// CHECK: return %[[RESULT]] : tensor<2x4xf32>
+func @static_high_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+ -> tensor<2x4xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<4x5xf32> to tensor<11x13xf32>
+ %1 = subtensor %0[4, 5] [2, 4] [1, 1] : tensor<11x13xf32> to tensor<2x4xf32>
+ return %1 : tensor<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_mixed_data_high_pad
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][2, 4] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[1, 3]
+// CHECK: linalg.yield %[[PAD]]
+// CHECK: return %[[RESULT]] : tensor<3x4xf32>
+func @static_mixed_data_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
+ -> tensor<3x4xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<4x5xf32> to tensor<11x13xf32>
+ %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<11x13xf32> to tensor<3x4xf32>
+ return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_high_pad
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[RESULT:.*]] = scf.if %{{.*}} -> (tensor<3x4xf32>) {
+// CHECK: %[[GEN:.*]] = tensor.generate
+// CHECK: scf.yield %[[GEN]]
+// CHECK: } else {
+// CHECK: %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
+// CHECK: %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
+// CHECK: %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
+// CHECK: scf.yield %[[CAST]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[%h1, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<?x5xf32> to tensor<?x13xf32>
+ %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<?x13xf32> to tensor<3x4xf32>
+ return %1 : tensor<3x4xf32>
+}
+
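Note that the pattern bails out on nonzero low padding (the hasZeroLowPad
check above), so an input like the following hypothetical variant (not part
of the test file) is left unchanged:

  %0 = linalg.pad_tensor %arg0 low[1, 0] high[6, 8] {
  ^bb0(%arg1: index, %arg2: index):
    linalg.yield %pad : f32
  } : tensor<4x5xf32> to tensor<11x13xf32>
  %1 = subtensor %0[1, 2] [2, 1] [1, 1] : tensor<11x13xf32> to tensor<2x1xf32>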
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 0037db27610ef..402a26475a503 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -97,6 +97,11 @@ struct TestLinalgTransforms
*this, "test-transform-pad-tensor",
llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
llvm::cl::init(false)};
+ Option<bool> testSwapSubTensorPadTensor{
+ *this, "test-swap-subtensor-padtensor",
+ llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
+ "pad_tensor(subtensor)"),
+ llvm::cl::init(false)};
ListOption<int64_t> tileSizesForPadding{
*this, "tile-sizes-for-padding",
llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -524,6 +529,12 @@ static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
+static void applySubTensorOfPadTensorSwapPattern(FuncOp funcOp) {
+ RewritePatternSet patterns(funcOp.getContext());
+ patterns.add<SubTensorOfPadTensorSwapPattern>(funcOp.getContext());
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
RewritePatternSet foldPattern(funcOp.getContext());
foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -602,6 +613,8 @@ void TestLinalgTransforms::runOnFunction() {
return applyLinalgToVectorPatterns(getFunction());
if (testTransformPadTensor)
return applyPadTensorToGenericPatterns(getFunction());
+ if (testSwapSubTensorPadTensor)
+ return applySubTensorOfPadTensorSwapPattern(getFunction());
if (testAffineMinSCFCanonicalizationPatterns)
return applyAffineMinSCFCanonicalizationPatterns(getFunction());
if (testTileAndPadPattern)