[Mlir-commits] [mlir] 0ccc44c - [mlir][linalg] Fix tile and fuse for outermost reduction.

Tobias Gysi llvmlistbot at llvm.org
Mon Nov 22 02:45:25 PST 2021


Author: Tobias Gysi
Date: 2021-11-22T10:44:15Z
New Revision: 0ccc44cec067abbc702d5d3afb44e0395c55820d

URL: https://github.com/llvm/llvm-project/commit/0ccc44cec067abbc702d5d3afb44e0395c55820d
DIFF: https://github.com/llvm/llvm-project/commit/0ccc44cec067abbc702d5d3afb44e0395c55820d.diff

LOG: [mlir][linalg] Fix tile and fuse for outermost reduction.

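Tile and fuse failed if the outermost tile loop is a reduction loop. In that case, the consumer passed to `TileLoopNest::fuseProducer` may not have been tiled, which violated an assertion. Replace the assertion with a check that returns failure for consumers that have not been tiled, and add a test case to verify the change.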

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114012

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 74f42d3d26509..25aee5f23a513 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -225,9 +225,9 @@ class TileLoopNest {
   LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
                            ArrayRef<int64_t> tileInterchange);
 
-  /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
-  /// fused producer of fails if fusion is not possible.
-  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
+  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
+  /// the fused producer or fails if fusion is not possible.
+  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
 
   /// Returns the replacement results for the original untiled root operation.
   ValueRange getRootOpReplacementResults();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 7156515cedae0..00904a48712b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -317,8 +317,11 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
 
 FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
                                                OpOperand *consumerOpOperand) {
-  assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
-         "expect the operand owner is the root operation or a fused producer");
+  // Check if the consumer has been tiled before. For example, it may not have
+  // been tiled if the outermost tile loop is a reduction loop.
+  if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0)
+    return failure();
+
   assert(this->isValid() &&
          "expect the tile loop nest to satisfy all invariants");
 

diff  --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
index edc14c42b6fb4..90b9ad60b97f3 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -232,6 +232,41 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+
+//      CHECK:  fuse_outermost_reduction
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
+func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>,
+                               %arg1: tensor<10xf32>) -> tensor<10xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32>
+
+  // Cannot fuse the output fill since the reduction loop is the outermost loop.
+  //      CHECK:      %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]])
+  %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32>
+
+  //      CHECK:  scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
+  //      CHECK:    scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+
+  // Check the input fill has been fused.
+  //      CHECK:      %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME:                                        %[[IV1]], %[[IV0]]
+  //      CHECK:      %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]])
+  //      CHECK:      %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME:                                        %[[IV1]]
+  //      CHECK:  linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
+  %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    %3 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %3 : f32
+  } -> tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
+
+// -----
+
 //  CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
 //  CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)>
 //  CHECK-DAG:  #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>


        


More information about the Mlir-commits mailing list