[Mlir-commits] [mlir] 93c4229 - [mlir][TilingInterface] NFC code changes separated out from introduction of `scf::tileUsingSCFForallOp`. (#67081)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 26 13:42:31 PDT 2023


Author: MaheshRavishankar
Date: 2023-09-26T13:42:27-07:00
New Revision: 93c42299bdb1ef094857ef2d065670af0695c26b

URL: https://github.com/llvm/llvm-project/commit/93c42299bdb1ef094857ef2d065670af0695c26b
DIFF: https://github.com/llvm/llvm-project/commit/93c42299bdb1ef094857ef2d065670af0695c26b.diff

LOG: [mlir][TilingInterface] NFC code changes separated out from introduction of `scf::tileUsingSCFForallOp`. (#67081)

This patch contains NFC changes that are a precursor to the introduction
of the `scf::tileUsingSCFForallOp` method in
https://github.com/llvm/llvm-project/pull/67083.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
    mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
    mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
    mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
    mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index ca641c596c7b7bb..9f49d97e141e0c8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -60,7 +60,7 @@ struct SCFTilingResult {
   /// of the last op.
   SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  SmallVector<Operation *> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
@@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
   /// generated operation.
   llvm::SetVector<Operation *> tiledAndFusedOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  SmallVector<Operation *> loops;
   /// The replacement values to use for the tiled and fused operations.
   llvm::DenseMap<Value, Value> replacements;
 };

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 04b0d3a27eef325..9ce780d3d249cfb 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
     SmallVector<Operation *> opsToReplace{target};
     llvm::append_range(opsToReplace, tiledResults->fusedProducers);
     for (Operation *toReplace : opsToReplace) {
-      SmallVector<Value> replacements;
-      replacements.reserve(toReplace->getNumResults());
-      for (OpResult res : toReplace->getResults()) {
-        auto it = tiledResults->replacements.find(res);
-        if (it == tiledResults->replacements.end())
-          replacements.push_back(res);
-        else
-          replacements.push_back(it->getSecond());
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res))
+          rewriter.replaceAllUsesWith(res, replacement);
+      if (toReplace->use_empty()) {
+        rewriter.eraseOp(toReplace);
       }
-      rewriter.replaceOp(toReplace, replacements);
     }
 
     // Report back the relevant handles to the transform op.

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ab59eac2ac4d6f8..bc913e94a2837b2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
+/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
+template <typename SrcOpTy>
+static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
+}
+template <typename SrcOpTy>
+static SmallVector<Operation *>
+getAsOperations(const SmallVector<SrcOpTy> &ops) {
+  return getAsOperations(ArrayRef<SrcOpTy>(ops));
+}
+
+/// Convert a list of `Operation *` to a list of `DstOpTy.
+template <typename DstOpTy>
+static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
+  return llvm::to_vector(
+      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
+}
+template <typename DstOpTy>
+static SmallVector<DstOpTy>
+castToTypedOperations(const SmallVector<Operation *> &ops) {
+  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
+}
+
 //===----------------------------------------------------------------------===//
 // tileUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                        Range loopRange, Value iv,
-                                       Value tileSize) {
-  std::optional<int64_t> ts = getConstantIntValue(tileSize);
-  if (ts && ts.value() == 1)
-    return getAsOpFoldResult(tileSize);
+                                       OpFoldResult tileSize) {
+  if (isConstantIntValue(tileSize, 1))
+    return tileSize;
 
   if (tileDividesIterationDomain(
           Range{loopRange.offset, loopRange.size, tileSize}))
@@ -296,8 +319,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
   }
 
-  scf::SCFTilingResult tilingResult;
   SmallVector<OpFoldResult> offsets, sizes;
+  SmallVector<scf::ForOp> forLoops;
   {
     // If there is an interchange specified, permute the iteration domain and
     // the tile sizes.
@@ -320,8 +343,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
     // 3. Materialize an empty loop nest that iterates over the tiles. These
     // loops for now do not return any values even if the original operation has
     // results.
-    tilingResult.loops = generateTileLoopNest(
-        rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
+    forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
+                                    tileSizeVector, offsets, sizes);
 
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -331,30 +354,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   }
 
   LLVM_DEBUG({
-    if (!tilingResult.loops.empty()) {
+    if (!forLoops.empty()) {
       llvm::dbgs() << "LoopNest shell :\n";
-      tilingResult.loops.front().dump();
+      forLoops.front().dump();
       llvm::dbgs() << "\n";
     }
   });
 
   // 4. Generate the tiled implementation within the inner most loop.
-  if (!tilingResult.loops.empty())
-    rewriter.setInsertionPoint(
-        tilingResult.loops.back().getBody()->getTerminator());
+  if (!forLoops.empty())
+    rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
   FailureOr<TilingResult> tiledImplementation =
       op.getTiledImplementation(rewriter, offsets, sizes);
-  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
+
   if (op->getNumResults() == 0) {
-    // nothing more to do.
-    return tilingResult;
+    return scf::SCFTilingResult{
+        tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
   }
 
   // If loops are empty, the tiled op is used as the replacement for the untiled
   // op.
-  if (tilingResult.loops.empty()) {
-    tilingResult.replacements = tiledImplementation->tiledValues;
-    return tilingResult;
+  if (forLoops.empty()) {
+    return scf::SCFTilingResult{tiledImplementation->tiledOps,
+                                getAsOperations(forLoops),
+                                tiledImplementation->tiledValues};
   }
 
   // 5. Yield all the results of the tiled operation. The surrounding loop
@@ -378,18 +401,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
                                              destinationTensors)))
     return rewriter.notifyMatchFailure(op, "failed to get destinations");
 
-  tilingResult.replacements = yieldTiledValues(
+  SmallVector<Value> replacements = yieldTiledValues(
       rewriter, destinationTensors, tiledImplementation.value(),
-      resultOffsetsList, resultSizesList, tilingResult.loops);
-
+      resultOffsetsList, resultSizesList, forLoops);
   LLVM_DEBUG({
-    if (!tilingResult.loops.empty()) {
+    if (!forLoops.empty()) {
       llvm::dbgs() << "After tiled implementation :\n";
-      tilingResult.loops.front().dump();
+      forLoops.front().dump();
       llvm::dbgs() << "\n";
     }
   });
-  return tilingResult;
+  return scf::SCFTilingResult{tiledImplementation->tiledOps,
+                              getAsOperations(forLoops), replacements};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -467,6 +490,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
   results.mergeOp = mergeOp;
   return results;
 }
+
 //===----------------------------------------------------------------------===//
 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
@@ -637,7 +661,9 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   }
 
   // 1. First tile the consumer.
-  scf::SCFTileAndFuseResult tileAndFuseResult;
+  SmallVector<scf::ForOp> forLoops;
+  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
+  DenseMap<Value, Value> replacements;
   llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
   {
     FailureOr<scf::SCFTilingResult> tilingResult =
@@ -645,20 +671,21 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     if (failed(tilingResult))
       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
     for (auto *tiledOp : tilingResult->tiledOps)
-      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
-    tileAndFuseResult.loops = std::move(tilingResult->loops);
-    for (const auto &result : llvm::enumerate(
-             llvm::zip(consumer->getResults(), tilingResult->replacements))) {
-      tileAndFuseResult.replacements[std::get<0>(result.value())] =
-          std::get<1>(result.value());
+      tiledAndFusedOps.insert(tiledOp);
+    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
+    for (auto [index, origValue, replacement] :
+         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
+      replacements[origValue] = replacement;
       yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
-          result.index())] = result.index();
+          index)] = index;
     }
   }
 
   // If there are no loops generated, fusion is immaterial.
-  if (tileAndFuseResult.loops.empty())
-    return tileAndFuseResult;
+  if (forLoops.empty()) {
+    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+                                     getAsOperations(forLoops), replacements};
+  }
 
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
@@ -675,7 +702,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tiledAndFusedOps.back(), candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -685,19 +712,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
-    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
-                                   tileAndFuseResult.loops);
-    if (!fusedProducer)
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
+    if (!fusedResult)
       continue;
 
     if (Operation *tiledAndFusedOp =
-            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
-      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
+      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
+      tiledAndFusedOps.insert(tiledAndFusedOp);
       addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
-  return tileAndFuseResult;
+  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
+                                   getAsOperations(forLoops), replacements};
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 4f5900fda3e76bd..cf5a1b828f95b75 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
+  %gemm = linalg.matmul {__internal_transform__ = "fusion"}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %gemm : tensor<?x?xf32>
@@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   %generic = linalg.generic {
-      __internal_linalg_transform__ = "fusion",
+      __internal_transform__ = "fusion",
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
@@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
   %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
   %init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
   %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %gemm1 = linalg.matmul  {__internal_linalg_transform__ = "gemm_fusion"}
+  %gemm1 = linalg.matmul  {__internal_transform__ = "gemm_fusion"}
       ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %gemm1 : tensor<?x?xf32>
 }
@@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
   %init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
   %transpose = linalg.generic {
-      __internal_linalg_transform__ = "fusion",
+      __internal_transform__ = "fusion",
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
@@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
   %3 = linalg.generic {
-      __internal_linalg_transform__ = "gemm_interchange_fusion",
+      __internal_transform__ = "gemm_interchange_fusion",
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
       ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
@@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"],
-     __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+     __internal_transform__ = "gemm_plus_gemm_fusion"}
     ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%5 : tensor<?x?xf32>) {
     ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x
                       affine_map<(d0, d1) -> (d1, d0)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"],
-     __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+     __internal_transform__ = "gemm_plus_gemm_fusion"}
     ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%5 : tensor<?x?xf32>) {
     ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
   %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
   %2 = linalg.matmul
-    {__internal_linalg_transform__ = "gemm_sequence_fusion"}
+    {__internal_transform__ = "gemm_sequence_fusion"}
     ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
   return %2 : tensor<?x?xf32>
@@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
       linalg.yield %10, %9 : f32, f32
     } -> (tensor<30xf32>, tensor<30x3xf32>)
   %6 = linalg.generic {
-      __internal_linalg_transform__ = "reduction_sequence_fusion",
+      __internal_transform__ = "reduction_sequence_fusion",
       indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index f47850a5cb6d229..f725d19e14a0c5b 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?
       ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
   %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
   %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
-  %gemm1 = linalg.matmul  {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
+  %gemm1 = linalg.matmul  {__internal_transform__ = "gemm_sequence_fusion_and_yield"}
       ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
 }

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 2d6069973c8bf78..cbc5d6c186d6d34 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -6,7 +6,7 @@ func.func @dynamic_2d_pad_tensor(%input_tensor: tensor<?x?xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_2dtiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+    } {__internal_transform__ = "pad_2dtiling"}: tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 //  CHECK-DAG:  #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
@@ -38,7 +38,7 @@ func.func @dynamic_2d_pad_tensor_inner_tiling(%input_tensor: tensor<?x?xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_inner_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+    } {__internal_transform__ = "pad_inner_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 //   CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
@@ -68,7 +68,7 @@ func.func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+    } {__internal_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32>
   return %0 : tensor<15x16xf32>
 }
 // CHECK-LABEL: func @static_pad_tensor(
@@ -95,7 +95,7 @@ func.func @static_pad_tensor_inner_tiling(%input_tensor: tensor<7x9xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+    } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
   return %0 : tensor<15x16xf32>
 }
 // CHECK-LABEL: func @static_pad_tensor_inner_tiling(
@@ -122,7 +122,7 @@ func.func @dynamic_2d_pad_tensor_outer_tiling(%input_tensor: tensor<?x?xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_outer_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
+    } {__internal_transform__ = "pad_outer_tiling"}: tensor<?x?xf32> to tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
 // CHECK-LABEL: func @dynamic_2d_pad_tensor_outer_tiling
@@ -134,7 +134,7 @@ func.func @static_pad_tensor_outer_tiling(%input_tensor: tensor<7x9xf32>,
   %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
     ^bb0(%arg1: index, %arg2: index):
       tensor.yield %pad_value : f32
-    } {__internal_linalg_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
+    } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32>
   return %0 : tensor<15x16xf32>
 }
 // CHECK-LABEL: func @static_pad_tensor_outer_tiling

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index cacef3c47b5e1cd..2153eb6f237fcfd 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -2,7 +2,7 @@
 
 func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul {__internal_linalg_transform__ = "simple_gemm"}
+  %0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -45,7 +45,7 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 
 func.func @simple_matmul_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
     %arg2 : memref<?x?xf32>) {
-  linalg.matmul {__internal_linalg_transform__ = "simple_gemm_memref"}
+  linalg.matmul {__internal_transform__ = "simple_gemm_memref"}
       ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
       outs(%arg2 : memref<?x?xf32>)
   return
@@ -92,7 +92,7 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
   %0:2 = linalg.generic {
       indexing_maps = [#map0, #map1, #map2],
       iterator_types = ["parallel", "parallel", "parallel"]}
-      {__internal_linalg_transform__ = "parallel_generic_transpose"}
+      {__internal_transform__ = "parallel_generic_transpose"}
       ins(%arg0 : tensor<128x200x300xf32>)
       outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) {
     ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
@@ -139,7 +139,7 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
   %0 = linalg.conv_2d_nhwc_hwcf {
       strides = dense<[2, 3]> : tensor<2xi64>,
       dilation = dense<[4, 5]> : tensor<2xi64>,
-      __internal_linalg_transform__ = "simple_conv"}
+      __internal_transform__ = "simple_conv"}
       ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
       outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
   return %0 : tensor<?x?x?x?xf32>
@@ -205,7 +205,7 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
-    {__internal_linalg_transform__ = "indexed_semantics"}
+    {__internal_transform__ = "indexed_semantics"}
     ins(%arg0: tensor<?x?xf32>)
     outs(%arg1: tensor<?x?xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):
@@ -229,7 +229,7 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
 
 func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %0 = linalg.matmul {__internal_linalg_transform__ = "gemm_interchange"}
+  %0 = linalg.matmul {__internal_transform__ = "gemm_interchange"}
       ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
@@ -283,7 +283,7 @@ func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
 //       CHECK:       memref.subview
 //       CHECK:       linalg.copy
 func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
-  linalg.copy {__internal_linalg_transform__ = "simple_copy_memref"}
+  linalg.copy {__internal_transform__ = "simple_copy_memref"}
       ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
   return
 }

diff  --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index 2fcc7bcadb60450..2573e11979dbc47 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -37,51 +37,51 @@ using namespace mlir;
 namespace {
 
 /// Marker used as attribute name in generated Linalg rewriting transformations.
-const StringLiteral kLinalgTransformMarker = "__internal_linalg_transform__";
+const StringLiteral kTransformMarker = "__internal_transform__";
 
 /// Helper class to control application of linalg transformation patterns.
 /// Control comes in 2 forms:
 ///   1. attribute matching and setting behavior using the attribute named
-///      `kLinalgTransformMarker`. This can be used to build a state machine
+///      `kTransformMarker`. This can be used to build a state machine
 ///      using attributes and incrementally applying patterns to advance states.
 ///   2. filter function, which is a simple lambda on the Operation* that
 ///      returns a LogicalResult.
-struct LinalgTransformationFilter {
+struct TransformationFilter {
   using FilterFunction = std::function<LogicalResult(Operation *)>;
 
-  explicit LinalgTransformationFilter(
+  explicit TransformationFilter(
       ArrayRef<StringAttr> matchDisjunction = {},
       std::optional<StringAttr> replacement = std::nullopt);
 
-  explicit LinalgTransformationFilter(
+  explicit TransformationFilter(
       const FilterFunction &f, ArrayRef<StringAttr> matchDisjunction = {},
       std::optional<StringAttr> replacement = std::nullopt);
 
-  LinalgTransformationFilter(LinalgTransformationFilter &&) = default;
-  LinalgTransformationFilter(const LinalgTransformationFilter &) = default;
+  TransformationFilter(TransformationFilter &&) = default;
+  TransformationFilter(const TransformationFilter &) = default;
   LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
-  void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
-                                         Operation *op) const;
+  void replaceTransformationFilter(PatternRewriter &rewriter,
+                                   Operation *op) const;
 
-  LinalgTransformationFilter &addFilter(const FilterFunction &f) {
+  TransformationFilter &addFilter(const FilterFunction &f) {
     if (f)
       filters.push_back(f);
     return *this;
   }
 
   template <typename... OpTypes>
-  LinalgTransformationFilter &addOpFilter() {
+  TransformationFilter &addOpFilter() {
     return addFilter(
         [](Operation *op) { return success(isa<OpTypes...>(op)); });
   }
 
-  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+  TransformationFilter &addOpNameFilter(StringRef opName) {
     return addFilter([opName](Operation *op) {
       return success(op->getName().getStringRef() == opName);
     });
   }
 
-  LinalgTransformationFilter &setMatchByDefault() {
+  TransformationFilter &setMatchByDefault() {
     matchByDefault = true;
     return *this;
   }
@@ -95,20 +95,19 @@ struct LinalgTransformationFilter {
   bool matchByDefault;
 };
 
-LinalgTransformationFilter::LinalgTransformationFilter(
+TransformationFilter::TransformationFilter(
     ArrayRef<StringAttr> matchDisjunction,
     std::optional<StringAttr> replacement)
     : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
       replacement(replacement), matchByDefault(false) {}
 
-LogicalResult
-LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
-                                           Operation *op) const {
+LogicalResult TransformationFilter::checkAndNotify(PatternRewriter &rewriter,
+                                                   Operation *op) const {
   if (llvm::any_of(filters,
                    [&](const FilterFunction &f) { return failed(f(op)); }))
     return failure();
 
-  auto attr = op->template getAttrOfType<StringAttr>(kLinalgTransformMarker);
+  auto attr = op->template getAttrOfType<StringAttr>(kTransformMarker);
 
   if (!attr) {
     // 1. Has no filter case and matchDisjunction is empty.
@@ -134,12 +133,12 @@ LinalgTransformationFilter::checkAndNotify(PatternRewriter &rewriter,
   });
 }
 
-void LinalgTransformationFilter::replaceLinalgTransformationFilter(
+void TransformationFilter::replaceTransformationFilter(
     PatternRewriter &rewriter, Operation *op) const {
   if (replacement.has_value())
-    op->setAttr(kLinalgTransformMarker, *replacement);
+    op->setAttr(kTransformMarker, *replacement);
   else
-    op->removeAttr(rewriter.getStringAttr(kLinalgTransformMarker));
+    op->removeAttr(rewriter.getStringAttr(kTransformMarker));
 }
 
 /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
@@ -147,18 +146,17 @@ void LinalgTransformationFilter::replaceLinalgTransformationFilter(
 /// using a `filter` to avoid recursive application.
 struct TestTileUsingSCFForOp
     : public OpInterfaceRewritePattern<TilingInterface> {
-  TestTileUsingSCFForOp(
-      MLIRContext *context, scf::SCFTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
+  TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
+                        TransformationFilter filter = TransformationFilter(),
+                        PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
 
   /// Construct a generic pattern applied to `opName`.
-  TestTileUsingSCFForOp(
-      StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
+  TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
+                        scf::SCFTilingOptions options,
+                        TransformationFilter filter = TransformationFilter(),
+                        PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
 
@@ -179,13 +177,13 @@ struct TestTileUsingSCFForOp
     }
 
     for (auto *tiledOp : tilingResult->tiledOps)
-      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
+      filter.replaceTransformationFilter(rewriter, tiledOp);
     return success();
   }
 
 private:
   scf::SCFTilingOptions options;
-  LinalgTransformationFilter filter;
+  TransformationFilter filter;
 };
 
 /// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
@@ -196,7 +194,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
     : public OpInterfaceRewritePattern<TilingInterface> {
   TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
       MLIRContext *context, scf::SCFTileAndFuseOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      TransformationFilter filter = TransformationFilter(),
       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
@@ -205,7 +203,7 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
   TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
       StringRef opName, MLIRContext *context,
       scf::SCFTileAndFuseOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      TransformationFilter filter = TransformationFilter(),
       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
@@ -229,14 +227,14 @@ struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
     }
     rewriter.replaceOp(op, replacements);
 
-    filter.replaceLinalgTransformationFilter(
+    filter.replaceTransformationFilter(
         rewriter, tileAndFuseResult->tiledAndFusedOps.front());
     return success();
   }
 
 private:
   scf::SCFTileAndFuseOptions options;
-  LinalgTransformationFilter filter;
+  TransformationFilter filter;
 };
 
 /// Pattern to tile a consumer and fuse producer with it
@@ -254,7 +252,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
 
   TestTileConsumerFuseAndYieldProducerUsingSCFForOp(
       MLIRContext *context, scf::SCFTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      TransformationFilter filter = TransformationFilter(),
       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
         options(std::move(options)), filter(std::move(filter)) {}
@@ -302,6 +300,8 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
     std::deque<tensor::ExtractSliceOp> candidates;
     addCandidateSlices(tilingResult->tiledOps.back(), candidates);
     OpBuilder::InsertionGuard g(rewriter);
+    auto forLoops = llvm::to_vector(llvm::map_range(
+        tilingResult->loops, [](auto op) { return cast<scf::ForOp>(op); }));
     while (!candidates.empty()) {
       // Traverse the slices in BFS fashion.
       tensor::ExtractSliceOp candidateSliceOp = candidates.front();
@@ -309,8 +309,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
 
       // Materialize the slice of the producer in place.
       std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
-          tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
-                                     tilingResult->loops);
+          tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
       if (!fusedProducer)
         continue;
 
@@ -318,11 +317,10 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
       // to be yielded from within the tiled loop.
       OpResult untiledProducer = fusedProducer->origProducer;
       if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
-            return !isIgnoredUser(user, tilingResult->loops.front());
+            return !isIgnoredUser(user, forLoops.front());
           })) {
         yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
-                                         fusedProducer.value(),
-                                         tilingResult->loops);
+                                         fusedProducer.value(), forLoops);
         yieldedValuesToOrigValues.push_back(untiledProducer);
       }
 
@@ -332,7 +330,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
         addCandidateSlices(fusedProducerOp, candidates);
     }
 
-    scf::ForOp outermostLoop = tilingResult->loops.front();
+    scf::ForOp outermostLoop = forLoops.front();
     for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
       Value replacement = outermostLoop.getResult(index);
       rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
@@ -340,8 +338,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
       });
     }
     rewriter.eraseOp(rootOp);
-    filter.replaceLinalgTransformationFilter(rewriter,
-                                             tilingResult->tiledOps.back());
+    filter.replaceTransformationFilter(rewriter, tilingResult->tiledOps.back());
     return success();
   }
 
@@ -370,7 +367,7 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
   }
 
   scf::SCFTilingOptions options;
-  LinalgTransformationFilter filter;
+  TransformationFilter filter;
 };
 
 /// Pattern to lower operations that implement the `TilingInterface` to
@@ -453,8 +450,8 @@ static void addPatternForTiling(MLIRContext *context,
   SmallVector<OpFoldResult> tileSizesOfr =
       getAsIndexOpFoldResult(context, tileSizes);
   tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
-  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
-                                    StringAttr::get(context, "tiled"));
+  TransformationFilter filter(StringAttr::get(context, filterName),
+                              StringAttr::get(context, "tiled"));
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
@@ -467,8 +464,8 @@ static void addPatternForTileFuseAndYield(MLIRContext *context,
   SmallVector<OpFoldResult> tileSizesOfr =
       getAsIndexOpFoldResult(context, tileSizes);
   tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
-  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
-                                    StringAttr::get(context, "tiled"));
+  TransformationFilter filter(StringAttr::get(context, filterName),
+                              StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
       context, tilingOptions, filter);
 }
@@ -483,8 +480,8 @@ static void addPatternForTileAndFuse(MLIRContext *context,
       getAsIndexOpFoldResult(context, tileSizes);
   tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr)
       .setInterchange(interchange);
-  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
-                                    StringAttr::get(context, "tiled"));
+  TransformationFilter filter(StringAttr::get(context, filterName),
+                              StringAttr::get(context, "tiled"));
   patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
       context, tileAndFuseOptions, filter);
 }


        


More information about the Mlir-commits mailing list