[Mlir-commits] [mlir] 7a52f79 - [mlir][transform] Add support for expressing scalable vector sizes
Andrzej Warzynski
llvmlistbot at llvm.org
Thu Jun 8 12:54:49 PDT 2023
Author: Andrzej Warzynski
Date: 2023-06-08T20:54:17+01:00
New Revision: 7a52f79126a59717012d8039ef875f68e3c637fd
URL: https://github.com/llvm/llvm-project/commit/7a52f79126a59717012d8039ef875f68e3c637fd
DIFF: https://github.com/llvm/llvm-project/commit/7a52f79126a59717012d8039ef875f68e3c637fd.diff
LOG: [mlir][transform] Add support for expressing scalable vector sizes
This patch enables specifying scalable vector sizes when using the
Transform dialect to drive vectorisation, e.g.:
```
transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]]
```
This is implemented by extending the MaskedVectorizeOp with a dedicated
attribute for "scalability" and by overloading `parseDynamicIndexList`
so that MaskedVectorizeOp can continue using the auto-generated parser
and printer.
At the moment, only the trailing vector size can be scalable. The following
is not yet supported:
```
transform.structured.masked_vectorize %0 vector_sizes [8, [16], [4]]
```
As the vectoriser does not support scalable vectorisation just yet, a
warning is issued when scalable vector sizes are used. You can also use
the debug output, `--debug-only=linalg-vectorization`, to check whether
scalable vectorisation has been switched on.
This change is a part of a larger effort to enable scalable
vectorisation in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
Similar patch for tiling: https://reviews.llvm.org/D150944
Differential Revision: https://reviews.llvm.org/D151892
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/transform-op-tile.mlir
mlir/test/Dialect/Linalg/vectorization-masked.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 6ba9f9aff42d1..3e2cb78bf9d12 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1897,13 +1897,16 @@ def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
Variadic<TransformHandleTypeInterface>:$vector_sizes,
UnitAttr:$vectorize_nd_extract,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
- $static_vector_sizes);
+ $static_vector_sizes,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$last_vector_size_scalable);
+
let results = (outs);
let assemblyFormat = [{
$target
`vector_sizes` custom<DynamicIndexList>($vector_sizes,
$static_vector_sizes,
- type($vector_sizes))
+ type($vector_sizes),
+ $last_vector_size_scalable)
attr-dict
`:` type($target)
}];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1d7f448ff180a..b2c5bd17a4793 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -592,7 +592,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
/// dynamic shapes.
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
- bool vectorizeNDExtract = false);
+ bool vectorizeNDExtract = false,
+ bool lastVectorSizeScalable = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 82a563a03c3ac..fad380d4005f1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -58,8 +58,8 @@ namespace mlir {
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
- AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square,
- bool isTrailingIdxScalable = false);
+ BoolAttr isTrailingIdxScalable = {},
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
/// Parser hook for custom directive in assemblyFormat.
///
@@ -100,6 +100,20 @@ inline ParseResult parseDynamicIndexList(
/*isTrailingIdxScalable=*/nullptr, &valueTypes,
delimiter);
}
+/// Overload of parseDynamicIndexList for assembly formats that keep the
+/// trailing-index scalability in a BoolAttr (e.g. MaskedVectorizeOp's
+/// `custom<DynamicIndexList>(..., $last_vector_size_scalable)` directive).
+/// Delegates to the overload that takes a `bool *` scalability flag, then
+/// always writes `isTrailingIdxScalable` (`false` when no scalable index was
+/// parsed), even if the delegated parse failed. Returns the delegated result.
+inline ParseResult parseDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
+ BoolAttr &isTrailingIdxScalable,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+ bool scalable = false;
+ auto res = parseDynamicIndexList(parser, values, integers, &scalable,
+ &valueTypes, delimiter);
+ // Materialize the flag unconditionally so the out-attribute is always set.
+ auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+ isTrailingIdxScalable = scalableAttr;
+ return res;
+}
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f7fa2f107754b..c534e78d3404a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2590,8 +2590,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
- /*valueTypes=*/{}, OpAsmParser::Delimiter::Square,
- getLastTileSizeScalable());
+ /*valueTypes=*/{}, getLastTileSizeScalableAttr(),
+ OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
@@ -3091,7 +3091,6 @@ transform::VectorizeOp::applyToOne(Operation *target,
//===----------------------------------------------------------------------===//
// MaskedVectorizeOp
//===----------------------------------------------------------------------===//
-
DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
mlir::transform::TransformResults &transformResults,
mlir::transform::TransformState &state) {
@@ -3146,7 +3145,8 @@ DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
}
if (failed(linalg::vectorize(rewriter, target, vectorSizes,
- getVectorizeNdExtract()))) {
+ getVectorizeNdExtract(),
+ getLastVectorSizeScalable()))) {
return mlir::emitSilenceableFailure(target->getLoc())
<< "Attempted to vectorize, but failed";
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index d081e1a90a0d2..90ab75cbcc910 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1529,11 +1529,16 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
/// operations with dynamic shapes.
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
- bool vectorizeNDExtract) {
+ bool vectorizeNDExtract,
+ bool lastVectorSizeScalable) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
+ LDBG("Scalable vectorisation: " << lastVectorSizeScalable << "\n");
+
+ if (lastVectorSizeScalable)
+ op->emitWarning("Scalable vectorization is not supported yet");
if (failed(
vectorizeOpPrecondition(op, inputVectorSizes, vectorizeNDExtract))) {
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3c531bc99cff2..db69195a7c704 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1220,17 +1220,21 @@ void ForallOp::print(OpAsmPrinter &p) {
if (isNormalized()) {
p << ") in ";
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
- /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+ OpAsmParser::Delimiter::Paren);
} else {
p << ") = ";
printDynamicIndexList(p, op, getDynamicLowerBound(), getStaticLowerBound(),
- /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+ OpAsmParser::Delimiter::Paren);
p << " to ";
printDynamicIndexList(p, op, getDynamicUpperBound(), getStaticUpperBound(),
- /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+ OpAsmParser::Delimiter::Paren);
p << " step ";
printDynamicIndexList(p, op, getDynamicStep(), getStaticStep(),
- /*valueTypes=*/{}, OpAsmParser::Delimiter::Paren);
+ /*valueTypes=*/{}, /*=isTrailingIdxScalable=*/{},
+ OpAsmParser::Delimiter::Paren);
}
printInitializationList(p, getRegionOutArgs(), getOutputs(), " shared_outs");
p << " ";
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index d0310730ca79b..0f75cc10fc823 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -103,8 +103,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
TypeRange valueTypes,
- AsmParser::Delimiter delimiter,
- bool isTrailingIdxScalable) {
+ BoolAttr isTrailingIdxScalable,
+ AsmParser::Delimiter delimiter) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
printer << leftDelimiter;
@@ -114,7 +114,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
}
int64_t trailingScalableInteger;
- if (isTrailingIdxScalable) {
+ if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
// ATM only the trailing idx can be scalable
trailingScalableInteger = integers.back();
integers = integers.drop_back();
@@ -133,8 +133,9 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
});
// Print the trailing scalable index
- if (isTrailingIdxScalable) {
- printer << ", ";
+ if (isTrailingIdxScalable && isTrailingIdxScalable.getValue()) {
+ if (!integers.empty())
+ printer << ", ";
printer << "[";
printer << trailingScalableInteger;
printer << "]";
@@ -156,10 +157,10 @@ ParseResult mlir::parseDynamicIndexList(
auto res = parser.parseOptionalOperand(operand);
// If `foundScalable` has already been set to `true` then a non-trailing
- // tile size was identified as scalable.
+ // index was identified as scalable.
if (foundScalable) {
parser.emitError(parser.getNameLoc())
- << "non-trailing tile size cannot be scalable";
+ << "non-trailing index cannot be scalable";
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index e00a48429ed56..3300e86997978 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -238,7 +238,7 @@ func.func @scalable_and_fixed_length_tile(
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-error @below {{non-trailing tile size cannot be scalable}}
+ // expected-error @below {{non-trailing index cannot be scalable}}
// expected-error @below {{expected SSA value or integer}}
%1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
}
diff --git a/mlir/test/Dialect/Linalg/vectorization-masked.mlir b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
index 65b8b5b38461e..1b1202532ca27 100644
--- a/mlir/test/Dialect/Linalg/vectorization-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-masked.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s
func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
%arg1: tensor<?xf32>,
@@ -484,3 +484,18 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [8, 16, 4] : !transform.any_op
}
+
+// -----
+
+// Scalable trailing vector size (`[4]`): masked_vectorize accepts it and
+// forwards the scalability flag to linalg::vectorize, which currently emits a
+// "not supported yet" warning on the target op.
+func.func @vectorize_dynamic_matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+ // expected-warning @+1 {{Scalable vectorization is not supported yet}}
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [8, 16, [4]] : !transform.any_op
+}
More information about the Mlir-commits
mailing list