[Mlir-commits] [mlir] 9e7c97d - [mlir][vector] Fix bug in transfer op flattening

Thomas Raoux llvmlistbot at llvm.org
Fri Sep 9 09:03:06 PDT 2022


Author: Thomas Raoux
Date: 2022-09-09T16:02:52Z
New Revision: 9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7

URL: https://github.com/llvm/llvm-project/commit/9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7
DIFF: https://github.com/llvm/llvm-project/commit/9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7.diff

LOG: [mlir][vector] Fix bug in transfer op flattening

The logic that decides whether a transfer op can be flattened wasn't
considering the shape being loaded, so it incorrectly assumed that some
transfer ops were reading contiguous data.

Differential Revision: https://reviews.llvm.org/D133544
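
To make the failure mode concrete, here is a minimal standalone C++
sketch (illustrative only, not LLVM code; the shape and strides are
taken from the first negative test added below) of the removed
heuristic. It only counted how many contiguous elements the memref
offers, not whether the loaded shape lines up with the memref's inner
dimensions. A sketch of the new shape-aware check follows the diff at
the end of this post.

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the removed getContiguousInnerDim heuristic (dynamic-shape
// handling omitted): walk inner dimensions while they remain
// contiguous, stopping once at least `requiredContiguousSize`
// contiguous elements have been found.
static int64_t contiguousInnerDim(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &strides,
                                  int64_t requiredContiguousSize) {
  int64_t innerDim = shape.size();
  int64_t innerSize = 1;
  while (innerDim > 0) {
    int64_t next = innerDim - 1;
    if (strides[next] != innerSize)
      break;
    innerSize *= shape[next];
    innerDim = next;
    if (innerSize >= requiredContiguousSize)
      break;
  }
  return innerDim;
}

int main() {
  // memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>>: fully contiguous.
  std::vector<int64_t> shape = {5, 4, 3, 2};
  std::vector<int64_t> strides = {24, 6, 2, 1};
  // vector<2x2x2x2xi8> has 16 elements.
  int64_t dim = contiguousInnerDim(shape, strides, 16);
  std::printf("firstContiguousInnerDim = %lld\n", (long long)dim);
  // Prints 1, so the old pattern accepted the flattening. But the 16
  // elements actually read sit at offsets 0-3, 6-9, 24-27, 30-33:
  // offsets 4 and 5 are skipped because only 2 of the 3 elements along
  // dim 2 are loaded, so a flat 1-D read of 16 elements is wrong.
  return 0;
}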

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 5fe393b48b10f..92b103364ea27 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -339,34 +339,26 @@ class TransferWriteDropUnitDimsPattern
   }
 };
 
-/// Returns the position of the first inner dimension that has contiguous layout
-/// with at least `requiredContiguousSize` contiguous elements.
-/// When such a dimension is found, the return value satisfies:
-///   0 <= return_value <= memrefType.getRank() - 1.
-/// When no such dimension is found, the return value is memrefType.getRank().
-static int64_t getContiguousInnerDim(MemRefType memrefType,
-                                     int64_t requiredContiguousSize) {
+/// Return true if the memref type's innermost dimensions match the given
+/// target shape and are laid out contiguously. Otherwise return false.
+static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
+                                            ArrayRef<int64_t> targetShape) {
   auto shape = memrefType.getShape();
   SmallVector<int64_t> strides;
   int64_t offset;
-  int64_t innerDim = shape.size();
-  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
-    int64_t innerSize = 1;
-    while (true) {
-      if (innerDim == 0)
-        break;
-      const int64_t nextDim = innerDim - 1;
-      if (shape[nextDim] == ShapedType::kDynamicSize)
-        break;
-      if (strides[nextDim] != innerSize)
-        break;
-      innerSize *= shape[nextDim];
-      innerDim = nextDim;
-      if (innerSize >= requiredContiguousSize)
-        break;
-    }
+  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
+    return false;
+  if (strides.back() != 1)
+    return false;
+  strides.pop_back();
+  int64_t flatDim = 1;
+  for (auto [targetDim, memrefDim, memrefStride] :
+       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+    flatDim *= memrefDim;
+    if (flatDim != memrefStride || targetDim != memrefDim)
+      return false;
   }
-  return innerDim;
+  return true;
 }
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -427,10 +419,12 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    int64_t firstContiguousInnerDim =
-        getContiguousInnerDim(sourceType, vectorType.getNumElements());
-    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+    if (!hasMatchingInnerContiguousShape(
+            sourceType,
+            vectorType.getShape().take_back(vectorType.getRank() - 1)))
       return failure();
+    int64_t firstContiguousInnerDim =
+        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -485,10 +479,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    int64_t firstContiguousInnerDim =
-        getContiguousInnerDim(sourceType, vectorType.getNumElements());
-    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+    if (!hasMatchingInnerContiguousShape(
+            sourceType,
+            vectorType.getShape().take_back(vectorType.getRank() - 1)))
       return failure();
+    int64_t firstContiguousInnerDim =
+        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 41e61887a6311..3c8e280212bed 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -12,9 +12,9 @@ func.func @transfer_read_flattenable_with_offset(
 // CHECK-LABEL: func @transfer_read_flattenable_with_offset
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
-// C-HECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// C-HECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
-// C-HECK:         return %[[VEC2D]]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
 
 // -----
 
@@ -26,12 +26,12 @@ func.func @transfer_write_flattenable_with_offset(
     return
 }
 
-// C-HECK-LABEL: func @transfer_write_flattenable_with_offset
-// C-HECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// C-HECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
-// C-HECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
-// C-HECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
-// C-HECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
 // -----
 
@@ -104,3 +104,31 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 // CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
 // CHECK-SAME:    {in_bounds = [true]}
 // CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// -----
+
+func.func @transfer_read_flattenable_negative(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : 
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
+    return %v : vector<2x2x2x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative
+//       CHECK:   vector.transfer_read {{.*}} vector<2x2x2x2xi8>
+
+// -----
+
+func.func @transfer_read_flattenable_negative2(
+      %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : 
+      memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative2
+//       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
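
For reference, here is the new check replayed on the tests above as a
minimal standalone C++ sketch (illustrative only, not LLVM code; the
helper name below is my own). The patch iterates with
llvm::reverse(llvm::zip(...)), which for ranges of different lengths
effectively aligns them at their backs; the loop below spells out that
back-to-front walk with explicit indexing.

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors hasMatchingInnerContiguousShape: the innermost stride must be
// 1, and, walking outward, each accumulated flattened size must equal
// the next stride while the memref dims match the target shape.
static bool innerShapeIsContiguous(std::vector<int64_t> shape,
                                   std::vector<int64_t> strides,
                                   std::vector<int64_t> targetShape) {
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (size_t i = 0; i < targetShape.size(); ++i) {
    int64_t targetDim = targetShape[targetShape.size() - 1 - i];
    int64_t memrefDim = shape[shape.size() - 1 - i];
    int64_t memrefStride = strides[strides.size() - 1 - i];
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

int main() {
  // Positive test: memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> read as
  // vector<5x4x3x2xi8>. targetShape drops the vector's outermost dim.
  std::printf("%d\n", innerShapeIsContiguous({5, 4, 3, 2}, {24, 6, 2, 1},
                                             {4, 3, 2})); // prints 1
  // Negative test 1: vector<2x2x2x2xi8>; target dim 2 does not match
  // memref dim 3, so the block being read is not contiguous in memory.
  std::printf("%d\n", innerShapeIsContiguous({5, 4, 3, 2}, {24, 6, 2, 1},
                                             {2, 2, 2})); // prints 0
  // Negative test 2: strides [24, 8, 2, 1] leave a gap (flattened size
  // 6 != stride 8), so the memref itself is not contiguous there.
  std::printf("%d\n", innerShapeIsContiguous({5, 4, 3, 2}, {24, 8, 2, 1},
                                             {4, 3, 2})); // prints 0
  return 0;
}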