[Mlir-commits] [mlir] 4c7225d - [mlir][Transform] Fix implementation of the generic apply that is based on applyToOne.
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jun 23 05:29:28 PDT 2022
Author: Nicolas Vasilache
Date: 2022-06-23T05:28:09-07:00
New Revision: 4c7225d19a9d1ff62c0ae39049ca3afe2a46c571
URL: https://github.com/llvm/llvm-project/commit/4c7225d19a9d1ff62c0ae39049ca3afe2a46c571
DIFF: https://github.com/llvm/llvm-project/commit/4c7225d19a9d1ff62c0ae39049ca3afe2a46c571.diff
LOG: [mlir][Transform] Fix implementation of the generic apply that is based on applyToOne.
Applying an N-result producing transformation to M payload ops via `applyToOne`
yields M lists of N result operations, but the transform op needs N results of M operations each.
This requires a transposition of the results obtained by calling `applyToOne`.
This revision fixes the issue and adds more advanced tests that exercise the behavior.
Differential Revision: https://reviews.llvm.org/D128414
Added:
mlir/test/Dialect/Transform/selective-targeting.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
mlir/test/Dialect/Transform/test-interpreter.mlir
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8f0dc16d35ab7..f3e42cefb2d45 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -348,7 +348,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
does **not** fail when no ops were vectorized.
Note that this transformation is invalidating the handles to any payload IR
- operation that is contained inside the vectoriztaion target.
+ operation that is contained inside the vectorization target.
}];
let arguments = (ins PDL_Operation:$target,
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 390d6c6240657..e6fbfc88e31ee 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -794,6 +794,24 @@ applyTransformToEach(ArrayRef<Operation *> targets,
}
return DiagnosedSilenceableFailure::success();
}
+
+/// Transposes M per-op lists of N result ops each into N lists of M ops.
+static inline SmallVector<SmallVector<Operation *, 1>>
+transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
+ SmallVector<SmallVector<Operation *, 1>> res;
+ if (m.empty())
+ return res;
+ int64_t rows = m.size(), cols = m[0].size();
+ for (int64_t j = 0; j < cols; ++j)
+ res.push_back(SmallVector<Operation *, 1>(rows, nullptr));
+ for (int64_t i = 0; i < rows; ++i) {
+ assert(static_cast<int64_t>(m[i].size()) == cols);
+ for (int64_t j = 0; j < cols; ++j) {
+ res[j][i] = m[i][j];
+ }
+ }
+ return res;
+}
} // namespace detail
} // namespace transform
} // namespace mlir
@@ -815,27 +833,51 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
});
if (!result.succeeded())
return result;
- for (const SmallVector<Operation *> &oneTargetResults : results) {
- if (OpTy::template hasTrait<OpTrait::ZeroResults>())
- continue;
- if (OpTy::template hasTrait<OpTrait::OneResult>()) {
- transformResults.set(
- this->getOperation()->getResult(0).template cast<OpResult>(),
- oneTargetResults);
- continue;
- }
- if (this->getOperation()->getNumResults() != oneTargetResults.size()) {
- Diagnostic diag(this->getOperation()->getLoc(),
- DiagnosticSeverity::Error);
- diag << "unexpected number of results (got " << oneTargetResults.size()
- << " expected " << this->getOperation()->getNumResults() << ")";
- return DiagnosedSilenceableFailure::silencableFailure(std::move(diag));
- }
- for (const auto &it :
- llvm::zip(this->getOperation()->getResults(), oneTargetResults)) {
- transformResults.set(std::get<0>(it).template cast<OpResult>(),
- std::get<1>(it));
- }
+ if (results.empty())
+ return DiagnosedSilenceableFailure::success();
+
+ // Ensure all applications return the same number of results.
+ // Variadic cases are much trickier to handle in a generic fashion.
+ int64_t nRes = results[0].size();
+ if (llvm::any_of(results, [&](const auto &r) {
+ return static_cast<int64_t>(r.size()) != nRes;
+ })) {
+ return static_cast<OpTy *>(this)->emitSilenceableError()
+ << "expected all applications of " << OpTy::getOperationName()
+ << " to produce " << nRes
+ << " results.\n If you need variadic results, consider using a "
+ "generic `apply` instead of the specialized `applyToOne`";
+ }
+ // Ensure the number of results agrees with what the transform op expects.
+ if (this->getOperation()->getNumResults() != nRes) {
+ InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
+ << "unexpected number of results (got " << nRes
+ << " expected "
+ << this->getOperation()->getNumResults() << ")";
+ return DiagnosedSilenceableFailure::definiteFailure();
+ }
+
+ // If no results, bail early.
+ if (OpTy::template hasTrait<OpTrait::ZeroResults>())
+ return DiagnosedSilenceableFailure::success();
+
+ // Transpose the M per-application lists of N results each into N lists of M
+ // ops, one list per result of the transform op.
+ SmallVector<SmallVector<Operation *, 1>> transposedResults =
+ detail::transposeResults(results);
+ // A single-result op applied to M ops produces one result holding M ops.
+ if (OpTy::template hasTrait<OpTrait::OneResult>()) {
+ assert(transposedResults.size() == 1 && "Expected single result");
+ transformResults.set(
+ this->getOperation()->getResult(0).template cast<OpResult>(),
+ transposedResults[0]);
+ return DiagnosedSilenceableFailure::success();
+ }
+ // M ops, N results each.
+ for (const auto &it :
+ llvm::zip(this->getOperation()->getResults(), transposedResults)) {
+ transformResults.set(std::get<0>(it).template cast<OpResult>(),
+ std::get<1>(it));
}
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir
new file mode 100644
index 0000000000000..eee4c2bf54456
--- /dev/null
+++ b/mlir/test/Dialect/Transform/selective-targeting.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul_tensors_1(
+func.func @matmul_tensors_1(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // This operation is marked for tiling only.
+ // CHECK-COUNT-3: scf.for
+ // CHECK-COUNT-3: tensor.extract_slice
+ // CHECK: linalg.matmul
+ // CHECK-SAME: -> tensor<4x4xf32>
+ %0 = linalg.matmul { test.attrA }
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ func.return %0 : tensor<128x128xf32>
+}
+
+func.func @matmul_tensors_2(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // This operation is marked for tiling and vectorization (both test.attrA
+ // and test.attrC are set).
+ // CHECK-COUNT-3: scf.for
+ // CHECK-COUNT-3: vector.transfer_read
+ // CHECK: vector.contract
+ // CHECK-NOT: linalg.matmul
+ // CHECK: vector.transfer_write
+ %0 = linalg.matmul { test.attrA, test.attrC }
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ func.return %0 : tensor<128x128xf32>
+}
+
+func.func @matmul_tensors_3(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // This operation is marked for vectorization only.
+ // CHECK-NOT: scf.for
+ // CHECK-COUNT-3: vector.transfer_read
+ // CHECK: vector.contract
+ // CHECK-SAME: into vector<128x128xf32>
+ // CHECK: vector.transfer_write
+ %0 = linalg.matmul { test.attrC }
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // Match linalg.matmul operations that have test.attrA set.
+ pdl.pattern @pdl_target_attrA : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ // Match linalg.matmul operations that have test.attrC set.
+ pdl.pattern @pdl_target_attrC : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrC" = %attr}-> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target_attrA in %arg1
+ transform.structured.tile %0 {sizes = [4, 4, 4]}
+ %1 = pdl_match @pdl_target_attrC in %arg1
+ %2 = transform.get_closest_isolated_parent %1
+ transform.structured.vectorize %2
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_one
+func.func @vectorize_one(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: vector.contract
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ func.return %0 : tensor<128x128xf32>
+}
+
+func.func @vectorize_none(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>,
+ %arg2: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: linalg.matmul
+ %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ func.return %0 : tensor<128x128xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @pdl_target : benefit(1) {
+ %args = operands
+ %results = types
+ %attr = attribute
+ %0 = operation "linalg.matmul"(%args : !pdl.range<value>) {"test.attrA" = %attr}-> (%results : !pdl.range<type>)
+ // TODO: we don't want this, but it is the required terminator for pdl.pattern
+ rewrite %0 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = pdl_match @pdl_target in %arg1
+ %1 = get_closest_isolated_parent %0
+ transform.structured.vectorize %1
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @vectorize_all
+func.func @vectorize_all(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>,
+ %arg3: tensor<128x128xf32> {linalg.inplaceable = true})
+ -> tensor<128x128xf32> {
+ // CHECK: vector.contract
+ %0 = linalg.matmul {test.attrA}
+ ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ // CHECK: vector.contract
+ %1 = linalg.matmul ins(%arg0, %0: tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg3: tensor<128x128xf32>)
+ -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+ transform.structured.vectorize %arg0
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index 6a54b0ce6f59e..a487d1dbef19c 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --verify-diagnostics
+// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
// expected-remark @below {{applying transformation}}
transform.test_transform_op
@@ -385,3 +385,54 @@ transform.sequence {
// expected-error @below {{unexpected number of results (got 0 expected 3)}}
transform.test_wrong_number_of_results %arg0
}
+
+// -----
+
+func.func @foo() {
+ "op" () : () -> ()
+ "op" () : () -> ()
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @some : benefit(1) {
+ %0 = pdl.operands
+ %1 = pdl.types
+ %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ pdl.rewrite %2 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @some in %arg1
+ // expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}}
+ transform.test_wrong_number_of_multi_results %0
+ }
+}
+
+// -----
+
+func.func @foo() {
+ "op" () : () -> ()
+ "op" () : () -> ()
+ "op" () : () -> ()
+ return
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ pdl.pattern @some : benefit(1) {
+ %0 = pdl.operands
+ %1 = pdl.types
+ %2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
+ pdl.rewrite %2 with "transform.dialect"
+ }
+
+ transform.sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @some in %arg1
+ // Transform matches 3 ops and produces 2 results.
+ %1:2 = transform.test_correct_number_of_multi_results %0
+ }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index c48a7936adb12..9242a51dbb840 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -232,6 +232,24 @@ mlir::test::TestWrongNumberOfResultsOp::applyToOne(
return SmallVector<Operation *>{};
}
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne(
+ Operation *op, transform::TransformState &state) {
+ static int count = 0;
+ if (count++ > 0)
+ return SmallVector<Operation *>{};
+ OperationState opState(op->getLoc(), "foo");
+ return SmallVector<Operation *>{OpBuilder(op).create(opState)};
+}
+
+FailureOr<SmallVector<Operation *>>
+mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne(
+ Operation *op, transform::TransformState &state) {
+ OperationState opState(op->getLoc(), "foo");
+ return SmallVector<Operation *>{OpBuilder(op).create(opState),
+ OpBuilder(op).create(opState)};
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 1b8ddb9649c34..78eade06d0d65 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -144,4 +144,33 @@ def TestWrongNumberOfResultsOp
}];
}
+def TestWrongNumberOfMultiResultsOp
+ : Op<Transform_Dialect, "test_wrong_number_of_multi_results",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$result);
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "::mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+ ::mlir::Operation *target, transform::TransformState &state);
+ }];
+}
+
+def TestCorrectNumberOfMultiResultsOp
+ : Op<Transform_Dialect, "test_correct_number_of_multi_results",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let arguments = (ins PDL_Operation:$target);
+ let results = (outs PDL_Operation:$result1,
+ PDL_Operation:$result2);
+ let assemblyFormat = "$target attr-dict";
+ let cppNamespace = "::mlir::test";
+ let extraClassDeclaration = [{
+ ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
+ ::mlir::Operation *target, transform::TransformState &state);
+ }];
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
More information about the Mlir-commits
mailing list