[Mlir-commits] [mlir] 726835c - [mlir] Update how scalable indices are printed
Andrzej Warzynski
llvmlistbot at llvm.org
Fri Jun 2 09:04:18 PDT 2023
Author: Andrzej Warzynski
Date: 2023-06-02T16:47:56+01:00
New Revision: 726835cd51503c3d287904ea2d4055c41f969e71
URL: https://github.com/llvm/llvm-project/commit/726835cd51503c3d287904ea2d4055c41f969e71
DIFF: https://github.com/llvm/llvm-project/commit/726835cd51503c3d287904ea2d4055c41f969e71.diff
LOG: [mlir] Update how scalable indices are printed
This patch makes sure that scalable indices (which normally
represent scalable tile or vector sizes) are printed correctly, i.e.
wrapped in an additional pair of square brackets. Currently only the
trailing index can be scalable:
```
%1, %loop = transform.structured.tile %0 [2, 8, [4]]
```
This change complements https://reviews.llvm.org/D150944 and is part
of a larger effort to enable scalable vectorisation in Linalg. See this
RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
Differential Revision: https://reviews.llvm.org/D151978
Added:
Modified:
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Transform/Utils/Utils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Transform/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index cab2a0bcc11b1..82a563a03c3ac 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -51,10 +51,15 @@ namespace mlir {
/// indicating their types. This allows idiomatic printing of mixed value and
/// integer attributes in a list. E.g.
/// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+///
+/// If `isTrailingIdxScalable` is true, then wrap the trailing index with
+/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
+/// used for scalable tile or vector sizes.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+ bool isTrailingIdxScalable = false);
/// Parser hook for custom directive in assemblyFormat.
///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 51dcd7e17c0f5..133ce91bbcb84 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2555,7 +2555,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
- printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
+ printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
+ /*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
+ getLastTileSizeScalable());
printOptionalInterchange(p, getInterchange());
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c8d64201cb2a2..3c531bc99cff2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1262,7 +1262,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
if (succeeded(parser.parseOptionalKeyword("in"))) {
// Parse upper bounds.
if (parseDynamicIndexList(
- parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+ parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
@@ -1274,7 +1274,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse lower bounds.
if (parser.parseEqual() ||
parseDynamicIndexList(
- parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
+ parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
@@ -1283,7 +1283,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse upper bounds.
if (parser.parseKeyword("to") ||
parseDynamicIndexList(
- parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+ parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
/*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
index b50a7660e2bf2..e7516423fb58c 100644
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,6 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
return success();
}
- return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
- &valueTypes);
+ return parseDynamicIndexList(parser, values, integers,
+ /*isTrailingIdxScalable=*/nullptr, &valueTypes);
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 13cca8131b682..d0310730ca79b 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -103,7 +103,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
TypeRange valueTypes,
- AsmParser::Delimiter delimiter) {
+ AsmParser::Delimiter delimiter,
+ bool isTrailingIdxScalable) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
printer << leftDelimiter;
@@ -111,6 +112,14 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
printer << rightDelimiter;
return;
}
+
+ int64_t trailingScalableInteger;
+ if (isTrailingIdxScalable) {
+ // ATM only the trailing idx can be scalable
+ trailingScalableInteger = integers.back();
+ integers = integers.drop_back();
+ }
+
unsigned idx = 0;
llvm::interleaveComma(integers, printer, [&](int64_t integer) {
if (ShapedType::isDynamic(integer)) {
@@ -122,6 +131,15 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
printer << integer;
}
});
+
+ // Print the trailing scalable index
+ if (isTrailingIdxScalable) {
+ printer << ", ";
+ printer << "[";
+ printer << trailingScalableInteger;
+ printer << "]";
+ }
+
printer << rightDelimiter;
}
diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b85df428f5a8d..7ddfcc6071873 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -97,3 +97,11 @@ transform.sequence failures(propagate) {
transform.print %arg0 {name = "test"} : !transform.any_op
transform.print {name = "test"}
}
+
+// CHECK: transform.sequence
+// CHECK: transform.structured.tile %0[4, 4, [4]]
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
More information about the Mlir-commits
mailing list