[Mlir-commits] [mlir] 57b101b - [mlir][vector] Handle scalars in extract_strided_slice(broadcast)

Lei Zhang llvmlistbot at llvm.org
Fri Apr 1 09:11:25 PDT 2022


Author: Lei Zhang
Date: 2022-04-01T12:07:47-04:00
New Revision: 57b101bdec15c3ed421972ddf1d10f1de3c1f8c1

URL: https://github.com/llvm/llvm-project/commit/57b101bdec15c3ed421972ddf1d10f1de3c1f8c1
DIFF: https://github.com/llvm/llvm-project/commit/57b101bdec15c3ed421972ddf1d10f1de3c1f8c1.diff

LOG: [mlir][vector] Handle scalars in extract_strided_slice(broadcast)

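When the broadcast source is a scalar or a single-element vector, we cannot generate an extract_strided_slice op on it, so the source is broadcast directly to the result type instead.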

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D122902

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9fd71222d8ecd..db8d40ae19daf 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2501,24 +2501,27 @@ class StridedSliceBroadcast final
     if (!broadcast)
       return failure();
     auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
-    unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
+    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
     auto dstVecType = op.getType().cast<VectorType>();
     unsigned dstRank = dstVecType.getRank();
-    unsigned rankDiff = dstRank - srcRrank;
+    unsigned rankDiff = dstRank - srcRank;
     // Check if the most inner dimensions of the source of the broadcast are the
     // same as the destination of the extract. If this is the case we can just
     // use a broadcast as the original dimensions are untouched.
     bool lowerDimMatch = true;
-    for (unsigned i = 0; i < srcRrank; i++) {
+    for (unsigned i = 0; i < srcRank; i++) {
       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
         lowerDimMatch = false;
         break;
       }
     }
     Value source = broadcast.getSource();
-    if (!lowerDimMatch) {
-      // The inner dimensions don't match, it means we need to extract from the
-      // source of the orignal broadcast and then broadcast the extracted value.
+    // If the inner dimensions don't match, it means we need to extract from the
+    // source of the original broadcast and then broadcast the extracted value.
+    // We also need to handle degenerate cases where the source is effectively
+    // just a single scalar.
+    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
+    if (!lowerDimMatch && !isScalarSrc) {
       source = rewriter.create<ExtractStridedSliceOp>(
           op->getLoc(), source,
           getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9bfd08b827c62..2b02dc143b4e3 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -762,6 +762,34 @@ func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_broadcast3
+//  CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+//       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x4xf32>
+//       CHECK: return %[[V]]
+func @extract_strided_broadcast3(%arg0: vector<1xf32>) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x8xf32>
+ %1 = vector.extract_strided_slice %0
+      {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+      : vector<1x8xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_strided_broadcast4
+//  CHECK-SAME: (%[[ARG:.+]]: f32)
+//       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x4xf32>
+//       CHECK: return %[[V]]
+func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
+ %0 = vector.broadcast %arg0 : f32 to vector<1x8xf32>
+ %1 = vector.extract_strided_slice %0
+      {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+      : vector<1x8xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: consecutive_shape_cast
 //       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
 //  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>


        


More information about the Mlir-commits mailing list