[Mlir-commits] [mlir] a5b3677 - [mlir][transform] Add support for expressing scalable tile sizes
Andrzej Warzynski
llvmlistbot at llvm.org
Thu Jun 1 01:28:31 PDT 2023
Author: Andrzej Warzynski
Date: 2023-06-01T09:28:03+01:00
New Revision: a5b3677ddc4eb0a080f9b80ac82a56d39f952350
URL: https://github.com/llvm/llvm-project/commit/a5b3677ddc4eb0a080f9b80ac82a56d39f952350
DIFF: https://github.com/llvm/llvm-project/commit/a5b3677ddc4eb0a080f9b80ac82a56d39f952350.diff
LOG: [mlir][transform] Add support for expressing scalable tile sizes
This patch enables specifying scalable tile sizes when using the
Transform dialect to drive tiling, e.g.:
```
%1, %loop = transform.structured.tile %0 [[4]]
```
This is implemented by extending the TileOp with a dedicated attribute
for "scalability" and by updating various parsing hooks. At the moment,
only the trailing tile size can be scalable. The following is not yet
supported:
```
%1, %loop = transform.structured.tile %0 [[4], [4]]
```
This change is part of a larger effort to enable scalable vectorisation
in Linalg. See this RFC for more context:
* https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/
Differential Revision: https://reviews.llvm.org/D150944
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/Transform/Utils/Utils.cpp
mlir/lib/Interfaces/ViewLikeInterface.cpp
mlir/test/Dialect/Linalg/transform-op-tile.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 06ef84b43f04b..856eac88b36e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1528,7 +1528,8 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
let arguments = (ins TransformHandleTypeInterface:$target,
Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange);
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
+ DefaultValuedOptionalAttr<BoolAttr, "false">:$last_tile_size_scalable);
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
Variadic<TransformHandleTypeInterface>:$loops);
let builders = [
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
index 87113197524ff..cab2a0bcc11b1 100644
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -72,17 +72,27 @@ void printDynamicIndexList(
/// 1. `result` is filled with the i64 ArrayAttr "[`kDynamic`, 7, 42,
/// `kDynamic`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
+///
+/// Trailing indices can be scalable. For example, "42" in "[7, [42]]" is
+/// scalable. This notation is similar to how scalable dims are marked when
+/// defining Vectors. If \p isTrailingIdxScalable is null, scalable indices are
+/// not allowed/expected. When it's not null, this hook will set the
+/// corresponding value to:
+/// * true if the trailing idx is scalable,
+/// * false otherwise.
ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+ DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable = nullptr,
+ SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
inline ParseResult parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> &valueTypes,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
- return parseDynamicIndexList(parser, values, integers, &valueTypes,
+ return parseDynamicIndexList(parser, values, integers,
+ /*isTrailingIdxScalable=*/nullptr, &valueTypes,
delimiter);
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a6a3fbb2e23b8..51dcd7e17c0f5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2391,6 +2391,7 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<Operation *> tiled;
SmallVector<SmallVector<Operation *, 4>, 4> loops;
loops.resize(getLoops().size());
+ bool scalable = getLastTileSizeScalable();
for (auto [i, op] : llvm::enumerate(targets)) {
auto tilingInterface = dyn_cast<TilingInterface>(op);
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
@@ -2409,10 +2410,21 @@ transform::TileOp::apply(TransformResults &transformResults,
SmallVector<Value, 4> sizes;
sizes.reserve(tileSizes.size());
unsigned dynamicIdx = 0;
- for (OpFoldResult ofr : getMixedSizes()) {
+ unsigned trailingIdx = getMixedSizes().size() - 1;
+
+ for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
- sizes.push_back(b.create<arith::ConstantIndexOp>(
- getLoc(), cast<IntegerAttr>(attr).getInt()));
+ // Only the trailing tile size is allowed to be scalable at the moment.
+ if (scalable && (ofrIdx == trailingIdx)) {
+ auto val = b.create<arith::ConstantIndexOp>(
+ getLoc(), attr.cast<IntegerAttr>().getInt());
+ Value vscale =
+ b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType());
+ sizes.push_back(b.create<arith::MulIOp>(getLoc(), val, vscale));
+ } else {
+ sizes.push_back(b.create<arith::ConstantIndexOp>(
+ getLoc(), cast<IntegerAttr>(attr).getInt()));
+ }
continue;
}
ArrayRef<Operation *> dynamicSizes = dynamicSizeProducers[dynamicIdx];
@@ -2507,8 +2519,9 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
DenseI64ArrayAttr staticSizes;
FunctionType functionalType;
llvm::SMLoc operandLoc;
+ bool scalable = false;
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
- parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
+ parseDynamicIndexList(parser, dynamicSizes, staticSizes, &scalable) ||
parseOptionalInterchange(parser, result) ||
parser.parseColonType(functionalType))
return ParseResult::failure();
@@ -2531,6 +2544,10 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
return failure();
}
+ auto scalableAttr = parser.getBuilder().getBoolAttr(scalable);
+ result.addAttribute(getLastTileSizeScalableAttrName(result.name),
+ scalableAttr);
+
result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
result.addTypes(functionalType.getResults());
return success();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index acfeb0f1e205d..c8d64201cb2a2 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1261,9 +1261,9 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
dynamicSteps;
if (succeeded(parser.parseOptionalKeyword("in"))) {
// Parse upper bounds.
- if (parseDynamicIndexList(parser, dynamicUbs, staticUbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ if (parseDynamicIndexList(
+ parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
@@ -1273,26 +1273,26 @@ ParseResult ForallOp::parse(OpAsmParser &parser, OperationState &result) {
} else {
// Parse lower bounds.
if (parser.parseEqual() ||
- parseDynamicIndexList(parser, dynamicLbs, staticLbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicLbs, staticLbs, /*scalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicLbs, indexType, result.operands))
return failure();
// Parse upper bounds.
if (parser.parseKeyword("to") ||
- parseDynamicIndexList(parser, dynamicUbs, staticUbs,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicUbs, staticUbs, /*scalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicUbs, indexType, result.operands))
return failure();
// Parse step values.
if (parser.parseKeyword("step") ||
- parseDynamicIndexList(parser, dynamicSteps, staticSteps,
- /*valueTypes=*/nullptr,
- OpAsmParser::Delimiter::Paren) ||
+ parseDynamicIndexList(
+ parser, dynamicSteps, staticSteps, /*scalable=*/nullptr,
+ /*valueTypes=*/nullptr, OpAsmParser::Delimiter::Paren) ||
parser.resolveOperands(dynamicSteps, indexType, result.operands))
return failure();
}
diff --git a/mlir/lib/Dialect/Transform/Utils/Utils.cpp b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
index d516a56feed47..b50a7660e2bf2 100644
--- a/mlir/lib/Dialect/Transform/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Transform/Utils/Utils.cpp
@@ -42,5 +42,6 @@ ParseResult mlir::transform::parsePackedOrDynamicIndexList(
return success();
}
- return parseDynamicIndexList(parser, values, integers, &valueTypes);
+ return parseDynamicIndexList(parser, values, integers, /*scalable=*/nullptr,
+ &valueTypes);
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
index 4f48f0a57c307..13cca8131b682 100644
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -128,13 +128,26 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes,
- AsmParser::Delimiter delimiter) {
+ DenseI64ArrayAttr &integers, bool *isTrailingIdxScalable,
+ SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
SmallVector<int64_t, 4> integerVals;
+ bool foundScalable = false;
auto parseIntegerOrValue = [&]() {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);
+
+ // If `foundScalable` has already been set to `true` then a non-trailing
+ // tile size was identified as scalable.
+ if (foundScalable) {
+ parser.emitError(parser.getNameLoc())
+ << "non-trailing tile size cannot be scalable";
+ return failure();
+ }
+
+ if (isTrailingIdxScalable && parser.parseOptionalLSquare().succeeded())
+ foundScalable = true;
+
if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
integerVals.push_back(ShapedType::kDynamic);
@@ -146,6 +159,8 @@ ParseResult mlir::parseDynamicIndexList(
return failure();
integerVals.push_back(integer);
}
+ if (foundScalable && parser.parseOptionalRSquare().failed())
+ return failure();
return success();
};
if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue,
@@ -153,6 +168,8 @@ ParseResult mlir::parseDynamicIndexList(
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+ if (isTrailingIdxScalable)
+ *isTrailingIdxScalable = foundScalable;
return success();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index f005752bfc034..e00a48429ed56 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
transform.sequence failures(propagate) {
^bb0(%arg1: !transform.any_op):
@@ -149,3 +149,96 @@ transform.sequence failures(propagate) {
transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
}
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+
+module {
+ func.func @scalable_tile(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: f32) -> tensor<?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) outs(%arg2 : tensor<?xf32>) {
+ ^bb0(%in_1: f32, %in_2: f32, %out: f32):
+ %1 = arith.addf %in_1, %in_2 : f32
+ %2 = arith.mulf %arg3, %1 : f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+ }
+}
+
+// CHECK-LABEL: func.func @scalable_tile(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
+// CHECK: %[[C4:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
+// CHECK: %[[VEC_SIZE:.*]] = arith.constant 4 : index
+// CHECK: %[[VS:.*]] = vector.vscale
+// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
+// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
+// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: linalg.generic {indexing_maps = {{.*}}, iterator_types = ["parallel"]} ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : tensor<?xf32>, tensor<?xf32>) outs(%[[SLICE_ARG2]] : tensor<?xf32>) {
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loop = transform.structured.tile %0 [[4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
+// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
+// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: %[[VS:.*]] = vector.vscale
+// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
+// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
+// CHECK: scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
+
+func.func @scalable_and_fixed_length_tile(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1, %loops:3 = transform.structured.tile %0 [4, 4, [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
+
+// -----
+
+// TODO: Add support for specifying more than one scalable tile size
+
+func.func @scalable_and_fixed_length_tile(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+
+ return %0 : tensor<128x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{non-trailing tile size cannot be scalable}}
+ // expected-error @below {{expected SSA value or integer}}
+ %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+}
More information about the Mlir-commits
mailing list