[Mlir-commits] [mlir] 269cb22 - [mlir][transform] extract a minimal DomainAndOperandsAffineMapT… (#145034)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 20 06:45:24 PDT 2025
Author: Nicolas Vasilache
Date: 2025-06-20T15:45:21+02:00
New Revision: 269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec
URL: https://github.com/llvm/llvm-project/commit/269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec
DIFF: https://github.com/llvm/llvm-project/commit/269cb22ae82fd83ecc7a7ef7f7a4110e4c7a43ec.diff
LOG: [mlir][transform] extract a minimal DomainAndOperandsAffineMapTransferInterface (#145034)
out of LinalgStructuredInterface and use that for PadTilingInterface
(the interface landed under the name `IndexingMapOpInterface`).
Along the way, a bug was found in the handling of scalar values; fix it
and add a test.
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index dbc1ac60e0973..74c4c0a8835f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -222,9 +222,59 @@ def LinalgFillOpInterface : OpInterface<"FillOpInterface"> {
];
}
+def IndexingMapOpInterface : OpInterface<"IndexingMapOpInterface"> {
+ let description = [{
+ Interface for operations that connect an iteration domain to operands via
+ affine maps. Provides methods to access indexing maps between iteration
+ domain and operand index spaces.
+ }];
+ let cppNamespace = "::mlir::linalg";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps attribute within the current operation.
+ }],
+ /*retTy=*/"ArrayAttr",
+ /*methodName=*/"getIndexingMaps"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the indexing maps within the current operation.
+ }],
+ /*retTy=*/"SmallVector<AffineMap>",
+ /*methodName=*/"getIndexingMapsArray",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ auto range = $_op.getIndexingMaps()
+ .template getAsValueRange<AffineMapAttr>();
+ return {range.begin(), range.end()};
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the input or output indexing map for `opOperand`.
+ }],
+ /*retTy=*/"AffineMap",
+ /*methodName=*/"getMatchingIndexingMap",
+ /*args=*/(ins "OpOperand*":$opOperand),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(opOperand->getOwner() == this->getOperation());
+ auto indexingMaps =
+ $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
+ return *(indexingMaps.begin() + opOperand->getOperandNumber());
+ }]
+ >,
+ ];
+}
+
// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
def LinalgStructuredInterface
- : OpInterface<"LinalgOp", [DestinationStyleOpInterface]> {
+ : OpInterface<"LinalgOp", [
+ DestinationStyleOpInterface,
+ IndexingMapOpInterface
+ ]> {
let cppNamespace = "::mlir::linalg";
let methods = [
//===------------------------------------------------------------------===//
@@ -465,21 +515,6 @@ def LinalgStructuredInterface
blockArgument.getArgNumber());
}]
>,
- InterfaceMethod<
- /*desc=*/[{
- Return the input or output indexing map for `opOperand`.
- }],
- /*retTy=*/"AffineMap",
- /*methodName=*/"getMatchingIndexingMap",
- /*args=*/(ins "OpOperand*":$opOperand),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(opOperand->getOwner() == this->getOperation());
- auto indexingMaps =
- $_op.getIndexingMaps().template getAsValueRange<AffineMapAttr>();
- return *(indexingMaps.begin() + opOperand->getOperandNumber());
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Return the indexing map for a `result`.
@@ -576,27 +611,6 @@ def LinalgStructuredInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{ return success(); }]
>,
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps attribute within the current operation.
- }],
- /*retTy=*/"ArrayAttr",
- /*methodName=*/"getIndexingMaps"
- >,
- InterfaceMethod<
- /*desc=*/[{
- Return the indexing maps within the current operation.
- }],
- /*retTy=*/"SmallVector<AffineMap>",
- /*methodName=*/"getIndexingMapsArray",
- /*args=*/(ins),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- auto range = $_op.getIndexingMaps()
- .template getAsValueRange<AffineMapAttr>();
- return {range.begin(), range.end()};
- }]
- >,
InterfaceMethod<
/*desc=*/[{
Return true if any of the operands has a dynamic shape.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 59b7fdeef10b3..a6dab03d6473f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -612,10 +612,9 @@ using PadSizeComputationFunction =
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>>
-computeLinalgPaddedShape(RewriterBase &rewriter, OpOperand &operandToPad,
- ArrayRef<Range> iterationDomain,
- const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
+ RewriterBase &rewriter, OpOperand &operandToPad,
+ ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
///
@@ -632,7 +631,7 @@ rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
const PadTilingInterfaceOptions &constOptions,
SmallVector<tensor::PadOp> &padOps,
PadSizeComputationFunction computePaddingSizeFun =
- &computeLinalgPaddedShape);
+ &computeIndexingMapOpInterfacePaddedShape);
namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index e627fc83f2ba7..5d55adbf46f36 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2229,10 +2229,12 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
return diag;
}
- // Only Linalg ops for now, until TilingInterface exposes a loopsToOperand
- // map / C++ APIs to compute the effect of padding on operands.
- if (!isa<LinalgOp>(targetOp.getOperation())) {
- auto diag = emitSilenceableError() << "only LinalgOp supported atm";
+ // Only IndexingMapOpInterface ops for now, until TilingInterface exposes a
+ // loopsToOperand map / C++ APIs to compute the effect of padding on
+ // operands.
+ if (!isa<IndexingMapOpInterface>(targetOp.getOperation())) {
+ auto diag = emitSilenceableError() << "only IndexingMapOpInterface ops "
+ "supported atm";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index a9d7bc64f2a6b..5383ae48aeb3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -155,11 +155,13 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
return paddedShape;
}
-FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
+FailureOr<SmallVector<OpFoldResult>>
+linalg::computeIndexingMapOpInterfacePaddedShape(
RewriterBase &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
- auto linalgOp = llvm::dyn_cast<LinalgOp>(operandToPad.getOwner());
- if (!linalgOp)
+ auto transferOp =
+ llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
+ if (!transferOp)
return failure();
// clang-format off
@@ -173,7 +175,7 @@ FailureOr<SmallVector<OpFoldResult>> linalg::computeLinalgPaddedShape(
for (const Range &range : iterationDomain)
loopUpperBounds.push_back(range.size);
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&operandToPad);
+ AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
indexingMap, loopUpperBounds, options);
@@ -255,7 +257,18 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
SmallVector<Value> newOperands;
newOperands.reserve(opToPad->getNumOperands());
for (OpOperand &opOperand : opToPad->getOpOperands()) {
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << opOperand.get() << "\n");
+ Value operand = opOperand.get();
+ LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+
+ // Skip non-tensor (scalar-like) operands: they are forwarded unchanged.
+ Type operandType = operand.getType();
+ if (!isa<RankedTensorType>(operandType)) {
+ assert((!isa<ShapedType>(operandType) ||
+ isa<VectorType>(operandType)) &&
+ "Unexpected non-vector ShapedType");
+ newOperands.push_back(operand);
+ continue;
+ }
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
@@ -266,14 +279,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
- assert(opOperand.getOperandNumber() < options.paddingValues.size() &&
- "--no padding value specified");
+ if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
+ return rewriter.notifyMatchFailure(opToPad,
+ "--no padding value specified");
+ }
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(opOperand.get()),
+ rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
*maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
index c361885693cbc..f0a410fa4015f 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad-tiling-interface.mlir
@@ -1,5 +1,33 @@
// RUN: mlir-opt --transform-interpreter -canonicalize -split-input-file --verify-diagnostics %s | FileCheck %s
+// CHECK-LABEL: pad_fill
+// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<8x25xf32>) -> tensor<8x25xf32>
+func.func @pad_fill(%value: f32, %output: tensor<24x25xf32>) -> tensor<24x25xf32>
+{
+ %0 = linalg.fill ins(%value : f32) outs(%output : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %fill = transform.structured.match ops{["linalg.fill"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+
+ // Tile to 5 then pad to 8
+ %fill_l1, %loops_l1 = transform.structured.tile_using_for %fill tile_sizes [5]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ %fill_padded, %_ = transform.structured.pad_tiling_interface %fill_l1 to padding_sizes [8] {
+ padding_values=[0.0 : f32, 0.0 : f32],
+ padding_dimensions=[0]
+ } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: pad_lhs
func.func @pad_lhs(
%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>)
More information about the Mlir-commits
mailing list