[Mlir-commits] [mlir] faae4d5 - [mlir][vector][transform] Expose tensor slice -> transfer folding patterns

Matthias Springer llvmlistbot at llvm.org
Fri Jun 9 07:27:16 PDT 2023


Author: Matthias Springer
Date: 2023-06-09T16:23:25+02:00
New Revision: faae4d5d8127b999a0cd8d00cae6237aba407c06

URL: https://github.com/llvm/llvm-project/commit/faae4d5d8127b999a0cd8d00cae6237aba407c06
DIFF: https://github.com/llvm/llvm-project/commit/faae4d5d8127b999a0cd8d00cae6237aba407c06.diff

LOG: [mlir][vector][transform] Expose tensor slice -> transfer folding patterns

Add a new transform op to populate patterns: ApplyFoldTensorSliceIntoTransferPatternsOp.

Differential Revision: https://reviews.llvm.org/D152531

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
    mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
    mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 253aeedf15aba..806c3f9fca50d 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,4 +306,16 @@ def ApplyTransferToScfPatternsOp : Op<Transform_Dialect,
   }];
 }
 
+def ApplyFoldTensorSliceIntoTransferPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.fold_tensor_slice_into_transfer",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice -> vector.transfer_read and
+    vector.transfer_write -> tensor.insert_slice op chains should be folded into
+    vector transfer read and write ops.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // VECTOR_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index da99232ed6ab8..505cb5c11253a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -138,6 +138,11 @@ void transform::ApplyTransferToScfPatternsOp::populatePatterns(
   populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions);
 }
 
+void transform::ApplyFoldTensorSliceIntoTransferPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorTransferTensorSliceTransforms(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
index cc17025fe0f1e..acd401704eb5b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir
@@ -1,4 +1,11 @@
-// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.apply_patterns to %module_op {
+    transform.apply_patterns.vector.fold_tensor_slice_into_transfer
+  } : !transform.any_op
+}
 
 // CHECK-LABEL: func @transfer_read_of_extract_slice(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
@@ -16,16 +23,14 @@ func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2
   return %1 : vector<5x6xf32>
 }
 
-// -----
-
-// CHECK-LABEL: func @transfer_read_of_extract_slice(
+// CHECK-LABEL: func @transfer_read_of_extract_slice_1d(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
 //   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
 //   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
 //       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
 //       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
 //       CHECK:   return %[[r]]
-func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+func.func @transfer_read_of_extract_slice_1d(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
   %c3 = arith.constant 3 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.0 : f32
@@ -34,8 +39,6 @@ func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2
   return %1 : vector<6xf32>
 }
 
-// -----
-
 // CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
 //   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index
@@ -53,8 +56,6 @@ func.func @transfer_read_of_extract_slice_rank_reducing(%t : tensor<?x?x?xf32>,
   return %1 : vector<5x6xf32>
 }
 
-// -----
-
 // CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing(
 //       CHECK:   extract_slice
 //       CHECK:   vector.transfer_read
@@ -67,8 +68,6 @@ func.func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor<?x?x
   return %1 : vector<5x6xf32>
 }
 
-// -----
-
 // CHECK-LABEL: func @insert_slice_of_transfer_write(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
 //       CHECK:   %[[c3:.*]] = arith.constant 3 : index
@@ -81,8 +80,6 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
   return %1 : tensor<?x12xf32>
 }
 
-// -----
-
 // CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending(
 //       CHECK:   vector.transfer_write
 //       CHECK:   insert_slice
@@ -93,8 +90,6 @@ func.func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor<?x
   return %1 : tensor<?x?x12xf32>
 }
 
-// -----
-
 // CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending(
 //  CHECK-SAME:     %[[t1:.*]]: tensor<?x?x12xf32>, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index
 //   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 4fbddcee574a1..a5de1fd4de431 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -690,26 +690,6 @@ struct TestVectorGatherLowering
   }
 };
 
-struct TestVectorTransferTensorSlicePatterns
-    : public PassWrapper<TestVectorTransferTensorSlicePatterns,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorTransferTensorSlicePatterns)
-
-  StringRef getArgument() const final {
-    return "test-vector-transfer-tensor-slice-patterns";
-  }
-  StringRef getDescription() const final {
-    return "Test patterns that fold vector transfer and tensor slice ops";
-  }
-
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorTransferTensorSliceTransforms(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -771,8 +751,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorGatherLowering>();
 
-  PassRegistration<TestVectorTransferTensorSlicePatterns>();
-
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 }
 } // namespace test


        


More information about the Mlir-commits mailing list