[Mlir-commits] [mlir] 269cb22 - [mlir][transform] extract a minimal DomainAndOperandsAffineMapT… (#145034)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 20 06:45:24 PDT 2025


Author: Nicolas Vasilache
Date: 2025-06-20T15:45:21+02:00
New Revision: 269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec

URL: https://github.com/llvm/llvm-project/commit/269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec
DIFF: https://github.com/llvm/llvm-project/commit/269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec.diff

LOG: [mlir][transform] extract a minimal DomainAndOperandsAffineMapTransferInterface out of LinalgStructuredInterface and use that for PadTilingInterface (#145034)

Along the way, a bug in the handling of scalar operands was found; this commit
fixes it and adds a test.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
    mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index dbc1ac60e0973..74c4c0a8835f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
   ];
 }
 
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+  let description = [{
+    Interface for operations that connect an iteration domain to operands via
+    affine maps. Provides methods to access indexing maps between iteration
+    domain and operand index spaces.
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps attribute within the current operation.
+      }],
+      /*retTy=*/"ArrayAttr",
+      /*methodName=*/"getIndexingMaps"
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indexing maps within the current operation.
+      }],
+      /*retTy=*/"SmallVector<AffineMap>",
+      /*methodName=*/"getIndexingMapsArray",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = $_op.getIndexingMaps()
+          .template getAsValueRange<AffineMapAttr>();
+        return {range.begin(), range.end()};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input or output indexing map for `opOperand`.
+      }],
+      /*retTy=*/"AffineMap",
+      /*methodName=*/"getMatchingIndexingMap",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        auto indexingMaps =
+          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+        return *(indexingMaps.begin() + opOperand->getOperandNumber());
+      }]
+    >,
+  ];
+}
+
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface
-    : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
+    : OpInterface<"LinalgOp", [
+      DestinationStyleOpInterface,
+      IndexingMapOpInterface
+  ]> {
   let cppNamespace = "::mlir::linalg";
   let methods = [
     //===------------------------------------------------------------------===//
@@ -465,21 +515,6 @@ def LinalgStructuredInterface
             blockArgument.getArgNumber());
       }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the input or output indexing map for `opOperand`.
-      }],
-      /*retTy=*/"AffineMap",
-      /*methodName=*/"getMatchingIndexingMap",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        auto indexingMaps =
-          $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
-        return *(indexingMaps.begin() + opOperand->getOperandNumber());
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the indexing map for a `result`.
@@ -576,27 +611,6 @@ def LinalgStructuredInterface
       /*methodBody=*/"",
       /*defaultImplementation=*/[{ return success(); }]
     >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps attribute within the current operation.
-      }],
-      /*retTy=*/"ArrayAttr",
-      /*methodName=*/"getIndexingMaps"
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the indexing maps within the current operation.
-      }],
-      /*retTy=*/"SmallVector<AffineMap>",
-      /*methodName=*/"getIndexingMapsArray",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto range = $_op.getIndexingMaps()
-          .template getAsValueRange<AffineMapAttr>();
-        return {range.begin(), range.end()};
-      }]
-    >,
     InterfaceMethod<
       /*desc=*/[{
         Return true if any of the operands has a dynamic shape.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 59b7fdeef10b3..a6dab03d6473f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -612,10 +612,9 @@ using PadSizeComputationFunction =
         const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>>
-computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
-                         ArrayRef<Range> iterationDomain,
-                         const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
+    RewriterBase &rewriter, OpOperand &operandToPad,
+    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
 
 /// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
 ///
@@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                   const PadTilingInterfaceOptions &constOptions,
                   SmallVector<tensor::PadOp> &padOps,
                   PadSizeComputationFunction computePaddingSizeFun =
-                      &computeLinalgPaddedShape);
+                      &computeIndexingMapOpInterfacePaddedShape);
 
 namespace detail {
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e627fc83f2ba7..5d55adbf46f36 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
       return diag;
     }
 
-    // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
-    // map / C++ APIs to compute the effect of padding on operands.
-    if (!isa<LinalgOp>(targetOp.getOperation())) {
-      auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+    // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
+    // loopsToOperand map / C++ APIs to compute the effect of padding on
+    // operands.
+    if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
+      auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
+                                            "supported atm";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index a9d7bc64f2a6b..5383ae48aeb3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
   return paddedShape;
 }
 
-FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
+FailureOr<SmallVector<OpFoldResult>>
+linalg::computeIndexingMapOpInterfacePaddedShape(
     RewriterBase &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
-  auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
-  if (!linalgOp)
+  auto transferOp =
+      llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
+  if (!transferOp)
     return failure();
 
   // clang-format off
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
   for (const Range &range : iterationDomain)
     loopUpperBounds.push_back(range.size);
 
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
+  AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
   return computePaddedShape(
       rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
       indexingMap, loopUpperBounds, options);
@@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
   SmallVector<Value> newOperands;
   newOperands.reserve(opToPad->getNumOperands());
   for (OpOperand &opOperand : opToPad->getOpOperands()) {
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
+    Value operand = opOperand.get();
+    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+
+    // 2.a. Skip scalar-like operands.
+    Type operandType = operand.getType();
+    if (!isa<RankedTensorType>(operandType)) {
+      assert(!isa<ShapedType>(operandType) ||
+             isa<VectorType>(operandType) &&
+                 "Unexpected non-vector ShapedType");
+      newOperands.push_back(operand);
+      continue;
+    }
     // 2.a. Compute padded shape.
     FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
         computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
@@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
     // 2.b. Expect proper `paddingValues`.
     // TODO: we may want to allow garbage padding in the future, in which case
     // we would just not assert.
-    assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
-           "--no padding value specified");
+    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
+      return rewriter.notifyMatchFailure(opToPad,
+                                         "--no padding value specified");
+    }
     Attribute paddingValueAttr =
         options.paddingValues[opOperand.getOperandNumber()];
 
     // 2.c. Perform actual padding.
     Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
+        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
         *maybePaddedShape, paddingValueAttr);
     LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index c361885693cbc..f0a410fa4015f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -1,5 +1,33 @@
 // RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
 
+//     CHECK-LABEL: pad_fill
+//           CHECK:   linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
+{
+  %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+
+    // Tile to 5 then pad to 8
+    %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5] 
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
+      padding_values=[0.0 : f32, 0.0 : f32],
+      padding_dimensions=[0]
+    } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+    transform.yield
+  }
+}
+
+// -----
+
 //     CHECK-LABEL: pad_lhs
 func.func @pad_lhs(
   %arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)


        


More information about the Mlir-commits mailing list