[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)

James Newling llvmlistbot at llvm.org
Mon Jun 30 11:31:07 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/146368

>From b3877b8d6c90569c89eddddca63303fa0bf3e28a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 09:15:48 -0700
Subject: [PATCH 1/2] extend to broadcastlike, code simplifications

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 126 ++++++++++-----------
 mlir/test/Dialect/Vector/canonicalize.mlir |  46 +++++++-
 2 files changed, 104 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..ed616fb4d343b 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1696,59 +1696,71 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
+/// 1s, are considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+  // Condition 1: dst has hight rank.
+  // Condition 2: src shape is a suffix of dst shape.
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value src = broadcastLikeOp->getOperand(0);
+
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == src.getType())
+    return src;
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Get required types and ranks in the chain
+  //    src -> broadcastDst -> dst
+  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned srcRank = srcType ? srcType.getRank() : 0;
+  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+  unsigned dstRank = dstType ? dstType.getRank() : 0;
 
-  unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank > broadcastSrcRank)
+  // Cannot do without the broadcast if overall the rank increases.
+  if (dstRank > srcRank)
     return Value();
-  // Check that the dimension of the result haven't been broadcasted.
-  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
-  if (extractVecType && broadcastVecType &&
-      extractVecType.getShape() !=
-          broadcastVecType.getShape().take_back(extractResultRank))
+
+  assert(srcType && "src must be a vector type because of previous checks");
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
     return Value();
 
-  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
-  int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+  // Replace extract(broadcast(X)) with extract(X).
+  // First, determine the new extraction position.
+  unsigned deltaOverall = srcRank - dstRank;
+  unsigned deltaBroadcast = broadcastDstRank - srcRank;
 
-  // Detect all the positions that come from "dim-1" broadcasting.
-  // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
-  // extract position to `0` when extracting from the source operand.
-  llvm::SetVector<int64_t> broadcastedUnitDims =
-      broadcastOp.computeBroadcastedUnitDims();
-  SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
-  OpBuilder b(extractOp.getContext());
-  int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
-  for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
-    if (broadcastedUnitDims.contains(i))
-      extractPos[i] = b.getIndexAttr(0);
-  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
-  // matching extract position when extracting from the source operand.
-  int64_t rankDiff = broadcastSrcRank - extractResultRank;
-  extractPos.erase(extractPos.begin(),
-                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+  SmallVector<OpFoldResult> newPositions(deltaOverall);
+  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+  for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+  }
+  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2193,32 +2205,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
 
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
-    Operation *defOp = extractOp.getVector().getDefiningOp();
-    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
-      return failure();
 
-    Value source = defOp->getOperand(0);
-    if (extractOp.getType() == source.getType())
+    Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
       return failure();
-    auto getRank = [](Type type) {
-      return llvm::isa<VectorType>(type)
-                 ? llvm::cast<VectorType>(type).getRank()
-                 : 0;
-    };
-    unsigned broadcastSrcRank = getRank(source.getType());
-    unsigned extractResultRank = getRank(extractOp.getType());
-    // We only consider the case where the rank of the source is less than or
-    // equal to the rank of the extract dst. The other cases are handled in the
-    // folding patterns.
-    if (extractResultRank < broadcastSrcRank)
-      return failure();
-    // For scalar result, the input can only be a rank-0 vector, which will
-    // be handled by the folder.
-    if (extractResultRank == 0)
+
+    Value source = broadcastLikeOp->getOperand(0);
+    if (isBroadcastableTo(source.getType(), outType) !=
+        BroadcastableToResult::Success)
       return failure();
 
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-        extractOp, extractOp.getType(), source);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..350233d1f7969 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
 
 // -----
 
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.splat %a : vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
   return %r : f32
@@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 
 // -----
 
+// CHECK-LABEL: fold_extract_vector_from_splat
+//       CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
 //  CHECK-SAME:   %[[A:.*]]: vector<2x1xf32>
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+//  CHECK-SAME:   %[[A:.*]]: vector<2x4xf32>
+//  CHECK-SAME:   %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+//       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK:   return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<4xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
@@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+//  CHECK-SAME:   %[[A:.*]]: vector<1xf32>
+//       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+//       CHECK:   return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+  -> vector<1x1xf32> {
+  %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+  %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+  return %r : vector<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_extract_shuffle
 //  CHECK-SAME:   %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
 //   CHECK-NOT:   vector.shuffle
@@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
     %arg0 : tensor<?x?xf32>,
     %arg1 : memref<?x?xf32>,
     %v0 : vector<4x2xf32>
-  ) -> vector<4x2xf32> 
+  ) -> vector<4x2xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
 //       CHECK:   vector.transfer_read
 func.func @negative_store_to_load_tensor_broadcast_masked(
     %arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
-  -> vector<4x2x6xf32> 
+  -> vector<4x2x6xf32>
 {
   %c0 = arith.constant 0 : index
   %cf0 = arith.constant 0.0 : f32

>From da8e03a560fc0850b280b39cb82e0df8e86a4fce Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 11:32:03 -0700
Subject: [PATCH 2/2] improve comments, add test

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 92 +++++++++++++++-------
 mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++
 2 files changed, 78 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ed616fb4d343b..9461ba02dd546 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1696,8 +1696,8 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
-/// 1s, are considered 'broadcastlike'.
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered to be 'broadcastlike'.
 static bool isBroadcastLike(Operation *op) {
   if (isa<BroadcastOp, SplatOp>(op))
     return true;
@@ -1706,9 +1706,12 @@ static bool isBroadcastLike(Operation *op) {
   if (!shapeCast)
     return false;
 
-  // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+  // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
   // Condition 1: dst has hight rank.
   // Condition 2: src shape is a suffix of dst shape.
+  //
+  // Note that checking that dst shape has a prefix of 1s is not sufficient,
+  // for example (2,3) -> (1,3,2) is not broadcast-like.
   VectorType srcType = shapeCast.getSourceVectorType();
   ArrayRef<int64_t> srcShape = srcType.getShape();
   uint64_t srcRank = srcType.getRank();
@@ -1716,51 +1719,84 @@ static bool isBroadcastLike(Operation *op) {
   return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
 }
 
-/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
+/// Fold extract(broadcast(X)) to either extract(X) or just X.
+///
+/// Example:
+///
+///        broadcast           extract
+/// (3, 4) --------> (2, 3, 4) ------> (4)
+///
+/// becomes
+///                  extract
+/// (3, 4) --------------------------> (4)
+///
+///
+/// The variable names in this implementation correspond to the shapes above
+/// as follows:
+///
+/// - (3, 4) is the `input` shape.
+/// - (2, 3, 4) is the `broadcast` shape.
+/// - (4) is the `extract` shape.
+///
+/// This folding is possible when the `extract` shape is a suffix of the
+/// `input` shape.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
 
-  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
-  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
+  Operation *defOp = extractOp.getVector().getDefiningOp();
+  if (!defOp || !isBroadcastLike(defOp))
     return Value();
 
-  Value src = broadcastLikeOp->getOperand(0);
+  Value input = defOp->getOperand(0);
 
   // Replace extract(broadcast(X)) with X
-  if (extractOp.getType() == src.getType())
-    return src;
+  if (extractOp.getType() == input.getType())
+    return input;
 
   // Get required types and ranks in the chain
-  //    src -> broadcastDst -> dst
-  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
-  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  unsigned srcRank = srcType ? srcType.getRank() : 0;
-  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
-  unsigned dstRank = dstType ? dstType.getRank() : 0;
+  //    input -> broadcast -> extract
+  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
+  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned inputRank = inputType ? inputType.getRank() : 0;
+  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
+  unsigned extractRank = extractType ? extractType.getRank() : 0;
 
   // Cannot do without the broadcast if overall the rank increases.
-  if (dstRank > srcRank)
+  if (extractRank > inputRank)
     return Value();
 
-  assert(srcType && "src must be a vector type because of previous checks");
-
-  ArrayRef<int64_t> srcShape = srcType.getShape();
-  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
+  // Proof by contradiction that, at this point, input is a vector.
+  //     Suppose input is a scalar.
+  // ==> inputRank is 0.
+  // ==> extractRank is 0 (because extractRank <= inputRank).
+  // ==> extract is scalar (because rank-0 extraction is always scalar).
+  // ==> input and extract are scalar, so same type.
+  // ==> returned early (check same type).
+  //     Contradiction!
+  assert(inputType && "input must be a vector type because of previous checks");
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  // In the case where there is a broadcast dimension in the suffix, it is not
+  // possible to replace extract(broadcast(X)) with extract(X). Example:
+  //
+  //     broadcast       extract
+  // (1) --------> (3,4) ------> (4)
+  if (extractType &&
+      extractType.getShape() != inputShape.take_back(extractRank))
     return Value();
 
   // Replace extract(broadcast(X)) with extract(X).
   // First, determine the new extraction position.
-  unsigned deltaOverall = srcRank - dstRank;
-  unsigned deltaBroadcast = broadcastDstRank - srcRank;
-
+  unsigned deltaOverall = inputRank - extractRank;
+  unsigned deltaBroadcast = broadcastRank - inputRank;
   SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
   SmallVector<OpFoldResult> newPositions(deltaOverall);
   IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
-  for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
     newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
   }
   auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
   extractOp->setOperands(
-      llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
+      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
   extractOp.setStaticPosition(staticPos);
   return extractOp.getResult();
 }
@@ -2206,12 +2242,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
   LogicalResult matchAndRewrite(ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
 
-    Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+    Operation *defOp = extractOp.getVector().getDefiningOp();
     VectorType outType = dyn_cast<VectorType>(extractOp.getType());
-    if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
+    if (!defOp || !isBroadcastLike(defOp) || !outType)
       return failure();
 
-    Value source = broadcastLikeOp->getOperand(0);
+    Value source = defOp->getOperand(0);
     if (isBroadcastableTo(source.getType(), outType) !=
         BroadcastableToResult::Success)
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 350233d1f7969..c7d9074b853f9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -829,6 +829,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
 
 // -----
 
+// Test where the shape_cast is not broadcast-like, even though it adds a leading 1.
+// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
+//  CHECK-NEXT: vector.shape_cast
+//  CHECK-NEXT: vector.extract
+//  CHECK-NEXT: return
+func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+  %idx0 : index, %idx1 : index) -> vector<2xf32> {
+  %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
+  return %r : vector<2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_to_higher_rank
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>



More information about the Mlir-commits mailing list