[Mlir-commits] [mlir] b8a1f00 - [mlir][TilingInterface] Add support for interchange to tiling patterns that use the `TilingInterface`.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue Jul 19 22:25:29 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-07-20T05:24:17Z
New Revision: b8a1f00d414ece5597423449d433fbb26a217626

URL: https://github.com/llvm/llvm-project/commit/b8a1f00d414ece5597423449d433fbb26a217626
DIFF: https://github.com/llvm/llvm-project/commit/b8a1f00d414ece5597423449d433fbb26a217626.diff

LOG: [mlir][TilingInterface] Add support for interchange to tiling patterns that use the `TilingInterface`.

Differential Revision: https://reviews.llvm.org/D129956

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
    mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f3ee8a5b27f6..52cc52325eb9a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -26,7 +26,7 @@ namespace mlir {
 namespace scf {
 
 using SCFTileSizeComputationFunction =
-    std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>;
+    std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
@@ -51,6 +51,13 @@ struct SCFTilingOptions {
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
   SCFTilingOptions &setTileSizes(ArrayRef<int64_t> ts);
+
+  /// Order of the tiled loops: loop `i` steps over dim `interchangeVector[i]`.
+  SmallVector<unsigned> interchangeVector = {};
+  SCFTilingOptions &setInterchange(ArrayRef<unsigned> interchange) {
+    interchangeVector = llvm::to_vector(interchange);
+    return *this;
+  }
 };
 
 struct SCFTilingResult {

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 5e0ea65dd20df..3bad54327e078 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -29,7 +29,7 @@ using namespace mlir;
 scf::SCFTilingOptions &
 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
   assert(!tileSizeComputationFunction && "tile sizes already set");
-  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
+  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
   tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
     OpBuilder::InsertionGuard guard(b);
     b.setInsertionPointToStart(
@@ -42,6 +42,49 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
   return *this;
 }
 
+/// Helper method to adjust the interchange vector to match the iteration
+/// domain: extra entries are dropped, missing trailing ones map to themselves.
+static SmallVector<unsigned>
+fillInterchangeVector(ArrayRef<unsigned> interchangeVector,
+                      size_t iterationDomainSize) {
+  SmallVector<unsigned> result = llvm::to_vector(interchangeVector);
+  if (result.size() >= iterationDomainSize) {
+    result.resize(iterationDomainSize);
+    return result;
+  }
+  for (unsigned dim : llvm::seq<unsigned>(result.size(), iterationDomainSize))
+    result.push_back(dim);
+  return result;
+}
+
+/// Returns `vector` reordered so that element `i` is `vector[interchange[i]]`.
+template <typename T>
+static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector,
+                                               ArrayRef<unsigned> interchange) {
+  assert(interchange.size() == vector.size());
+  auto gather = [&](unsigned src) { return vector[src]; };
+  return llvm::to_vector(llvm::map_range(interchange, gather));
+}
+/// Returns the inverse permutation: `result[interchange[i]] == i` for all `i`.
+/// `interchange` must be a permutation of [0, size); this is not checked here.
+static SmallVector<unsigned>
+invertPermutationVector(ArrayRef<unsigned> interchange) {
+  SmallVector<unsigned> inverse(interchange.size());
+  for (unsigned i = 0, e = interchange.size(); i < e; ++i)
+    inverse[interchange[i]] = i;
+  return inverse;
+}
+/// Returns true if `interchange` is a permutation of [0, interchange.size()).
+/// Out-of-range entries are rejected too: callers use the entries as indices
+/// into vectors with exactly `interchange.size()` elements.
+static bool isPermutation(ArrayRef<unsigned> interchange) {
+  llvm::SmallDenseSet<unsigned, 4> seenVals;
+  for (unsigned val : interchange)
+    if (val >= interchange.size() || !seenVals.insert(val).second)
+      return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // TileUsingSCFForOp pattern implementation.
 //===----------------------------------------------------------------------===//
@@ -137,7 +180,7 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
   // dimensions.
-  SmallVector<Value, 4> tileSizeVector =
+  SmallVector<Value> tileSizeVector =
       options.tileSizeComputationFunction(rewriter, op);
   if (tileSizeVector.size() < iterationDomain.size()) {
     auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
@@ -147,12 +190,38 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
   scf::SCFTilingResult tilingResult;
   SmallVector<OpFoldResult> offsets, sizes;
   {
+    // If there is an interchange specified, permute the iteration domain and
+    // the tile sizes.
+    SmallVector<unsigned> interchangeVector;
+    if (!options.interchangeVector.empty()) {
+      interchangeVector = fillInterchangeVector(options.interchangeVector,
+                                                iterationDomain.size());
+    }
+    if (!interchangeVector.empty()) {
+      if (!isPermutation(interchangeVector)) {
+        return rewriter.notifyMatchFailure(
+            op, "invalid intechange vector, not a permutation of the entire "
+                "iteration space");
+      }
+
+      iterationDomain =
+          applyPermutationToVector(iterationDomain, interchangeVector);
+      tileSizeVector =
+          applyPermutationToVector(tileSizeVector, interchangeVector);
+    }
+
     // 3. Materialize an empty loop nest that iterates over the tiles. These
     // loops for now do not return any values even if the original operation has
     // results.
     tilingResult.loops = generateTileLoopNest(
         rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
 
+    if (!interchangeVector.empty()) {
+      auto inversePermutation = invertPermutationVector(interchangeVector);
+      offsets = applyPermutationToVector(offsets, inversePermutation);
+      sizes = applyPermutationToVector(sizes, inversePermutation);
+    }
+
     LLVM_DEBUG({
       if (!tilingResult.loops.empty()) {
         llvm::errs() << "LoopNest shell :\n";

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index dd77211d8ccc6..81e2bfbe2d9ff 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -183,3 +183,50 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
 // CHECK-SAME:           outs(%[[OUTS_TILE]] :
 //      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
 //      CHECK        scf.yield %[[INSERT]]
+
+// -----
+// Tile the generic with sizes {10, 20} and interchange {1, 0}; fuse producers.
+func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %cst = arith.constant 0.0 : f32
+  %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = linalg.matmul
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = linalg.generic {
+      __internal_linalg_transform__ = "gemm_interchange_fusion",
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %4 = arith.addf %b0, %b0 : f32
+        linalg.yield %4 : f32
+      } -> tensor<?x?xf32>
+  return %3 : tensor<?x?xf32>
+}
+//      CHECK: func.func @interchange_matmul_fusion(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor
+//      CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:.+]] = %[[INIT]])
+//      CHECK:     scf.for %[[IV1:[a-zA-Z0-9]+]] =
+// CHECK-SAME:         iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
+//  CHECK-DAG:       %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
+//  CHECK-DAG:       %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
+//  CHECK-DAG:       %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
+//      CHECK:       %[[FILL_TILE:.+]] = linalg.fill
+// CHECK-SAME:           outs(%[[INIT_TILE]] :
+//      CHECK:       %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:           ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:           outs(%[[FILL_TILE]] :
+//      CHECK:       %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//      CHECK:       %[[GENERIC_TILE:.+]] = linalg.generic
+// CHECK-SAME:           ins(%[[GEMM_TILE]] :
+// CHECK-SAME:           outs(%[[INIT_TILE_2]] :
+//      CHECK:       %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
+//      CHECK        scf.yield %[[INSERT]]

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index d8ec2c56409e2..ad5307c43e265 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -226,3 +226,52 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
   } -> (tensor<?x?xf32>)
   return %0 : tensor<?x?xf32>
 }
+
+// -----
+// Tile sizes {10, 20, 30} with interchange {1, 2, 0}: loop order is N, K, M.
+func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul {__internal_linalg_transform__ = "gemm_interchange"}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+//      CHECK: func.func @interchange_matmul(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C10:.+]] = arith.constant 10 : index
+//  CHECK-DAG:   %[[C20:.+]] = arith.constant 20 : index
+//  CHECK-DAG:   %[[C30:.+]] = arith.constant 30 : index
+//  CHECK-DAG:   %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//  CHECK-DAG:   %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//      CHECK:   %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
+// CHECK-SAME:       iter_args(%[[INIT0:.+]] = %[[ARG2]])
+//      CHECK:     %[[TS_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C20]], %[[N]]]
+//      CHECK:     %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
+// CHECK-SAME:         iter_args(%[[INIT1:.+]] = %[[INIT0]])
+//      CHECK:       %[[TS_K:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C30]], %[[K]]]
+//      CHECK:       %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
+// CHECK-SAME:           iter_args(%[[INIT2:.+]] = %[[INIT1]])
+//  CHECK-DAG:         %[[TS_M:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[C10]], %[[M]]]
+//  CHECK-DAG:         %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:             [%[[IV2]], %[[IV1]]] [%[[TS_M]], %[[TS_K]]] [1, 1]
+//  CHECK-DAG:         %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
+// CHECK-SAME:             [%[[IV1]], %[[IV0]]] [%[[TS_K]], %[[TS_N]]] [1, 1]
+//  CHECK-DAG:         %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT2]]
+// CHECK-SAME:             [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:         %[[GEMM_TILE:.+]] = linalg.matmul
+// CHECK-SAME:             ins(%[[LHS_TILE]], %[[RHS_TILE]] :
+// CHECK-SAME:             outs(%[[INIT_TILE]] :
+//      CHECK:         %[[UPDATE:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INIT2]]
+// CHECK-SAME:             [%[[IV2]], %[[IV0]]] [%[[TS_M]], %[[TS_N]]] [1, 1]
+//      CHECK:         scf.yield %[[UPDATE]]
+//      CHECK:       scf.yield %[[INNER2]]
+//      CHECK:     scf.yield %[[INNER1]]
+//      CHECK:   return %[[OUTER]]

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index cebe7b1086ee1..214a4053e232c 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -147,10 +147,11 @@ struct TestTilingInterfacePass
 
 template <class Pattern>
 static void
-addPatternForTiling(MLIRContext *context, ArrayRef<int64_t> tileSizes,
-                    StringRef filterName, RewritePatternSet &patterns) {
+addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns,
+                    StringRef filterName, ArrayRef<int64_t> tileSizes,
+                    ArrayRef<unsigned> interchange = {}) {
   scf::SCFTilingOptions tilingOptions;
-  tilingOptions.setTileSizes(tileSizes);
+  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
   linalg::LinalgTransformationFilter filter(
       StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
   patterns.add<Pattern>(context, tilingOptions, filter);
@@ -161,29 +162,35 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
   if (testTiling) {
     // 1. Tiling M and N dims of `linalg.matmul` on tensors.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, {10, 20}, "simple_gemm", patterns);
+        context, patterns, "simple_gemm", {10, 20});
     // 2. Tiling M, N and K of `linalg.matmul` on buffers.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, {10, 20, 30}, "simple_gemm_memref", patterns);
+        context, patterns, "simple_gemm_memref", {10, 20, 30});
     // 3. Tiling 3D parallel generic op which implements a transpose
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, {10, 0, 20}, "parallel_generic_transpose", patterns);
+        context, patterns, "parallel_generic_transpose", {10, 0, 20});
     // 4. Tiling 2D conv op.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
+        context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30});
     // 5. Tiling a simple op with `linalg.index` inside.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
-        context, {10, 20}, "indexed_semantics", patterns);
+        context, patterns, "indexed_semantics", {10, 20});
+    // 6. Tiling + interchange of an operation
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
     return;
   }
   if (testTileConsumerAndFuseProducer) {
     // 1. Tile and fuse of gemm with bias-add operation.
     addPatternForTiling<
         TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, {10, 20}, "fusion", patterns);
+        context, patterns, "fusion", {10, 20});
+    addPatternForTiling<
+        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+        context, patterns, "gemm_fusion", {10});
     addPatternForTiling<
         TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
-        context, {10}, "gemm_fusion", patterns);
+        context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
     return;
   }
 }


        


More information about the Mlir-commits mailing list