[Mlir-commits] [mlir] 834fcfe - Reland "[mlir][Vector] Extend xfer drop unit dim patterns"

Diego Caballero llvmlistbot at llvm.org
Thu Jun 1 15:23:30 PDT 2023


Author: Diego Caballero
Date: 2023-06-01T22:22:16Z
New Revision: 834fcfed248dc1cd0fe68158dbd1e5f9a9e19e3d

URL: https://github.com/llvm/llvm-project/commit/834fcfed248dc1cd0fe68158dbd1e5f9a9e19e3d
DIFF: https://github.com/llvm/llvm-project/commit/834fcfed248dc1cd0fe68158dbd1e5f9a9e19e3d.diff

LOG: Reland "[mlir][Vector] Extend xfer drop unit dim patterns"

This reverts commit 76d71f3792b2b1864992446f7b1028b026dccd11.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 8d97bbfb7257..fa901d068a75 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -63,6 +63,7 @@ class TransferOptimization {
   std::vector<Operation *> opToErase;
 };
 
+} // namespace
 /// Return true if there is a path from start operation to dest operation,
 /// otherwise return false. The operations have to be in the same region.
 bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
@@ -288,14 +289,25 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
 }
 
+/// Returns a copy of `shape` with every unit (size-1) dimension dropped. The
+/// remaining dimensions keep their relative order, e.g. [1, 3, 1, 2] becomes
+/// [3, 2]. The result size always equals `getReducedRank(shape)`.
+static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
+  auto isNonUnitDim = [](int64_t dimSize) { return dimSize != 1; };
+  return llvm::to_vector(llvm::make_filter_range(shape, isNonUnitDim));
+}
+
 /// Returns true if `v` is produced by an `arith.constant 0 : index` op.
 static bool isZero(Value v) {
   auto cst = v.getDefiningOp<arith::ConstantIndexOp>();
   return cst && cst.value() == 0;
 }
 
-/// Rewrites vector.transfer_read ops where the source has unit dims, by
-/// inserting a memref.subview dropping those unit dims.
+namespace {
+
+/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
+/// inserting a memref.subview dropping those unit dims. The vector shapes are
+/// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -317,12 +329,15 @@ class TransferReadDropUnitDimsPattern
       return failure();
     if (!transferReadOp.getPermutationMap().isMinorIdentity())
       return failure();
+    // Check if the source shape can be further reduced.
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
-      return failure(); // The source shape can't be further reduced.
-    if (reducedRank != vectorType.getRank())
-      return failure(); // This pattern requires the vector shape to match the
-                        // reduced source shape.
+      return failure();
+    // Check if the reduced vector shape matches the reduced source shape.
+    // Otherwise, this case is not supported yet.
+    int vectorReducedRank = getReducedRank(vectorType.getShape());
+    if (reducedRank != vectorReducedRank)
+      return failure();
     if (llvm::any_of(transferReadOp.getIndices(),
                      [](Value v) { return !isZero(v); }))
       return failure();
@@ -331,14 +346,22 @@ class TransferReadDropUnitDimsPattern
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
+    auto reducedVectorType = VectorType::get(
+        getReducedShape(vectorType.getShape()), vectorType.getElementType());
+
+    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, vectorType, newTransferReadOp);
+    rewriter.replaceOp(transferReadOp, shapeCast);
+
     return success();
   }
 };
 
-/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has
-/// unit dims, by inserting a memref.subview dropping those unit dims.
+/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
+/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
+/// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -360,12 +383,15 @@ class TransferWriteDropUnitDimsPattern
       return failure();
     if (!transferWriteOp.getPermutationMap().isMinorIdentity())
       return failure();
+    // Check if the destination shape can be further reduced.
     int reducedRank = getReducedRank(sourceType.getShape());
     if (reducedRank == sourceType.getRank())
-      return failure(); // The source shape can't be further reduced.
-    if (reducedRank != vectorType.getRank())
-      return failure(); // This pattern requires the vector shape to match the
-                        // reduced source shape.
+      return failure();
+    // Check if the reduced vector shape matches the reduced destination shape.
+    // Otherwise, this case is not supported yet.
+    int vectorReducedRank = getReducedRank(vectorType.getShape());
+    if (reducedRank != vectorReducedRank)
+      return failure();
     if (llvm::any_of(transferWriteOp.getIndices(),
                      [](Value v) { return !isZero(v); }))
       return failure();
@@ -374,12 +400,20 @@ class TransferWriteDropUnitDimsPattern
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    VectorType reducedVectorType = VectorType::get(
+        getReducedShape(vectorType.getShape()), vectorType.getElementType());
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
+        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+
     return success();
   }
 };
 
+} // namespace
+
 /// Return true if the memref type has its inner dimension matching the given
 /// shape. Otherwise return false.
 static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
@@ -439,6 +473,8 @@ checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
   return success();
 }
 
+namespace {
+
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -736,6 +772,7 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     return success();
   }
 };
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e4e2e3b69c67..3efa06948f54 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -15,6 +15,14 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.vector.apply_rank_reducing_subview_patterns %module_op
+      : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
 func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -28,6 +36,97 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.vector.apply_rank_reducing_subview_patterns %module_op
+      : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @transfer_read_and_vector_rank_reducing(
+      %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
+      memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
+    return %v : vector<3x2x1xf32>
+}
+// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2x1xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2x1xf32> to memref<3x2xf32>
+//       CHECK:   %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>
+//       CHECK:   vector.shape_cast %[[READ]] : vector<3x2xf32> to vector<3x2x1xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.vector.apply_rank_reducing_subview_patterns %module_op
+      : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @transfer_write_and_vector_rank_reducing(
+      %arg : memref<1x1x3x2x1xf32>,
+      %vec : vector<3x2x1xf32>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
+      vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
+    return
+}
+// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2x1xf32>, %[[VEC:.+]]: vector<3x2x1xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2x1xf32> to memref<3x2xf32>
+//       CHECK:   %[[SHCAST:.+]] = vector.shape_cast %[[VEC]] : vector<3x2x1xf32> to vector<3x2xf32>
+//       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.vector.apply_rank_reducing_subview_patterns %module_op
+      : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @transfer_read_and_vector_rank_reducing_to_0d(
+      %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0.0 : f32
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
+      memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
+    return %v : vector<1x1x1xf32>
+}
+
+// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
+//  CHECK-SAME:     %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
+//       CHECK:   %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
+//       CHECK:   vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  transform.vector.apply_rank_reducing_subview_patterns %module_op
+      : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+func.func @transfer_write_and_vector_rank_reducing_to_0d(
+      %dest : memref<1x1x1x1x1xf32>,
+      %val : vector<1x1x1xf32>) {
+    %zero = arith.constant 0 : index
+    vector.transfer_write %val, %dest[%zero, %zero, %zero, %zero, %zero] :
+      vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
+//  CHECK-SAME:     %[[DEST:.+]]: memref<1x1x1x1x1xf32>, %[[VAL:.+]]: vector<1x1x1xf32>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
+//       CHECK:   %[[CAST:.+]] = vector.shape_cast %[[VAL]] : vector<1x1x1xf32> to vector<f32>
+//       CHECK:   vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !transform.any_op):


        


More information about the Mlir-commits mailing list