[Mlir-commits] [mlir] a0e0201 - [mlir][linalg] Improve codegen when tiling PadTensor evenly
Matthias Springer
llvmlistbot at llvm.org
Wed Jul 14 19:42:19 PDT 2021
Author: Matthias Springer
Date: 2021-07-15T11:29:21+09:00
New Revision: a0e02018beb81946397f577f14df09e4b3b675da
URL: https://github.com/llvm/llvm-project/commit/a0e02018beb81946397f577f14df09e4b3b675da
DIFF: https://github.com/llvm/llvm-project/commit/a0e02018beb81946397f577f14df09e4b3b675da.diff
LOG: [mlir][linalg] Improve codegen when tiling PadTensor evenly
Produce simpler IR with more static type information and fewer affine expressions.
Differential Revision: https://reviews.llvm.org/D105530
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
mlir/test/Dialect/Linalg/tile.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 4aa7792eca90..5418bc3e3855 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -494,6 +494,7 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
+ tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
@@ -513,7 +514,15 @@ static void insertTilingPatterns(RewritePatternSet &patterns,
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>::insert(patterns, options);
patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+}
+
+static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
+ MLIRContext *ctx = funcOp.getContext();
+ RewritePatternSet patterns(ctx);
patterns.add<ExtractSliceOfPadTensorSwapPattern>(patterns.getContext());
+ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(
+ funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
}
static void
@@ -527,6 +536,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
MLIRContext *ctx = funcOp.getContext();
RewritePatternSet patterns(ctx);
insertTilingPatterns(patterns, options);
+ patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
@@ -534,6 +544,10 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
funcOp.walk([](LinalgOp op) {
op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
});
+
+ // Apply swap pattern after generating loop nest and running
+ // canonicalizations.
+ applyExtractSliceOfPadTensorSwapPattern(funcOp);
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
index 10f4dc3e34c3..36dc34a29657 100644
--- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
+++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir
@@ -92,3 +92,33 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
} : tensor<7x9xf32> to tensor<15x16xf32>
return %0 : tensor<15x16xf32>
}
+
+// -----
+// Tile size 3 evenly divides dim 1 (15), so every tile is statically 14x3.
+// TILE1-LABEL: func @static_pad_tile_evenly(
+// TILE1-SAME: %[[IN:.*]]: tensor<7x9xf32>, %[[OUT:.*]]: tensor<14x15xf32>
+// TILE1-DAG: %[[C0:.*]] = constant 0 : index
+// TILE1-DAG: %[[C3:.*]] = constant 3 : index
+// TILE1-DAG: %[[C15:.*]] = constant 15 : index
+// TILE1: %[[RESULT:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C15]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// TILE1: %[[R2:.*]] = scf.if
+// TILE1: %[[GEN:.*]] = tensor.generate
+// TILE1: scf.yield %[[GEN]] : tensor<14x3xf32>
+// TILE1: else
+// TILE1: %[[SLICE:.*]] = tensor.extract_slice %[[IN]][0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
+// TILE1: %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
+// TILE1: %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
+// TILE1: scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1: %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
+// TILE1: scf.yield %[[R3]] : tensor<14x15xf32>
+// TILE1: return %[[RESULT]] : tensor<14x15xf32>
+func @static_pad_tile_evenly(%input_tensor: tensor<7x9xf32>,
+ %output_tensor: tensor<14x15xf32>,
+ %pad_value: f32) -> tensor<14x15xf32> {
+ %0 = linalg.pad_tensor %input_tensor
+ low[0, 0] high[7, 6] into %output_tensor {
+ ^bb0(%arg1: index, %arg2: index):
+ linalg.yield %pad_value : f32
+ } : tensor<7x9xf32> to tensor<14x15xf32>
+ return %0 : tensor<14x15xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index 47d6dc1c3896..97b17ebb7bdc 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -20,10 +20,6 @@
// TILE-234-DAG: #[[$bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
// TILE-234-DAG: #[[$bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
-// TILE-2-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 10)>
-// TILE-02-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 12)>
-// TILE-002-DAG: #[[$bound_map_static:.*]] = affine_map<(d0) -> (2, -d0 + 16)>
-
// TILE-2-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
// TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
// TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)>
@@ -132,10 +128,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-2-DAG: %[[C2:.*]] = constant 2 : index
// TILE-2-DAG: %[[M:.*]] = constant 10 : index
// TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
-// TILE-2: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[I]])
-// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x16xf32, #[[$strided2D]]>
-// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
-// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
+// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x16xf32, #[[$strided2D]]>
+// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
// TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul_static(
@@ -143,10 +137,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-02-DAG: %[[C2:.*]] = constant 2 : index
// TILE-02-DAG: %[[N:.*]] = constant 12 : index
// TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
-// TILE-02: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[J]])
-// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
-// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
-// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
+// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x2xf32, #[[$strided2D]]>
+// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul_static(
@@ -154,10 +146,8 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-002-DAG: %[[C2:.*]] = constant 2 : index
// TILE-002-DAG: %[[C16:.*]] = constant 16 : index
// TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-// TILE-002: %[[MIN2:.*]] = affine.min #[[$bound_map_static]](%[[K]])
-// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
-// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x12xf32, #[[$strided2D]]>
+// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
+// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul_static(
@@ -171,9 +161,9 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
// TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} {
// TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} {
// TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
-// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<?x?xf32, #[[$strided2D]]>
+// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x4xf32, #[[$strided2D]]>
+// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<4x3xf32, #[[$strided2D]]>
+// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x3xf32, #[[$strided2D]]>
//
// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
@@ -312,7 +302,7 @@ func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
// TILE-234: for
// TILE-234-NOT: for
// TILE-234: memref.subview{{.*}} : memref<127x99xf32>
-// TILE-234: linalg.fill{{.*}} : f32, memref<?x?xf32, #[[$stride_99_1_layout_map]]>
+// TILE-234: linalg.fill{{.*}} : f32, memref<?x3xf32, #[[$stride_99_1_layout_map]]>
func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
More information about the Mlir-commits mailing list