[Mlir-commits] [mlir] [mlir][vector] Restrict vector.shape_cast (scalable vectors) (PR #100331)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Jul 25 01:12:45 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/100331

>From 858c6a7e4b41b1b05da77dec177de33dfdf82c6d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 24 Jul 2024 09:54:11 +0100
Subject: [PATCH 1/5] [mlir][vector] Restrict vector.shape_cast (scalable
 vectors)

Updates the verifier for `vector.shape_cast` so that incorrect cases like
the following are immediately rejected:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```

Separately, here's a fix for the Linalg vectorizer that prevents it from
generating such shape casts (*):
* https://github.com/llvm/llvm-project/pull/100325

(*) Note that this is just one specific case that I've identified so far.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +++++
 mlir/test/Dialect/Vector/invalid.mlir    | 7 +++++++
 2 files changed, 12 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index df3a59ed80ad4..f145dc9f8817d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5238,6 +5238,11 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
     if (!isValidShapeCast(resultShape, sourceShape))
       return op->emitOpError("invalid shape cast");
   }
+
+  // Check that (non-)scalability is preserved
+  if (sourceVectorType.isScalable() != resultVectorType.isScalable())
+    return op->emitOpError("non-matching scalability flags");
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 208982a3e0e7b..3fad61198b474 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1182,6 +1182,13 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 
 // -----
 
+func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
+  // expected-error at +1 {{non-matching scalability flags}}
+  %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
+}
+
+// -----
+
 func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error at +1 {{'vector.bitcast' invalid kind of type specified}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32

>From f67115a0805144c907a04d7e4fd1dd981b81833c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 24 Jul 2024 17:21:20 +0100
Subject: [PATCH 2/5] fixup! [mlir][vector] Restrict vector.shape_cast
 (scalable vectors)

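Make the verifier even stricter: require the source and result vector
types to have the same number of scalable dimensions, rather than only
checking whether each of them is scalable at all.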
---
 mlir/include/mlir/IR/BuiltinTypes.td     | 6 ++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 6 ++++--
 mlir/test/Dialect/Vector/invalid.mlir    | 9 ++++++++-
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 0b3532dcc7d4f..079d1b4921645 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1168,6 +1168,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
       return !llvm::is_contained(getScalableDims(), false);
     }
 
+    /// Get the number of scalable dimension.
+    int64_t getNumScalableDims() const {
+      return llvm::count(getScalableDims(), true);
+    }
+
+
     /// Get or create a new VectorType with the same shape as `this` and an
     /// element type of bitwidth scaled by `scale`.
     /// Return null if the scaled element type cannot be represented.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f145dc9f8817d..d52411a54c36e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5240,8 +5240,10 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   }
 
   // Check that (non-)scalability is preserved
-  if (sourceVectorType.isScalable() != resultVectorType.isScalable())
-    return op->emitOpError("non-matching scalability flags");
+  if (sourceVectorType.getNumScalableDims() !=
+      resultVectorType.getNumScalableDims())
+    return op->emitOpError("non-matching scalable dims");
+  sourceVectorType.getNumDynamicDims();
 
   return success();
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3fad61198b474..79fbdc55ed426 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1183,12 +1183,19 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 // -----
 
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
-  // expected-error at +1 {{non-matching scalability flags}}
+  // expected-error at +1 {{non-matching scalable dims}}
   %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
 }
 
 // -----
 
+func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
+  // expected-error at +1 {{non-matching scalable dims}}
+  %0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]f32>
+}
+
+// -----
+
 func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error at +1 {{'vector.bitcast' invalid kind of type specified}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32

>From 946ff5b7d58c68004e1ff28e949e8205d087580c Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 24 Jul 2024 18:23:15 +0100
Subject: [PATCH 3/5] fixup! fixup! [mlir][vector] Restrict vector.shape_cast
 (scalable vectors)

Fix typos
---
 mlir/include/mlir/IR/BuiltinTypes.td  | 2 +-
 mlir/test/Dialect/Vector/invalid.mlir | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 079d1b4921645..38419dfa50106 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1168,7 +1168,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
       return !llvm::is_contained(getScalableDims(), false);
     }
 
-    /// Get the number of scalable dimension.
+    /// Get the number of scalable dimensions.
     int64_t getNumScalableDims() const {
       return llvm::count(getScalableDims(), true);
     }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 79fbdc55ed426..4fb9f7b7d9854 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1191,7 +1191,7 @@ func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
 
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
   // expected-error at +1 {{non-matching scalable dims}}
-  %0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]f32>
+  %0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]xf32>
 }
 
 // -----

>From 4071c99a718d14a2e6f9b08d9a999fd6b25d1435 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Wed, 24 Jul 2024 21:28:55 +0100
Subject: [PATCH 4/5] fixup! fixup! fixup! [mlir][vector] Restrict
 vector.shape_cast (scalable vectors)

Improve diagnostic: report the number of scalable dims found at the
source and at the result.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 ++++++---
 mlir/test/Dialect/Vector/invalid.mlir    | 4 ++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d52411a54c36e..3a4c5e2a60010 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5240,9 +5240,12 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   }
 
   // Check that (non-)scalability is preserved
-  if (sourceVectorType.getNumScalableDims() !=
-      resultVectorType.getNumScalableDims())
-    return op->emitOpError("non-matching scalable dims");
+  auto sourceNScalableDims = sourceVectorType.getNumScalableDims();
+  auto resultNScalableDims = resultVectorType.getNumScalableDims();
+  if (sourceNScalableDims != resultNScalableDims)
+    return op->emitOpError("different number of scalable dims at source (")
+           << sourceNScalableDims << ") and result (" << resultNScalableDims
+           << ")";
   sourceVectorType.getNumDynamicDims();
 
   return success();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 4fb9f7b7d9854..00914c1d1baf6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1183,14 +1183,14 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 // -----
 
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
-  // expected-error at +1 {{non-matching scalable dims}}
+  // expected-error at +1 {{different number of scalable dims at source (1) and result (0)}}
   %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
 }
 
 // -----
 
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
-  // expected-error at +1 {{non-matching scalable dims}}
+  // expected-error at +1 {{different number of scalable dims at source (2) and result (1)}}
   %0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]xf32>
 }
 

>From 54320e31cf94e9cb13299fb4766dc21549221036 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 25 Jul 2024 09:11:59 +0100
Subject: [PATCH 5/5] fixup! fixup! fixup! fixup! [mlir][vector] Restrict
 vector.shape_cast (scalable vectors)

Spell out `auto` (use `int64_t` for the scalable-dim counts)
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3a4c5e2a60010..d297c40760cd8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5240,8 +5240,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   }
 
   // Check that (non-)scalability is preserved
-  auto sourceNScalableDims = sourceVectorType.getNumScalableDims();
-  auto resultNScalableDims = resultVectorType.getNumScalableDims();
+  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
+  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
     return op->emitOpError("different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims



More information about the Mlir-commits mailing list