[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)
James Newling
llvmlistbot at llvm.org
Mon Jun 30 08:27:40 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/146368
In https://github.com/llvm/llvm-project/pull/140583 more shape_cast ops will appear. Specifically, broadcasts that just prepend ones become shape_cast ops (i.e. volume-preserving broadcasts are canonicalized to shape_casts). This PR ensures that broadcast-like shape_cast ops fold at least as well as broadcast ops.
>From c129bffd0f8a1e47f49f8e8a72401160cb6a69f5 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 08:26:21 -0700
Subject: [PATCH] extend to broadcastlike, code simplifications
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 123 ++++++++++-----------
mlir/test/Dialect/Vector/canonicalize.mlir | 46 +++++++-
2 files changed, 101 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..e4da65252c6e3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
+/// All BroadcastOps and SplatOps, and ShapeCastOps that only prepend 1s, are
+/// considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+ if (isa<BroadcastOp, SplatOp>(op))
+ return true;
+
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
+ if (!shapeCast)
+ return false;
+
+ VectorType srcType = shapeCast.getSourceVectorType();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ uint64_t srcRank = srcType.getRank();
+ ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+ return dstShape.size() <= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
return Value();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
- return source;
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
+ Value src = broadcastLikeOp->getOperand(0);
+
+ // Replace extract(broadcast(X)) with X
+ if (extractOp.getType() == src.getType())
+ return src;
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
- return source;
+ // Get required types and ranks in the chain
+ // src -> broadcastDst -> dst
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ unsigned srcRank = srcType ? srcType.getRank() : 0;
+ unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+ unsigned dstRank = dstType ? dstType.getRank() : 0;
- unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank > broadcastSrcRank)
+ // Cannot do without the broadcast if overall the rank increases.
+ if (dstRank > srcRank)
return Value();
- // Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
- if (extractVecType && broadcastVecType &&
- extractVecType.getShape() !=
- broadcastVecType.getShape().take_back(extractResultRank))
+
+ assert(srcType && "src must be a vector type because of previous checks");
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
return Value();
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
- int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+ // Replace extract(broadcast(X)) with extract(X).
+ // First, determine the new extraction position.
+ unsigned deltaOverall = srcRank - dstRank;
+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
- // Detect all the positions that come from "dim-1" broadcasting.
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
- // extract position to `0` when extracting from the source operand.
- llvm::SetVector<int64_t> broadcastedUnitDims =
- broadcastOp.computeBroadcastedUnitDims();
- SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
- OpBuilder b(extractOp.getContext());
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
- for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
- if (broadcastedUnitDims.contains(i))
- extractPos[i] = b.getIndexAttr(0);
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
- // matching extract position when extracting from the source operand.
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
- extractPos.erase(extractPos.begin(),
- std::next(extractPos.begin(), extractPos.size() - rankDiff));
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
- auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+ SmallVector<OpFoldResult> newPositions(deltaOverall);
+ IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+ for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+ }
+ auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
- llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+ llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
@@ -2193,32 +2202,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
- return failure();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
return failure();
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type)
- ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
- unsigned broadcastSrcRank = getRank(source.getType());
- unsigned extractResultRank = getRank(extractOp.getType());
- // We only consider the case where the rank of the source is less than or
- // equal to the rank of the extract dst. The other cases are handled in the
- // folding patterns.
- if (extractResultRank < broadcastSrcRank)
- return failure();
- // For scalar result, the input can only be a rank-0 vector, which will
- // be handled by the folder.
- if (extractResultRank == 0)
+
+ Value source = broadcastLikeOp->getOperand(0);
+ if (isBroadcastableTo(source.getType(), outType) !=
+ BroadcastableToResult::Success)
return failure();
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- extractOp, extractOp.getType(), source);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..350233d1f7969 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -764,10 +764,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
@@ -775,6 +775,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// -----
+// CHECK-LABEL: fold_extract_vector_from_splat
+// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.splat %a : vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -804,6 +814,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// -----
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+ %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
@@ -831,6 +856,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// -----
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+ -> vector<1x1xf32> {
+ %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+ %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+ return %r : vector<1x1xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
@@ -1549,7 +1587,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
- ) -> vector<4x2xf32>
+ ) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1606,7 +1644,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
- -> vector<4x2x6xf32>
+ -> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
More information about the Mlir-commits
mailing list