[Mlir-commits] [mlir] 98c73d5 - [mlir][vector] Restrict vector.shape_cast (scalable vectors) (#100331)

llvmlistbot at llvm.org
Thu Jul 25 01:50:28 PDT 2024


Author: Andrzej Warzyński
Date: 2024-07-25T09:50:25+01:00
New Revision: 98c73d5df7ff0b5d9c10bc9d44a584d631def1e6

URL: https://github.com/llvm/llvm-project/commit/98c73d5df7ff0b5d9c10bc9d44a584d631def1e6
DIFF: https://github.com/llvm/llvm-project/commit/98c73d5df7ff0b5d9c10bc9d44a584d631def1e6.diff

LOG: [mlir][vector] Restrict vector.shape_cast (scalable vectors) (#100331)

Updates the verifier for `vector.shape_cast` so that incorrect cases
where "scalability" is dropped are immediately rejected. For example:
```mlir
  vector.shape_cast %vec : vector<1x1x[4]xindex> to vector<4xindex>
```

Also, as a separate PR, I've prepared a fix for the Linalg vectorizer to
avoid generating such shape casts (*):
* https://github.com/llvm/llvm-project/pull/100325

(*) Note, that's just one specific case that I've identified so far.
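
To make the rule concrete, here is a minimal standalone C++ sketch of the check that the verifier now performs. It does not use MLIR: `VecTypeSketch`, `checkScalability`, and `runExamples` are hypothetical names introduced only for illustration. Only the `getNumScalableDims` helper and the error text mirror this change, and the sketch assumes C++17 for `std::optional`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (an assumption for this sketch):
// it keeps only the shape and the per-dimension "scalable" flags.
struct VecTypeSketch {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;

  // Mirrors the getNumScalableDims() helper added in this commit.
  int64_t getNumScalableDims() const {
    return std::count(scalableDims.begin(), scalableDims.end(), true);
  }
};

// Hypothetical free function modelling the new verifier step: returns the
// error text when the scalable-dim counts differ, std::nullopt otherwise.
std::optional<std::string> checkScalability(const VecTypeSketch &source,
                                            const VecTypeSketch &result) {
  int64_t sourceN = source.getNumScalableDims();
  int64_t resultN = result.getNumScalableDims();
  if (sourceN != resultN)
    return "different number of scalable dims at source (" +
           std::to_string(sourceN) + ") and result (" +
           std::to_string(resultN) + ")";
  return std::nullopt;
}

// Replays the two negative tests from this commit plus one accepted cast.
void runExamples() {
  // vector<15x[2]xf32> to vector<30xf32>: the scalable flag is dropped.
  VecTypeSketch a{{15, 2}, {false, true}};
  VecTypeSketch b{{30}, {false}};
  assert(checkScalability(a, b).has_value());

  // vector<2x[15]x[2]xf32> to vector<30x[2]xf32>: two flags become one.
  VecTypeSketch c{{2, 15, 2}, {false, true, true}};
  VecTypeSketch d{{30, 2}, {false, true}};
  assert(checkScalability(c, d).has_value());

  // vector<1x1x[4]xindex> to vector<[4]xindex>: the flag is preserved.
  VecTypeSketch e{{1, 1, 4}, {false, false, true}};
  VecTypeSketch f{{4}, {true}};
  assert(!checkScalability(e, f).has_value());
}
```

Note that the sketch, like the check in this commit, compares only the number of scalable dimensions. It says nothing about whether the shapes themselves are compatible, which the existing shape checks in the verifier handle separately.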

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 0b3532dcc7d4f..176a167a3ca31 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1168,6 +1168,11 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
       return !llvm::is_contained(getScalableDims(), false);
     }
 
+    /// Get the number of scalable dimensions.
+    int64_t getNumScalableDims() const {
+      return llvm::count(getScalableDims(), true);
+    }
+
     /// Get or create a new VectorType with the same shape as `this` and an
     /// element type of bitwidth scaled by `scale`.
     /// Return null if the scaled element type cannot be represented.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index df3a59ed80ad4..d297c40760cd8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5238,6 +5238,16 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
     if (!isValidShapeCast(resultShape, sourceShape))
       return op->emitOpError("invalid shape cast");
   }
+
+  // Check that (non-)scalability is preserved
+  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
+  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+  if (sourceNScalableDims != resultNScalableDims)
+    return op->emitOpError("different number of scalable dims at source (")
+           << sourceNScalableDims << ") and result (" << resultNScalableDims
+           << ")";
+  sourceVectorType.getNumDynamicDims();
+
   return success();
 }
 

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 208982a3e0e7b..00914c1d1baf6 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1182,6 +1182,20 @@ func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
 
 // -----
 
+func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
+  // expected-error @+1 {{different number of scalable dims at source (1) and result (0)}}
+  %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
+}
+
+// -----
+
+func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<2x[15]x[2]xf32>) {
+  // expected-error @+1 {{different number of scalable dims at source (2) and result (1)}}
+  %0 = vector.shape_cast %arg0 : vector<2x[15]x[2]xf32> to vector<30x[2]xf32>
+}
+
+// -----
+
 func.func @bitcast_not_vector(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error @+1 {{'vector.bitcast' invalid kind of type specified}}
   %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to f32


        

