[Mlir-commits] [mlir] faae4d5 - [mlir][vector][transform] Expose tensor slice -> transfer folding patterns
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 9 07:27:16 PDT 2023
Author: Matthias Springer
Date: 2023-06-09T16:23:25+02:00
New Revision: faae4d5d8127b999a0cd8d00cae6237aba407c06
URL: https://github.com/llvm/llvm-project/commit/faae4d5d8127b999a0cd8d00cae6237aba407c06
DIFF: https://github.com/llvm/llvm-project/commit/faae4d5d8127b999a0cd8d00cae6237aba407c06.diff
LOG: [mlir][vector][transform] Expose tensor slice -> transfer folding patterns
Add a new transform op to populate patterns: ApplyFoldTensorSliceIntoTransferPatternsOp.
Differential Revision: https://reviews.llvm.org/D152531
Added:
Modified:
mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 253aeedf15aba..806c3f9fca50d 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,4 +306,16 @@ def ApplyTransferToScfPatternsOp : Op<Transform_Dialect,
}];
}
+def ApplyFoldTensorSliceIntoTransferPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.fold_tensor_slice_into_transfer",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that tensor.extract_slice -> vector.transfer_read and
+ vector.transfer_write -> tensor.insert_slice op chains should be folded into
+ vector transfer read and write ops.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
#endif // VECTOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index da99232ed6ab8..505cb5c11253a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,6 +138,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
}
+void transform::ApplyFoldTensorSliceIntoTransferPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ populateVectorTransferTensorSliceTransforms(patterns);
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
index cc17025fe0f1e..acd401704eb5b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -1,4 +1,11 @@
-// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+ transform.apply_patterns to %module_op {
+ transform.apply_patterns.vector.fold_tensor_slice_into_transfer
+ } : !transform.any_op
+}
// CHECK-LABEL: func @transfer_read_of_extract_slice(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
@@ -16,16 +23,14 @@ func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2
return %1 : vector<5x6xf32>
}
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-LABEL: func @transfer_read_of_extract_slice_1d(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
// CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
// CHECK: return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+func.func @transfer_read_of_extract_slice_1d(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.0 : f32
@@ -34,8 +39,6 @@ func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2
return %1 : vector<6xf32>
}
-// -----
-
// CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
@@ -53,8 +56,6 @@ func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>,
return %1 : vector<5x6xf32>
}
-// -----
-
// CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
// CHECK: extract_slice
// CHECK: vector.transfer_read
@@ -67,8 +68,6 @@ func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x
return %1 : vector<5x6xf32>
}
-// -----
-
// CHECK-LABEL: func @insert_slice_of_transfer_write(
// CHECK-SAME: %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
// CHECK: %[[c3:.*]] = arith.constant 3 : index
@@ -81,8 +80,6 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
return %1 : tensor<?x12xf32>
}
-// -----
-
// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
// CHECK: vector.transfer_write
// CHECK: insert_slice
@@ -93,8 +90,6 @@ func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x
return %1 : tensor<?x?x12xf32>
}
-// -----
-
// CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
// CHECK-SAME: %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 4fbddcee574a1..a5de1fd4de431 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -690,26 +690,6 @@ struct TestVectorGatherLowering
}
};
-struct TestVectorTransferTensorSlicePatterns
- : public PassWrapper<TestVectorTransferTensorSlicePatterns,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestVectorTransferTensorSlicePatterns)
-
- StringRef getArgument() const final {
- return "test-vector-transfer-tensor-slice-patterns";
- }
- StringRef getDescription() const final {
- return "Test patterns that fold vector transfer and tensor slice ops";
- }
-
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateVectorTransferTensorSliceTransforms(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -771,8 +751,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
- PassRegistration<TestVectorTransferTensorSlicePatterns>();
-
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
}
} // namespace test
More information about the Mlir-commits
mailing list