[Mlir-commits] [mlir] 14d073b - Revert "[mlir][transform] Allow arbitrary indices to be scalable"
Alexander Belyaev
llvmlistbot at llvm.org
Tue Jul 4 00:41:27 PDT 2023
Author: Alexander Belyaev
Date: 2023-07-04T09:40:52+02:00
New Revision: 14d073b50f960674a62ef8ad2c34f6fc1e9b0061
URL: https://github.com/llvm/llvm-project/commit/14d073b50f960674a62ef8ad2c34f6fc1e9b0061
DIFF: https://github.com/llvm/llvm-project/commit/14d073b50f960674a62ef8ad2c34f6fc1e9b0061.diff
LOG: Revert "[mlir][transform] Allow arbitrary indices to be scalable"
This reverts commit 048764f23a380fd6f8cc562a0008dcc6095fb594.
Breaks https://lab.llvm.org/buildbot/#/builders/61/builds/45451
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Transform/Utils/Utils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/transform-op-tile.mlir
mlir/test/Dialect/Transform/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index c33be09f818e93..7caae2b480be2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1686,7 +1686,7 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
- DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
Variadic<TransformHandleTypeInterface>:$loops);
let builders = [
@@ -2008,10 +2008,9 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
let arguments = (ins TransformHandleTypeInterface:$target,
Variadic<TransformHandleTypeInterface>:$vector_sizes,
UnitAttr:$vectorize_nd_extract,
- DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
- $scalable_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
- $static_vector_sizes);
+ $static_vector_sizes,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$last_vector_size_scalable);
let results = (outs);
let assemblyFormat = [{
@@ -2019,7 +2018,7 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
`vector_sizes` custom<DynamicIndexList>($vector_sizes,
$static_vector_sizes,
type($vector_sizes),
- $scalable_sizes)
+ $last_vector_size_scalable)
attr-dict
`:` type($target)
}];
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 9c99c4ceec69e4..fad380d4005f1c 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -52,15 +52,13 @@ namespace mlir {
/// integer attributes in a list. E.g.
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
///
-/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
-/// This notation is similar to how scalable dims are marked when defining
-/// Vectors. For each value in `integers`, the corresponding `bool` in
-/// `scalables` encodes whether it's a scalable index. If `scalables` is
-/// empty then assume that all indices are non-scalable.
+/// If `isTrailingIdxScalable` is true, then wrap the trailing index with
+/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
+/// used for scalable tile or vector sizes.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
- ArrayRef<bool> scalables = {},
+ BoolAttr isTrailingIdxScalable = {},
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
/// Parser hook for custom directive in assemblyFormat.
@@ -80,43 +78,41 @@ void printDynamicIndexList(
/// `kDynamic`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
///
-/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
-/// This notation is similar to how scalable dims are marked when defining
-/// Vectors. For each value in `integers`, the corresponding `bool` in
-/// `scalables` encodes whether it's a scalable index.
+/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
+/// scalable. This notation is similar to how scalable dims are marked when
+/// defining Vectors. If \p isTrailingIdxScalable is null, scalable indices are
+/// not allowed/expected. When it's not null, this hook will set the
+/// corresponding value to:
+/// * true if the trailing idx is scalable,
+/// * false otherwise.
ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
+ DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
-inline ParseResult parseDynamicIndexList(
- OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
- DenseBoolArrayAttr scalables = {};
- return parseDynamicIndexList(parser, values, integers, scalables, valueTypes,
- delimiter);
-}
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
- DenseBoolArrayAttr scalables = {};
- return parseDynamicIndexList(parser, values, integers, scalables,
- &valueTypes, delimiter);
+ return parseDynamicIndexList(parser, values, integers,
+ /*isTrailingIdxScalable=*/nullptr, &valueTypes,
+ delimiter);
}
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
- DenseBoolArrayAttr &scalables,
+ BoolAttr &isTrailingIdxScalable,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
- return parseDynamicIndexList(parser, values, integers, scalables, &valueTypes,
- delimiter);
+ bool scalable = false;
+ auto res = parseDynamicIndexList(parser, values, integers, &scalable,
+ &valueTypes, delimiter);
+ auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+ isTrailingIdxScalable = scalableAttr;
+ return res;
}
/// Verify that a the `values` has as many elements as the number of entries in
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 78f82c98726411..781e48a9824a1c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2434,7 +2434,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Operation *> tiled;
SmallVector<SmallVector<Operation *, 4>, 4> loops;
loops.resize(getLoops().size());
- auto scalableSizes = getScalableSizes();
+ bool scalable = getLastTileSizeScalable();
for (auto [i, op] : llvm::enumerate(targets)) {
auto tilingInterface = dyn_cast<TilingInterface>(op);
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
@@ -2453,10 +2453,12 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
+ unsigned trailingIdx = getMixedSizes().size() - 1;
for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
- if (scalableSizes[ofrIdx]) {
+ // Only the trailing tile size is allowed to be scalable atm.
+ if (scalable && (ofrIdx == trailingIdx)) {
auto val = b.create<arith::ConstantIndexOp>(
getLoc(), attr.cast<IntegerAttr>().getInt());
Value vscale =
@@ -2558,10 +2560,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
DenseI64ArrayAttr staticSizes;
FunctionType functionalType;
llvm::SMLoc operandLoc;
- DenseBoolArrayAttr scalableSizes;
-
+ bool scalable = false;
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
- parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableSizes) ||
+ parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
parseOptionalInterchange(parser, result) ||
parser.parseColonType(functionalType))
return ParseResult::failure();
@@ -2584,7 +2585,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
return failure();
}
- result.addAttribute(getScalableSizesAttrName(result.name), scalableSizes);
+ auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+ result.addAttribute(getLastTileSizeScalableAttrName(result.name),
+ scalableAttr);
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
result.addTypes(functionalType.getResults());
@@ -2594,7 +2597,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
- /*valueTypes=*/{}, getScalableSizesAttr(),
+ /*valueTypes=*/{}, getLastTileSizeScalableAttr(),
OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
p << " : ";
@@ -3141,14 +3144,15 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
}
// TODO: Check that the correct number of vectorSizes was provided.
+ SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
+ scalableVecDims.back() = getLastVectorSizeScalable();
for (Operation *target : targets) {
if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Unsupported Op, cannot vectorize";
}
- if (failed(linalg::vectorize(rewriter, target, vectorSizes,
- getScalableSizes(),
+ if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims,
getVectorizeNdExtract()))) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 97c086bf5b728c..4f805d692637ea 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1254,20 +1254,20 @@ void ForallOp::print(OpAsmPrinter &p) {
if (isNormalized()) {
p << ") in ";
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
- /*valueTypes=*/{}, /*scalables=*/{},
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
OpAsmParser::Delimiter::Paren);
} else {
p << ") = ";
printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
- /*valueTypes=*/{}, /*scalables=*/{},
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
OpAsmParser::Delimiter::Paren);
p << " to ";
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
- /*valueTypes=*/{}, /*scalables=*/{},
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
OpAsmParser::Delimiter::Paren);
p << " step ";
printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
- /*valueTypes=*/{}, /*scalables=*/{},
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
OpAsmParser::Delimiter::Paren);
}
printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
@@ -1299,9 +1299,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
dynamicSteps;
if (succeeded(parser.parseOptionalKeyword("in"))) {
// Parse upper bounds.
- if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ if (parseDynamicIndexList(
+ parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
@@ -1311,26 +1311,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
} else {
// Parse lower bounds.
if (parser.parseEqual() ||
- parseDynamicIndexList(parser, dynamicLbs, staticLbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
return failure();
// Parse upper bounds.
if (parser.parseKeyword("to") ||
- parseDynamicIndexList(parser, dynamicUbs, staticUbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
// Parse step values.
if (parser.parseKeyword("step") ||
- parseDynamicIndexList(parser, dynamicSteps, staticSteps,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicSteps, indexType, result.operands))
return failure();
}
diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
index d516a56feed478..e7516423fb58c7 100644
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,5 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
return success();
}
- return parseDynamicIndexList(parser, values, integers, &valueTypes);
+ return parseDynamicIndexList(parser, values, integers,
+ /*isTrailingIdxScalable=*/nullptr, &valueTypes);
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 667f66bb99610b..0f75cc10fc8234 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -102,7 +102,8 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
- TypeRange valueTypes, ArrayRef<bool> scalables,
+ TypeRange valueTypes,
+ BoolAttr isTrailingIdxScalable,
AsmParser::Delimiter delimiter) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
@@ -112,42 +113,59 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
return;
}
- unsigned dynamicValIdx = 0;
- unsigned scalableIndexIdx = 0;
+ int64_t trailingScalableInteger;
+ if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
+ // ATM only the trailing idx can be scalable
+ trailingScalableInteger = integers.back();
+ integers = integers.drop_back();
+ }
+
+ unsigned idx = 0;
llvm::interleaveComma(integers, printer, [&](int64_t integer) {
- if (not scalables.empty() && scalables[scalableIndexIdx])
- printer << "[";
if (ShapedType::isDynamic(integer)) {
- printer << values[dynamicValIdx];
+ printer << values[idx];
if (!valueTypes.empty())
- printer << " : " << valueTypes[dynamicValIdx];
- ++dynamicValIdx;
+ printer << " : " << valueTypes[idx];
+ ++idx;
} else {
printer << integer;
}
- if (!scalables.empty() && scalables[scalableIndexIdx])
- printer << "]";
-
- scalableIndexIdx++;
});
+ // Print the trailing scalable index
+ if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
+ if (!integers.empty())
+ printer << ", ";
+ printer << "[";
+ printer << trailingScalableInteger;
+ printer << "]";
+ }
+
printer << rightDelimiter;
}
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
+ DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
SmallVector<int64_t, 4> integerVals;
- SmallVector<bool, 4> scalableVals;
+ bool foundScalable = false;
auto parseIntegerOrValue = [&]() {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);
- // When encountering `[`, assume that this is a scalable index.
- scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+ // If `foundScalable` has already been set to `true` then a non-trailing
+ // index was identified as scalable.
+ if (foundScalable) {
+ parser.emitError(parser.getNameLoc())
+ << "non-trailing index cannot be scalable";
+ return failure();
+ }
+
+ if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
+ foundScalable = true;
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
@@ -160,10 +178,7 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
integerVals.push_back(integer);
}
-
- // If this is assumed to be a scalable index, verify that there's a closing
- // `]`.
- if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+ if (foundScalable && parser.parseOptionalRSquare().failed())
return failure();
return success();
};
@@ -172,7 +187,8 @@ ParseResult mlir::parseDynamicIndexList(
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
- scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+ if (isTrailingIdxScalable)
+ *isTrailingIdxScalable = foundScalable;
return success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 8b449770ee8a1b..3300e869979780 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -220,3 +220,25 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
}
+
+// -----
+
+// TODO: Add support for specifying more than one scalable tile size
+
+func.func @scalable_and_fixed_length_tile(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{non-trailing index cannot be scalable}}
+ // expected-error @below {{expected SSA value or integer}}
+ %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index dc35a9a6c9032d..7ddfcc60718730 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -105,11 +105,3 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
}
-
-// CHECK: transform.sequence
-// CHECK: transform.structured.tile %0{{\[}}[2], 4, 8]
-transform.sequence failures(propagate) {
-^bb0(%arg1: !transform.any_op):
- %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.tile %0 [[2], 4, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-}
More information about the Mlir-commits
mailing list