[Mlir-commits] [mlir] 32c4324 - [mlir][linalg] Always generate an extract/insert slice pair when tiling output tensors.

Tobias Gysi llvmlistbot at llvm.org
Mon Nov 22 05:13:01 PST 2021


Author: Tobias Gysi
Date: 2021-11-22T13:12:43Z
New Revision: 32c43241e716280d3443d684416826b1e7e5781b

URL: https://github.com/llvm/llvm-project/commit/32c43241e716280d3443d684416826b1e7e5781b
DIFF: https://github.com/llvm/llvm-project/commit/32c43241e716280d3443d684416826b1e7e5781b.diff

LOG: [mlir][linalg] Always generate an extract/insert slice pair when tiling output tensors.

Adapt tiling to always generate an extract/insert slice pair for output tensors, even if the tensor is not tiled. Having an explicit extract/insert slice pair simplifies follow-up transformations such as padding and bufferization. In particular, it makes explicit which slices of the iteration arguments are read and written.

Depends On D114067

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114085

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Utils/Utils.cpp
    mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
    mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 60a500e2a12cf..cf0aee6bd2138 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -820,8 +820,12 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
-    // If the shape is not tiled, we can use it as is.
-    if (!isTiled(map, tileSizes)) {
+    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
+    // an extract/insert slice pair for all output tensors simplifies follow up
+    // transformations such as padding and bufferization since the
+    // extract/insert slice pairs make the accessed iteration argument
+    // subdomains explicit.
+    if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
       tiledShapes.push_back(shapedOp);
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
                               << opOperand->get().getType() << "\n");

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
index 8359f5cf79ab9..9eb0e35860f8f 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -263,7 +263,9 @@ module {
 //       CHECK:     %[[ST_FILL:.*]] = linalg.fill(%[[C0]], %[[ST]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
 //       CHECK:     %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
 //   CHECK-NOT:       fill
-//       CHECK:       %[[ST_MM:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[BB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:       %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0]
+//       CHECK:       %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ST_FILL_SUB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+//       CHECK:       %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]]
 //       CHECK:       scf.yield %[[ST_MM]] : tensor<?x?xf32>
 //       CHECK:     %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}}
 //       CHECK:     scf.yield %[[MM]] : tensor<?x?xf32>
@@ -307,11 +309,13 @@ module {
 
 // TLOOP:      %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
 // TLOOP:      %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
+// TLOOP:      %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
 
 // TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:   outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP:      linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP-SAME:   outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+// TLOOP:      %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
+// TLOOP:      linalg.yield %[[AB_SUB_]] : [[TY]]
 // TLOOP:    }
 // TLOOP:    %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
 // TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]
@@ -380,11 +384,13 @@ module {
 
 // TLOOP:      %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
 // TLOOP:      %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
+// TLOOP:      %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
 
 // TLOOP:      %[[AB_SUB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME:   ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME:   outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP:      linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP-SAME:   outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+// TLOOP:      %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
+// TLOOP:      linalg.yield %[[AB_SUB_]] : [[TY]]
 // TLOOP:    }
 // TLOOP:    %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
 // TLOOP-SAME:  into %[[OUT_]][%[[I]], %[[J]]]

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index e46218995ffb8..ec8d43b23de7c 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -47,6 +47,8 @@ builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
 builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
                           %arg1: tensor<12x25xf32>,
                           %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  //  MATMUL-DAG:  %[[C0:.*]] = arith.constant 0 : index
+  //  MATMUL-DAG:  %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c12 = arith.constant 12 : index
   %c25 = arith.constant 25 : index
@@ -67,7 +69,17 @@ builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
   // MATMUL-SAME:                                        %[[TS1]], %[[TS0]]
   //      MATMUL:      %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
   //      MATMUL:        scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
-  //      MATMUL:          %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+
+  // Check there is an extract/insert slice pair for the output operand.
+  //  MATMUL-DAG:          %[[D0:.*]] = tensor.dim %[[ARG5]], %[[C0]]
+  //  MATMUL-DAG:          %[[D1:.*]] = tensor.dim %[[ARG5]], %[[C1]]
+  //      MATMUL:          %[[T2:.*]] = tensor.extract_slice %[[ARG5]]
+  // MATMUL-SAME:                                            0, 0
+  // MATMUL-SAME:                                            %[[D0]], %[[D1]]
+  //      MATMUL:          %[[T3:.*]] = linalg.matmul {{.*}} outs(%[[T2]]
+  //      MATMUL:          %{{.*}} = tensor.insert_slice %[[T3]] into %[[ARG5]]
+  // MATMUL-SAME:                                            0, 0
+  // MATMUL-SAME:                                            %[[D0]], %[[D1]]
   %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
@@ -185,7 +197,8 @@ builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
   //      MATMUL:          %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
   // MATMUL-SAME:                                            %[[IV1]], %[[IV2]]
   //      MATMUL:          %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
-  //      MATMUL:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+  //      MATMUL:          %[[T4:.*]] = tensor.extract_slice %[[ARG5]]
+  //      MATMUL:          %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T4]]
   %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %2 : tensor<24x25xf32>
 }


        


More information about the Mlir-commits mailing list