[Mlir-commits] [mlir] 469b9cb - [mlir][VectorOps] Don't fold extract chains that include dynamic indices (#68333)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 6 06:32:42 PDT 2023


Author: Benjamin Maxwell
Date: 2023-10-06T14:32:37+01:00
New Revision: 469b9cbe5a410afc836a22fa75870d37dbf9cecc

URL: https://github.com/llvm/llvm-project/commit/469b9cbe5a410afc836a22fa75870d37dbf9cecc
DIFF: https://github.com/llvm/llvm-project/commit/469b9cbe5a410afc836a22fa75870d37dbf9cecc.diff

LOG: [mlir][VectorOps] Don't fold extract chains that include dynamic indices (#68333)

Folding extract chains that include dynamic indices is not yet supported
and previously led to a confusing crash, in which an extract op with a
kDynamic marker but no dynamic positions was created. The verifier has
also been updated to check for this and to hint at where the problem is
likely to be.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 027ef3605aeba46..044b6cc07d3d629 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1244,6 +1244,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  // Note: This check must come before getMixedPosition() to prevent a crash.
+  auto dynamicMarkersCount =
+      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
+  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
+    return emitOpError(
+        "mismatch between dynamic and static positions (kDynamic marker but no "
+        "corresponding dynamic position) -- this can only happen due to an "
+        "incorrect fold/rewrite");
   auto position = getMixedPosition();
   if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
@@ -1285,6 +1293,9 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
+    // TODO: Canonicalization for dynamic position not implemented yet.
+    if (currentOp.hasDynamicPosition())
+      return failure();
     ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
     globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 05615b96ae6d69f..924886c50030967 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1693,6 +1693,18 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, %
 
 // -----
 
+// CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts
+//  CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index)
+//       CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32>
+//       CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 {
+  %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// -----
+
 // CHECK-LABEL: extract_extract_strided2
 //  CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
 //       CHECK: %[[V:.*]] = vector.extract %[[A]][1] : vector<4xf32> from vector<2x4xf32>


        


More information about the Mlir-commits mailing list