[Mlir-commits] [mlir] [mlir][vector] Preserve alignment attribute during gather/scatter bufferization (PR #188924)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 27 01:19:37 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Jorn Tuyls (jtuyls)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/188924.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+4-3) 
- (modified) mlir/test/Dialect/Vector/bufferize.mlir (+5-5) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 352f477a8746e..3301dd921b3b6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -167,9 +167,10 @@ struct ScatterOpInterface
     if (failed(buffer))
       return failure();
     vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
-                              /*resultType=*/nullptr, *buffer,
+                              /*resultType=*/Type{}, *buffer,
                               scatterOp.getOffsets(), scatterOp.getIndices(),
-                              scatterOp.getMask(), scatterOp.getValueToStore());
+                              scatterOp.getMask(), scatterOp.getValueToStore(),
+                              scatterOp.getAlignmentAttr());
     replaceOpWithBufferizedValues(rewriter, op, *buffer);
     return success();
   }
@@ -212,7 +213,7 @@ struct GatherOpInterface
     replaceOpWithNewBufferizedOp<vector::GatherOp>(
         rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
         gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
-        gatherOp.getPassThru());
+        gatherOp.getPassThru(), gatherOp.getAlignmentAttr());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 70adefd0dc4ec..3162e48fb36db 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -39,13 +39,13 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
 //       CHECK:   %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
 //       CHECK:   memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
-//       CHECK:   vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+//       CHECK:   vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] {alignment = 8 : i64} : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 //       CHECK:   %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
 //       CHECK:   return %[[tensor]] : tensor<16x16xf32>
-func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>, 
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
                   %mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
   %c0 = arith.constant 0 : index
-  %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+  %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value {alignment = 8 : i64}
       : tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
   return %0 : tensor<16x16xf32>
 }
@@ -57,10 +57,10 @@ func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
 //  CHECK-SAME:     %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
 //       CHECK:   %[[m:.*]] = bufferization.to_buffer %[[base]] : tensor<?x?xf32> to memref<?x?xf32>
 //       CHECK:   %[[c0:.*]] = arith.constant 0 : index
-//       CHECK:   %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//       CHECK:   %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] {alignment = 8 : i64} : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0 : index
-  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru {alignment = 8 : i64} : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/188924


More information about the Mlir-commits mailing list