[Mlir-commits] [mlir] c5e6d56 - [mlir][vector] Propagate alignment when emulating masked{load, stores}. (#155648)
llvmlistbot at llvm.org
Fri Sep 5 06:48:44 PDT 2025
Author: Erick Ochoa Lopez
Date: 2025-09-05T13:48:40Z
New Revision: c5e6d56f01732661b0b12159543356035d251d3e
URL: https://github.com/llvm/llvm-project/commit/c5e6d56f01732661b0b12159543356035d251d3e
DIFF: https://github.com/llvm/llvm-project/commit/c5e6d56f01732661b0b12159543356035d251d3e.diff
LOG: [mlir][vector] Propagate alignment when emulating masked{load,stores}. (#155648)
Propagate alignment from `vector.maskedload` and `vector.maskedstore` to
`memref.load` and `memref.store` during the `VectorEmulateMaskedLoadStore`
pass.
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index cb3e8dc67a1ae..78f74eef7bee3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -64,6 +64,7 @@ struct VectorMaskedLoadOpConverter final
Value mask = maskedLoadOp.getMask();
Value base = maskedLoadOp.getBase();
Value iValue = maskedLoadOp.getPassThru();
+ std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
@@ -73,8 +74,9 @@ struct VectorMaskedLoadOpConverter final
auto ifOp = scf::IfOp::create(
rewriter, loc, maskBit,
[&](OpBuilder &builder, Location loc) {
- auto loadedValue =
- memref::LoadOp::create(builder, loc, base, indices);
+ auto loadedValue = memref::LoadOp::create(
+ builder, loc, base, indices, /*nontemporal=*/false,
+ alignment.value_or(0));
auto combinedValue =
vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -132,6 +134,8 @@ struct VectorMaskedStoreOpConverter final
Value mask = maskedStoreOp.getMask();
Value base = maskedStoreOp.getBase();
Value value = maskedStoreOp.getValueToStore();
+ bool nontemporal = false;
+ std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
IntegerAttr::get(indexType, 1));
@@ -141,7 +145,8 @@ struct VectorMaskedStoreOpConverter final
auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
- memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
+ memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
+ nontemporal, alignment.value_or(0));
rewriter.setInsertionPointAfter(ifOp);
indices.back() =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4..e74eb08339684 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -54,6 +54,22 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
return %0: vector<4xf32>
}
+// CHECK-LABEL: @vector_maskedload_with_alignment
+// CHECK: memref.load
+// CHECK-SAME: {alignment = 8 : i64}
+// CHECK: memref.load
+// CHECK-SAME: {alignment = 8 : i64}
+func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_1 : vector<4xi1>
+ %s = arith.constant 0.0 : f32
+ %pass_thru = vector.splat %s : vector<4xf32>
+ %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ return %0: vector<4xf32>
+}
+
// CHECK-LABEL: @vector_maskedstore
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
@@ -93,3 +109,17 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
return
}
+
+// CHECK-LABEL: @vector_maskedstore_with_alignment
+// CHECK: memref.store
+// CHECK-SAME: {alignment = 8 : i64}
+// CHECK: memref.store
+// CHECK-SAME: {alignment = 8 : i64}
+func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+ %idx_0 = arith.constant 0 : index
+ %idx_1 = arith.constant 1 : index
+ %idx_4 = arith.constant 4 : index
+ %mask = vector.create_mask %idx_1 : vector<4xi1>
+ vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+ return
+}
More information about the Mlir-commits mailing list