[Mlir-commits] [mlir] [mlir][TilingInterface] Move TilingInterface tests to use transform dialect ops. (PR #77204)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jan 6 10:08:17 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

In the process, a couple of test-only transform dialect ops are added.
These operations are not intended to be used as fully fleshed-out
transformation ops; they exist only to support testing.

A separate operation, `transform.structured.convert_to_loops`, is added to
`LinalgTransformOps.td` to lower a `TilingInterface` operation to loops
using the `generateScalarImplementation` method implemented by that
operation. Eventually, this and the other operations related to tiling
with the `TilingInterface` need to move to a better place (i.e., out
of the `Linalg` dialect).

---

Patch is 111.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77204.diff


16 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+31-1) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+24) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+22-33) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+48-22) 
- (modified) mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir (+84-10) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir (+103-22) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir (+13-1) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir (+74-13) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (+117-38) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir (+43-2) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt (+9-1) 
- (removed) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp (-650) 
- (added) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+267) 
- (added) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+70) 
- (added) mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg (+1) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bc257d17483e3b..7d10ba0ae829e5 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -293,7 +293,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let results = (outs TransformHandleTypeInterface:$transformed,
                       Variadic<TransformHandleTypeInterface>:$loops);
 
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
   let hasVerifier = 1;
 }
 
@@ -1269,6 +1272,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    For operations that implement the `TilingInterface`, and implement
+    the `generateScalarImplementation` method, lowers the operation to
+    loops. This operation does not return any handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` type($target)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::TilingInterface target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // DecomposeInterfaceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 2f8f337bb8057c..5d2d78e6e6165b 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -97,6 +97,30 @@ struct SCFTileAndFuseOptions {
     tilingOptions = options;
     return *this;
   }
+
+  /// Control function to check if a slice needs to be fused or not,
+  /// The control function receives
+  /// 1) the slice along which fusion is to be done,
+  /// 2) the producer value that is to be fused
+  /// 3) a boolean value set to `true` if the fusion is from
+  ///    a destination operand.
+  /// It retuns two booleans
+  /// - returns `true` if the fusion should be done through the candidate slice
+  /// - returns `true` if a replacement for the fused producer needs to be
+  ///   yielded from within the tiled loop. Note that it is valid to return
+  ///   `true` only if the slice fused is disjoint across all iterations of the
+  ///   tiled loop. It is up to the caller to ensure that this is true for the
+  ///   fused producers.
+  using ControlFnTy = std::function<std::tuple<bool, bool>(
+      tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+      bool isDestinationOperand)>;
+  ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) {
+    return std::make_tuple(true, false);
+  };
+  SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
+    fusionControlFn = controlFn;
+    return *this;
+  }
 };
 
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5254aac976f462..97d2b4a3be5c56 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -492,38 +492,6 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
-ParseResult transform::FuseOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  OpAsmParser::UnresolvedOperand targetOperand;
-  if (parser.parseOperand(targetOperand) ||
-      parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  FunctionType trailingType;
-  SMLoc typeLoc;
-  if (parser.getCurrentLocation(&typeLoc) ||
-      parser.parseColonType(trailingType)) {
-    return failure();
-  }
-  if (trailingType.getNumInputs() != 1)
-    return parser.emitError(typeLoc) << "expected one input type";
-
-  result.addTypes(trailingType.getResults());
-  if (parser.resolveOperand(targetOperand, trailingType.getInput(0),
-                            result.operands))
-    return failure();
-  return success();
-}
-
-void transform::FuseOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  p << getTarget();
-  p.printOptionalAttrDict((*this)->getAttrs());
-  p << " : ";
-  p.printFunctionalType(TypeRange(getOperand().getType()),
-                        getResults().getTypes());
-}
-
 LogicalResult transform::FuseOp::verify() {
   SmallVector<int64_t> permutation =
       extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
@@ -2111,6 +2079,22 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertToLoopsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
+    transform::TransformRewriter &rewriter, TilingInterface target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  FailureOr<SmallVector<scf::ForOp>> loops =
+      scf::lowerToLoopsUsingSCFForOp(rewriter, target);
+  if (failed(loops))
+    return emitDefaultDefiniteFailure(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // RewriteInDestinationPassingStyleOp
 //===----------------------------------------------------------------------===//
@@ -2620,7 +2604,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
     }
 
     scf::SCFTilingOptions tilingOptions;
-    if (!tileSizes.empty()) {
+    if (tileSizes.empty()) {
+      tilingOptions.setTileSizeComputationFunction(
+          [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> {
+            return {};
+          });
+    } else {
       tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b,
                                                                   Operation *) {
         SmallVector<OpFoldResult> sizes;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1b6b4db9d20907..38e0625d7ce093 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -283,10 +283,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
-  if (numLoops == 0) {
-    return rewriter.notifyMatchFailure(
-        op, "unable to tile op with no iteration domain");
-  }
+
   // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
   // skips tiling a particular dimension. This convention is significantly
   // simpler to handle instead of adjusting affine maps to account for missing
@@ -728,32 +725,36 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
   }
 
   // 1. First tile the consumer.
-  SmallVector<scf::ForOp> forLoops;
   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
-  DenseMap<Value, Value> replacements;
-  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
-  {
-    FailureOr<scf::SCFTilingResult> tilingResult =
-        tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
-    if (failed(tilingResult))
-      return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
-    for (auto *tiledOp : tilingResult->tiledOps)
-      tiledAndFusedOps.insert(tiledOp);
-    forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
-    for (auto [index, origValue, replacement] :
-         llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
-      replacements[origValue] = replacement;
-      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
-          index)] = index;
-    }
-  }
+  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
+  FailureOr<scf::SCFTilingResult> tilingResult =
+      tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+  if (failed(tilingResult))
+    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+  for (auto *tiledOp : tilingResult->tiledOps)
+    tiledAndFusedOps.insert(tiledOp);
+  SmallVector<scf::ForOp> forLoops =
+      castToTypedOperations<scf::ForOp>(tilingResult->loops);
 
   // If there are no loops generated, fusion is immaterial.
   if (forLoops.empty()) {
+    DenseMap<Value, Value> replacements;
+    for (auto [origVal, replacement] :
+         llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
+      replacements[origVal] = replacement;
+    }
     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                      getAsOperations(forLoops), replacements};
   }
 
+  // To keep track of replacements for now just record the map from the original
+  // untiled value to the result number of the for loop. Since the loop gets
+  // potentially replaced during fusion, keeping the value directly wont work.
+  DenseMap<Value, size_t> origValToResultNumber;
+  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
+    origValToResultNumber[result] = index;
+  }
+
   // 2. Typically, the operands of the tiled operation are slices of the
   //    operands of the untiled operation. These are expressed in IR using
   //    `tensor.extract_slice` operations with source being the operands of the
@@ -776,6 +777,18 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     tensor::ExtractSliceOp candidateSliceOp = candidates.front();
     candidates.pop_front();
 
+    // Find the original producer of the slice.
+    auto [fusableProducer, destinationInitArg] =
+        getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
+                                          forLoops);
+    if (!fusableProducer)
+      continue;
+
+    auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
+        candidateSliceOp, fusableProducer, destinationInitArg.has_value());
+    if (!fuseSlice)
+      continue;
+
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
@@ -784,6 +797,13 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     if (!fusedResult)
       continue;
 
+    if (yieldReplacement) {
+      yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+                                       fusedResult.value(), forLoops);
+      origValToResultNumber[fusableProducer] =
+          forLoops.front().getNumResults() - 1;
+    }
+
     if (Operation *tiledAndFusedOp =
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
@@ -791,6 +811,12 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
       addCandidateSlices(tiledAndFusedOp, candidates);
     }
   }
+
+  DenseMap<Value, Value> replacements;
+  for (auto [origVal, resultNumber] : origValToResultNumber) {
+    replacements[origVal] = forLoops.front()->getResult(resultNumber);
+  }
+
   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
                                    getAsOperations(forLoops), replacements};
 }
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index c8199c325abfec..7245498f641ecf 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s
 
 func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
   %arg2 : memref<?x?xf32>) {
@@ -6,13 +6,22 @@ func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
       outs(%arg2 : memref<?x?xf32>)
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %matmul : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @gemm
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
 //   CHECK-DAG:   %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
 //       CHECK:   scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
@@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
     }
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %generic = transform.structured.match ops{["linalg.generic"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %generic : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func @indexed_generic
 //  CHECK-SAME:     %[[ARG0:.+]]: memref<200x300xi32>
 //  CHECK-SAME:     %[[ARG1:.+]]: memref<300xi16>
@@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
       outs(%arg2 : memref<?x?x?x?xf32>)
   return
 }
-//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
-//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %conv : !transform.any_op
+    transform.yield
+  }
+}
+
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
 //       CHECK: func @conv_strides_and_dilation(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
@@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
 //       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
 //       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
 //       CHECK:               scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
-//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
-//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
+//   CHECK-DAG:                 %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
+//   CHECK-DAG:                 %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
 //   CHECK-DAG:                 %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
 //   CHECK-DAG:                 %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
 //   CHECK-DAG:                 %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
@@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
       outs(%arg2 : memref<?x?x?x?xf32>)
   return
 }
-//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
-//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %pool : !transform.any_op
+    transform.yield
+  }
+}
+
+//  CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
+//  CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
 //       CHECK: func @pool_strides_and_dilation
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref
 //       CHECK:         scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
 //       CHECK:           scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
 //       CHECK:             scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
-//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
-//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
+//   CHECK-DAG:               %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]])
+//   CHECK-DAG:               %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]])
 //   CHECK-DAG:               %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
 //   CHECK-DAG:               %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
 //       CHECK:               %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]]
@@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>,
     }
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %map = transform.structured.match ops{["linalg.map"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %map : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func.func @map(
 // CHECK-SAME:    %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>,
 // CHECK-SAME:    %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>,
@@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>,
                    outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0]
   return
 }
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.convert_to_loops %transpose : !transform.any_op
+    transform.yield
+  }
+}
 // CHECK-LABEL: func.func @transpose(
 // CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]: me...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77204


More information about the Mlir-commits mailing list