[Mlir-commits] [mlir] [mlir][vector] Propagate alignment when emulating masked{load,stores}. (PR #155648)

Erick Ochoa Lopez llvmlistbot at llvm.org
Thu Aug 28 05:22:14 PDT 2025


https://github.com/amd-eochoalo updated https://github.com/llvm/llvm-project/pull/155648

>From 52d6a7ba4f00854490b15cbaf58aa5eb73973ddc Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Wed, 27 Aug 2025 09:20:10 -0700
Subject: [PATCH 1/2] [mlir][vector] Propagate alignment when emulating
 masked{load,stores}.

---
 .../VectorEmulateMaskedLoadStore.cpp          | 13 ++++++--
 .../vector-emulate-masked-load-store.mlir     | 30 +++++++++++++++++++
 2 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index cb3e8dc67a1ae..a00a1352ce40d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -64,6 +64,9 @@ struct VectorMaskedLoadOpConverter final
     Value mask = maskedLoadOp.getMask();
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
+    bool nontemporal = false;
+    auto alignment = maskedLoadOp.getAlignment();
+    uint64_t align = alignment.has_value() ? alignment.value() : 0;
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -73,8 +76,8 @@ struct VectorMaskedLoadOpConverter final
       auto ifOp = scf::IfOp::create(
           rewriter, loc, maskBit,
           [&](OpBuilder &builder, Location loc) {
-            auto loadedValue =
-                memref::LoadOp::create(builder, loc, base, indices);
+            auto loadedValue = memref::LoadOp::create(
+                builder, loc, base, indices, nontemporal, align);
             auto combinedValue =
                 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
             scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -132,6 +135,9 @@ struct VectorMaskedStoreOpConverter final
     Value mask = maskedStoreOp.getMask();
     Value base = maskedStoreOp.getBase();
     Value value = maskedStoreOp.getValueToStore();
+    bool nontemporal = false;
+    auto alignment = maskedStoreOp.getAlignment();
+    uint64_t align = alignment.has_value() ? alignment.value() : 0;
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -141,7 +147,8 @@ struct VectorMaskedStoreOpConverter final
       auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
-      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
+      memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
+                              nontemporal, align);
 
       rewriter.setInsertionPointAfter(ifOp);
       indices.back() =
diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
index 3867f075af8e4..e74eb08339684 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir
@@ -54,6 +54,22 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL:  @vector_maskedload_with_alignment
+//       CHECK:       memref.load
+//       CHECK-SAME:  {alignment = 8 : i64}
+//       CHECK:       memref.load
+//       CHECK-SAME:  {alignment = 8 : i64}
+func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  %s = arith.constant 0.0 : f32
+  %pass_thru = vector.splat %s : vector<4xf32>
+  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %0: vector<4xf32>
+}
+
 // CHECK-LABEL:  @vector_maskedstore
 //  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
 //   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
@@ -93,3 +109,17 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
   vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
   return
 }
+
+// CHECK-LABEL:  @vector_maskedstore_with_alignment
+//       CHECK:       memref.store
+//       CHECK-SAME:  {alignment = 8 : i64}
+//       CHECK:       memref.store
+//       CHECK-SAME:  {alignment = 8 : i64}
+func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
+  %idx_0 = arith.constant 0 : index
+  %idx_1 = arith.constant 1 : index
+  %idx_4 = arith.constant 4 : index
+  %mask = vector.create_mask %idx_1 : vector<4xi1>
+  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
+  return
+}

>From 6432c874ddf8d4c5b1eef131ae4c365fb1407066 Mon Sep 17 00:00:00 2001
From: Erick Ochoa <erick.ochoalopez at amd.com>
Date: Thu, 28 Aug 2025 05:21:52 -0700
Subject: [PATCH 2/2] Use explicit type instead of auto

---
 .../Vector/Transforms/VectorEmulateMaskedLoadStore.cpp | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index a00a1352ce40d..bdf5b205b20af 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -65,8 +65,7 @@ struct VectorMaskedLoadOpConverter final
     Value base = maskedLoadOp.getBase();
     Value iValue = maskedLoadOp.getPassThru();
     bool nontemporal = false;
-    auto alignment = maskedLoadOp.getAlignment();
-    uint64_t align = alignment.has_value() ? alignment.value() : 0;
+    std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
     auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -77,7 +76,7 @@ struct VectorMaskedLoadOpConverter final
           rewriter, loc, maskBit,
           [&](OpBuilder &builder, Location loc) {
             auto loadedValue = memref::LoadOp::create(
-                builder, loc, base, indices, nontemporal, align);
+                builder, loc, base, indices, nontemporal, alignment.value_or(0));
             auto combinedValue =
                 vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
             scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -136,8 +135,7 @@ struct VectorMaskedStoreOpConverter final
     Value base = maskedStoreOp.getBase();
     Value value = maskedStoreOp.getValueToStore();
     bool nontemporal = false;
-    auto alignment = maskedStoreOp.getAlignment();
-    uint64_t align = alignment.has_value() ? alignment.value() : 0;
+    std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
     auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
     Value one = arith::ConstantOp::create(rewriter, loc, indexType,
                                           IntegerAttr::get(indexType, 1));
@@ -148,7 +146,7 @@ struct VectorMaskedStoreOpConverter final
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
       memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
-                              nontemporal, align);
+                              nontemporal, alignment.value_or(0));
 
       rewriter.setInsertionPointAfter(ifOp);
       indices.back() =



More information about the Mlir-commits mailing list