[Mlir-commits] [mlir] [mlir][linalg-transform] dyn_cast DestinationStyleOpInterface and early return (PR #166299)

Hsiang-Chieh Tsou llvmlistbot at llvm.org
Wed Nov 5 16:50:15 PST 2025


https://github.com/hsjts0u updated https://github.com/llvm/llvm-project/pull/166299

From 2234d6ca78c84da61903019b306e459d39798d47 Mon Sep 17 00:00:00 2001
From: Jay Tsou <hsjts0u at gmail.com>
Date: Mon, 3 Nov 2025 20:21:08 -0800
Subject: [PATCH 1/2] dyn_cast DestinationStyleOpInterface and early return

---
 .../lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3a433825fd31a..59629c422a034 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       // Iterate over the outputs of the producer and over the loop bbArgs and
       // check if any bbArg points to the same value as the producer output. In
       // such case, make the producer output point to the bbArg directly.
-      for (OpOperand &initOperandPtr :
-           cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
+      auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
+      if (!dpsInterface)
+        return;
+
+      for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
         Value producerOperand =
             clone->getOperand(initOperandPtr.getOperandNumber());
         for (BlockArgument containerIterArg :

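For context, the crash this first patch guards against: cast<DestinationStyleOpInterface> asserts when the cloned producer does not implement the interface (e.g. a TilingInterface op with no DPS inits), while dyn_cast<> yields a null interface that can be tested, allowing a graceful early return. A minimal sketch of the pattern -- the function name and scaffolding are illustrative only, not part of the patch:

  #include "mlir/Interfaces/DestinationStyleOpInterface.h"

  // Bail out gracefully when the cloned producer has no DPS inits.
  static void rewireInitsIfDps(mlir::Operation *clone) {
    // cast<> would assert here for non-DPS ops; dyn_cast<> lets us test.
    auto dps = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(clone);
    if (!dps)
      return; // nothing to rewire for a non-DPS producer
    for (mlir::OpOperand &initOperand : dps.getDpsInitsMutable()) {
      // ... match initOperand against the loop iter_args, as in the patch.
      (void)initOperand;
    }
  }
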
From 9091c8f9d98f8a91d969319dc8b98456489d503d Mon Sep 17 00:00:00 2001
From: Jay Tsou <hsjts0u at gmail.com>
Date: Wed, 5 Nov 2025 16:48:11 -0800
Subject: [PATCH 2/2] Add test and address comments

---
 .../TransformOps/LinalgTransformOps.cpp       |  2 +-
 .../transform-op-fuse-into-containing.mlir    | 34 ++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 78 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 15 ++++
 4 files changed, 128 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 59629c422a034..aa8206347e9b1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1063,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       resultNumber, offsets, sizes);
 
   // Cleanup clone.
-  if (dyn_cast<LoopLikeOpInterface>(containingOp))
+  if (isa<LoopLikeOpInterface>(containingOp))
     rewriter.eraseOp(tileableProducer);
 
   return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
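
This hunk is a cleanup: the result of the cast was only used as a boolean, and isa<> states that intent directly. The idiom, restated outside the diff:

  // Use isa<> for a pure type test; reserve dyn_cast<> for when the
  // casted value is actually used afterwards.
  if (mlir::isa<mlir::LoopLikeOpInterface>(containingOp))
    rewriter.eraseOp(tileableProducer);
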
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index e5216089692b4..ab38f9f2f5943 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -253,6 +253,40 @@ module {
 
 // -----
 
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 4)>
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_op_no_dps
+  func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
+    %0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
+    %1 = tensor.empty() : tensor<4x4x4xf32>
+    // CHECK: scf.forall
+    %2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
+      %3 = affine.apply #map(%arg3)
+      %4 = affine.apply #map1(%arg4)
+      // CHECK: "test.tiling_no_dps_op"
+      // CHECK: "test.unregistered_op"
+      %extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
+      %5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
+      }
+    }
+    return %2 : tensor<4x4x4xf32>
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+}
+
+// -----
+
 module {
   // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
   //  CHECK-SAME:   %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 4d4ec02546bc7..403294ee423cd 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1051,6 +1051,84 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TilingNoDpsOp
+//===----------------------------------------------------------------------===//
+
+static Value getSlice(OpBuilder &builder, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  auto staticOffsets = getConstantIntValues(offsets);
+  auto staticSizes = getConstantIntValues(sizes);
+  auto staticStrides = getConstantIntValues(strides);
+
+  auto sourceShape = cast<ShapedType>(source.getType()).getShape();
+  if (staticSizes && ArrayRef(*staticSizes) == sourceShape)
+    return source;
+
+  return {mlir::tensor::ExtractSliceOp::create(builder, loc, source, offsets,
+                                               sizes, strides)};
+}
+
+static ShapedType getSliceType(ShapedType type, ArrayRef<OpFoldResult> sizes) {
+  auto staticSizes = getConstantIntValues(sizes);
+  if (staticSizes.has_value())
+    return type.cloneWith(*staticSizes, type.getElementType());
+  return nullptr;
+}
+
+SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
+  auto shape = cast<ShapedType>(getResult().getType()).getShape();
+  auto zero = getAsIndexOpFoldResult(getContext(), 0);
+  auto one = getAsIndexOpFoldResult(getContext(), 1);
+  return llvm::map_to_vector(shape, [&](int64_t size) {
+    return Range{.offset = zero,
+                 .size = getAsIndexOpFoldResult(getContext(), size),
+                 .stride = one};
+  });
+}
+
+SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
+  auto tensorType = cast<ShapedType>(getResult().getType());
+  SmallVector<utils::IteratorType> types(
+      static_cast<size_t>(tensorType.getRank()), utils::IteratorType::parallel);
+  return types;
+}
+
+FailureOr<TilingResult>
+TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes) {
+  auto loc = getLoc();
+  auto strides = SmallVector<OpFoldResult>(
+      static_cast<size_t>(cast<ShapedType>(getOperand(0).getType()).getRank()),
+      getAsIndexOpFoldResult(getContext(), 1));
+  auto inputSlices = llvm::map_to_vector(getOperands(), [&](Value operand) {
+    return getSlice(builder, loc, operand, offsets, sizes, strides);
+  });
+  auto resultType =
+      getSliceType(cast<ShapedType>(getResult().getType()), sizes);
+  auto tiledOp = TilingNoDpsOp::create(builder, loc, TypeRange{resultType},
+                                       ValueRange(inputSlices));
+  return TilingResult{.tiledOps = {tiledOp},
+                      .tiledValues = SmallVector<Value>{tiledOp.getResult()},
+                      .generatedSlices =
+                          map_to_vector(inputSlices, [](Value val) {
+                            return val.getDefiningOp();
+                          })};
+}
+
+LogicalResult TilingNoDpsOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  resultOffsets.assign(offsets.begin(), offsets.end());
+  resultSizes.assign(sizes.begin(), sizes.end());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OpWithShapedTypeInferTypeAdaptorInterfaceOp
 //===----------------------------------------------------------------------===//
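
For readers tracing how these hooks are exercised: fuse_into_containing_op tiles the producer through TilingInterface, roughly as below. This is a simplified sketch with assumed variable and function names, not the actual transform code:

  #include "mlir/Interfaces/TilingInterface.h"

  // Sketch of a driver consuming the interface implemented above.
  static mlir::LogicalResult
  tileProducer(mlir::Operation *producer, mlir::OpBuilder &builder,
               llvm::ArrayRef<mlir::OpFoldResult> offsets,
               llvm::ArrayRef<mlir::OpFoldResult> sizes) {
    auto tilingOp = mlir::dyn_cast<mlir::TilingInterface>(producer);
    if (!tilingOp)
      return mlir::failure();
    mlir::FailureOr<mlir::TilingResult> tiled =
        tilingOp.getTiledImplementation(builder, offsets, sizes);
    if (mlir::failed(tiled))
      return mlir::failure();
    // tiled->tiledValues replace the producer's uses inside the loop; for
    // a non-DPS op like test.tiling_no_dps_op there are no inits to rewire.
    return mlir::success();
  }
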
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a3430ba49a291..620d950c0d2af 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ValueBoundsOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -2887,6 +2888,20 @@ def TestLinalgFillOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test TilingInterface.
+//===----------------------------------------------------------------------===//
+
+def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
+    [Pure, DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
+  let results = (outs AnyRankedTensor:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test NVVM RequiresSM trait.
 //===----------------------------------------------------------------------===//
