[Mlir-commits] [mlir] 726835c - [mlir] Update how scalable indices are printed

Andrzej Warzynski llvmlistbot at llvm.org
Fri Jun 2 09:04:18 PDT 2023


Author: Andrzej Warzynski
Date: 2023-06-02T16:47:56+01:00
New Revision: 726835cd51503c3d287904ea2d4055c41f969e71

URL: https://github.com/llvm/llvm-project/commit/726835cd51503c3d287904ea2d4055c41f969e71
DIFF: https://github.com/llvm/llvm-project/commit/726835cd51503c3d287904ea2d4055c41f969e71.diff

LOG: [mlir] Update how scalable indices are printed

This patch makes sure that scalable indices (that would normally
represent scalable tile or vector sizes) are printed correctly, i.e.
with additional square brackets:
```
%1, %loop = transform.structured.tile %0 [2, 8, [4]]
```

This change complements https://reviews.llvm.org/D150944 and is a part
of a larger effort to enable scalable vectorisation in Linalg. See this
RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

Differential Revision: https://reviews.llvm.org/D151978

Added: 
    

Modified: 
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Transform/Utils/Utils.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Transform/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index cab2a0bcc11b1..82a563a03c3ac 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -51,10 +51,15 @@ namespace mlir {
 /// indicating their types. This allows idiomatic printing of mixed value and
 /// integer attributes in a list. E.g.
 /// `[%arg0 : index, 7, 42, %arg42 : i32]`.
+///
+/// If `isTrailingIdxScalable` is true, then wrap the trailing index with
+/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
+/// used for scalable tile or vector sizes.
 void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
+    bool isTrailingIdxScalable = false);
 
 /// Parser hook for custom directive in assemblyFormat.
 ///

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 51dcd7e17c0f5..133ce91bbcb84 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2555,7 +2555,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
-  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
+  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
+                        /*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
+                        getLastTileSizeScalable());
   printOptionalInterchange(p, getInterchange());
   p << " : ";
   p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c8d64201cb2a2..3c531bc99cff2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1262,7 +1262,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
   if (succeeded(parser.parseOptionalKeyword("in"))) {
     // Parse upper bounds.
     if (parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();
@@ -1274,7 +1274,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
     // Parse lower bounds.
     if (parser.parseEqual() ||
         parseDynamicIndexList(
-            parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
+            parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
 
         parser.resolveOperands(dynamicLbs, indexType, result.operands))
@@ -1283,7 +1283,7 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
     // Parse upper bounds.
     if (parser.parseKeyword("to") ||
         parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
             /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();

diff  --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
index b50a7660e2bf2..e7516423fb58c 100644
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,6 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
     return success();
   }
 
-  return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
-                               &valueTypes);
+  return parseDynamicIndexList(parser, values, integers,
+                               /*isTrailingIdxScalable=*/nullptr, &valueTypes);
 }

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 13cca8131b682..d0310730ca79b 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -103,7 +103,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter) {
+                                 AsmParser::Delimiter delimiter,
+                                 bool isTrailingIdxScalable) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -111,6 +112,14 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     printer << rightDelimiter;
     return;
   }
+
+  int64_t trailingScalableInteger;
+  if (isTrailingIdxScalable) {
+    // ATM only the trailing idx can be scalable
+    trailingScalableInteger = integers.back();
+    integers = integers.drop_back();
+  }
+
   unsigned idx = 0;
   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
     if (ShapedType::isDynamic(integer)) {
@@ -122,6 +131,15 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
       printer << integer;
     }
   });
+
+  // Print the trailing scalable index
+  if (isTrailingIdxScalable) {
+    printer << ", ";
+    printer << "[";
+    printer << trailingScalableInteger;
+    printer << "]";
+  }
+
   printer << rightDelimiter;
 }
 

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index b85df428f5a8d..7ddfcc6071873 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -97,3 +97,11 @@ transform.sequence failures(propagate) {
   transform.print %arg0 {name = "test"} : !transform.any_op
   transform.print {name = "test"}
 }
+
+// CHECK: transform.sequence
+// CHECK: transform.structured.tile %0[4, 4, [4]]
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}


        


More information about the Mlir-commits mailing list