[Mlir-commits] [mlir] [mlir][vector] Reject alignment attribute on tensor-level gather/scatter (PR #188924)
Jorn Tuyls
llvmlistbot at llvm.org
Fri Mar 27 02:55:03 PDT 2026
https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/188924
>From 47de5b386e6846f4b61e8be8b8eefc411f90ce30 Mon Sep 17 00:00:00 2001
From: Jorn <jorn.tuyls at gmail.com>
Date: Fri, 27 Mar 2026 02:33:13 -0700
Subject: [PATCH] [mlir][vector] Reject alignment attribute on tensor-level
gather/scatter
Alignment is a memory concept that only applies to memref bases. Add
verifier checks to vector.gather and vector.scatter to reject the
alignment attribute when the base is not a memref. If a future use case
requires alignment on tensor-level ops, this verifier error will surface
it so the restriction can be revisited.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 ++++++++
mlir/test/Dialect/Vector/invalid.mlir | 18 ++++++++++++++++++
2 files changed, 26 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 73632875ca9e2..05dd3f597f096 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6191,6 +6191,10 @@ LogicalResult GatherOp::verify() {
return emitOpError("expected result dim to match mask dim");
if (resVType != getPassThruVectorType())
return emitOpError("expected pass_thru of same type as result type");
+ if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+ return emitOpError(
+ "alignment is only supported for memref bases, not tensor bases");
+ }
return success();
}
@@ -6299,6 +6303,10 @@ LogicalResult ScatterOp::verify() {
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
+ if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+ return emitOpError(
+ "alignment is only supported for memref bases, not tensor bases");
+ }
return success();
}
namespace {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 8f8429e5844d6..f90312c915334 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1545,6 +1545,15 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
// -----
+func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error at +1 {{'vector.gather' op alignment is only supported for memref bases, not tensor bases}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 8 : i64 } : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1624,6 +1633,15 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
// -----
+func.func @scatter_tensor_alignment(%base: tensor<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error at +1 {{'vector.scatter' op alignment is only supported for memref bases, not tensor bases}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = 8 : i64 }
+ : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<?xf32>
+}
+
+// -----
+
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error at +1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
More information about the Mlir-commits
mailing list