[Mlir-commits] [mlir] 5ae2fe7 - [mlir][vector] Reject alignment attribute on tensor-level gather/scatter (#188924)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 28 01:06:24 PDT 2026
Author: Jorn Tuyls
Date: 2026-03-28T09:06:19+01:00
New Revision: 5ae2fe75c3898cbf78f170d3cd686e02182f36fc
URL: https://github.com/llvm/llvm-project/commit/5ae2fe75c3898cbf78f170d3cd686e02182f36fc
DIFF: https://github.com/llvm/llvm-project/commit/5ae2fe75c3898cbf78f170d3cd686e02182f36fc.diff
LOG: [mlir][vector] Reject alignment attribute on tensor-level gather/scatter (#188924)
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index c1536d6e062cd..bd419f2ba93ee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6192,6 +6192,10 @@ LogicalResult GatherOp::verify() {
return emitOpError("expected result dim to match mask dim");
if (resVType != getPassThruVectorType())
return emitOpError("expected pass_thru of same type as result type");
+ if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+ return emitOpError(
+ "alignment is only supported for memref bases, not tensor bases");
+ }
return success();
}
@@ -6300,6 +6304,10 @@ LogicalResult ScatterOp::verify() {
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore dim to match mask dim");
+ if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
+ return emitOpError(
+ "alignment is only supported for memref bases, not tensor bases");
+ }
return success();
}
namespace {
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 8f8429e5844d6..f90312c915334 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1545,6 +1545,15 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
// -----
+func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %c0 : index) {
+ // expected-error@+1 {{'vector.gather' op alignment is only supported for memref bases, not tensor bases}}
+ %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+ { alignment = 8 : i64 } : tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1624,6 +1633,15 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
// -----
+func.func @scatter_tensor_alignment(%base: tensor<?xf32>, %indices: vector<16xi32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>, %c0: index) {
+ // expected-error@+1 {{'vector.scatter' op alignment is only supported for memref bases, not tensor bases}}
+ vector.scatter %base[%c0][%indices], %mask, %value { alignment = 8 : i64 }
+ : tensor<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<?xf32>
+}
+
+// -----
+
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
More information about the Mlir-commits
mailing list