[Mlir-commits] [mlir] [mlir][vector] Preserve alignment attribute during gather/scatter bufferization (PR #188924)
Jorn Tuyls
llvmlistbot at llvm.org
Fri Mar 27 01:19:06 PDT 2026
https://github.com/jtuyls created https://github.com/llvm/llvm-project/pull/188924
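This patch forwards the `alignment` attribute of `vector.gather` and `vector.scatter` to the ops created during bufferization, so the attribute is no longer dropped; the bufferize.mlir tests are updated to check for it.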
>From 4b0eb5d744dbf8a8b3805cdca15805afb92f985e Mon Sep 17 00:00:00 2001
From: Jorn <jorn.tuyls at gmail.com>
Date: Fri, 27 Mar 2026 00:48:00 -0700
Subject: [PATCH] [mlir][vector] Preserve alignment attribute during
gather/scatter bufferization
---
.../Vector/Transforms/BufferizableOpInterfaceImpl.cpp | 7 ++++---
mlir/test/Dialect/Vector/bufferize.mlir | 10 +++++-----
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 352f477a8746e..3301dd921b3b6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -167,9 +167,10 @@ struct ScatterOpInterface
if (failed(buffer))
return failure();
vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
- /*resultType=*/nullptr, *buffer,
+ /*resultType=*/Type{}, *buffer,
scatterOp.getOffsets(), scatterOp.getIndices(),
- scatterOp.getMask(), scatterOp.getValueToStore());
+ scatterOp.getMask(), scatterOp.getValueToStore(),
+ scatterOp.getAlignmentAttr());
replaceOpWithBufferizedValues(rewriter, op, *buffer);
return success();
}
@@ -212,7 +213,7 @@ struct GatherOpInterface
replaceOpWithNewBufferizedOp<vector::GatherOp>(
rewriter, gatherOp, gatherOp.getVectorType(), *buffer,
gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(),
- gatherOp.getPassThru());
+ gatherOp.getPassThru(), gatherOp.getAlignmentAttr());
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/bufferize.mlir b/mlir/test/Dialect/Vector/bufferize.mlir
index 70adefd0dc4ec..3162e48fb36db 100644
--- a/mlir/test/Dialect/Vector/bufferize.mlir
+++ b/mlir/test/Dialect/Vector/bufferize.mlir
@@ -39,13 +39,13 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
-// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] {alignment = 8 : i64} : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
// CHECK: return %[[tensor]] : tensor<16x16xf32>
-func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
+func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
%c0 = arith.constant 0 : index
- %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
+ %0 = vector.scatter %base[%c0, %c0][%v], %mask, %value {alignment = 8 : i64}
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
return %0 : tensor<16x16xf32>
}
@@ -57,10 +57,10 @@ func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)
// CHECK: %[[m:.*]] = bufferization.to_buffer %[[base]] : tensor<?x?xf32> to memref<?x?xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
-// CHECK: %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: %[[out:.*]] = vector.gather %[[m]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[pass_thru]] {alignment = 8 : i64} : memref<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0 : index
- %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru {alignment = 8 : i64} : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}
More information about the Mlir-commits mailing list