[Mlir-commits] [mlir] f275148 - [mlir][Linalg] Better builders for transform ops
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Dec 14 06:23:55 PST 2022
Author: Nicolas Vasilache
Date: 2022-12-14T06:22:52-08:00
New Revision: f27514800cc50677d640deae555bf999653a4c6f
URL: https://github.com/llvm/llvm-project/commit/f27514800cc50677d640deae555bf999653a4c6f
DIFF: https://github.com/llvm/llvm-project/commit/f27514800cc50677d640deae555bf999653a4c6f.diff
LOG: [mlir][Linalg] Better builders for transform ops
Also adopt DenseI64ArrayAttr in those transform ops.
Differential Revision: https://reviews.llvm.org/D140009
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 9c80f56332cef..1cac6b83a1e9c 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -721,14 +721,24 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
```
}];
+ // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
let results = (outs PDL_Operation:$for_op,
PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
- let assemblyFormat = "$target attr-dict";
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticTileSizes)>
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `by` `tile_sizes` `=` $tile_sizes
+ attr-dict
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -808,16 +818,31 @@ def TileReductionUsingForeachThreadOp :
```
}];
+ // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
let arguments = (ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
+ DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs PDL_Operation:$foreach_thread_op,
PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
- let assemblyFormat = "$target attr-dict";
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "ArrayRef<int64_t>":$staticNumThreads,
+ "ArrayRef<int64_t>":$staticTileSizes,
+ CArg<"ArrayAttr", "{}">:$mapping)>
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `by`
+ (`num_threads` `=` $num_threads^)?
+ (`,` `tile_sizes` `=` $tile_sizes^)?
+ (`,` `mapping` `=` $mapping^)?
+ attr-dict
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -825,6 +850,7 @@ def TileReductionUsingForeachThreadOp :
::llvm::SmallVectorImpl<::mlir::Operation *> &results,
::mlir::transform::TransformState &state);
}];
+
}
def TileOp : Op<Transform_Dialect, "structured.tile",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 853321f22dba7..c8995e609e778 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1200,20 +1200,31 @@ transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
// TileReductionUsingScfOp
//===----------------------------------------------------------------------===//
+void transform::TileReductionUsingScfOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<int64_t> staticTileSizes) {
+ // Call the default builder.
+ // This is future-proof re mixed static-dynamic and setting up the proper
+ // operands segment sizes attributes for multiple variadic operands.
+ // In the absence of this, horrible bugs ensue.
+ // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+ MLIRContext *ctx = builder.getContext();
+ auto opTy = pdl::OperationType::get(ctx);
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ build(builder, result,
+ /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+ /*target=*/target,
+ /*tile_sizes=*/staticTileSizesAttr);
+}
+
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
- SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
- SmallVector<OpFoldResult> sizes;
- for (int64_t size : tileSizes) {
- sizes.push_back(rewriter.getIndexAttr(size));
- }
-
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
- sizes);
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
if (failed(result))
return emitDefaultSilenceableFailure(target);
@@ -1228,14 +1239,37 @@ DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
// TileReductionUsingForeachThreadOp
//===----------------------------------------------------------------------===//
+void transform::TileReductionUsingForeachThreadOp::build(
+ OpBuilder &builder, OperationState &result, Value target,
+ ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
+ ArrayAttr mapping) {
+ // Call the default builder.
+ // This is future-proof re mixed static-dynamic and setting up the proper
+ // operands segment sizes attributes for multiple variadic operands.
+ // In the absence of this, horrible bugs ensue.
+ // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
+ MLIRContext *ctx = builder.getContext();
+ auto opTy = pdl::OperationType::get(ctx);
+ auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
+ auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
+ build(builder, result,
+ /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
+ /*target=*/target,
+ /*num_threads=*/staticNumThreadsAttr,
+ /*tile_sizes=*/staticTileSizesAttr,
+ /*mapping=*/mapping);
+}
+
DiagnosedSilenceableFailure
transform::TileReductionUsingForeachThreadOp::applyToOne(
linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
- SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
- SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
+ SmallVector<OpFoldResult> numThreads =
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
+ SmallVector<OpFoldResult> tileSizes =
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
FailureOr<linalg::ForeachThreadReductionTilingResult> result =
linalg::tileReductionUsingForeachThread(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
index 13aec82b10a44..370930661bd20 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -17,7 +17,8 @@ func.func @reduction_tile(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 5] }
+ %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0
+ by tile_sizes = [0, 5]
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -71,7 +72,8 @@ func.func @reduction_tile_transpose(%arg0: tensor<?x?xf32>, %out: tensor<?xf32>)
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [5, 0] }
+ %loop, %1, %2, %3 = transform.structured.tile_reduction_using_scf %0
+ by tile_sizes = [5, 0]
}
// CHECK: func @reduction_tile_transpose
@@ -107,7 +109,8 @@ func.func @reduction_tile_parallel(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
- %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
+ %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+ by num_threads = [0, 5], tile_sizes = []
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -159,7 +162,8 @@ func.func @matmul_tile_parallel(
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
- %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] }
+ %loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
+ by num_threads = [0, 0, 5], tile_sizes = []
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)>
@@ -219,7 +223,7 @@ transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
- { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+ by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>]
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
@@ -285,7 +289,7 @@ transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
%loop, %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0
- { num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>] }
+ by num_threads = [0, 5], tile_sizes = [0, 3], mapping = [#gpu.thread<x>]
// CHECK: expecting fill
// CHECK-NEXT: linalg.fill
More information about the Mlir-commits mailing list