[Mlir-commits] [mlir] [mlir][vector] Consistently handle rank-preserving shape_cast (PR #135855)

James Newling llvmlistbot at llvm.org
Tue Apr 15 14:08:59 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/135855

Before this PR, the following operation
```
%1 = vector.shape_cast %0 : vector<3x2xf32> to vector<2x3xf32>
```
was not illegal. There were checks that n-D to k-D shape casts were strictly expanding (n < k) or collapsing (n > k), but the case of n == k was always considered legal w.r.t. shape.

With this PR, rank-preserving shape_casts are only legal if they insert/remove dimensions of size 1. For example `<1x4xf32> -> <4x1xf32>` is legal.  This is consistent with the n < k and n > k cases.

This PR also improves the error messages generated with `emitOpError`.

>From c1f4264a71d6d80350056d0d9ca86a0ac2c1e04f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 15 Apr 2025 13:09:43 -0700
Subject: [PATCH] fix edge case where n=k (rank-preserving shape_cast)

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 17 ++---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 63 ++++++++++---------
 mlir/test/Dialect/Vector/invalid.mlir         | 15 +++--
 mlir/test/Dialect/Vector/ops.mlir             |  8 +++
 4 files changed, 61 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 7fc56b1aa4e7e..a9e25f23ef90f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,19 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
+    The shape_cast operation casts from an n-D source vector to a k-D result
+    vector. The element type remains the same, as does the number of elements
+    (product of dimensions).
+
+    If reducing or preserving rank (n >= k), all result dimension sizes must be
+    products of contiguous source dimension sizes. If expanding rank (n < k),
+    source dimensions must all factor into contiguous sequences of destination
+    dimension sizes.
 
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
     sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    order (i.e. 0 <= j < k).
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bee5c1fd6ed58..554dbba081898 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
 
 /// Returns true if each element of 'a' is equal to the product of a contiguous
 /// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
   unsigned rankA = a.size();
   unsigned rankB = b.size();
-  assert(rankA < rankB);
+  assert(rankA <= rankB);
 
   auto isOne = [](int64_t v) { return v == 1; };
 
@@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
+    return op->emitOpError("has different source and result element types");
+  ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
+  if (lowRankShape.size() > highRankShape.size())
+    std::swap(lowRankShape, highRankShape);
 
   // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+  int64_t nLowRankElms =
+      std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
+                      std::multiplies<int64_t>{});
+  int64_t nHighRankElms =
+      std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
+                      std::multiplies<int64_t>{});
+
+  if (nLowRankElms != nHighRankElms) {
+    return op->emitOpError(
+        "has a different number of source and result elements");
+  }
+
+  if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
+    return op->emitOpError(
+        "is invalid (does not uniformly collapse or expand)");
   }
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return op->emitOpError(
+               "has a different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
   sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
     // Only allows valid transitive folding (expand/collapse dimensions).
     VectorType srcType = otherOp.getSource().getType();
+
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
+
+    ArrayRef<int64_t> lowRankShape = srcType.getShape();
+    ArrayRef<int64_t> highRankShape = resultType.getShape();
+    if (lowRankShape.size() > highRankShape.size())
+      std::swap(lowRankShape, highRankShape);
+
+    if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
       return {};
-    }
+
     setOperand(otherOp.getSource());
     return getResult();
   }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index dbf829e014b8d..9f94fb0574504 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 // -----
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{op source/result vectors must have same element type}}
+  // expected-error at +1 {{op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{op source/result number of elements must match}}
+  // expected-error at +1 {{op has a different number of source and result elements}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
+func.func @shape_cast_invalid_rank_preservating(%arg0 : vector<3x2xf32>) {
+  // expected-error at +1 {{op is invalid (does not uniformly collapse or expand)}}
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
 func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
+  // expected-error at +1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
 }
 
 // -----
 
 func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
+  // expected-error at +1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..527bccf8383ca 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,14 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
   return %1 : vector<1x1x1x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+  return %0 : vector<4x1xf32>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xi32>,



More information about the Mlir-commits mailing list