[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 30 08:28:16 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: James Newling (newling)

<details>
<summary>Changes</summary>

With https://github.com/llvm/llvm-project/pull/140583, more shape_cast ops will appear: broadcasts that only prepend unit dimensions (i.e. volume-preserving broadcasts) are canonicalized to shape_cast ops. This PR ensures that such broadcast-like shape_cast ops fold at least as well as broadcast ops do.
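
For concreteness, here is a minimal sketch of the kind of fold this enables, adapted from the new `fold_extract_shape_cast_to_lower_rank` test in the diff (the function name is illustrative):

```mlir
// The shape_cast only prepends a unit dimension, so it is broadcast-like and
// the extract can read directly from the shape_cast's source.
func.func @extract_of_broadcastlike_shape_cast(%a : vector<2x4xf32>,
    %idx0 : index, %idx1 : index) -> vector<4xf32> {
  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
  // After this PR, -canonicalize folds the above to:
  //   %r = vector.extract %a[%idx1] : vector<4xf32> from vector<2x4xf32>
  return %r : vector<4xf32>
}
```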

---
Full diff: https://github.com/llvm/llvm-project/pull/146368.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+59-64) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+42-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..e4da65252c6e3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered 'broadcast-like'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value src = broadcastLikeOp->getOperand(0);
+
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == src.getType())
+    return src;
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Get required types and ranks in the chain
+  //    src -> broadcastDst -> dst
+  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned srcRank = srcType ? srcType.getRank() : 0;
+  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+  unsigned dstRank = dstType ? dstType.getRank() : 0;
 
-  unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank > broadcastSrcRank)
+  // The broadcast cannot be removed if the rank increases overall (src -> dst).
+  if (dstRank > srcRank)
     return Value();
-  // Check that the dimension of the result haven't been broadcasted.
-  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
-  if (extractVecType && broadcastVecType &&
-      extractVecType.getShape() !=
-          broadcastVecType.getShape().take_back(extractResultRank))
+
+  assert(srcType && "src must be a vector type because of previous checks");
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
     return Value();
 
-  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
-  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+  // Replace extract(broadcast(X)) with extract(X).
+  // First, determine the new extraction position.
+  unsigned deltaOverall = srcRank - dstRank;
+  unsigned deltaBroadcast = broadcastDstRank - srcRank;
 
-  // Detect all the positions that come from "dim-1" broadcasting.
-  // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
-  // extract position to `0` when extracting from the source operand.
-  llvm::SetVector<int64_t> broadcastedUnitDims =
-      broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
-  OpBuilder b(extractOp.getContext());
-  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
-  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
-    if (broadcastedUnitDims.contains(i))
-      extractPos[i] = b.getIndexAttr(0);
-  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
-  // matching extract position when extracting from the source operand.
-  int64_t rankDiff = broadcastSrcRank - extractResultRank;
-  extractPos.erase(extractPos.begin(),
-                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+  SmallVector<OpFoldResult> newPositions(deltaOverall);
+  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+  for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+  }
+  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2193,32 +2202,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    Operation *defOp = extractOp.getVector().getDefiningOp();
-    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
-      return failure();
 
-    Value source = defOp->getOperand(0);
-    if (extractOp.getType() == source.getType())
+    Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
       return failure();
-    auto getRank = [](Type type) {
-      return llvm::isa<VectorType>(type)
-                 ? llvm::cast<VectorType>(type).getRank()
-                 : 0;
-    };
-    unsigned broadcastSrcRank = getRank(source.getType());
-    unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is less than or
-    // equal to the rank of the extract dst. The other cases are handled in the
-    // folding patterns.
-    if (extractResultRank < broadcastSrcRank)
-      return failure();
-    // For scalar result, the input can only be a rank-0 vector, which will
-    // be handled by the folder.
-    if (extractResultRank == 0)
+
+    Value source = broadcastLikeOp->getOperand(0);
+    if (isBroadcastableTo(source.getType(), outType) !=
+        BroadcastableToResult::Success)
       return failure();
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        extractOp, extractOp.getType(), source);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..350233d1f7969 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
 
 // -----
 
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
@@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 
 // -----
 
+// CHECK-LABEL: fold_extract_vector_from_splat
+//       CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+//  CHECK-SAME:   %[[A:.*]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
@@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+  -> vector<1x1xf32> {
+  %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+  %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+  return %r : vector<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_extract_shuffle
 //  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
 //   CHECK-NOT:   vector.shuffle
@@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 //       CHECK:   vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32> 
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/146368


More information about the Mlir-commits mailing list