[Mlir-commits] [mlir] [mlir][vector] Add verification for incorrect vector.extract (PR #115824)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Nov 21 06:59:51 PST 2024


================
@@ -1339,6 +1339,83 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return l == r;
 }
 
+// Common verification rules for `InsertOp` and `ExtractOp` involving indices
+// and shapes. `indexedType` is the vector type being indexed by the operation,
+// i.e., the destination type in `InsertOp` and the source type in `ExtractOp`.
+// `nonIndexedType` is the type inserted by an `InsertOp` or extracted by an
+// `ExtractOp`, respectively.
+static LogicalResult verifyInsertExtractIndicesAndShapes(Operation *op,
+                                                         VectorType indexedType,
+                                                         int64_t numIndices,
+                                                         Type nonIndexedType) {
+  assert((isa<InsertOp>(op) || isa<ExtractOp>(op)) &&
+         "Expected InsertOp or ExtractOp");
+
+  std::string nonIndexedStr = isa<InsertOp>(op) ? "inserted" : "extracted";
+  std::string indexedStr = isa<InsertOp>(op) ? "destination" : "source";
+  int64_t indexedRank = indexedType.getRank();
+  if (numIndices > indexedRank) {
+    return op->emitOpError()
+           << "expected a number of indices no greater than the " << indexedStr
+           << " vector rank";
+  }
+
+  if (auto nonIndexedVecType = dyn_cast<VectorType>(nonIndexedType)) {
+    // Vector case, including meaningful cases such as:
+    //  * 0-D vector:
+    //    * vector.extract %src[2]: vector<f32> from vector<8xf32>
+    //    * vector.insert %src, %dst[3]: vector<f32> into vector<8xf32>
+    //  * One-element vector:
+    //    * vector.extract %src[4]: vector<1xf32> from vector<8xf32>
+    //    * vector.insert %src, %dst[1]: vector<1xf32> into vector<8xf32>
+    //    * vector.extract %src[7]: vector<1xf32> from vector<8x1xf32>
+    //    * vector.insert %src, %dst[5]: vector<1xf32> into vector<8x1xf32>
+    int64_t nonIndexedRank = nonIndexedVecType.getRank();
+    bool isSingleElem1DNonIndexedVec =
+        (nonIndexedRank == 1 && nonIndexedVecType.getDimSize(0) == 1);
+    bool isSingleElem1DIndexedVec =
+        (indexedRank == 1 && indexedType.getDimSize(0) == 1);
+    // Verify 0-D -> single-element 1-D supported cases.
+    if ((indexedRank == 0 && isSingleElem1DNonIndexedVec) ||
+        (nonIndexedRank == 0 && isSingleElem1DIndexedVec)) {
+      return op->emitOpError("expected source and destination vectors with "
+                             "different number of elements");
+    }
+
+    // Verify indices for all the cases.
+    int64_t indexedRankMinusIndices = indexedRank - numIndices;
+    if (indexedRankMinusIndices != nonIndexedRank &&
+        (!isSingleElem1DNonIndexedVec || indexedRankMinusIndices != 0)) {
+      return op->emitOpError()
+             << "expected " << indexedStr
+             << " vector rank minus number of indices to match the rank of the "
+             << nonIndexedStr << " vector";
+    }
+    // Check that if we are inserting or extracting a sub-vector, the
+    // corresponding source and destination shapes match.
+    if (indexedRankMinusIndices > 0) {
+      auto indexedShape = indexedType.getShape();
+      if (indexedShape.drop_front(numIndices) != nonIndexedVecType.getShape()) {
+        return op->emitOpError() << "expected " << nonIndexedStr
+                                 << " vector shape to match the sub-vector "
+                                    "shape of the "
+                                 << indexedStr << " vector";
+      }
+    }
+
+    return success();
+  }
+
+  // Scalar case.
+  if (indexedRank != numIndices) {
+    return op->emitOpError()
+           << "expected " << indexedStr
+           << " vector rank to match the number of indices for scalar cases";
----------------
banach-space wrote:

This error message is a bit confusing:
* at the level of the verifier (i.e. in this method) it's somewhat clear what "scalar case" means, but
* at the op level it is not.

How about something like:
```
'vector.extract' op expected source vector rank to match the number of indices when extracting/inserting a scalar
```
? Extra bonus points for selecting "extracting" for `vector.extract` and "inserting" for `vector.insert` :)
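
To illustrate the per-op wording, here is a minimal standalone sketch with no MLIR dependency. `OpKind` and `scalarRankMismatchMessage` are hypothetical names for illustration only, not part of the PR. In the actual verifier the choice would presumably hang off the existing `isa<InsertOp>(op)` check, and `emitOpError` would supply the `'<op-name>' op` prefix itself:

```cpp
#include <string>

// Hypothetical stand-in for telling InsertOp apart from ExtractOp; the real
// verifier would use isa<InsertOp>(op) instead.
enum class OpKind { Insert, Extract };

// Builds the full diagnostic text, picking the op name, the indexed operand
// name, and the verb to match the op. The "'<op-name>' op " prefix is spelled
// out here only so the sketch shows the complete message.
std::string scalarRankMismatchMessage(OpKind kind) {
  bool isInsert = kind == OpKind::Insert;
  std::string opName = isInsert ? "vector.insert" : "vector.extract";
  std::string indexedStr = isInsert ? "destination" : "source";
  std::string verb = isInsert ? "inserting" : "extracting";
  return "'" + opName + "' op expected " + indexedStr +
         " vector rank to match the number of indices when " + verb +
         " a scalar";
}
```

For `vector.insert` this yields the same sentence with "destination" and "inserting" substituted.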

https://github.com/llvm/llvm-project/pull/115824

