[Mlir-commits] [mlir] [mlir][VectorOps] Don't fold extract chains that include dynamic indices (PR #68333)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 5 09:33:29 PDT 2023

llvmbot wrote:




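Folding a chain of extract ops that includes dynamic indices is not yet supported. Previously, attempting it led to a confusing crash, because the fold created an extract op that had a kDynamic marker in its static position but no corresponding dynamic position. The verifier has also been updated to check for this mismatch and to hint at where the problem is likely to be.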

Full diff: https://github.com/llvm/llvm-project/pull/68333.diff

2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+11) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12) 

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..f84a574c4634fc3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 LogicalResult vector::ExtractOp::verify() {
+  // Note: This check must come before getMixedPosition() to prevent a crash.
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
+    return emitOpError(
+        "mismatch between dynamic and static positions (kDynamic marker but no "
+        "corresponding dynamic position) -- this can only happen due to an "
+        "incorrect/fold rewrite");
   auto position = getMixedPosition();
   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
+    // TODO: Canonicalization for dynamic position not implemented yet.
+    if (currentOp.hasDynamicPosition())
+      return failure();
     ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 05615b96ae6d69f..924886c50030967 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
 // -----
+// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
+//  CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
+//       CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
+  %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+// -----
 // CHECK-LABEL: extract_extract_strided2
 //  CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>




More information about the Mlir-commits mailing list