[Mlir-commits] [mlir] [mlir][vector] Restrict vector.shape_cast (scalable vectors) (PR #100331)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 24 02:05:02 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Updates the verifier for `vector.shape_cast` so that casts that change
scalability (for example, by dropping the scalable flag) are rejected
immediately. The following incorrect case is now an error:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```
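
As a rough illustration of the rule the verifier now enforces, here is a minimal standalone sketch. It does not use the MLIR API: `ToyVectorType` and `checkScalability` are hypothetical names that only mimic `VectorType::isScalable()` and the new check, which compares whether *any* dimension is scalable on each side rather than which dimensions are.

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (assumption, not the real class):
// a shape plus one "scalable" flag per dimension.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  // True if at least one dimension is scalable,
  // e.g. the [4] in vector<1x1x[4]xindex>.
  bool isScalable() const {
    return std::any_of(scalableDims.begin(), scalableDims.end(),
                       [](bool flag) { return flag; });
  }
};

// Returns an empty string when the cast passes the scalability check,
// otherwise the diagnostic text used in the patch.
std::string checkScalability(const ToyVectorType &source,
                             const ToyVectorType &result) {
  if (source.isScalable() != result.isScalable())
    return "non-matching scalability flags";
  return "";
}
```

Under this model, casting `vector<1x1x[4]xindex>` to `vector<4xindex>` fails the check, while casting it to `vector<[4]xindex>` passes it.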

Separately, here's a fix that prevents the Linalg vectorizer from
generating such shape casts (*):
* https://github.com/llvm/llvm-project/pull/100325

(*) Note that this is just one specific case that I've identified so far.


---
Full diff: https://github.com/llvm/llvm-project/pull/100331.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+5) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+7) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index df3a59ed80ad4..f145dc9f8817d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5238,6 +5238,11 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
     if (!isValidShapeCast(resultShape, sourceShape))
       return op->emitOpError("invalid shape cast");
   }
+
+  // Check that (non-)scalability is preserved
+  if (sourceVectorType.isScalable() != resultVectorType.isScalable())
+    return op->emitOpError("non-matching scalability flags");
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 208982a3e0e7b..3fad61198b474 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1182,6 +1182,13 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 
 // -----
 
+func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
+  // expected-error@+1 {{non-matching scalability flags}}
+  %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
+}
+
+// -----
+
 func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error@+1 {{'vector.bitcast' invalid kind of type specified}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/100331

