[Mlir-commits] [mlir] [mlir][vector] Avoid setting padding by default to `0` in `vector.transfer_read`, prefer `ub.poison` (PR #146088)
Fabian Mora
llvmlistbot at llvm.org
Fri Jun 27 07:32:24 PDT 2025
https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/146088
Context:
`vector.transfer_read` always requires a padding value. Most of its builders take no `padding` value and silently assume the safe value of `0`. However, the padding should be a conscious choice by the API user, because an implicit default makes it easy to introduce bugs.
For example, while making this patch I found several places where the padding value was not propagated when one `vector.transfer_read` was rewritten into another `vector.transfer_read`. In every case, the bug was caused by a builder that does not require the padding to be specified.
Additionally, IMO using poison as the default padding value is better, as it indicates that the user "doesn't care" about the actual padding value.
With that in mind, this patch changes the `vector.transfer_read` builders to always take a `std::optional<Value> padding` argument. The argument itself can never be omitted, but for convenience one can pass `std::nullopt`, in which case the transfer read is padded with `ub.poison`.
>From ca3e2b2187e6c8be0ebee5e174d4227e509dbff9 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fabian.mora-cordero at amd.com>
Date: Fri, 27 Jun 2025 14:16:45 +0000
Subject: [PATCH] [mlir][vector] Avoid setting padding by default in vector
transfer read, prefer ub.poison
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
---
mlir/include/mlir/Dialect/Arith/IR/Arith.h | 3 ++
mlir/include/mlir/Dialect/Vector/IR/Vector.td | 5 ++-
.../mlir/Dialect/Vector/IR/VectorOps.td | 18 ++++-----
.../Affine/Transforms/SuperVectorize.cpp | 3 +-
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5 +++
.../Transforms/LegalizeVectorStorage.cpp | 1 +
.../Linalg/Transforms/Vectorization.cpp | 38 +++++++++++-------
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 39 +++++++++----------
.../Vector/Transforms/VectorDistribute.cpp | 2 +-
.../Transforms/VectorTransferOpTransforms.cpp | 3 +-
.../Affine/SuperVectorize/vectorize_1d.mlir | 12 +++---
.../vectorize_affine_apply.mlir | 6 +--
.../ArmSVE/legalize-transfer-read.mlir | 8 ++--
.../Vector/vector-transfer-flatten.mlir | 24 ++++++------
.../Vector/vector-warp-distribute.mlir | 4 +-
15 files changed, 95 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 0bee876ac9bfa..84d1a2535e863 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -154,6 +154,9 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
+
+/// Creates an `arith.constant` operation with a zero value of type `type`.
+Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index 1922cc63ef353..5125ae7c13717 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "ub::UBDialect"
+ ];
}
// Base class for Vector dialect ops.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e6b85de5a522a..c1fcc5299416e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1543,30 +1543,28 @@ def Vector_TransferReadOp :
}];
let builders = [
- /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+ /// 1. Builder that sets padding to `padding` or poison if not provided and
+ /// an empty mask (variant with attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
+ "std::optional<Value>":$padding,
"AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
- /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+ /// 2. Builder that sets padding to `padding` or poison if not provided and
+ /// an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
+ "std::optional<Value>":$padding,
"AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
- "Value":$padding,
- CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
- /// 4. Builder that sets padding to zero and permutation map to
- /// 'getMinorIdentityMap'.
- OpBuilder<(ins "VectorType":$vectorType,
- "Value":$source,
- "ValueRange":$indices,
- CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+ "std::optional<Value>":$padding,
+ CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index f6f192a6d964a..6e8f7126df325 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = state.builder.create<vector::TransferReadOp>(
- loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
+ loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, std::nullopt,
+ permutationMap);
// Register replacement for future uses in the scope.
state.registerOpVectorReplacement(loadOp, transfer);
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5194f2b58669a..c9fe579a0b8a9 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -292,6 +292,11 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
return false;
}
+Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
+ Type type) {
+ return builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(type));
+}
+
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d52ff4d4257c7..3dbb93b8a0669 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
// Create the new `transfer_read`.
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), collapsedVT, collapsedMem, indices,
+ readOp.getPadding(),
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
// Cast back to the original vector type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 830ae5414c6bd..444396aaeccfc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1191,6 +1191,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs,
+ arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
permutationMap, inBounds);
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1228,9 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
- inBounds);
+ loc, resultType, extractOp.getTensor(), transferReadIdxs,
+ arith::getZeroConstant(rewriter, loc, resultType.getElementType()),
+ permutationMap, inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1386,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1441,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
- loc, readType, opOperand->get(), indices, readMap);
+ loc, readType, opOperand->get(), indices,
+ arith::getZeroConstant(rewriter, loc, elemType), readMap);
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
Value readValue = read->getResult(0);
@@ -2641,6 +2644,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
+ arith::getZeroConstant(rewriter, loc, srcElementType),
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue =
@@ -3487,15 +3491,18 @@ struct Conv1DGenerator
SmallVector<Value> resPadding(resShape.size(), zero);
// Read the whole lhs, rhs and res in one shot (with zero padding).
- Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
- lhsPadding);
+ Value lhs = rewriter.create<vector::TransferReadOp>(
+ loc, lhsType, lhsShaped, lhsPadding,
+ arith::getZeroConstant(rewriter, loc, lhsEltType));
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == ConvOperationKind::Conv)
- rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
- rhsPadding);
- Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
- resPadding);
+ rhs = rewriter.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, rhsPadding,
+ arith::getZeroConstant(rewriter, loc, rhsEltType));
+ Value res = rewriter.create<vector::TransferReadOp>(
+ loc, resType, resShaped, resPadding,
+ arith::getZeroConstant(rewriter, loc, resEltType));
// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3749,22 @@ struct Conv1DGenerator
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+ arith::getZeroConstant(rewriter, loc, lhsEltType));
auto maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
- ValueRange{zero, zero});
+ Value rhs = rewriter.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, ValueRange{zero, zero},
+ arith::getZeroConstant(rewriter, loc, rhsEltType));
auto maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero},
+ arith::getZeroConstant(rewriter, loc, resEltType));
auto maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..fc7ed7e479b49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, AffineMapAttr permutationMapAttr,
+ ValueRange indices, std::optional<Value> padding,
+ AffineMapAttr permutationMapAttr,
/*optional*/ ArrayAttr inBoundsAttr) {
+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding, /*mask=*/Value(), inBoundsAttr);
+ *padding, /*mask=*/Value(), inBoundsAttr);
}
/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, AffineMap permutationMap,
+ ValueRange indices, std::optional<Value> padding,
+ AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) {
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
- build(builder, result, vectorType, source, indices, permutationMapAttr,
- inBoundsAttr);
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ build(builder, result, vectorType, source, indices, *padding,
+ permutationMapAttr, inBoundsAttr);
}
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, Value padding,
+ ValueRange indices, std::optional<Value> padding,
std::optional<ArrayRef<bool>> inBounds) {
AffineMap permutationMap = getTransferMinorIdentityMap(
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding,
+ *padding,
/*mask=*/Value(), inBoundsAttr);
}
-/// 4. Builder that sets padding to zero and permutation map to
-/// 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices,
- std::optional<ArrayRef<bool>> inBounds) {
- Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
- build(builder, result, vectorType, source, indices, padding, inBounds);
-}
-
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index af90ed8f5deaf..ba9f39c6393ce 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
- loc, cast<VectorType>(type), buffer, indices,
+ loc, cast<VectorType>(type), buffer, indices, std::nullopt,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 785a8aaf3f0a9..efdae93e730bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
- loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+ loc, flatVectorType, collapsedSource, collapsedIndices,
+ transferReadOp.getPadding(), collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_read with the new one reading from the
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
index 81b04ccceaf27..72ced5b53879b 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for {{.*}} step 128
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
%a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
%P = memref.dim %B, %c2 : memref<?x?x?xf32>
// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i3 = 0 to %M { // vectorized
%a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
@@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i8 = 0 to %M { // vectorized
affine.for %i9 = 0 to %N {
@@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
// CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
affine.for %i4 = 0 to %M { // vectorized
affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
@@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK: %{{.*}} = ub.poison : f32
// CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
@@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
index 15a7133cf0f65..7d4d111c09799 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
@@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf3
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48x
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
index 5f923cdafb956..49bd2eddbdedd 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -11,8 +11,8 @@
// CHECK-LABEL: @base_case
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
@@ -36,8 +36,8 @@ func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<
// CHECK-LABEL: @with_3d_vector
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 0f04d3b79b535..d18edd0ac5563 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
@@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
@@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
@@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
-// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK: %[[C_0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
@@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims(
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
@@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
-// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 7cfbcdf101d11..1161dbd4b2166 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-SCF-IF: gpu.barrier
// CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
- // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
- // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+ // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+ // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
}
More information about the Mlir-commits
mailing list