[Mlir-commits] [mlir] 9e7c97d - [mlir][vector] Fix bug in transfer op flattening
Thomas Raoux
llvmlistbot at llvm.org
Fri Sep 9 09:03:06 PDT 2022
Author: Thomas Raoux
Date: 2022-09-09T16:02:52Z
New Revision: 9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7
URL: https://github.com/llvm/llvm-project/commit/9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7
DIFF: https://github.com/llvm/llvm-project/commit/9e7c97d8ce12893f09fabbf6b54c8ee0297e7ed7.diff
LOG: [mlir][vector] Fix bug in transfer op flattening
The logic that decides whether a transfer op can be flattened wasn't
considering the shape being loaded, so it incorrectly assumed that some
transfer ops were reading contiguous data.
Differential Revision: https://reviews.llvm.org/D133544
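To illustrate the bug (a sketch adapted from the negative test added
below; the function name is illustrative): the memref's inner dimensions
are contiguous in memory, but the vector shape being read does not match
them, so collapsing the access into a single 1-D read would touch the
wrong elements:

func.func @example(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  // The memref's inner dims are 4x3x2, but the vector reads a 2x2x2x2
  // sub-box; a flattened 1-D read of 16 contiguous elements would not
  // correspond to that sub-box.
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
  return %v : vector<2x2x2x2xi8>
}

The old check only required enough contiguous inner elements (here
120 >= 16) and never compared the transfer's shape against the memref's
inner dimensions, so it accepted reads like this one.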
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 5fe393b48b10f..92b103364ea27 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -339,34 +339,26 @@ class TransferWriteDropUnitDimsPattern
}
};
-/// Returns the position of the first inner dimension that has contiguous layout
-/// with at least `requiredContiguousSize` contiguous elements.
-/// When such a dimension is found, the return value satisfies:
-/// 0 <= return_value <= memrefType.getRank() - 1.
-/// When no such dimension is found, the return value is memrefType.getRank().
-static int64_t getContiguousInnerDim(MemRefType memrefType,
- int64_t requiredContiguousSize) {
+/// Return true if the inner dimensions of the memref type match the given
+/// target shape. Otherwise return false.
+static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
+ ArrayRef<int64_t> targetShape) {
auto shape = memrefType.getShape();
SmallVector<int64_t> strides;
int64_t offset;
- int64_t innerDim = shape.size();
- if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
- int64_t innerSize = 1;
- while (true) {
- if (innerDim == 0)
- break;
- const int64_t nextDim = innerDim - 1;
- if (shape[nextDim] == ShapedType::kDynamicSize)
- break;
- if (strides[nextDim] != innerSize)
- break;
- innerSize *= shape[nextDim];
- innerDim = nextDim;
- if (innerSize >= requiredContiguousSize)
- break;
- }
+ if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
+ return false;
+ if (strides.back() != 1)
+ return false;
+ strides.pop_back();
+ int64_t flatDim = 1;
+ for (auto [targetDim, memrefDim, memrefStride] :
+ llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+ flatDim *= memrefDim;
+ if (flatDim != memrefStride || targetDim != memrefDim)
+ return false;
}
- return innerDim;
+ return true;
}
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -427,10 +419,12 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- int64_t firstContiguousInnerDim =
- getContiguousInnerDim(sourceType, vectorType.getNumElements());
- if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+ if (!hasMatchingInnerContiguousShape(
+ sourceType,
+ vectorType.getShape().take_back(vectorType.getRank() - 1)))
return failure();
+ int64_t firstContiguousInnerDim =
+ sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
@@ -485,10 +479,12 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- int64_t firstContiguousInnerDim =
- getContiguousInnerDim(sourceType, vectorType.getNumElements());
- if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+ if (!hasMatchingInnerContiguousShape(
+ sourceType,
+ vectorType.getShape().take_back(vectorType.getRank() - 1)))
return failure();
+ int64_t firstContiguousInnerDim =
+ sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 41e61887a6311..3c8e280212bed 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -12,9 +12,9 @@ func.func @transfer_read_flattenable_with_offset(
// CHECK-LABEL: func @transfer_read_flattenable_with_offset
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
-// C-HECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// C-HECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
-// C-HECK: return %[[VEC2D]]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
// -----
@@ -26,12 +26,12 @@ func.func @transfer_write_flattenable_with_offset(
return
}
-// C-HECK-LABEL: func @transfer_write_flattenable_with_offset
-// C-HECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// C-HECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
-// C-HECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
-// C-HECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
-// C-HECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
// -----
@@ -104,3 +104,31 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// -----
+
+func.func @transfer_read_flattenable_negative(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
+ return %v : vector<2x2x2x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative
+// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
+
+// -----
+
+func.func @transfer_read_flattenable_negative2(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative2
+// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
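For contrast, a sketch of the positive case whose CHECK lines the patch
re-enables above (the test body lives in the existing file and is not
part of the hunk, so this reconstruction is approximate): when the
trailing vector dimensions match the memref's contiguous inner
dimensions, the transfer is still flattened:

func.func @transfer_read_flattenable_with_offset(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0 : i8
  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
  return %v : vector<5x4x3x2xi8>
}

Per the CHECK lines, this rewrites to a memref.collapse_shape, a 1-D
vector.transfer_read producing vector<120xi8>, and a vector.shape_cast
back to vector<5x4x3x2xi8>.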