[Mlir-commits] [mlir] [mlir][vector] Add verification for incorrect vector.extract (PR #115824)

Kunwar Grover llvmlistbot at llvm.org
Tue Nov 26 07:23:01 PST 2024


================
@@ -1339,6 +1339,83 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return l == r;
 }
 
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices
+// and shapes. `indexedType` is the vector type being indexed by the operation,
+// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.
+// `nonIndexedType` is the type inserted by an `InsertOp` or extracted by an
+// `ExtractOp`, respectively.
+static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,
+                                                         VectorType indexedType,
+                                                         int64_t numIndices,
+                                                         Type nonIndexedType) {
+  assert((isa<InsertOp>(op) || isa<ExtractOp>(op)) &&
+         "Expected InsertOp or ExtractOp");
+
+  std::string nonIndexedStr = isa<InsertOp>(op) ? "inserted" : "extracted";
+  std::string indexedStr = isa<InsertOp>(op) ? "destination" : "source";
+  int64_t indexedRank = indexedType.getRank();
+  if (numIndices > indexedRank) {
+    return op->emitOpError()
+           << "expected a number of indices no greater than the " << indexedStr
+           << " vector rank";
+  }
+
+  if (auto nonIndexedVecType = dyn_cast<VectorType>(nonIndexedType)) {
+    // Vector case, including meaningful cases such as:
+    //  * 0-D vector:
+    //    * vector.extract %src[2]: vector<f32> from vector<8xf32>
+    //    * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
+    //  * One-element vector:
+    //    * vector.extract %src[4]: vector<1xf32> from vector<8xf32>
+    //    * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
+    //    * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
+    //    * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
----------------
Groverkss wrote:

Can we make the verifier stricter? I don't see what we gain from allowing all of these result types for the same source type and indices:

```
1. vector.extract %src[4]: vector<f32> from vector<8xf32>
2. vector.extract %src[4]: vector<1xf32> from vector<8xf32>
3. vector.extract %src[4]: f32 from vector<8xf32>
```

The second example especially is really confusing. Can we instead make the verifier accept exactly one representation for a given input type and set of indices? We could pick either example 1 or example 3 above as the canonical form:

Possibility 1 (0-d vectors possible as a result):

```
Rule 1: Result is always a vector type of rank sourceVectorRank - numIndices
```
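
As a rough sketch of what Possibility 1 would check, here is a plain C++ helper that does not depend on MLIR. `possibility1ResultRank` is a hypothetical, illustrative name, not an existing MLIR function. It computes the only result rank the verifier would accept, assuming the result must always be a vector:

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper (not part of MLIR): under Possibility 1, the result of
// vector.extract is always a vector whose rank is sourceRank - numIndices.
// Returns std::nullopt when there are more indices than the source rank,
// mirroring the existing "no greater than the source vector rank" check.
constexpr std::optional<int64_t> possibility1ResultRank(int64_t sourceRank,
                                                        int64_t numIndices) {
  if (numIndices < 0 || numIndices > sourceRank)
    return std::nullopt;
  // Indexing every dimension yields a 0-D vector, e.g. vector<f32>.
  return sourceRank - numIndices;
}
```

Under this rule, one index into `vector<8xf32>` would only be valid with a `vector<f32>` result (example 1 above); `f32` and `vector<1xf32>` would be rejected.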

Possibility 2 (no 0-d vectors as a result):

```
Rule 1: Result is scalar when sourceVectorRank == numIndices
Rule 2: Result is a vector of rank sourceVectorRank - numIndices when numIndices < sourceVectorRank
```
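
Possibility 2 can be sketched the same way. This is again plain C++ rather than the real MLIR verifier, and `ExtractResultShape` and `possibility2Result` are made-up names for illustration. The sketch models the result as either a scalar or a vector of a given rank:

```cpp
#include <cstdint>
#include <optional>

// Hypothetical model (not MLIR API) of the result shape the verifier would
// require under Possibility 2. Only ranks are modeled, not element types or
// the trailing dimension sizes.
struct ExtractResultShape {
  bool isScalar;  // true: element type such as f32; false: a vector type
  int64_t rank;   // vector rank; always 0 when isScalar is true
};

constexpr std::optional<ExtractResultShape>
possibility2Result(int64_t sourceRank, int64_t numIndices) {
  if (numIndices < 0 || numIndices > sourceRank)
    return std::nullopt;  // rejected: more indices than the source rank
  if (numIndices == sourceRank)
    return ExtractResultShape{true, 0};  // Rule 1: scalar result
  // Rule 2: vector result of rank sourceRank - numIndices (never 0-D).
  return ExtractResultShape{false, sourceRank - numIndices};
}
```

With this rule, one index into `vector<8xf32>` yields `f32` (example 3 above), while one index into `vector<4x8xf32>` yields a rank-1 vector such as `vector<8xf32>`. A real verifier would presumably also compare `nonIndexedType` against the trailing dimensions of the source, which this sketch leaves out.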

I'd guess Possibility 2 is the best representation for now, since it matches what the LLVM and SPIR-V lowerings can do.

This would also eliminate the need for any bitcasts/reshapes for vector.extract.

What do you think?

https://github.com/llvm/llvm-project/pull/115824

