[Mlir-commits] [mlir] ad7ef19 - [mlir][transform] Allow arbitrary indices to be scalable

Andrzej Warzynski llvmlistbot at llvm.org
Wed Jul 5 01:55:39 PDT 2023


Author: Andrzej Warzynski
Date: 2023-07-05T09:53:26+01:00
New Revision: ad7ef1923fe582a95f16a877dd75889eb347c774

URL: https://github.com/llvm/llvm-project/commit/ad7ef1923fe582a95f16a877dd75889eb347c774
DIFF: https://github.com/llvm/llvm-project/commit/ad7ef1923fe582a95f16a877dd75889eb347c774.diff

LOG: [mlir][transform] Allow arbitrary indices to be scalable

This change lifts the limitation that only the trailing dimensions/sizes
in dynamic index lists can be scalable. It allows us to extend
`MaskedVectorizeOp` and `TileOp` from the Transform dialect so that the
following is allowed:

  %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]]

This is also a follow-up to https://reviews.llvm.org/D153372
and will enable the following (the middle vector dimension is scalable):

  transform.structured.masked_vectorize %0 vector_sizes [2, [4], 8]

To facilitate this change, the hooks for printing and parsing dynamic
index lists are updated accordingly (`printDynamicIndexList` and
`parseDynamicIndexList`, respectively). `MaskedVectorizeOp` and `TileOp`
are updated to include an array attribute of bools that captures
whether each corresponding vector dimension/tile size, respectively, is
scalable or not.

NOTE 1: I am re-landing this after the initial version was reverted. To
fix the regression, this revision updates the Python bindings for the
Transform dialect in addition to the original patch.

NOTE 2: This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/

This relands 048764f23a380fd6f8cc562a0008dcc6095fb594 with fixes.

Differential Revision: https://reviews.llvm.org/D154336

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Dialect/Transform/Utils/Utils.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/python/mlir/dialects/_structured_transform_ops_ext.py
    mlir/python/mlir/ir.py
    mlir/test/Dialect/Linalg/transform-op-tile.mlir
    mlir/test/Dialect/Transform/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 3830d65b99e38e..5faaf32246d103 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1690,7 +1690,7 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
                    Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
-                   DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
+                   DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
   let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
                       Variadic<TransformHandleTypeInterface>:$loops);
   let builders = [
@@ -2012,9 +2012,10 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
   let arguments = (ins TransformHandleTypeInterface:$target,
                        Variadic<TransformHandleTypeInterface>:$vector_sizes,
                        UnitAttr:$vectorize_nd_extract,
+                       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
+                          $scalable_sizes,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes,
-                       DefaultValuedOptionalAttr<BoolAttr, "false">:$last_vector_size_scalable);
+                          $static_vector_sizes);
 
   let results = (outs);
   let assemblyFormat = [{
@@ -2022,7 +2023,7 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
       `vector_sizes` custom<DynamicIndexList>($vector_sizes,
                                               $static_vector_sizes,
                                               type($vector_sizes),
-                                              $last_vector_size_scalable)
+                                              $scalable_sizes)
       attr-dict
       `:` type($target)
   }];

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index fad380d4005f1c..65ef514908d181 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -52,13 +52,15 @@ namespace mlir {
 /// integer attributes in a list. E.g.
 /// `[%arg0 : index, 7, 42, %arg42 : i32]`.
 ///
-/// If  `isTrailingIdxScalable` is true, then wrap the trailing index with
-/// square brackets, e.g. `[42]`, to denote scalability. This would normally be
-/// used for scalable tile or vector sizes.
+/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
+/// This notation is similar to how scalable dims are marked when defining
+/// Vectors. For each value in `integers`, the corresponding `bool` in
+/// `scalables` encodes whether it's a scalable index. If `scalables` is
+/// empty, all indices are assumed to be non-scalable.
 void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    BoolAttr isTrailingIdxScalable = {},
+    ArrayRef<bool> scalables = {},
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Parser hook for custom directive in assemblyFormat.
@@ -78,41 +80,43 @@ void printDynamicIndexList(
 ///   `kDynamic`]"
 ///   2. `ssa` is filled with "[%arg0, %arg1]".
 ///
-/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
-/// scalable. This notation is similar to how scalable dims are marked when
-/// defining Vectors. If /p isTrailingIdxScalable is null, scalable indices are
-/// not allowed/expected. When it's not null, this hook will set the
-/// corresponding value to:
-///   * true if the trailing idx is scalable,
-///   * false otherwise.
+/// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable.
+/// This notation is similar to how scalable dims are marked when defining
+/// Vectors. For each value in `integers`, the corresponding `bool` in
+/// `scalableVals` encodes whether it's a scalable index.
 ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
+    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals,
     SmallVectorImpl<Type> *valueTypes = nullptr,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
+inline ParseResult parseDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               valueTypes, delimiter);
+}
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
-  return parseDynamicIndexList(parser, values, integers,
-                               /*isTrailingIdxScalable=*/nullptr, &valueTypes,
-                               delimiter);
+  DenseBoolArrayAttr scalableVals = {};
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter);
 }
 inline ParseResult parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
     DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
-    BoolAttr &isTrailingIdxScalable,
+    DenseBoolArrayAttr &scalableVals,
     AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
 
-  bool scalable = false;
-  auto res = parseDynamicIndexList(parser, values, integers, &scalable,
-                                   &valueTypes, delimiter);
-  auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
-  isTrailingIdxScalable = scalableAttr;
-  return res;
+  return parseDynamicIndexList(parser, values, integers, scalableVals,
+                               &valueTypes, delimiter);
 }
 
 /// Verify that a the `values` has as many elements as the number of entries in

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 28c1a019652072..041f9b97e5a36b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2451,7 +2451,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
   SmallVector<Operation *> tiled;
   SmallVector<SmallVector<Operation *, 4>, 4> loops;
   loops.resize(getLoops().size());
-  bool scalable = getLastTileSizeScalable();
+  auto scalableSizes = getScalableSizes();
   for (auto [i, op] : llvm::enumerate(targets)) {
     auto tilingInterface = dyn_cast<TilingInterface>(op);
     auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
@@ -2470,12 +2470,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
         SmallVector<Value, 4> sizes;
         sizes.reserve(tileSizes.size());
         unsigned dynamicIdx = 0;
-        unsigned trailingIdx = getMixedSizes().size() - 1;
 
         for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
           if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
-            // Only the trailing tile size is allowed to be scalable atm.
-            if (scalable && (ofrIdx == trailingIdx)) {
+            if (scalableSizes[ofrIdx]) {
               auto val = b.create<arith::ConstantIndexOp>(
                   getLoc(), attr.cast<IntegerAttr>().getInt());
               Value vscale =
@@ -2577,9 +2575,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
   DenseI64ArrayAttr staticSizes;
   FunctionType functionalType;
   llvm::SMLoc operandLoc;
-  bool scalable = false;
+  DenseBoolArrayAttr scalableVals;
+
   if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
-      parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
+      parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
       parseOptionalInterchange(parser, result) ||
       parser.parseColonType(functionalType))
     return ParseResult::failure();
@@ -2602,9 +2601,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
     return failure();
   }
 
-  auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
-  result.addAttribute(getLastTileSizeScalableAttrName(result.name),
-                      scalableAttr);
+  result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
 
   result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
   result.addTypes(functionalType.getResults());
@@ -2614,7 +2611,7 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
-                        /*valueTypes=*/{}, getLastTileSizeScalableAttr(),
+                        /*valueTypes=*/{}, getScalableSizesAttr(),
                         OpAsmParser::Delimiter::Square);
   printOptionalInterchange(p, getInterchange());
   p << " : ";
@@ -3161,15 +3158,14 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
   }
 
   // TODO: Check that the correct number of vectorSizes was provided.
-  SmallVector<bool> scalableVecDims(vectorSizes.size(), false);
-  scalableVecDims.back() = getLastVectorSizeScalable();
   for (Operation *target : targets) {
     if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
 
-    if (failed(linalg::vectorize(rewriter, target, vectorSizes, scalableVecDims,
+    if (failed(linalg::vectorize(rewriter, target, vectorSizes,
+                                 getScalableSizes(),
                                  getVectorizeNdExtract()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 01cfa679a77511..ddcdffebed1392 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1254,20 +1254,20 @@ void ForallOp::print(OpAsmPrinter &p) {
   if (isNormalized()) {
     p << ") in ";
     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
-                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          /*valueTypes=*/{}, /*scalables=*/{},
                           OpAsmParser::Delimiter::Paren);
   } else {
     p << ") = ";
     printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
-                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          /*valueTypes=*/{}, /*scalables=*/{},
                           OpAsmParser::Delimiter::Paren);
     p << " to ";
     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
-                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          /*valueTypes=*/{}, /*scalables=*/{},
                           OpAsmParser::Delimiter::Paren);
     p << " step ";
     printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
-                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          /*valueTypes=*/{}, /*scalable=*/{},
                           OpAsmParser::Delimiter::Paren);
   }
   printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
@@ -1299,9 +1299,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
       dynamicSteps;
   if (succeeded(parser.parseOptionalKeyword("in"))) {
     // Parse upper bounds.
-    if (parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
-            /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
+    if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
+                              /*valueTypes=*/nullptr,
+                              OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();
 
@@ -1311,26 +1311,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
   } else {
     // Parse lower bounds.
     if (parser.parseEqual() ||
-        parseDynamicIndexList(
-            parser, dynamicLbs, staticLbs, /*isTrailingIdxScalable=*/nullptr,
-            /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
+        parseDynamicIndexList(parser, dynamicLbs, staticLbs,
+                              /*valueTypes=*/nullptr,
+                              OpAsmParser::Delimiter::Paren) ||
 
         parser.resolveOperands(dynamicLbs, indexType, result.operands))
       return failure();
 
     // Parse upper bounds.
     if (parser.parseKeyword("to") ||
-        parseDynamicIndexList(
-            parser, dynamicUbs, staticUbs, /*isTrailingIdxScalable=*/nullptr,
-            /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
+        parseDynamicIndexList(parser, dynamicUbs, staticUbs,
+                              /*valueTypes=*/nullptr,
+                              OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicUbs, indexType, result.operands))
       return failure();
 
     // Parse step values.
     if (parser.parseKeyword("step") ||
-        parseDynamicIndexList(
-            parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
-            /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
+        parseDynamicIndexList(parser, dynamicSteps, staticSteps,
+                              /*valueTypes=*/nullptr,
+                              OpAsmParser::Delimiter::Paren) ||
         parser.resolveOperands(dynamicSteps, indexType, result.operands))
       return failure();
   }

diff  --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
index e7516423fb58c7..d516a56feed478 100644
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,6 +42,5 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
     return success();
   }
 
-  return parseDynamicIndexList(parser, values, integers,
-                               /*isTrailingIdxScalable=*/nullptr, &valueTypes);
+  return parseDynamicIndexList(parser, values, integers, &valueTypes);
 }

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 0f75cc10fc8234..667f66bb99610b 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -102,8 +102,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
 void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
-                                 TypeRange valueTypes,
-                                 BoolAttr isTrailingIdxScalable,
+                                 TypeRange valueTypes, ArrayRef<bool> scalables,
                                  AsmParser::Delimiter delimiter) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
@@ -113,33 +112,24 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
     return;
   }
 
-  int64_t trailingScalableInteger;
-  if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
-    // ATM only the trailing idx can be scalable
-    trailingScalableInteger = integers.back();
-    integers = integers.drop_back();
-  }
-
-  unsigned idx = 0;
+  unsigned dynamicValIdx = 0;
+  unsigned scalableIndexIdx = 0;
   llvm::interleaveComma(integers, printer, [&](int64_t integer) {
+    if (not scalables.empty() && scalables[scalableIndexIdx])
+      printer << "[";
     if (ShapedType::isDynamic(integer)) {
-      printer << values[idx];
+      printer << values[dynamicValIdx];
       if (!valueTypes.empty())
-        printer << " : " << valueTypes[idx];
-      ++idx;
+        printer << " : " << valueTypes[dynamicValIdx];
+      ++dynamicValIdx;
     } else {
       printer << integer;
     }
-  });
+    if (!scalables.empty() && scalables[scalableIndexIdx])
+      printer << "]";
 
-  // Print the trailing scalable index
-  if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
-    if (!integers.empty())
-      printer << ", ";
-    printer << "[";
-    printer << trailingScalableInteger;
-    printer << "]";
-  }
+    scalableIndexIdx++;
+  });
 
   printer << rightDelimiter;
 }
@@ -147,25 +137,17 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
 ParseResult mlir::parseDynamicIndexList(
     OpAsmParser &parser,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
+    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
     SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
 
   SmallVector<int64_t, 4> integerVals;
-  bool foundScalable = false;
+  SmallVector<bool, 4> scalableVals;
   auto parseIntegerOrValue = [&]() {
     OpAsmParser::UnresolvedOperand operand;
     auto res = parser.parseOptionalOperand(operand);
 
-    // If `foundScalable` has already been set to `true` then a non-trailing
-    // index was identified as scalable.
-    if (foundScalable) {
-      parser.emitError(parser.getNameLoc())
-          << "non-trailing index cannot be scalable";
-      return failure();
-    }
-
-    if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
-      foundScalable = true;
+    // When encountering `[`, assume that this is a scalable index.
+    scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
 
     if (res.has_value() && succeeded(res.value())) {
       values.push_back(operand);
@@ -178,7 +160,10 @@ ParseResult mlir::parseDynamicIndexList(
         return failure();
       integerVals.push_back(integer);
     }
-    if (foundScalable && parser.parseOptionalRSquare().failed())
+
+    // If this is assumed to be a scalable index, verify that there's a closing
+    // `]`.
+    if (scalableVals.back() && parser.parseOptionalRSquare().failed())
       return failure();
     return success();
   };
@@ -187,8 +172,7 @@ ParseResult mlir::parseDynamicIndexList(
     return parser.emitError(parser.getNameLoc())
            << "expected SSA value or integer";
   integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
-  if (isTrailingIdxScalable)
-    *isTrailingIdxScalable = foundScalable;
+  scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
   return success();
 }
 

diff  --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 47c1bbb31c0b7a..190b3bc9124ea3 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -14,6 +14,9 @@
 IntOrAttrList = Sequence[Union[IntegerAttr, int]]
 OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
 
+BoolOrAttrList = Sequence[Union[BoolAttr, bool]]
+OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]]
+
 
 def _get_int_int_array_attr(
     values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
@@ -226,6 +229,7 @@ def __init__(
             Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
         ] = None,
         interchange: OptionalIntList = None,
+        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
@@ -240,6 +244,7 @@ def __init__(
             Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
         ] = None,
         interchange: OptionalIntList = None,
+        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
@@ -254,6 +259,7 @@ def __init__(
             Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
         ] = None,
         interchange: OptionalIntList = None,
+        scalable_sizes: OptionalBoolList = None,
         loc=None,
         ip=None,
     ):
@@ -261,6 +267,8 @@ def __init__(
             interchange = []
         if sizes is None:
             sizes = []
+        if scalable_sizes is None:
+            scalable_sizes = []
 
         static_sizes = []
         dynamic_sizes = []
@@ -298,6 +306,7 @@ def __init__(
             dynamic_sizes=dynamic_sizes,
             static_sizes=sizes_attr,
             interchange=interchange,
+            scalable_sizes=scalable_sizes,
             loc=loc,
             ip=ip,
         )

diff  --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 10a0f5bd2c6b95..76077acb6a579c 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -105,6 +105,10 @@ def _f64ArrayAttr(x, context):
 def _denseI64ArrayAttr(x, context):
     return DenseI64ArrayAttr.get(x, context=context)
 
+@register_attribute_builder("DenseBoolArrayAttr")
+def _denseBoolArrayAttr(x, context):
+    return DenseBoolArrayAttr.get(x, context=context)
+
 
 @register_attribute_builder("TypeAttr")
 def _typeAttr(x, context):

diff  --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 3300e869979780..8b449770ee8a1b 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -220,25 +220,3 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
-
-// -----
-
-// TODO: Add support for for specyfying more than one scalable tile size
-
-func.func @scalable_and_fixed_length_tile(
-  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32> {
-  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
-                     outs(%arg2: tensor<128x128xf32>)
-    -> tensor<128x128xf32>
-
-  return %0 : tensor<128x128xf32>
-}
-
-transform.sequence failures(propagate) {
-^bb0(%arg1: !transform.any_op):
-  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{non-trailing index cannot be scalable}}
-  // expected-error @below {{expected SSA value or integer}}
-  %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-}

diff  --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir
index 7ddfcc60718730..dc35a9a6c9032d 100644
--- a/mlir/test/Dialect/Transform/ops.mlir
+++ b/mlir/test/Dialect/Transform/ops.mlir
@@ -105,3 +105,11 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// CHECK: transform.sequence
+// CHECK: transform.structured.tile %0{{\[}}[2], 4, 8]
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.tile %0 [[2], 4, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}


        


More information about the Mlir-commits mailing list