[Mlir-commits] [mlir] 878d359 - [mlir][vector] Avoid setting padding by default to `0` in `vector.transfer_read` prefer `ub.poison` (#146088)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jun 30 12:20:46 PDT 2025
Author: Fabian Mora
Date: 2025-06-30T15:20:42-04:00
New Revision: 878d3594ed74cd1153bc5cba2cb7a449702cdf34
URL: https://github.com/llvm/llvm-project/commit/878d3594ed74cd1153bc5cba2cb7a449702cdf34
DIFF: https://github.com/llvm/llvm-project/commit/878d3594ed74cd1153bc5cba2cb7a449702cdf34.diff
LOG: [mlir][vector] Avoid setting padding by default to `0` in `vector.transfer_read` prefer `ub.poison` (#146088)
Context:
`vector.transfer_read` always requires a padding value. Most of its
builders take no `padding` value and assume the safe value of `0`.
However, the padding value should be a conscious choice by the API user,
as silently defaulting it makes it easy to introduce bugs.
For example, I found several occasions while making this patch that the
padding value was not getting propagated (`vector.transfer_read` was
transformed into another `vector.transfer_read`). These bugs were
always caused by builders that don't require specifying the padding.
Additionally, using `ub.poison` as the fallback value is better, as it
indicates the user "doesn't care" about the actual padding value, and it
forces users who do care to specify the padding semantics they want.
With that in mind, this patch changes the builders in
`vector.transfer_read` to always take a `std::optional<Value> padding`
argument. The argument has no default and must always be passed
explicitly, but for convenience users can pass `std::nullopt`, which pads
the transfer read with `ub.poison`.
---------
Signed-off-by: Fabian Mora <fabian.mora-cordero at amd.com>
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/Arith.h
mlir/include/mlir/Dialect/Vector/IR/Vector.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
index 0bee876ac9bfa..c2de3c9021b0b 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h
+++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h
@@ -154,6 +154,11 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);
arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);
+
+/// Creates an `arith.constant` operation with a zero value of type `type`. This
+/// method asserts if `type` is invalid for representing zero with
+/// `arith.constant`.
+Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
} // namespace arith
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
index 1922cc63ef353..5125ae7c13717 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td
@@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "ub::UBDialect"
+ ];
}
// Base class for Vector dialect ops.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e6b85de5a522a..dfb2756e57bea 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1543,30 +1543,29 @@ def Vector_TransferReadOp :
}];
let builders = [
- /// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
+ /// 1. Builder that sets padding to `padding` or poison if not provided and
+ /// an empty mask (variant with attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
+ "std::optional<Value>":$padding,
"AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
- /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
+ /// 2. Builder that sets padding to `padding` or poison if not provided and
+ /// an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
+ "std::optional<Value>":$padding,
"AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
- /// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
+ /// 3. Builder that sets padding to `padding` or poison if not provided and
+ /// permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
- "Value":$padding,
- CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
- /// 4. Builder that sets padding to zero and permutation map to
- /// 'getMinorIdentityMap'.
- OpBuilder<(ins "VectorType":$vectorType,
- "Value":$source,
- "ValueRange":$indices,
- CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
+ "std::optional<Value>":$padding,
+ CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index f6f192a6d964a..7fae260767e0a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = state.builder.create<vector::TransferReadOp>(
- loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
+ loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
+ /*padding=*/std::nullopt, permutationMap);
// Register replacement for future uses in the scope.
state.registerOpVectorReplacement(loadOp, transfer);
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5194f2b58669a..ec96a35ae82e7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -292,6 +292,16 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
return false;
}
+Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
+ Type type) {
+ // TODO: Incorporate this check to `FloatAttr::get*`.
+ assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
+ "type doesn't have a zero representation");
+ TypedAttr zeroAttr = builder.getZeroAttr(type);
+ assert(zeroAttr && "unsupported type for zero attribute");
+ return builder.create<arith::ConstantOp>(loc, zeroAttr);
+}
+
//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d52ff4d4257c7..3dbb93b8a0669 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
// Create the new `transfer_read`.
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), collapsedVT, collapsedMem, indices,
+ readOp.getPadding(),
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
// Cast back to the original vector type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 830ae5414c6bd..b467114c72f7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1183,6 +1183,10 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto srcRank = extractOp.getTensor().getType().getRank();
SmallVector<bool> inBounds(dstRank, true);
+ // Get the value to pad transfer reads with 0.
+ Value padding =
+ arith::getZeroConstant(rewriter, loc, resultType.getElementType());
+
// 2a. Handle scalar broadcast access.
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
MLIRContext *ctx = rewriter.getContext();
@@ -1190,7 +1194,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs,
+ loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
permutationMap, inBounds);
// Mask this broadcasting xfer_read here rather than relying on the generic
@@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}
auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
- inBounds);
+ loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
+ permutationMap, inBounds);
LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationHookResult{VectorizationHookStatus::NewOp,
@@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
-/// broadcasting makes it trivial to detrmine where broadcast, transposes and
+/// broadcasting makes it trivial to determine where broadcast, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
@@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
- loc, readType, opOperand->get(), indices, readMap);
+ loc, readType, opOperand->get(), indices,
+ /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
Value readValue = read->getResult(0);
@@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
+ /*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue =
@@ -3487,15 +3493,18 @@ struct Conv1DGenerator
SmallVector<Value> resPadding(resShape.size(), zero);
// Read the whole lhs, rhs and res in one shot (with zero padding).
- Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
- lhsPadding);
+ Value lhs = rewriter.create<vector::TransferReadOp>(
+ loc, lhsType, lhsShaped, lhsPadding,
+ /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == ConvOperationKind::Conv)
- rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
- rhsPadding);
- Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
- resPadding);
+ rhs = rewriter.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, rhsPadding,
+ /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
+ Value res = rewriter.create<vector::TransferReadOp>(
+ loc, resType, resShaped, resPadding,
+ /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
@@ -3742,19 +3751,22 @@ struct Conv1DGenerator
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
+ /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
auto maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());
// Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
- ValueRange{zero, zero});
+ Value rhs = rewriter.create<vector::TransferReadOp>(
+ loc, rhsType, rhsShaped, ValueRange{zero, zero},
+ /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
auto maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero},
+ /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
auto maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a11dbe2589205..fc7ed7e479b49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, AffineMapAttr permutationMapAttr,
+ ValueRange indices, std::optional<Value> padding,
+ AffineMapAttr permutationMapAttr,
/*optional*/ ArrayAttr inBoundsAttr) {
+
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding, /*mask=*/Value(), inBoundsAttr);
+ *padding, /*mask=*/Value(), inBoundsAttr);
}
/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, AffineMap permutationMap,
+ ValueRange indices, std::optional<Value> padding,
+ AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) {
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
- build(builder, result, vectorType, source, indices, permutationMapAttr,
- inBoundsAttr);
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
+ build(builder, result, vectorType, source, indices, *padding,
+ permutationMapAttr, inBoundsAttr);
}
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
- ValueRange indices, Value padding,
+ ValueRange indices, std::optional<Value> padding,
std::optional<ArrayRef<bool>> inBounds) {
AffineMap permutationMap = getTransferMinorIdentityMap(
llvm::cast<ShapedType>(source.getType()), vectorType);
@@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
+ Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
+ if (!padding)
+ padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
- padding,
+ *padding,
/*mask=*/Value(), inBoundsAttr);
}
-/// 4. Builder that sets padding to zero and permutation map to
-/// 'getMinorIdentityMap'.
-void TransferReadOp::build(OpBuilder &builder, OperationState &result,
- VectorType vectorType, Value source,
- ValueRange indices,
- std::optional<ArrayRef<bool>> inBounds) {
- Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
- Value padding = builder.create<arith::ConstantOp>(
- result.location, elemType, builder.getZeroAttr(elemType));
- build(builder, result, vectorType, source, indices, padding, inBounds);
-}
-
template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index af90ed8f5deaf..fb99e22c77ea0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
- loc, cast<VectorType>(type), buffer, indices,
+ loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 785a8aaf3f0a9..efdae93e730bd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
- loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+ loc, flatVectorType, collapsedSource, collapsedIndices,
+ transferReadOp.getPadding(), collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_read with the new one reading from the
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
index 81b04ccceaf27..72ced5b53879b 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for {{.*}} step 128
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
%a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
@@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
%P = memref.dim %B, %c2 : memref<?x?x?xf32>
// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i3 = 0 to %M { // vectorized
%a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
@@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i8 = 0 to %M { // vectorized
affine.for %i9 = 0 to %N {
@@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
// CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
affine.for %i4 = 0 to %M { // vectorized
affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
@@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK: %{{.*}} = ub.poison : f32
// CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
@@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
-// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
+// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
index 15a7133cf0f65..7d4d111c09799 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir
@@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf3
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48x
// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 {
// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 {
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
@@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24
// CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]])
// CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]])
-// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32>
// CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32>
// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
index 5f923cdafb956..49bd2eddbdedd 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -11,8 +11,8 @@
// CHECK-LABEL: @base_case
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
@@ -36,8 +36,8 @@ func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<
// CHECK-LABEL: @with_3d_vector
// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
-// CHECK: %[[PAD:.+]] = arith.constant 0 : i8
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<?x?x2x8xi8> into memref<?x?xi8>
// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 0f04d3b79b535..d18edd0ac5563 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
@@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
// CHECK-SAME: -> vector<1x1x4x3x2xi8>
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
@@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
@@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
-// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK: %[[C_0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C_0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
@@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims(
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
@@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
-// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 7cfbcdf101d11..1161dbd4b2166 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %
// CHECK-SCF-IF: gpu.barrier
// CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
- // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
- // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+ // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+ // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
// CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
}
More information about the Mlir-commits
mailing list