[Mlir-commits] [mlir] 318ce4a - [mlir][linalg] Improve codegen of ExtractSliceOfPadTensorSwapPattern
Matthias Springer
llvmlistbot at llvm.org
Wed Jul 14 19:23:57 PDT 2021
Author: Matthias Springer
Date: 2021-07-15T11:05:55+09:00
New Revision: 318ce4ad927d129a2bf96c2c872f4d107c45bdef
URL: https://github.com/llvm/llvm-project/commit/318ce4ad927d129a2bf96c2c872f4d107c45bdef
DIFF: https://github.com/llvm/llvm-project/commit/318ce4ad927d129a2bf96c2c872f4d107c45bdef.diff
LOG: [mlir][linalg] Improve codegen of ExtractSliceOfPadTensorSwapPattern
LOG: [mlir][linalg] Improve codegen of ExtractSliceOfPadTensorSwapPattern
Generate simpler code in the case where the low/high padding of the PadTensorOp is statically zero.
Differential Revision: https://reviews.llvm.org/D105529
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 4c3bb414bbcf..8d3ee8fe5566 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -866,6 +866,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
int64_t rank = padOp.getSourceType().getRank();
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
+ bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
+ auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
+ bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
auto srcSize =
@@ -874,7 +877,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// The new amount of low padding is `low - offset`. Except for the case
// where none of the low padding is read. In that case, the new amount of
// low padding is zero.
- Value newLow = max(zero, sub(low, offset));
+ //
+ // Optimization: If low = 0, then newLow = 0.
+ Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
appendIndex(newLow, newLows, staticNewLows);
// Start reading the data from position `offset - low`. Since the original
@@ -887,7 +892,10 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// In that case, set the offset to the end of source tensor. The new
// ExtractSliceOp length will be zero in that case. (Effectively reading no
// data from the source.)
- Value newOffset = min(max(sub(offset, low), zero), srcSize);
+ //
+ // Optimization: If low = 0, then the formula can be simplified.
+ Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
+ : min(offset, srcSize);
newOffsets.push_back(getAsOpFoldResult(newOffset));
// The original ExtractSliceOp was reading until position `offset + length`.
@@ -906,7 +914,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// endLoc = min(max(offset - low + length, 0), srcSize)
//
// The new ExtractSliceOp length is `endLoc - newOffset`.
- Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
+ //
+ // Optimization: If low = 0, then the formula can be simplified.
+ Value endLoc = hasLowPad
+ ? min(max(add(sub(offset, low), length), zero), srcSize)
+ : min(add(offset, length), srcSize);
Value newLength = sub(endLoc, newOffset);
newLengths.push_back(getAsOpFoldResult(newLength));
@@ -925,7 +937,9 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
// The amount of high padding is simply the number of elements remaining,
// so that the result has the same length as the original ExtractSliceOp.
- Value newHigh = sub(sub(length, newLength), newLow);
+ // As an optimization, if the original high padding is zero, then the new
+ // high padding must also be zero.
+ Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
appendIndex(newHigh, newHighs, staticNewHighs);
// Only unit stride supported.
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index 362de8ef1300..13f12d83133b 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -177,3 +177,43 @@ func @dynamic_extract_size(%arg0 : tensor<?x5xf32>, %s1: index, %pad : f32) -> t
%1 = tensor.extract_slice %0[2, 4] [%s1, 4] [1, 1] : tensor<?x13xf32> to tensor<?x4xf32>
return %1 : tensor<?x4xf32>
}
+
+// -----
+
+// CHECK-LABEL: @dynamic_zero_low_padding
+// CHECK: scf.if
+// CHECK: tensor.generate
+// CHECK: else
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice
+// CHECK: linalg.pad_tensor %[[SLICE]] low[0, 0]
+func @dynamic_zero_low_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+ %o1 : index, %o2 : index,
+ %s1 : index, %s2 : index)
+ -> tensor<?x?xf32> {
+ %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_zero_high_padding
+// CHECK: scf.if
+// CHECK: tensor.generate
+// CHECK: else
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice
+// CHECK: linalg.pad_tensor %[[SLICE]] low[%{{.*}}, %{{.*}}] high[0, 0]
+func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+ %o1 : index, %o2 : index,
+ %s1 : index, %s2 : index)
+ -> tensor<?x?xf32> {
+ %0 = linalg.pad_tensor %arg0 low[7, 8] high[0, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad : f32
+ } : tensor<?x?xf32> to tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list