[Mlir-commits] [mlir] 1e61b37 - [mlir][vector] Tighten the semantics of vector.gather (#135749)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 16 04:16:08 PDT 2025


Author: Andrzej Warzyński
Date: 2025-04-16T13:16:04+02:00
New Revision: 1e61b374ba3ba2891dc1abda732b0b9263216785

URL: https://github.com/llvm/llvm-project/commit/1e61b374ba3ba2891dc1abda732b0b9263216785
DIFF: https://github.com/llvm/llvm-project/commit/1e61b374ba3ba2891dc1abda732b0b9263216785.diff

LOG: [mlir][vector] Tighten the semantics of vector.gather (#135749)

This patch restricts the source (`base`) operand of `vector.gather` to
tensors and memrefs. Currently, the source is typed as `AnyShaped`, which
also admits vectors, allowing the following (invalid) construct to pass
verification:

```mlir
  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
       : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
(Note: the source `%base` here is a vector, which is invalid.)

In contrast, `vector.scatter` currently accepts only memrefs as its base,
so some asymmetry between the two ops remains. This PR is a step toward
aligning their semantics.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/IR/CommonTypeConstraints.td
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..d7518943229ea 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1972,7 +1972,7 @@ def Vector_GatherOp :
     DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
-    Arguments<(ins Arg<AnyShaped, "", [MemRead]>:$base,
+    Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemRead]>:$base,
                Variadic<Index>:$indices,
                VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec,
                VectorOfNonZeroRankOf<[I1]>:$mask,

diff  --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index e6f17ded4628b..45ec1846580f2 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 // Whether a type is a MemRefType.
 def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">;
 
+// Whether a type is a TensorType or a MemRefType.
+def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>;
+
 // Whether a type is an UnrankedMemRefType
 def IsUnrankedMemRefTypePred
         : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">;
@@ -426,7 +429,9 @@ class ValueSemanticsContainerOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, HasValueSemanticsPred,
   "container with value semantics">;
 
+//===----------------------------------------------------------------------===//
 // Vector types.
+//===----------------------------------------------------------------------===//
 
 class VectorOfNonZeroRankOf<list<Type> allowedTypes> :
   ShapedContainerType<allowedTypes, IsVectorOfNonZeroRankTypePred, "vector",
@@ -755,7 +760,7 @@ class StaticShapeTensorOf<list<Type> allowedTypes>
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
 //===----------------------------------------------------------------------===//
-// Memref type.
+// Memref types.
 //===----------------------------------------------------------------------===//
 
 // Any unranked memref whose element type is from the given `allowedTypes` list.
@@ -878,6 +883,14 @@ class NestedTupleOf<list<Type> allowedTypes> :
                        "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))",
                        "nested tuple">;
 
+//===----------------------------------------------------------------------===//
+// Mixed types
+//===----------------------------------------------------------------------===//
+
+class TensorOrMemRef<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsTensorOrMemRefTypePred, "Tensor or MemRef",
+                      "::mlir::ShapedType">;
+
 //===----------------------------------------------------------------------===//
 // Common type constraints
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..3a8320971bac4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
 
 // -----
 
+func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                                %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
+  %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
@@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector
 
 // -----
 
+func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
+                             %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  // expected-error at +2 {{custom op 'vector.scatter' invalid kind of type specified}}
+  vector.scatter %base[%c0][%indices], %mask, %pass_thru
+    : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
+
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list