[Mlir-commits] [mlir] 7a52f79 - [mlir][transform] Add support for expressing scalable vector sizes

Andrzej Warzynski llvmlistbot at llvm.org
Thu Jun 8 12:54:49 PDT 2023


Author: Andrzej Warzynski
Date: 2023-06-08T20:54:17+01:00
New Revision: 7a52f79126a59717012d8039ef875f68e3c637fd

URL: https://github.com/llvm/llvm-project/commit/7a52f79126a59717012d8039ef875f68e3c637fd
DIFF: https://github.com/llvm/llvm-project/commit/7a52f79126a59717012d8039ef875f68e3c637fd.diff

LOG: [mlir][transform] Add support for expressing scalable vector sizes

This patch enables specifying scalable vector sizes when using the
Transform dialect to drive vectorisation, e.g.:
```
transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]]
```
This is implemented by extending MaskedVectorizeOp with a dedicated
attribute for "scalability" (`last_vector_size_scalable`) and by
overloading `parseDynamicIndexList` (and updating `printDynamicIndexList`
accordingly) so that MaskedVectorizeOp can continue using the
auto-generated parser and printer.

At the moment, only the trailing vector size can be scalable. The following
is not yet supported:
```
transform.structured.masked_vectorize %0 vector_sizes [8, [16], [4]]
```

As the vectoriser does not support scalable vectorisation just yet, a
warning is issued when scalable vector sizes are used. You can also use
the debug output, `--debug-only=linalg-vectorization`, to check whether
scalable vectorisation has been switched on.

This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
  * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
Similar patch for tiling: https://reviews.llvm.org/D150944

Differential Revision: https://reviews.llvm.org/D151892

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/include/mlir/Interfaces/ViewLikeInterface.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Interfaces/ViewLikeInterface.cpp
    mlir/test/Dialect/Linalg/transform-op-tile.mlir
    mlir/test/Dialect/Linalg/vectorization-masked.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6ba9f9aff42d1..3e2cb78bf9d12 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1897,13 +1897,16 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
                        Variadic<TransformHandleTypeInterface>:$vector_sizes,
                        UnitAttr:$vectorize_nd_extract,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
-                          $static_vector_sizes);
+                          $static_vector_sizes,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$last_vector_size_scalable);
+
   let results = (outs);
   let assemblyFormat = [{
       $target
       `vector_sizes` custom<DynamicIndexList>($vector_sizes,
                                               $static_vector_sizes,
-                                              type($vector_sizes))
+                                              type($vector_sizes),
+                                              $last_vector_size_scalable)
       attr-dict
       `:` type($target)
   }];

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1d7f448ff180a..b2c5bd17a4793 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -592,7 +592,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 /// dynamic shapes.
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool lastVectorSizeScalable = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

diff  --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 82a563a03c3ac..fad380d4005f1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -58,8 +58,8 @@ namespace mlir {
 void printDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
     ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
-    bool isTrailingIdxScalable = false);
+    BoolAttr isTrailingIdxScalable = {},
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
 
 /// Parser hook for custom directive in assemblyFormat.
 ///
@@ -100,6 +100,20 @@ inline ParseResult parseDynamicIndexList(
                                /*isTrailingIdxScalable=*/nullptr, &valueTypes,
                                delimiter);
 }
+inline ParseResult parseDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+    BoolAttr &isTrailingIdxScalable,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  bool scalable = false;
+  auto res = parseDynamicIndexList(parser, values, integers, &scalable,
+                                   &valueTypes, delimiter);
+  auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+  isTrailingIdxScalable = scalableAttr;
+  return res;
+}
 
 /// Verify that a the `values` has as many elements as the number of entries in
 /// `attr` for which `isDynamic` evaluates to true.

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f7fa2f107754b..c534e78d3404a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2590,8 +2590,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
 void TileOp::print(OpAsmPrinter &p) {
   p << ' ' << getTarget();
   printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
-                        /*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
-                        getLastTileSizeScalable());
+                        /*valueTypes=*/{}, getLastTileSizeScalableAttr(),
+                        OpAsmParser::Delimiter::Square);
   printOptionalInterchange(p, getInterchange());
   p << " : ";
   p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
@@ -3091,7 +3091,6 @@ transform::VectorizeOp::applyToOne(Operation *target,
 //===----------------------------------------------------------------------===//
 // MaskedVectorizeOp
 //===----------------------------------------------------------------------===//
-
 DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     mlir::transform::TransformResults &transformResults,
     mlir::transform::TransformState &state) {
@@ -3146,7 +3145,8 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
     }
 
     if (failed(linalg::vectorize(rewriter, target, vectorSizes,
-                                 getVectorizeNdExtract()))) {
+                                 getVectorizeNdExtract(),
+                                 getLastVectorSizeScalable()))) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d081e1a90a0d2..90ab75cbcc910 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1529,11 +1529,16 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 /// operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool lastVectorSizeScalable) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
+  LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n");
+
+  if (lastVectorSizeScalable)
+    op->emitWarning("Scalable vectorization is not supported yet");
 
   if (failed(
           vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3c531bc99cff2..db69195a7c704 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1220,17 +1220,21 @@ void ForallOp::print(OpAsmPrinter &p) {
   if (isNormalized()) {
     p << ") in ";
     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
-                          /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          OpAsmParser::Delimiter::Paren);
   } else {
     p << ") = ";
     printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
-                          /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          OpAsmParser::Delimiter::Paren);
     p << " to ";
     printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
-                          /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          OpAsmParser::Delimiter::Paren);
     p << " step ";
     printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
-                          /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+                          /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+                          OpAsmParser::Delimiter::Paren);
   }
   printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
   p << " ";

diff  --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d0310730ca79b..0f75cc10fc823 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -103,8 +103,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                  OperandRange values,
                                  ArrayRef<int64_t> integers,
                                  TypeRange valueTypes,
-                                 AsmParser::Delimiter delimiter,
-                                 bool isTrailingIdxScalable) {
+                                 BoolAttr isTrailingIdxScalable,
+                                 AsmParser::Delimiter delimiter) {
   char leftDelimiter = getLeftDelimiter(delimiter);
   char rightDelimiter = getRightDelimiter(delimiter);
   printer << leftDelimiter;
@@ -114,7 +114,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
   }
 
   int64_t trailingScalableInteger;
-  if (isTrailingIdxScalable) {
+  if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
     // ATM only the trailing idx can be scalable
     trailingScalableInteger = integers.back();
     integers = integers.drop_back();
@@ -133,8 +133,9 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
   });
 
   // Print the trailing scalable index
-  if (isTrailingIdxScalable) {
-    printer << ", ";
+  if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
+    if (!integers.empty())
+      printer << ", ";
     printer << "[";
     printer << trailingScalableInteger;
     printer << "]";
@@ -156,10 +157,10 @@ ParseResult mlir::parseDynamicIndexList(
     auto res = parser.parseOptionalOperand(operand);
 
     // If `foundScalable` has already been set to `true` then a non-trailing
-    // tile size was identified as scalable.
+    // index was identified as scalable.
     if (foundScalable) {
       parser.emitError(parser.getNameLoc())
-          << "non-trailing tile size cannot be scalable";
+          << "non-trailing index cannot be scalable";
       return failure();
     }
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index e00a48429ed56..3300e86997978 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -238,7 +238,7 @@ func.func @scalable_and_fixed_length_tile(
 transform.sequence failures(propagate) {
 ^bb0(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  // expected-error @below {{non-trailing tile size cannot be scalable}}
+  // expected-error @below {{non-trailing index cannot be scalable}}
   // expected-error @below {{expected SSA value or integer}}
   %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }

diff  --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
index 65b8b5b38461e..1b1202532ca27 100644
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s
 
 func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
                                       %arg1: tensor<?xf32>,
@@ -484,3 +484,18 @@ transform.sequence failures(propagate) {
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
 }
+
+// -----
+
+func.func @vectorize_dynamic_matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+  // expected-warning @+1 {{Scalable vectorization is not supported yet}}
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+            outs(%C: memref<?x?xf32>)
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op
+}


        


More information about the Mlir-commits mailing list