[Mlir-commits] [mlir] 57b101b - [mlir][vector] Handle scalars in extract_strided_slice(broadcast)
Lei Zhang
llvmlistbot at llvm.org
Fri Apr 1 09:11:25 PDT 2022
Author: Lei Zhang
Date: 2022-04-01T12:07:47-04:00
New Revision: 57b101bdec15c3ed421972ddf1d10f1de3c1f8c1
URL: https://github.com/llvm/llvm-project/commit/57b101bdec15c3ed421972ddf1d10f1de3c1f8c1
DIFF: https://github.com/llvm/llvm-project/commit/57b101bdec15c3ed421972ddf1d10f1de3c1f8c1.diff
LOG: [mlir][vector] Handle scalars in extract_strided_slice(broadcast)
When the broadcast source is a scalar (or a single-element vector), we cannot generate extract_strided_slice ops on it; broadcast the source directly instead.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D122902
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9fd71222d8ecd..db8d40ae19daf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2501,24 +2501,27 @@ class StridedSliceBroadcast final
if (!broadcast)
return failure();
auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
- unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
+ unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
auto dstVecType = op.getType().cast<VectorType>();
unsigned dstRank = dstVecType.getRank();
- unsigned rankDiff = dstRank - srcRrank;
+ unsigned rankDiff = dstRank - srcRank;
// Check if the most inner dimensions of the source of the broadcast are the
// same as the destination of the extract. If this is the case we can just
// use a broadcast as the original dimensions are untouched.
bool lowerDimMatch = true;
- for (unsigned i = 0; i < srcRrank; i++) {
+ for (unsigned i = 0; i < srcRank; i++) {
if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
lowerDimMatch = false;
break;
}
}
Value source = broadcast.getSource();
- if (!lowerDimMatch) {
- // The inner dimensions don't match, it means we need to extract from the
- // source of the orignal broadcast and then broadcast the extracted value.
+ // If the inner dimensions don't match, it means we need to extract from the
+ // source of the orignal broadcast and then broadcast the extracted value.
+ // We also need to handle degenerated cases where the source is effectively
+ // just a single scalar.
+ bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
+ if (!lowerDimMatch && !isScalarSrc) {
source = rewriter.create<ExtractStridedSliceOp>(
op->getLoc(), source,
getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9bfd08b827c62..2b02dc143b4e3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -762,6 +762,34 @@ func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
// -----
+// CHECK-LABEL: func @extract_strided_broadcast3
+// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x4xf32>
+// CHECK: return %[[V]]
+func @extract_strided_broadcast3(%arg0: vector<1xf32>) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x8xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+ : vector<1x8xf32> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_strided_broadcast4
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x4xf32>
+// CHECK: return %[[V]]
+func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x8xf32>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+ : vector<1x8xf32> to vector<1x4xf32>
+ return %1 : vector<1x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: consecutive_shape_cast
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
More information about the Mlir-commits
mailing list