[Mlir-commits] [mlir] [mlir][vector] Relax constraints on shape_cast (PR #136587)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 22 16:53:01 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

<details>
<summary>Changes</summary>

`vector.shape_cast` was initially designed to be the union of collapse_shape and expand_shape. There was an inconsistency in the verifier that allowed any shape cast if the rank did not change, which led to a strange middle ground where you could cast from shape (4,3) to (3,4) but not from (4,3) to (2,3,2). That issue was fixed in https://github.com/llvm/llvm-project/pull/135855, but further feedback there and polling suggest that vector.shape_cast should instead allow all shape casts (so more like tensor.reshape than tensor.collapse_shape/tensor.expand_shape). This PR implements that by relaxing the verifier.

This PR also adds a simple canonicalizer, so that 

```
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
%1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
```
becomes 
```
%0 = vector.shape_cast %arg0 : vector<4xf32> to  vector<2x2xf32>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/136587.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+6-18) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+34-94) 
- (modified) mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir (+2-4) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+15-4) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+3-16) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (+14) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..4d49e52b21563 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
-
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
-    Each source dim is expanded (or contiguous sequence of source dims combined)
-    in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
-    sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    Casts to a vector with the same number of elements, element type, and
+    number of scalable dimensions.
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
@@ -2265,15 +2255,13 @@ def Vector_ShapeCastOp :
     2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
     is supported in that particular case, for now.
 
-    Example:
+    Examples:
 
     ```mlir
-    // Example casting to a lower vector rank.
-    %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
-    // Example casting to a higher vector rank.
-    %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
+    %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
 
+    // with 2 scalable dimensions (number of which must be preserved).
+    %3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
     ```
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..732a5d21a4b87 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5505,127 +5505,67 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
-
-  unsigned i = 0;
-  unsigned j = 0;
-  while (i < rankA && j < rankB) {
-    int64_t dimA = a[i];
-    int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
-      dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+LogicalResult ShapeCastOp::verify() {
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
-  }
+  VectorType sourceType = getSourceVectorType();
+  VectorType resultType = getResultVectorType();
 
-  return i == rankA && j == rankB;
-}
+  // Check that element type is preserved
+  if (sourceType.getElementType() != resultType.getElementType())
+    return emitOpError("has different source and result element types");
 
-static LogicalResult verifyVectorShapeCast(Operation *op,
-                                           VectorType sourceVectorType,
-                                           VectorType resultVectorType) {
-  // Check that element type is the same.
-  if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
-
-  // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+  // Check that number of elements is preserved
+  int64_t sourceNElms = sourceType.getNumElements();
+  int64_t resultNElms = resultType.getNumElements();
+  if (sourceNElms != resultNElms) {
+    return emitOpError() << "has different number of elements at source ("
+                         << sourceNElms << ") and result (" << resultNElms
+                         << ")";
   }
 
   // Check that (non-)scalability is preserved
-  int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
-  int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+  int64_t sourceNScalableDims = sourceType.getNumScalableDims();
+  int64_t resultNScalableDims = resultType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
-           << sourceNScalableDims << ") and result (" << resultNScalableDims
-           << ")";
-  sourceVectorType.getNumDynamicDims();
-
-  return success();
-}
-
-LogicalResult ShapeCastOp::verify() {
-  auto sourceVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
-  auto resultVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
-  // Check if source/result are of vector type.
-  if (sourceVectorType && resultVectorType)
-    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
+    return emitOpError() << "has different number of scalable dims at source ("
+                         << sourceNScalableDims << ") and result ("
+                         << resultNScalableDims << ")";
 
   return success();
 }
 
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
+  VectorType resultType = getType();
+
   // No-op shape cast.
-  if (getSource().getType() == getType())
+  if (getSource().getType() == resultType)
     return getSource();
 
-  VectorType resultType = getType();
-
-  // Canceling shape casts.
+  // Y = shape_cast(shape_cast(X)))
+  //      -> X, if X and Y have same type
+  //      -> shape_cast(X) otherwise.
   if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
-    // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
-      return {};
-    }
     setOperand(otherOp.getSource());
     return getResult();
   }
 
-  // Cancelling broadcast and shape cast ops.
+  // Y = shape_cast(broadcast(X))
+  //      -> X, if X and Y have same type, else
+  //      -> shape_cast(X) if X is a vector and the broadcast preserves
+  //         number of elements.
   if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
     if (bcastOp.getSourceType() == resultType)
       return bcastOp.getSource();
+    if (auto bcastSrcType = dyn_cast<VectorType>(bcastOp.getSourceType())) {
+      if (bcastSrcType.getNumElements() == resultType.getNumElements()) {
+        setOperand(bcastOp.getSource());
+        return getResult();
+      }
+    }
   }
 
   // shape_cast(constant) -> constant
diff --git a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
index ae2b5393ca449..60ad54bf5c370 100644
--- a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
   // CHECK-NEXT: vector.insert {{.*}}[1]
   // CHECK-NEXT: vector.insert {{.*}}[2]
   // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
   %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
   return %0 : vector<8x4xf32>
 }
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
   // CHECK-NEXT: vector.insert {{.*}}[1]
   // CHECK-NEXT: vector.insert {{.*}}[2]
   // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
+  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
   return %0 : vector<1x8x4xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..04d8e613d4156 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: dont_fold_expand_collapse
-//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-//       CHECK:   return %[[B]] : vector<8x8xf32>
+// CHECK-LABEL: fold_expand_collapse
+//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+//       CHECK:   return %[[A]] : vector<8x8xf32>
 func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
     %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
     %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -973,6 +972,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @fold_count_preserving_broadcast_shapecast
+//  CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
+//       CHECK:   %[[SHAPECAST:.*]] = vector.shape_cast %[[V]] : vector<4xf32> to vector<2x2xf32>
+//       CHECK:   return %[[SHAPECAST]] : vector<2x2xf32>
+func.func @fold_count_preserving_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<2x2xf32> {
+    %0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
+    %1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
+    return %1 : vector<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
 //       CHECK:   vector.broadcast
 //   CHECK-NOT:   vector.shape_cast
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a8320971bac4..fa4837126accb 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1131,34 +1131,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 // -----
 
+
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{op source/result vectors must have same element type}}
+  // expected-error at +1 {{'vector.shape_cast' op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{op source/result number of elements must match}}
+  // expected-error at +1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
-
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
   // expected-error at +1 {{different number of scalable dims at source (1) and result (0)}}
   %0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..36f7db8c39d4d 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -543,6 +543,20 @@ func.func @vector_print_on_scalar(%arg0: i64) {
   return
 }
 
+// CHECK-LABEL: @shape_cast_valid_rank_reduction
+func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
+  // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
+  return
+}
+
+// CHECK-LABEL: @shape_cast_valid_rank_expansion
+func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
+  // CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
+  return
+}
+
 // CHECK-LABEL: @shape_cast
 func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xf32>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/136587


More information about the Mlir-commits mailing list