[Mlir-commits] [mlir] 9b11323 - [mlir][linalg][transform] Fix TileOp builder

Matthias Springer llvmlistbot at llvm.org
Thu Jul 6 02:44:50 PDT 2023


Author: Matthias Springer
Date: 2023-07-06T11:40:33+02:00
New Revision: 9b113239048a0195b7a6794b9b8a592b32e65c4e

URL: https://github.com/llvm/llvm-project/commit/9b113239048a0195b7a6794b9b8a592b32e65c4e
DIFF: https://github.com/llvm/llvm-project/commit/9b113239048a0195b7a6794b9b8a592b32e65c4e.diff

LOG: [mlir][linalg][transform] Fix TileOp builder

The TileOp builders did not set `scalable_sizes`, which produces invalid ops. `scalable_sizes` must contain as many booleans as there are sizes.

Differential Revision: https://reviews.llvm.org/D154585

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 5faaf32246d103..8be49d05c2f0e0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1697,21 +1697,29 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
     OpBuilder<(ins "TypeRange":$loopTypes,
                    "Value":$target,
                    "ArrayRef<int64_t>":$staticTileSizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange,
+                   CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
+                      $scalableSizes)>, // nullopt: no tile size is scalable.
     OpBuilder<(ins "TypeRange":$loopTypes,
                    "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange,
+                   CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
+                      $scalableSizes)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticTileSizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange,
+                   CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
+                      $scalableSizes)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
-                   CArg<"ArrayRef<int64_t>", "{}">:$interchange)>
-
+                   CArg<"ArrayRef<int64_t>", "{}">:$interchange,
+                   CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
+                      $scalableSizes)>,
   ];
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1; // One scalable_sizes flag per tile size.
 
   let extraClassDeclaration = [{
     /// Returns the list of tile sizes, which may be static (Attribute) or

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 041f9b97e5a36b..a9e675b902e765 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2332,36 +2332,41 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange loopTypes, Value target,
                               ArrayRef<int64_t> staticTileSizes,
-                              ArrayRef<int64_t> interchange) {
+                              ArrayRef<int64_t> interchange,
+                              std::optional<ArrayRef<bool>> scalableSizes) {
   return build(builder, result, loopTypes,
                /*target=*/target,
                /*mixedTileSizes=*/
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-               interchange);
+               interchange, scalableSizes); // Forwarded unchanged.
 }
 
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               Value target, ArrayRef<int64_t> staticTileSizes,
-                              ArrayRef<int64_t> interchange) {
+                              ArrayRef<int64_t> interchange,
+                              std::optional<ArrayRef<bool>> scalableSizes) {
   build(builder, result, target,
         getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-        interchange);
+        interchange, scalableSizes); // Forwarded unchanged.
 }
 
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               Value target,
                               ArrayRef<OpFoldResult> mixedTileSizes,
-                              ArrayRef<int64_t> interchange) {
+                              ArrayRef<int64_t> interchange,
+                              std::optional<ArrayRef<bool>> scalableSizes) {
   // Loop types are automaticaly splat by the callee, setting up one is
   // enough.
   SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
-  build(builder, result, loopTypes, target, mixedTileSizes, interchange);
+  build(builder, result, loopTypes, target, mixedTileSizes, interchange,
+        scalableSizes); // Forwarded unchanged.
 }
 
 void transform::TileOp::build(OpBuilder &builder, OperationState &result,
                               TypeRange loopTypes, Value target,
                               ArrayRef<OpFoldResult> mixedTileSizes,
-                              ArrayRef<int64_t> interchange) {
+                              ArrayRef<int64_t> interchange,
+                              std::optional<ArrayRef<bool>> scalableSizes) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
@@ -2379,12 +2384,27 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result,
     resultTypes.append(numExpectedLoops, loopTypes[0]);
   else
     llvm::append_range(resultTypes, loopTypes);
+  // `scalable_sizes` must hold exactly one flag per tile size. Without an
+  // explicit `scalableSizes`, every tile size is marked non-scalable.
+  SmallVector<bool> expandedScalableSizes(mixedTileSizes.size(), false);
+  if (scalableSizes.has_value())
+    expandedScalableSizes.assign(scalableSizes->begin(), scalableSizes->end());
   build(builder, result, /*tiled_linalg_op=*/target.getType(),
         /*loops=*/resultTypes,
         /*target=*/target,
         /*dynamic_sizes=*/dynamicTileSizes,
         /*static_sizes=*/staticTileSizesAttr,
-        /*interchange=*/builder.getDenseI64ArrayAttr(interchange));
+        /*interchange=*/builder.getDenseI64ArrayAttr(interchange),
+        /*scalable_sizes=*/expandedScalableSizes);
+}
+
+/// Checks that `scalable_sizes` carries exactly one flag per tile size.
+LogicalResult transform::TileOp::verify() {
+  if (getMixedSizes().size() != getScalableSizes().size())
+    return emitOpError("expected same number of sizes (")
+           << getMixedSizes().size() << ") and scalable sizes ("
+           << getScalableSizes().size() << ")";
+  return success();
 }
 
 DiagnosedSilenceableFailure


        


More information about the Mlir-commits mailing list