[Mlir-commits] [mlir] [mlir][vector] Reject alignment attribute on tensor-level gather/scatter (PR #188924)

Jorn Tuyls llvmlistbot at llvm.org
Fri Mar 27 02:55:03 PDT 2026


https://github.com/jtuyls updated https://github.com/llvm/llvm-project/pull/188924

>From 47de5b386e6846f4b61e8be8b8eefc411f90ce30 Mon Sep 17 00:00:00 2001
From: Jorn <jorn.tuyls at gmail.com>
Date: Fri, 27 Mar 2026 02:33:13 -0700
Subject: [PATCH] [mlir][vector] Reject alignment attribute on tensor-level
 gather/scatter

Alignment is a memory concept that only applies to memref bases. Add
verifier checks to vector.gather and vector.scatter to reject the
alignment attribute when the base is not a memref. If a future use case
requires alignment on tensor-level ops, the verifier will surface it.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp |  8 ++++++++
 mlir/test/Dialect/Vector/invalid.mlir    | 18 ++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 73632875ca9e2..05dd3f597f096 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6191,6 +6191,10 @@ LogicalResult GatherOp::verify() {
     return emitOpError("expected result dim to match mask dim");
   if (resVType != getPassThruVectorType())
     return emitOpError("expected pass_thru of same type as result type");
+  if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+    return emitOpError(
+        "alignment is only supported for memref bases, not tensor bases");
+  }
   return success();
 }
 
@@ -6299,6 +6303,10 @@ LogicalResult ScatterOp::verify() {
     return emitOpError("expected valueToStore dim to match indices dim");
   if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore dim to match mask dim");
+  if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+    return emitOpError(
+        "alignment is only supported for memref bases, not tensor bases");
+  }
   return success();
 }
 namespace {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 8f8429e5844d6..f90312c915334 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1545,6 +1545,15 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
 
 // -----
 
+func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+  // expected-error at +1 {{'vector.gather' op alignment is only supported for memref bases, not tensor bases}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    { alignment = 8 : i64 } : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
                              %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1624,6 +1633,15 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
 
 // -----
 
+func.func @scatter_tensor_alignment(%base: tensor<?xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+  // expected-error at +1 {{'vector.scatter' op alignment is only supported for memref bases, not tensor bases}}
+  vector.scatter %base[%c0][%indices], %mask, %value { alignment = 8 : i64 }
+    : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<?xf32>
+}
+
+// -----
+
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   // expected-error at +1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}



More information about the Mlir-commits mailing list