[Mlir-commits] [mlir] 56796ae - [mlir][linalg] Fix tensor tiling together with interchange

Lei Zhang llvmlistbot at llvm.org
Fri Jul 15 10:54:54 PDT 2022


Author: Lei Zhang
Date: 2022-07-15T13:54:47-04:00
New Revision: 56796ae1a8db4c85dada28676f8303a5a3609c63

URL: https://github.com/llvm/llvm-project/commit/56796ae1a8db4c85dada28676f8303a5a3609c63
DIFF: https://github.com/llvm/llvm-project/commit/56796ae1a8db4c85dada28676f8303a5a3609c63.diff

LOG: [mlir][linalg] Fix tensor tiling together with interchange

In `linalg::tileConsumerAndFuseProducers`, tiling and fusion happen
at two levels; we partition the tile sizes and use one half for each
level. The partition boundary is the first non-parallel dimension
*after* interchange. However, the concrete tiling happens *together
with* the loop interchange, so the partial tile sizes must still be
given in the dimension order *before* the interchange. Otherwise the
two are inconsistent, which is what this patch fixes.
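
To make this concrete, below is a minimal standalone sketch of the
partitioning (plain C++ with standard containers, not the actual MLIR
helpers; the `split` computation and the `isParallel` vector are
simplified assumptions here), using the configuration from the new
test: tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1], and
iterator types [parallel, reduction, parallel].

  // Simplified, self-contained illustration; not the actual MLIR code.
  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    std::vector<int64_t> tileSizes = {5, 4, 7};
    std::vector<int64_t> tileInterchange = {0, 2, 1};
    // Iterator types of the original loops: d0/d2 parallel, d1 reduction.
    std::vector<bool> isParallel = {true, false, true};

    // The split point is the first non-parallel loop *after* interchange.
    // Interchanged order is d0, d2, d1, so the reduction d1 sits at
    // position 2 and split = 2.
    size_t split = 0;
    while (split < tileInterchange.size() &&
           isParallel[tileInterchange[split]])
      ++split;

    // Scatter the tile sizes through the permutation: tileRootOp consumes
    // them in the *original* dimension order because it performs the
    // interchange itself.
    std::vector<int64_t> outer(tileSizes.size(), 0);
    std::vector<int64_t> inner(tileSizes.size(), 0);
    for (size_t pos = 0; pos < tileInterchange.size(); ++pos) {
      int64_t dim = tileInterchange[pos];
      (pos < split ? outer : inner)[dim] = tileSizes[dim];
    }

    // Prints outer = [5, 0, 7] and inner = [0, 4, 0]. The old contiguous
    // slicing gave outer = [5, 4, 0] / inner = [0, 0, 7], i.e. it tiled
    // the reduction dimension in the outer step.
    for (const auto &v : {outer, inner}) {
      for (int64_t s : v)
        printf("%lld ", (long long)s);
      printf("\n");
    }
    return 0;
  }

With this partitioning the outer step tiles d0 and d2 (steps 5 and 7)
and the inner step tiles the reduction d1 (step 4), matching the
scf.for nest checked in the test added below.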

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D129804

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index d968b374c736b..66a558ce8cfaf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -454,19 +454,24 @@ FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
     }
   };
 
+  // Perform tiling and fusion in two steps. We need to respect the loop
+  // interchange here; filter parallel dimensions based on their order *after*
+  // permutation but pass in the original configuration *before* permutation,
+  // given that tiling and interchange happen together.
+  SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
+  SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
+  for (int64_t i : tileInterchange.take_front(split))
+    outerTileSizes[i] = tileSizes[i];
+  for (int64_t i : tileInterchange.drop_front(split))
+    innerTileSizes[i] = tileSizes[i];
+
   // Tile the outer parallel loops and fuse the output operands.
-  SmallVector<int64_t> outerTileSizes;
-  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
-  outerTileSizes.append(tileSizes.size() - split, 0);
   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
                                      tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
 
   // Tile the remaining loops and fuse the input operands.
-  SmallVector<int64_t> innerTileSizes;
-  innerTileSizes.append(split, 0);
-  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
   if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
                                      tileDistribution)))
     return failure();

diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index af6da5d7eaebb..1d4d62045c769 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -68,3 +68,54 @@ transform.with_pdl_patterns {
     transform.loop.peel %loops#0
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @interchange_reduction
+//  CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
+func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
+  %five = arith.constant 5.0 : f32
+  %init = linalg.init_tensor [12, 25] : tensor<12x25xf32>
+
+//   CHECK-DAG:  %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:  %[[C5:.+]] = arith.constant 5 : index
+//   CHECK-DAG:  %[[C7:.+]] = arith.constant 7 : index
+
+//       CHECK: %[[INIT:.+]] = linalg.init_tensor [12, 25]
+//       CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
+//       CHECK:   scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
+//       CHECK:     %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
+//       CHECK:     %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE0]] : tensor<?x?xf32>)
+//       CHECK:     scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
+//       CHECK:       %[[IN_SLICE:.+]] = tensor.extract_slice %[[INPUT]]
+//       CHECK:       %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
+//       CHECK:       linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)
+
+  %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%input : tensor<12x7x25xf32>) outs(%fill : tensor<12x25xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %2 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<12x25xf32>
+  func.return %0 : tensor<12x25xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.generic"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]}
+  }
+}


        

