[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)
James Newling
llvmlistbot at llvm.org
Fri Jul 18 10:11:46 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/146368
From 09ba159afe75e1ff476ff82d51668471699d40ed Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 09:15:48 -0700
Subject: [PATCH 1/6] extend to broadcastlike, code simplifications
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 126 ++++++++++-----------
mlir/test/Dialect/Vector/canonicalize.mlir | 46 +++++++-
2 files changed, 104 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7d615bfc12984..cfad95a7aee79 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
+/// 1s, are considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+ if (isa<BroadcastOp, SplatOp>(op))
+ return true;
+
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
+ if (!shapeCast)
+ return false;
+
+ // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+ // Condition 1: dst has higher rank.
+ // Condition 2: src shape is a suffix of dst shape.
+ VectorType srcType = shapeCast.getSourceVectorType();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ uint64_t srcRank = srcType.getRank();
+ ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+ return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
return Value();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
- return source;
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
+ Value src = broadcastLikeOp->getOperand(0);
+
+ // Replace extract(broadcast(X)) with X
+ if (extractOp.getType() == src.getType())
+ return src;
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
- return source;
+ // Get required types and ranks in the chain
+ // src -> broadcastDst -> dst
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ unsigned srcRank = srcType ? srcType.getRank() : 0;
+ unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+ unsigned dstRank = dstType ? dstType.getRank() : 0;
- unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank > broadcastSrcRank)
+ // Cannot do without the broadcast if overall the rank increases.
+ if (dstRank > srcRank)
return Value();
- // Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
- if (extractVecType && broadcastVecType &&
- extractVecType.getShape() !=
- broadcastVecType.getShape().take_back(extractResultRank))
+
+ assert(srcType && "src must be a vector type because of previous checks");
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
return Value();
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
- int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
+ // Replace extract(broadcast(X)) with extract(X).
+ // First, determine the new extraction position.
+ unsigned deltaOverall = srcRank - dstRank;
+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
- // Detect all the positions that come from "dim-1" broadcasting.
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
- // extract position to `0` when extracting from the source operand.
- llvm::SetVector<int64_t> broadcastedUnitDims =
- broadcastOp.computeBroadcastedUnitDims();
- SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
- OpBuilder b(extractOp.getContext());
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
- for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
- if (broadcastedUnitDims.contains(i))
- extractPos[i] = b.getIndexAttr(0);
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
- // matching extract position when extracting from the source operand.
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
- extractPos.erase(extractPos.begin(),
- std::next(extractPos.begin(), extractPos.size() - rankDiff));
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
- auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
+ SmallVector<OpFoldResult> newPositions(deltaOverall);
+ IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
+ for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
+ }
+ auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
- llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
+ llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
@@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
- return failure();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
return failure();
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type)
- ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
- unsigned broadcastSrcRank = getRank(source.getType());
- unsigned extractResultRank = getRank(extractOp.getType());
- // We only consider the case where the rank of the source is less than or
- // equal to the rank of the extract dst. The other cases are handled in the
- // folding patterns.
- if (extractResultRank < broadcastSrcRank)
- return failure();
- // For scalar result, the input can only be a rank-0 vector, which will
- // be handled by the folder.
- if (extractResultRank == 0)
+
+ Value source = broadcastLikeOp->getOperand(0);
+ if (isBroadcastableTo(source.getType(), outType) !=
+ BroadcastableToResult::Success)
return failure();
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- extractOp, extractOp.getType(), source);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ea2343efd246e..6ed64cb8313c2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_splat
+// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
@@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// -----
+// CHECK-LABEL: fold_extract_vector_from_splat
+// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
+func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.splat %a : vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -863,6 +873,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// -----
+// Test where the shape_cast is broadcast-like.
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
+// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
+// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
+// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+ %idx0 : index, %idx1 : index) -> vector<4xf32> {
+ %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ return %r : vector<4xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
@@ -890,6 +915,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// -----
+// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>
+// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
+// CHECK: return %[[R]] : vector<1x1xf32>
+func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
+ -> vector<1x1xf32> {
+ %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
+ %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
+ return %r : vector<1x1xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle
@@ -1623,7 +1661,7 @@ func.func @negative_store_to_load_tensor_memref(
%arg0 : tensor<?x?xf32>,
%arg1 : memref<?x?xf32>,
%v0 : vector<4x2xf32>
- ) -> vector<4x2xf32>
+ ) -> vector<4x2xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
@@ -1680,7 +1718,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
// CHECK: vector.transfer_read
func.func @negative_store_to_load_tensor_broadcast_masked(
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
- -> vector<4x2x6xf32>
+ -> vector<4x2x6xf32>
{
%c0 = arith.constant 0 : index
%cf0 = arith.constant 0.0 : f32
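
A worked example of the new position arithmetic in foldExtractFromBroadcast
(illustrative sketch only, not part of the patch; %src, %i and %j are
placeholder names):

  // Assumed input IR: extract a vector<4xf32> slice from a broadcast.
  %b = vector.broadcast %src : vector<1x4xf32> to vector<3x1x4xf32>
  %r = vector.extract %b[%i, %j] : vector<4xf32> from vector<3x1x4xf32>

Here deltaOverall = srcRank - dstRank = 2 - 1 = 1 and deltaBroadcast =
broadcastDstRank - srcRank = 3 - 2 = 1. The single leading source dimension
has size 1, so its new position becomes the constant 0 and the extract is
updated in place to read from %src directly:

  // Expected result after folding.
  %r = vector.extract %src[0] : vector<4xf32> from vector<1x4xf32>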
From 1c46b4eab4b8d1cc6000e0e78de13a5fa7ec9153 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 30 Jun 2025 11:32:03 -0700
Subject: [PATCH 2/6] improve comments, add test
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 92 +++++++++++++++-------
mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++
2 files changed, 78 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cfad95a7aee79..3ea8d0eb784c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
-/// 1s, are considered 'broadcastlike'.
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
if (isa<BroadcastOp, SplatOp>(op))
return true;
@@ -1717,9 +1717,12 @@ static bool isBroadcastLike(Operation *op) {
if (!shapeCast)
return false;
- // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
// Condition 1: dst has higher rank.
// Condition 2: src shape is a suffix of dst shape.
+ //
+ // Note that checking that dst shape has a prefix of 1s is not sufficient,
+ // for example (2,3) -> (1,3,2) is not broadcast-like.
VectorType srcType = shapeCast.getSourceVectorType();
ArrayRef<int64_t> srcShape = srcType.getShape();
uint64_t srcRank = srcType.getRank();
@@ -1727,51 +1730,84 @@ static bool isBroadcastLike(Operation *op) {
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
}
-/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
+/// Fold extract(broadcast(X)) to either extract(X) or just X.
+///
+/// Example:
+///
+/// broadcast extract
+/// (3, 4) --------> (2, 3, 4) ------> (4)
+///
+/// becomes
+/// extract
+/// (3,4) ---------------------------> (4)
+///
+///
+/// The variable names used in this implementation use names which correspond to
+/// the above shapes as,
+///
+/// - (3, 4) is `input` shape.
+/// - (2, 3, 4) is `broadcast` shape.
+/// - (4) is `extract` shape.
+///
+/// This folding is possible when the suffix of `input` shape is the same as
+/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
- if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
+ Operation *defOp = extractOp.getVector().getDefiningOp();
+ if (!defOp || !isBroadcastLike(defOp))
return Value();
- Value src = broadcastLikeOp->getOperand(0);
+ Value input = defOp->getOperand(0);
// Replace extract(broadcast(X)) with X
- if (extractOp.getType() == src.getType())
- return src;
+ if (extractOp.getType() == input.getType())
+ return input;
// Get required types and ranks in the chain
- // src -> broadcastDst -> dst
- auto srcType = llvm::dyn_cast<VectorType>(src.getType());
- auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
- unsigned srcRank = srcType ? srcType.getRank() : 0;
- unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
- unsigned dstRank = dstType ? dstType.getRank() : 0;
+ // input -> broadcast -> extract
+ auto inputType = llvm::dyn_cast<VectorType>(input.getType());
+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ unsigned inputRank = inputType ? inputType.getRank() : 0;
+ unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
+ unsigned extractRank = extractType ? extractType.getRank() : 0;
// Cannot do without the broadcast if overall the rank increases.
- if (dstRank > srcRank)
+ if (extractRank > inputRank)
return Value();
- assert(srcType && "src must be a vector type because of previous checks");
-
- ArrayRef<int64_t> srcShape = srcType.getShape();
- if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
+ // Proof by contradiction that, at this point, input is a vector.
+ // Suppose input is a scalar.
+ // ==> inputRank is 0.
+ // ==> extractRank is 0 (because extractRank <= inputRank).
+ // ==> extract is scalar (because rank-0 extraction is always scalar).
+ // ==> input and extract are scalar, so same type.
+ // ==> returned early (check same type).
+ // Contradiction!
+ assert(inputType && "input must be a vector type because of previous checks");
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+
+ // In the case where there is a broadcast dimension in the suffix, it is not
+ // possible to replace extract(broadcast(X)) with extract(X). Example:
+ //
+ // broadcast extract
+ // (1) --------> (3,4) ------> (4)
+ if (extractType &&
+ extractType.getShape() != inputShape.take_back(extractRank))
return Value();
// Replace extract(broadcast(X)) with extract(X).
// First, determine the new extraction position.
- unsigned deltaOverall = srcRank - dstRank;
- unsigned deltaBroadcast = broadcastDstRank - srcRank;
-
+ unsigned deltaOverall = inputRank - extractRank;
+ unsigned deltaBroadcast = broadcastRank - inputRank;
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
SmallVector<OpFoldResult> newPositions(deltaOverall);
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
- for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
+ for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
}
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
- llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
+ llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
@@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
- Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ Operation *defOp = extractOp.getVector().getDefiningOp();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
- if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
+ if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();
- Value source = broadcastLikeOp->getOperand(0);
+ Value source = defOp->getOperand(0);
if (isBroadcastableTo(source.getType(), outType) !=
BroadcastableToResult::Success)
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6ed64cb8313c2..6809122974545 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -888,6 +888,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
// -----
+// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
+// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
+// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: vector.extract
+// CHECK-NEXT: return
+func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
+ %idx0 : index, %idx1 : index) -> vector<2xf32> {
+ %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
+ return %r : vector<2xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
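
The suffix-broadcast case mentioned in the new comment, written out as IR (an
illustrative sketch, not taken from the tests; %x and %i are placeholder
names):

  // The extracted vector<4xf32> consists of broadcasted data, so
  // extract(broadcast(X)) cannot be rewritten to extract(X) here.
  %b = vector.broadcast %x : vector<1xf32> to vector<3x4xf32>
  %r = vector.extract %b[%i] : vector<4xf32> from vector<3x4xf32>

This is the (1) -> (3,4) -> (4) chain from the comment: the extract shape (4)
does not match the suffix (1) of the input shape, so the fold bails out.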
From 302cb34913dc99f668b98742799f172e6292bb80 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:34:11 -0700
Subject: [PATCH 3/6] comment improvements
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 ++++++++++++------------
1 file changed, 17 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3ea8d0eb784c1..31dba8781745f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1718,11 +1718,9 @@ static bool isBroadcastLike(Operation *op) {
return false;
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
- // Condition 1: dst has higher rank.
- // Condition 2: src shape is a suffix of dst shape.
- //
// Note that checking that dst shape has a prefix of 1s is not sufficient,
- // for example (2,3) -> (1,3,2) is not broadcast-like.
+ // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
+ // is that the source shape is a suffix of the destination shape.
VectorType srcType = shapeCast.getSourceVectorType();
ArrayRef<int64_t> srcShape = srcType.getShape();
uint64_t srcRank = srcType.getRank();
@@ -1734,16 +1732,16 @@ static bool isBroadcastLike(Operation *op) {
///
/// Example:
///
-/// broadcast extract
-/// (3, 4) --------> (2, 3, 4) ------> (4)
+/// broadcast extract [1][2]
+/// (3, 4) --------> (2, 3, 4) ----------------> (4)
///
/// becomes
-/// extract
-/// (3,4) ---------------------------> (4)
+/// extract [1]
+/// (3,4) -------------------------------------> (4)
///
///
-/// The variable names used in this implementation use names which correspond to
-/// the above shapes as,
+/// The variable names used in this implementation correspond to the above
+/// shapes as,
///
/// - (3, 4) is `input` shape.
/// - (2, 3, 4) is `broadcast` shape.
@@ -1775,14 +1773,15 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (extractRank > inputRank)
return Value();
- // Proof by contradiction that, at this point, input is a vector.
- // Suppose input is a scalar.
- // ==> inputRank is 0.
- // ==> extractRank is 0 (because extractRank <= inputRank).
- // ==> extract is scalar (because rank-0 extraction is always scalar).
- // ==> input and extract are scalar, so same type.
- // ==> returned early (check same type).
- // Contradiction!
+ // The above condition guarantees that input is a vector:
+ //
+ // If input is a scalar:
+ // 1) inputRank is 0, so
+ // 2) extractRank is 0 (because extractRank <= inputRank), so
+ // 3) extract is scalar (because rank-0 extraction is always scalar), so
+ // 4) input and extract are scalar, so same type.
+ // But then we should have returned earlier when the types were compared for
+ // equivalence. So input is not a scalar at this point.
assert(inputType && "input must be a vector type because of previous checks");
ArrayRef<int64_t> inputShape = inputType.getShape();
From 8c85bc7a0959c9cb67819e6251e0edb230ef2c05 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:40:07 -0700
Subject: [PATCH 4/6] remove lengthy explanation
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++----------
1 file changed, 2 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 31dba8781745f..7723665926295 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1718,7 +1718,7 @@ static bool isBroadcastLike(Operation *op) {
return false;
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
- // Note that checking that dst shape has a prefix of 1s is not sufficient,
+ // Checking that the destination shape has a prefix of 1s is not sufficient,
// for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
// is that the source shape is a suffix of the destination shape.
VectorType srcType = shapeCast.getSourceVectorType();
@@ -1773,15 +1773,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (extractRank > inputRank)
return Value();
- // The above condition guarantees that input is a vector:
- //
- // If input is a scalar:
- // 1) inputRank is 0, so
- // 2) extractRank is 0 (because extractRank <= inputRank), so
- // 3) extract is scalar (because rank-0 extraction is always scalar), so
- // 4) input and extract are scalar, so same type.
- // But then we should have returned earlier when the types were compared for
- // equivalence. So input is not a scalar at this point.
+ // The above condition guarantees that input is a vector.
assert(inputType && "input must be a vector type because of previous checks");
ArrayRef<int64_t> inputShape = inputType.getShape();
From cb306132fa2e1f2ae46978a269358627d64966f3 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 09:42:18 -0700
Subject: [PATCH 5/6] broadcastlike vs broadcast-like
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 7723665926295..01eedceafb275 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1719,7 +1719,7 @@ static bool isBroadcastLike(Operation *op) {
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
// Checking that the destination shape has a prefix of 1s is not sufficient,
- // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition
+ // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
// is that the source shape is a suffix of the destination shape.
VectorType srcType = shapeCast.getSourceVectorType();
ArrayRef<int64_t> srcShape = srcType.getShape();
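
As a concrete counterpart to the reworded comment, the suffix condition in
isBroadcastLike distinguishes, for example (illustrative IR, not from the
patch; %v is a placeholder name):

  // Broadcast-like: the source shape (2,3) is a suffix of (1,1,2,3).
  %a = vector.shape_cast %v : vector<2x3xf32> to vector<1x1x2x3xf32>
  // Not broadcast-like: (2,3) is not a suffix of (1,3,2), despite the leading 1.
  %b = vector.shape_cast %v : vector<2x3xf32> to vector<1x3x2xf32>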
From 15f78d32c09644d28b453e4d1e3dab52fba4ddd9 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 18 Jul 2025 10:12:45 -0700
Subject: [PATCH 6/6] test simplification
---
mlir/test/Conversion/VectorToSCF/funk.mlir | 755 ++++++++++++++++++
.../Conversion/VectorToSCF/vector-to-scf.mlir | 3 +-
2 files changed, 756 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Conversion/VectorToSCF/funk.mlir
diff --git a/mlir/test/Conversion/VectorToSCF/funk.mlir b/mlir/test/Conversion/VectorToSCF/funk.mlir
new file mode 100644
index 0000000000000..556814cd04792
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/funk.mlir
@@ -0,0 +1,755 @@
+module {
+ func.func @vector_transfer_ops_0d(%arg0: memref<f32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[], %cst : memref<f32>, vector<f32>
+ vector.transfer_write %0, %arg0[] : vector<f32>, memref<f32>
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+#map1 = affine_map<(d0) -> (d0 + 1)>
+#map2 = affine_map<(d0) -> (d0 + 2)>
+#map3 = affine_map<(d0) -> (d0 + 3)>
+module {
+ func.func @materialize_read_1d() {
+ %c7 = arith.constant 7 : index
+ %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<7x42xf32>
+ affine.for %arg0 = 0 to 7 step 4 {
+ affine.for %arg1 = 0 to 42 step 4 {
+ %0 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+ %7 = affine.apply #map(%arg0, %arg2)
+ %8 = affine.apply #map(%arg0, %arg2)
+ %9 = arith.cmpi slt, %8, %c7 : index
+ %10 = scf.if %9 -> (vector<4xf32>) {
+ %11 = memref.load %alloc[%7, %arg1] : memref<7x42xf32>
+ %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+ scf.yield %12 : vector<4xf32>
+ } else {
+ scf.yield %arg3 : vector<4xf32>
+ }
+ scf.yield %10 : vector<4xf32>
+ }
+ %1 = affine.apply #map1(%arg1)
+ %2 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+ %7 = affine.apply #map(%arg0, %arg2)
+ %8 = affine.apply #map(%arg0, %arg2)
+ %9 = arith.cmpi slt, %8, %c7 : index
+ %10 = scf.if %9 -> (vector<4xf32>) {
+ %11 = memref.load %alloc[%7, %1] : memref<7x42xf32>
+ %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+ scf.yield %12 : vector<4xf32>
+ } else {
+ scf.yield %arg3 : vector<4xf32>
+ }
+ scf.yield %10 : vector<4xf32>
+ }
+ %3 = affine.apply #map2(%arg1)
+ %4 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+ %7 = affine.apply #map(%arg0, %arg2)
+ %8 = affine.apply #map(%arg0, %arg2)
+ %9 = arith.cmpi slt, %8, %c7 : index
+ %10 = scf.if %9 -> (vector<4xf32>) {
+ %11 = memref.load %alloc[%7, %3] : memref<7x42xf32>
+ %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+ scf.yield %12 : vector<4xf32>
+ } else {
+ scf.yield %arg3 : vector<4xf32>
+ }
+ scf.yield %10 : vector<4xf32>
+ }
+ %5 = affine.apply #map3(%arg1)
+ %6 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %cst) -> (vector<4xf32>) {
+ %7 = affine.apply #map(%arg0, %arg2)
+ %8 = affine.apply #map(%arg0, %arg2)
+ %9 = arith.cmpi slt, %8, %c7 : index
+ %10 = scf.if %9 -> (vector<4xf32>) {
+ %11 = memref.load %alloc[%7, %5] : memref<7x42xf32>
+ %12 = vector.insert %11, %arg3 [%arg2] : f32 into vector<4xf32>
+ scf.yield %12 : vector<4xf32>
+ } else {
+ scf.yield %arg3 : vector<4xf32>
+ }
+ scf.yield %10 : vector<4xf32>
+ }
+ "dummy_use"(%0, %2, %4, %6) : (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) -> ()
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+#map1 = affine_map<(d0, d1) -> (d0 + d1 + 1)>
+module {
+ func.func @materialize_read_1d_partially_specialized(%arg0: index, %arg1: index, %arg2: index) {
+ %c42 = arith.constant 42 : index
+ %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc(%arg0, %arg1, %arg2) : memref<7x?x?x42x?xf32>
+ affine.for %arg3 = 0 to 7 {
+ affine.for %arg4 = 0 to %arg0 {
+ affine.for %arg5 = 0 to %arg1 {
+ affine.for %arg6 = 0 to 42 step 2 {
+ affine.for %arg7 = 0 to %arg2 {
+ %0 = scf.for %arg8 = %c0 to %c4 step %c1 iter_args(%arg9 = %cst) -> (vector<4xf32>) {
+ %2 = affine.apply #map(%arg6, %arg8)
+ %3 = affine.apply #map(%arg6, %arg8)
+ %4 = arith.cmpi slt, %3, %c42 : index
+ %5 = scf.if %4 -> (vector<4xf32>) {
+ %6 = memref.load %alloc[%arg3, %arg4, %arg5, %2, %arg7] : memref<7x?x?x42x?xf32>
+ %7 = vector.insert %6, %arg9 [%arg8] : f32 into vector<4xf32>
+ scf.yield %7 : vector<4xf32>
+ } else {
+ scf.yield %arg9 : vector<4xf32>
+ }
+ scf.yield %5 : vector<4xf32>
+ }
+ %1 = scf.for %arg8 = %c0 to %c4 step %c1 iter_args(%arg9 = %cst) -> (vector<4xf32>) {
+ %2 = affine.apply #map1(%arg8, %arg6)
+ %3 = affine.apply #map1(%arg8, %arg6)
+ %4 = arith.cmpi slt, %3, %c42 : index
+ %5 = scf.if %4 -> (vector<4xf32>) {
+ %6 = memref.load %alloc[%arg3, %arg4, %arg5, %2, %arg7] : memref<7x?x?x42x?xf32>
+ %7 = vector.insert %6, %arg9 [%arg8] : f32 into vector<4xf32>
+ scf.yield %7 : vector<4xf32>
+ } else {
+ scf.yield %arg9 : vector<4xf32>
+ }
+ scf.yield %5 : vector<4xf32>
+ }
+ "dummy_use"(%0, %1) : (vector<4xf32>, vector<4xf32>) -> ()
+ }
+ }
+ }
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+module {
+ func.func @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+ %cst = arith.constant dense<0.000000e+00> : vector<3xf32>
+ %c3 = arith.constant 3 : index
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<4x3xf32>
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
+ affine.for %arg4 = 0 to %arg0 step 3 {
+ affine.for %arg5 = 0 to %arg1 {
+ affine.for %arg6 = 0 to %arg2 {
+ affine.for %arg7 = 0 to %arg3 step 5 {
+ %alloca = memref.alloca() : memref<vector<5x4x3xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<5x4x3xf32>> to memref<5xvector<4x3xf32>>
+ scf.for %arg8 = %c0 to %c5 step %c1 {
+ %2 = affine.apply #map(%arg7, %arg8)
+ %3 = arith.cmpi sgt, %arg3, %2 : index
+ scf.if %3 {
+ %4 = affine.apply #map(%arg7, %arg8)
+ %5 = vector.type_cast %0 : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+ scf.for %arg9 = %c0 to %c4 step %c1 {
+ %6 = scf.for %arg10 = %c0 to %c3 step %c1 iter_args(%arg11 = %cst) -> (vector<3xf32>) {
+ %7 = affine.apply #map(%arg4, %arg10)
+ %8 = affine.apply #map(%arg4, %arg10)
+ %9 = arith.cmpi sgt, %arg0, %8 : index
+ %10 = scf.if %9 -> (vector<3xf32>) {
+ %11 = memref.load %alloc[%7, %arg5, %arg6, %4] : memref<?x?x?x?xf32>
+ %12 = vector.insert %11, %arg11 [%arg10] : f32 into vector<3xf32>
+ scf.yield %12 : vector<3xf32>
+ } else {
+ scf.yield %arg11 : vector<3xf32>
+ }
+ scf.yield %10 : vector<3xf32>
+ }
+ memref.store %6, %5[%arg8, %arg9] : memref<5x4xvector<3xf32>>
+ }
+ } else {
+ memref.store %cst_0, %0[%arg8] : memref<5xvector<4x3xf32>>
+ }
+ }
+ %1 = memref.load %alloca[] : memref<vector<5x4x3xf32>>
+ "dummy_use"(%1) : (vector<5x4x3xf32>) -> ()
+ }
+ }
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1)>
+module {
+ func.func @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<1.000000e+00> : vector<3x4x1x5xf32>
+ %alloc = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
+ affine.for %arg4 = 0 to %arg0 step 3 {
+ affine.for %arg5 = 0 to %arg1 step 4 {
+ affine.for %arg6 = 0 to %arg2 {
+ affine.for %arg7 = 0 to %arg3 step 5 {
+ %alloca = memref.alloca() : memref<vector<3x4x1x5xf32>>
+ memref.store %cst, %alloca[] : memref<vector<3x4x1x5xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x4x1x5xf32>> to memref<3xvector<4x1x5xf32>>
+ scf.for %arg8 = %c0 to %c3 step %c1 {
+ %1 = affine.apply #map(%arg4, %arg8)
+ %2 = arith.cmpi sgt, %arg0, %1 : index
+ scf.if %2 {
+ %3 = affine.apply #map(%arg4, %arg8)
+ %4 = vector.type_cast %0 : memref<3xvector<4x1x5xf32>> to memref<3x4xvector<1x5xf32>>
+ scf.for %arg9 = %c0 to %c4 step %c1 {
+ %5 = affine.apply #map(%arg5, %arg9)
+ %6 = arith.cmpi sgt, %arg1, %5 : index
+ scf.if %6 {
+ %7 = affine.apply #map(%arg5, %arg9)
+ %8 = vector.type_cast %4 : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
+ scf.for %arg10 = %c0 to %c1 step %c1 {
+ %9 = affine.apply #map(%arg6, %arg10)
+ %10 = memref.load %8[%arg8, %arg9, %arg10] : memref<3x4x1xvector<5xf32>>
+ vector.transfer_write %10, %alloc[%3, %7, %9, %arg7] : vector<5xf32>, memref<?x?x?x?xf32>
+ }
+ } else {
+ }
+ }
+ } else {
+ }
+ }
+ }
+ }
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+ func.func @transfer_read_progressive(%arg0: memref<?x?xf32>, %arg1: index) -> vector<3x15xf32> {
+ %cst = arith.constant dense<7.000000e+00> : vector<15xf32>
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %cst_0 = arith.constant 7.000000e+00 : f32
+ %alloca = memref.alloca() : memref<vector<3x15xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+ scf.for %arg2 = %c0 to %c3 step %c1 {
+ %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+ %2 = affine.apply #map(%arg2)[%arg1]
+ %3 = arith.cmpi sgt, %dim, %2 : index
+ scf.if %3 {
+ %4 = affine.apply #map(%arg2)[%arg1]
+ %5 = vector.transfer_read %arg0[%4, %arg1], %cst_0 : memref<?x?xf32>, vector<15xf32>
+ memref.store %5, %0[%arg2] : memref<3xvector<15xf32>>
+ } else {
+ memref.store %cst, %0[%arg2] : memref<3xvector<15xf32>>
+ }
+ }
+ %1 = memref.load %alloca[] : memref<vector<3x15xf32>>
+ return %1 : vector<3x15xf32>
+ }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+ func.func @transfer_write_progressive(%arg0: memref<?x?xf32>, %arg1: index, %arg2: vector<3x15xf32>) {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x15xf32>>
+ memref.store %arg2, %alloca[] : memref<vector<3x15xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+ scf.for %arg3 = %c0 to %c3 step %c1 {
+ %dim = memref.dim %arg0, %c0 : memref<?x?xf32>
+ %1 = affine.apply #map(%arg3)[%arg1]
+ %2 = arith.cmpi sgt, %dim, %1 : index
+ scf.if %2 {
+ %3 = affine.apply #map(%arg3)[%arg1]
+ %4 = memref.load %0[%arg3] : memref<3xvector<15xf32>>
+ vector.transfer_write %4, %arg0[%3, %arg1] : vector<15xf32>, memref<?x?xf32>
+ } else {
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+ func.func @transfer_write_progressive_inbounds(%arg0: memref<?x?xf32>, %arg1: index, %arg2: vector<3x15xf32>) {
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x15xf32>>
+ memref.store %arg2, %alloca[] : memref<vector<3x15xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x15xf32>> to memref<3xvector<15xf32>>
+ scf.for %arg3 = %c0 to %c3 step %c1 {
+ %1 = affine.apply #map(%arg3)[%arg1]
+ %2 = memref.load %0[%arg3] : memref<3xvector<15xf32>>
+ vector.transfer_write %2, %arg0[%1, %arg1] {in_bounds = [true]} : vector<15xf32>, memref<?x?xf32>
+ }
+ return
+ }
+}
+
+// -----
+module {
+ func.func @transfer_read_simple(%arg0: memref<2x2xf32>) -> vector<2x2xf32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<2x2xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<2x2xf32>> to memref<2xvector<2xf32>>
+ scf.for %arg1 = %c0 to %c2 step %c1 {
+ %2 = vector.transfer_read %arg0[%arg1, %c0], %cst {in_bounds = [true]} : memref<2x2xf32>, vector<2xf32>
+ memref.store %2, %0[%arg1] : memref<2xvector<2xf32>>
+ }
+ %1 = memref.load %alloca[] : memref<vector<2x2xf32>>
+ return %1 : vector<2x2xf32>
+ }
+ func.func @transfer_read_minor_identity(%arg0: memref<?x?x?x?xf32>) -> vector<3x3xf32> {
+ %cst = arith.constant dense<0.000000e+00> : vector<3xf32>
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x3xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x3xf32>> to memref<3xvector<3xf32>>
+ scf.for %arg1 = %c0 to %c3 step %c1 {
+ %dim = memref.dim %arg0, %c2 : memref<?x?x?x?xf32>
+ %2 = arith.cmpi sgt, %dim, %arg1 : index
+ scf.if %2 {
+ %3 = vector.transfer_read %arg0[%c0, %c0, %arg1, %c0], %cst_0 : memref<?x?x?x?xf32>, vector<3xf32>
+ memref.store %3, %0[%arg1] : memref<3xvector<3xf32>>
+ } else {
+ memref.store %cst, %0[%arg1] : memref<3xvector<3xf32>>
+ }
+ }
+ %1 = memref.load %alloca[] : memref<vector<3x3xf32>>
+ return %1 : vector<3x3xf32>
+ }
+ func.func @transfer_write_minor_identity(%arg0: vector<3x3xf32>, %arg1: memref<?x?x?x?xf32>) {
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x3xf32>>
+ memref.store %arg0, %alloca[] : memref<vector<3x3xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<3x3xf32>> to memref<3xvector<3xf32>>
+ scf.for %arg2 = %c0 to %c3 step %c1 {
+ %dim = memref.dim %arg1, %c2 : memref<?x?x?x?xf32>
+ %1 = arith.cmpi sgt, %dim, %arg2 : index
+ scf.if %1 {
+ %2 = memref.load %0[%arg2] : memref<3xvector<3xf32>>
+ vector.transfer_write %2, %arg1[%c0, %c0, %arg2, %c0] : vector<3xf32>, memref<?x?x?x?xf32>
+ } else {
+ }
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+module {
+ func.func @transfer_read_strided(%arg0: memref<8x4xf32, #map>) -> vector<4xf32> {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %0 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %cst) -> (vector<4xf32>) {
+ %1 = memref.load %arg0[%c0, %arg1] : memref<8x4xf32, #map>
+ %2 = vector.insert %1, %arg2 [%arg1] : f32 into vector<4xf32>
+ scf.yield %2 : vector<4xf32>
+ }
+ return %0 : vector<4xf32>
+ }
+ func.func @transfer_write_strided(%arg0: vector<4xf32>, %arg1: memref<8x4xf32, #map>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ scf.for %arg2 = %c0 to %c4 step %c1 {
+ %0 = vector.extract %arg0[%arg2] : f32 from vector<4xf32>
+ memref.store %0, %arg1[%c0, %arg2] : memref<8x4xf32, #map>
+ }
+ return
+ }
+}
+
+// -----
+module {
+ func.func private @fake_side_effecting_fun(vector<2x2xf32>)
+ func.func @transfer_read_within_async_execute(%arg0: memref<2x2xf32>) -> !async.token {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %token = async.execute {
+ %alloca = memref.alloca() : memref<vector<2x2xf32>>
+ %0 = vector.type_cast %alloca : memref<vector<2x2xf32>> to memref<2xvector<2xf32>>
+ scf.for %arg1 = %c0 to %c2 step %c1 {
+ %2 = vector.transfer_read %arg0[%arg1, %c0], %cst {in_bounds = [true]} : memref<2x2xf32>, vector<2xf32>
+ memref.store %2, %0[%arg1] : memref<2xvector<2xf32>>
+ }
+ %1 = memref.load %alloca[] : memref<vector<2x2xf32>>
+ func.call @fake_side_effecting_fun(%1) : (vector<2x2xf32>) -> ()
+ async.yield
+ }
+ return %token : !async.token
+ }
+}
+
+// -----
+module {
+ func.func @transfer_read_with_tensor(%arg0: tensor<f32>) -> vector<1xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.transfer_read %arg0[], %cst : tensor<f32>, vector<f32>
+ %1 = vector.broadcast %0 : vector<f32> to vector<1xf32>
+ return %1 : vector<1xf32>
+ }
+}
+
+// -----
+module {
+ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>, %arg1: f32) {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ %0 = llvm.mlir.undef : vector<[16]xf32>
+ %1 = llvm.mlir.undef : vector<[16]xi32>
+ %2 = llvm.mlir.constant(0 : i32) : i32
+ %c0 = arith.constant 0 : index
+ %dim = memref.dim %arg0, %c0 : memref<?xf32, strided<[?], offset: ?>>
+ %3 = llvm.intr.stepvector : vector<[16]xi32>
+ %4 = arith.index_cast %dim : index to i32
+ %5 = llvm.insertelement %4, %1[%2 : i32] : vector<[16]xi32>
+ %6 = llvm.shufflevector %5, %1 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<[16]xi32>
+ %7 = arith.cmpi slt, %3, %6 : vector<[16]xi32>
+ %8 = llvm.insertelement %arg1, %0[%2 : i32] : vector<[16]xf32>
+ %9 = llvm.shufflevector %8, %0 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<[16]xf32>
+ %vscale = vector.vscale
+ %c16_vscale = arith.muli %vscale, %c16 : index
+ scf.for %arg2 = %c0 to %c16_vscale step %c1 {
+ %10 = vector.extract %7[%arg2] : i1 from vector<[16]xi1>
+ scf.if %10 {
+ %11 = vector.extract %9[%arg2] : f32 from vector<[16]xf32>
+ memref.store %11, %arg0[%arg2] : memref<?xf32, strided<[?], offset: ?>>
+ } else {
+ }
+ }
+ return
+ }
+}
+
+// -----
+module {
+ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.print punctuation <open>
+ scf.for %arg1 = %c0 to %c1 step %c1 {
+ %0 = vector.extract %arg0[] : f32 from vector<f32>
+ vector.print %0 : f32 punctuation <no_punctuation>
+ %1 = arith.cmpi ult, %arg1, %c0 : index
+ scf.if %1 {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ vector.print
+ return
+ }
+}
+
+// -----
+module {
+ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %0 = vector.shape_cast %arg0 : vector<2x2xf32> to vector<4xf32>
+ vector.print punctuation <open>
+ scf.for %arg1 = %c0 to %c2 step %c1 {
+ vector.print punctuation <open>
+ scf.for %arg2 = %c0 to %c2 step %c1 {
+ %2 = arith.muli %arg1, %c2 : index
+ %3 = arith.addi %arg2, %2 : index
+ %4 = vector.extract %0[%3] : f32 from vector<4xf32>
+ vector.print %4 : f32 punctuation <no_punctuation>
+ %5 = arith.cmpi ult, %arg2, %c1 : index
+ scf.if %5 {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ %1 = arith.cmpi ult, %arg1, %c1 : index
+ scf.if %1 {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ vector.print
+ return
+ }
+}
+
+// -----
+module {
+ func.func @vector_print_scalable_vector(%arg0: vector<[4]xi32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %vscale = vector.vscale
+ %c4_vscale = arith.muli %vscale, %c4 : index
+ %0 = arith.subi %c4_vscale, %c1 : index
+ vector.print punctuation <open>
+ scf.for %arg1 = %c0 to %c4_vscale step %c1 {
+ %1 = vector.extract %arg0[%arg1] : i32 from vector<[4]xi32>
+ vector.print %1 : i32 punctuation <no_punctuation>
+ %2 = arith.cmpi ult, %arg1, %0 : index
+ scf.if %2 {
+ vector.print punctuation <comma>
+ }
+ }
+ vector.print punctuation <close>
+ vector.print
+ return
+ }
+}
+
+// -----
+module {
+ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[4]xf32> {
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x[4]xf32>>
+ %alloca_0 = memref.alloca() : memref<vector<3x[4]xi1>>
+ %dim = memref.dim %arg0, %c1 : memref<3x?xf32>
+ %0 = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+ memref.store %0, %alloca_0[] : memref<vector<3x[4]xi1>>
+ %1 = vector.type_cast %alloca : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+ %2 = vector.type_cast %alloca_0 : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+ scf.for %arg1 = %c0 to %c3 step %c1 {
+ %4 = memref.load %2[%arg1] : memref<3xvector<[4]xi1>>
+ %5 = vector.transfer_read %arg0[%arg1, %c0], %cst, %4 {in_bounds = [true]} : memref<3x?xf32>, vector<[4]xf32>
+ memref.store %5, %1[%arg1] : memref<3xvector<[4]xf32>>
+ }
+ %3 = memref.load %alloca[] : memref<vector<3x[4]xf32>>
+ return %3 : vector<3x[4]xf32>
+ }
+}
+
+// -----
+module {
+ func.func @transfer_write_array_of_scalable(%arg0: vector<3x[4]xf32>, %arg1: memref<3x?xf32>) {
+ %c3 = arith.constant 3 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<3x[4]xf32>>
+ %alloca_0 = memref.alloca() : memref<vector<3x[4]xi1>>
+ %dim = memref.dim %arg1, %c1 : memref<3x?xf32>
+ %0 = vector.create_mask %c1, %dim : vector<3x[4]xi1>
+ memref.store %0, %alloca_0[] : memref<vector<3x[4]xi1>>
+ memref.store %arg0, %alloca[] : memref<vector<3x[4]xf32>>
+ %1 = vector.type_cast %alloca : memref<vector<3x[4]xf32>> to memref<3xvector<[4]xf32>>
+ %2 = vector.type_cast %alloca_0 : memref<vector<3x[4]xi1>> to memref<3xvector<[4]xi1>>
+ scf.for %arg2 = %c0 to %c3 step %c1 {
+ %3 = memref.load %1[%arg2] : memref<3xvector<[4]xf32>>
+ %4 = memref.load %2[%arg2] : memref<3xvector<[4]xi1>>
+ vector.transfer_write %3, %arg1[%arg2, %c0], %4 {in_bounds = [true]} : vector<[4]xf32>, memref<3x?xf32>
+ }
+ return
+ }
+}
+
+// -----
+module {
+ func.func @cannot_lower_transfer_write_with_leading_scalable(%arg0: vector<[4]x4xf32>, %arg1: memref<?x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %dim = memref.dim %arg1, %c0 : memref<?x4xf32>
+ %0 = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+ vector.transfer_write %arg0, %arg1[%c0, %c0], %0 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x4xf32>
+ return
+ }
+}
+
+// -----
+module {
+ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf32>) -> vector<[4]x4xf32> {
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dim = memref.dim %arg0, %c0 : memref<?x4xf32>
+ %0 = vector.create_mask %dim, %c4 : vector<[4]x4xi1>
+ %1 = vector.transfer_read %arg0[%c0, %c0], %cst, %0 {in_bounds = [true, true]} : memref<?x4xf32>, vector<[4]x4xf32>
+ return %1 : vector<[4]x4xf32>
+ }
+ func.func @does_not_crash_on_unpack_one_dim(%arg0: memref<1x1x1x1xi32>, %arg1: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+ %c1 = arith.constant 1 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<1x1x1x1xi32>>
+ %alloca_0 = memref.alloca() : memref<vector<1x1xi1>>
+ memref.store %arg1, %alloca_0[] : memref<vector<1x1xi1>>
+ %0 = vector.type_cast %alloca : memref<vector<1x1x1x1xi32>> to memref<1xvector<1x1x1xi32>>
+ %1 = vector.type_cast %alloca_0 : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+ scf.for %arg2 = %c0 to %c1 step %c1 {
+ %3 = vector.type_cast %0 : memref<1xvector<1x1x1xi32>> to memref<1x1xvector<1x1xi32>>
+ scf.for %arg3 = %c0 to %c1 step %c1 {
+ %4 = vector.type_cast %3 : memref<1x1xvector<1x1xi32>> to memref<1x1x1xvector<1xi32>>
+ scf.for %arg4 = %c0 to %c1 step %c1 {
+ %5 = memref.load %1[%arg2] : memref<1xvector<1xi1>>
+ %6 = vector.transfer_read %arg0[%arg2, %c0, %c0, %c0], %c0_i32, %5 {in_bounds = [true]} : memref<1x1x1x1xi32>, vector<1xi32>
+ memref.store %6, %4[%arg2, %arg3, %arg4] : memref<1x1x1xvector<1xi32>>
+ }
+ }
+ }
+ %2 = memref.load %alloca[] : memref<vector<1x1x1x1xi32>>
+ return %2 : vector<1x1x1x1xi32>
+ }
+ func.func @add_arrays_of_scalable_vectors(%arg0: memref<1x2x?xf32>, %arg1: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c2 = arith.constant 2 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<1x2x[4]xf32>>
+ %alloca_0 = memref.alloca() : memref<vector<1x2x[4]xi1>>
+ %dim = memref.dim %arg0, %c2 : memref<1x2x?xf32>
+ %0 = vector.create_mask %c2, %c2, %dim : vector<1x2x[4]xi1>
+ memref.store %0, %alloca_0[] : memref<vector<1x2x[4]xi1>>
+ %1 = vector.type_cast %alloca : memref<vector<1x2x[4]xf32>> to memref<1xvector<2x[4]xf32>>
+ %2 = vector.type_cast %alloca_0 : memref<vector<1x2x[4]xi1>> to memref<1xvector<2x[4]xi1>>
+ scf.for %arg2 = %c0 to %c1 step %c1 {
+ %4 = vector.type_cast %1 : memref<1xvector<2x[4]xf32>> to memref<1x2xvector<[4]xf32>>
+ %5 = vector.type_cast %2 : memref<1xvector<2x[4]xi1>> to memref<1x2xvector<[4]xi1>>
+ scf.for %arg3 = %c0 to %c2 step %c1 {
+ %6 = memref.load %5[%arg2, %arg3] : memref<1x2xvector<[4]xi1>>
+ %7 = vector.transfer_read %arg0[%arg2, %arg3, %c0], %cst, %6 {in_bounds = [true]} : memref<1x2x?xf32>, vector<[4]xf32>
+ memref.store %7, %4[%arg2, %arg3] : memref<1x2xvector<[4]xf32>>
+ }
+ }
+ %3 = memref.load %alloca[] : memref<vector<1x2x[4]xf32>>
+ return %3 : vector<1x2x[4]xf32>
+ }
+ func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%arg0: vector<[4]x[4]xf32>, %arg1: memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg0, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ return
+ }
+ func.func @unroll_transfer_write_target_rank_zero(%arg0: vector<2xi32>) {
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() : memref<4xi32>
+ vector.transfer_write %arg0, %alloc[%c0] {in_bounds = [true]} : vector<2xi32>, memref<4xi32>
+ return
+ }
+}
+
+// -----
+module {
+ func.func @scalable_transpose_store_unmasked(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ vector.transfer_write %0, %arg1[%arg2, %arg3] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+ }
+}
+
+// -----
+module {
+ func.func @scalable_transpose_store_dynamic_mask(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ %1 = vector.create_mask %arg4, %arg5 : vector<[4]x4xi1>
+ vector.transfer_write %0, %arg1[%arg2, %arg3], %1 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+ }
+}
+
+// -----
+module {
+ func.func @scalable_transpose_store_constant_mask(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ %1 = vector.constant_mask [4, 3] : vector<[4]x4xi1>
+ vector.transfer_write %0, %arg1[%arg2, %arg3], %1 {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+ func.func @negative_scalable_transpose_store_0(%arg0: vector<[4]x4xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<4x[4]xf32>>
+ %0 = vector.transpose %arg0, [1, 0] : vector<[4]x4xf32> to vector<4x[4]xf32>
+ memref.store %0, %alloca[] : memref<vector<4x[4]xf32>>
+ %1 = vector.type_cast %alloca : memref<vector<4x[4]xf32>> to memref<4xvector<[4]xf32>>
+ scf.for %arg4 = %c0 to %c4 step %c1 {
+ %2 = affine.apply #map(%arg4)[%arg2]
+ %3 = memref.load %1[%arg4] : memref<4xvector<[4]xf32>>
+ vector.transfer_write %3, %arg1[%2, %arg3] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ }
+ return
+ }
+}
+
+// -----
+#map = affine_map<(d0)[s0] -> (d0 + s0)>
+module {
+ func.func @negative_scalable_transpose_store_1(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %alloca = memref.alloca() : memref<vector<4x[4]xf32>>
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ %1 = vector.transpose %0, [1, 0] : vector<[4]x4xf32> to vector<4x[4]xf32>
+ memref.store %1, %alloca[] : memref<vector<4x[4]xf32>>
+ %2 = vector.type_cast %alloca : memref<vector<4x[4]xf32>> to memref<4xvector<[4]xf32>>
+ scf.for %arg4 = %c0 to %c4 step %c1 {
+ %3 = affine.apply #map(%arg4)[%arg2]
+ %4 = memref.load %2[%arg4] : memref<4xvector<[4]xf32>>
+ vector.transfer_write %4, %arg1[%3, %arg3] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+ }
+ return
+ }
+}
+
+// -----
+module {
+ func.func @negative_scalable_transpose_store_2(%arg0: vector<4x[4]xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ %0 = vector.transpose %arg0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+ vector.transfer_write %0, %arg1[%arg2, %arg3] {in_bounds = [false, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+ }
+}
+
+// -----
+module {
+ func.func @negative_scalable_transpose_store_3(%arg0: vector<[4]x4xf32>, %arg1: memref<?x?xf32>, %arg2: index, %arg3: index) {
+ vector.transfer_write %arg0, %arg1[%arg2, %arg3] {in_bounds = [true, true]} : vector<[4]x4xf32>, memref<?x?xf32>
+ return
+ }
+}
+
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index 33177736eb5fe..1ed82954398f0 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
-// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
+// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {
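
The updated CHECK lines reflect that the extract through the broadcast-like
shape_cast of the 0-d vector now folds away during lowering; schematically
(placeholder names, sketch only):

  // Before: a broadcast-like shape_cast followed by an extract.
  %flat = vector.shape_cast %vec : vector<f32> to vector<1xf32>
  %el = vector.extract %flat[%idx] : f32 from vector<1xf32>

  // After folding: extract directly from the 0-d vector.
  %el = vector.extract %vec[] : f32 from vector<f32>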