[Mlir-commits] [mlir] [MLIR][Tensor] Add Destination style RewritePattern for DimOp. (PR #65780)

Amy Wang llvmlistbot at llvm.org
Fri Sep 8 14:59:01 PDT 2023


https://github.com/kaitingwang updated https://github.com/llvm/llvm-project/pull/65780:

>From d9ddbab0cff73ffbde4c93f4e380e3bb4ac48c67 Mon Sep 17 00:00:00 2001
From: Amy Wang <kai.ting.wang at huawei.com>
Date: Fri, 8 Sep 2023 09:50:46 -0400
Subject: [PATCH] [MLIR][Tensor] Add Destination style RewritePattern for
 DimOp.

Fold the dim of a destination-passing-style op's result into the dim of
the corresponding init operand. This lets canonicalization remove
unnecessary tensor.dim ops, which in turn allows other operations to be
folded away; see conv_tensors_dynamic, where the redundant affine.min
operations disappear.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 23 +++++++++++++++++-
 mlir/test/Dialect/Linalg/canonicalize.mlir    | 24 +++++++++++++++++--
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 12 +++-------
 .../Linalg/transform-tile-reduction.mlir      | 14 +++++------
 4 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 42d89cd5a76208a..40189b444a8aafa 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+/// Fold dim of a destination-passing-style (DPS) op result into the dim of the
+/// init operand at the same index as that result.
+struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto source = dimOp.getSource();
+    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
+    if (!destOp)
+      return failure();
+    // The source has a defining op, so it is necessarily an OpResult.
+    auto resultIndex = cast<OpResult>(source).getResultNumber();
+    auto *initOperand = destOp.getDpsInitOperand(resultIndex);
+    // Only the dim's source operand changes, so update the op in place.
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
+    return success();
+  }
+};
 } // namespace
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<DimOfCastOp>(context);
+  results.add<DimOfCastOp, DimOfDestStyleOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 783660727ce1638..297b5c4e332c811 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> {
 
 //  CHECK-DAG:   %[[I1:.+]] = arith.constant 1 : index
 //  CHECK-DAG:   %[[F0:.+]] = arith.constant 0.000000e+00 : f32
-//      CHECK:   %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
 //      CHECK:   %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
-//      CHECK:   %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
+//      CHECK:   %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
 //      CHECK:   %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
 //      CHECK:   %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
 //      CHECK:   %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
@@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
     ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
   return %arg0 : tensor<16x64x256xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
+//       CHECK: tensor.dim
+//       CHECK: tensor.dim
+//   CHECK-NOT: tensor.dim
+//       CHECK: return
+func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
+  %1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
+  %dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
+  %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %3: tensor<?x?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 6f21e1e20c3d4e6..0f27a92c119cf42 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
 // CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
 // CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
 // CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
 // CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
 // CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)>
 // CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)>
 
 //      CHECK: func @conv_tensors_dynamic
@@ -225,8 +223,6 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
 //  CHECK-DAG:   %[[FILTER_OC:.+]] = tensor.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
 //  CHECK-DAG:   %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
 //  CHECK-DAG:   %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
-//  CHECK-DAG:   %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
-//  CHECK-DAG:   %[[FILL_W:.+]] = tensor.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
 
 //      CHECK:   scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
 // CHECK-NEXT:     %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
@@ -234,14 +230,12 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
 // CHECK-NEXT:     scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
 // CHECK-NEXT:       %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
 // CHECK-NEXT:       %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
-// CHECK-NEXT:       %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[FILL_H]], %[[FILTER_H]]]
-// CHECK-NEXT:       %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
+// CHECK-NEXT:       %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]], %[[FILTER_H]]]
 // CHECK-NEXT:       scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
 // CHECK-NEXT:         %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
 // CHECK-NEXT:         %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
 // CHECK-NEXT:         %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
-// CHECK-NEXT:         %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
-// CHECK-NEXT:         %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
+// CHECK-NEXT:         %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]], %[[FILTER_W]]]
 // CHECK-NEXT:         %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
 // CHECK-SAME:               [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
 // CHECK-NEXT:         scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
@@ -253,7 +247,7 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
 // CHECK-NEXT:           %[[ST_FILTER:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, %[[IV3]]]
 // CHECK-SAME:                 [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
 // CHECK-NEXT:           %[[ST_FILL:.+]] = tensor.extract_slice %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
-// CHECK-SAME:                 [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_2]]]
+// CHECK-SAME:                 [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC_2]]]
 // CHECK-NEXT:           %[[ST_CONV:.+]] = linalg.conv_2d_nhwc_hwcf
 // CHECK-SAME:                 ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
 // CHECK-SAME:                 outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 70e535b74f055bc..934be889cecb20f 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -43,9 +43,7 @@ transform.sequence failures(propagate) {
 //     CHECK:       arith.addf
 //     CHECK:       linalg.yield
 //     CHECK:     } -> tensor<?x?xf32>
-//     CHECK:     %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
-//     CHECK:     %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
-//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
 //     CHECK:     scf.yield %[[INS]] : tensor<?x5xf32>
 //     CHECK:   }
 //     CHECK:   %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
@@ -76,14 +74,16 @@ transform.sequence failures(propagate) {
     by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
 
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 //     CHECK: func @reduction_tile_transpose
 //     CHECK:   tensor.empty(%{{.*}}) : tensor<5x?xf32>
 //     CHECK:   linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
 //     CHECK:   scf.for
-//     CHECK:     linalg.generic
-//     CHECK:     %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
-//     CHECK:     %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
-//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
+//     CHECK:     %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
+//     CHECK:     %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
+//     CHECK:     %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
 //     CHECK:     scf.yield {{.*}} : tensor<5x?xf32>
 //     CHECK:   }
 //     CHECK:   linalg.generic



More information about the Mlir-commits mailing list