[Mlir-commits] [mlir] [vector][mlir] Canonicalize to shape_cast where possible (PR #140583)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 26 09:04:24 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

<details>
<summary>Changes</summary>

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see https://github.com/llvm/llvm-project/pull/138777) 

For example, each of the following can be expressed as a shape_cast:

```mlir
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
```

This PR adds canonicalization patterns that convert the three examples above to shape_casts.

I've added further context as review comments on the PR.

I'm happy to split this PR up and add the new patterns separately. 

---

Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff


10 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+84-53) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-61) 
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+4-4) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26-41) 
- (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+2-2) 
- (added) mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir (+162) 
- (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+60) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+5-7) 
- (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-85) 
- (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+4-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// Rewrites a vector.extract whose (vector) result has as many elements as its
+/// source into a vector.shape_cast. Equal element counts mean every extracted
+/// leading dim has size 1, so the rewrite only fires when every position is
+/// the constant 0.
+///
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "result is a scalar, shape_cast doesn't support scalar");
+    }
+
+    // Only a statically all-zero position is a plain copy. Dynamic positions
+    // are not known to be 0, and negative constants (poison) have no
+    // shape_cast equivalent.
+    auto isZero = [](OpFoldResult v) { return isConstantIntValue(v, 0); };
+    if (!llvm::all_of(extractOp.getMixedPosition(), isZero)) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "position is not statically all zeros");
+    }
+
+    if (sourceType.getNumElements() != outType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "source and result have different number of elements");
+    }
+
+    // Extracting from a scalable unit dim (e.g. vector<[1]x4xf32>) drops
+    // scalability; shape_cast must preserve the number of scalable dims.
+    if (sourceType.getNumScalableDims() != outType.getNumScalableDims()) {
+      return rewriter.notifyMatchFailure(
+          extractOp, "source and result have different number of scalable dims");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  // ExtractToShapeCast rewrites extracts at a statically all-zero position
+  // that preserve the element count into vector.shape_cast.
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// Rewrites a vector.broadcast that does not change the number of elements
+/// (i.e. it only prepends unit dims) into a vector.shape_cast.
+///
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+///
+/// Not applied when scalability changes, e.g. vector<1xi8> to vector<[1]xi8>
+/// or vector<4xi8> to vector<[1]x4xi8>: those replicate data vscale times.
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements()) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source and result have different number of elements");
+    }
+
+    // getNumElements() counts scalable dims by their minimum size, so it can
+    // not distinguish 1 from [1]. Require that the leading (new) dims are
+    // fixed-size and the trailing dims keep the source's scalability.
+    ArrayRef<bool> outScalableDims = outType.getScalableDims();
+    int64_t numLeadingDims = outType.getRank() - sourceType.getRank();
+    if (llvm::is_contained(outScalableDims.take_front(numLeadingDims), true) ||
+        outScalableDims.drop_front(numLeadingDims) !=
+            sourceType.getScalableDims()) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "broadcast changes scalability, so it is not a copy");
+    }
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  // BroadcastToShapeCast rewrites broadcasts that do not change the element
+  // count into vector.shape_cast.
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
   //
   if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// Canonicalizes an order-preserving vector.transpose into a vector.shape_cast.
+/// A transpose is order preserving when its flattened input and output hold
+/// identical values, so it is semantically a reshape.
+///
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+    if (isOrderPreserving(transpose)) {
+      Value source = transpose.getVector();
+      VectorType resultType = transpose.getType();
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpose, resultType,
+                                                       source);
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        transpose, "not order preserving, so not semantically a 'copy'");
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  // TransposeToShapeCast rewrites any order-preserving transpose into a
+  // vector.shape_cast. Patterns are listed alphabetically.
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
+  // 2-D transposes with a unit dim are order preserving; they are rewritten to
+  // vector.shape_cast by the TransposeOp canonicalization, not lowered here.
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/140583


More information about the Mlir-commits mailing list