[Mlir-commits] [mlir] Relax checks on vector.shape_cast (PR #136587)

James Newling llvmlistbot at llvm.org
Mon Apr 21 10:51:00 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/136587

Alternative to https://github.com/llvm/llvm-project/pull/135855
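
This relaxes the `vector.shape_cast` verifier so that it only requires the element type, the number of elements, and the number of scalable dimensions to be preserved; the requirement that the cast be a pure expansion or collapse of contiguous dimensions is dropped, along with the corresponding shape checks in the folder. For example, the following cast, previously rejected with "invalid shape cast" (see the tests moved from invalid.mlir to ops.mlir below), now verifies:

```mlir
// 2x15 is not a contiguous regrouping of 5x1x3x2, but the element count
// (30) and element type are unchanged, so this cast is now valid.
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
```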

From 84dff32df484a5d001db6334f034ff343900cac2 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 21 Apr 2025 10:48:22 -0700
Subject: [PATCH] remove checks for collapse / expand

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 21 +----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 87 +++----------------
 mlir/test/Dialect/Vector/canonicalize.mlir    |  9 +-
 mlir/test/Dialect/Vector/invalid.mlir         | 14 ---
 mlir/test/Dialect/Vector/ops.mlir             | 14 ++++
 5 files changed, 35 insertions(+), 110 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..64b5c58cfec24 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
-
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
-    Each source dim is expanded (or contiguous sequence of source dims combined)
-    in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
-    sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    The shape_cast operation casts from a source vector to a target vector,
+    retaining the element type and number of elements.
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
@@ -2268,12 +2258,7 @@ def Vector_ShapeCastOp :
     Example:
 
     ```mlir
-    // Example casting to a lower vector rank.
-    %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
-    // Example casting to a higher vector rank.
-    %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
-
+    %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
     ```
   }];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 368259b38b153..2fcfab4770d4d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5505,48 +5505,20 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
-
-  unsigned i = 0;
-  unsigned j = 0;
-  while (i < rankA && j < rankB) {
-    int64_t dimA = a[i];
-    int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
-      dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
-
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
-  }
+LogicalResult ShapeCastOp::verify() {
+  auto sourceVectorType =
+      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
+  auto resultVectorType =
+      llvm::dyn_cast_or_null<VectorType>(getResult().getType());
 
-  return i == rankA && j == rankB;
-}
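+  // The ODS type constraints should already guarantee vector source and
+  // result types here; fail defensively if they do not.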
+  if (!sourceVectorType || !resultVectorType)
+    return failure();
 
-static LogicalResult verifyVectorShapeCast(Operation *op,
-                                           VectorType sourceVectorType,
-                                           VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
+    return emitOpError("source/result vectors must have same element type");
   auto sourceShape = sourceVectorType.getShape();
   auto resultShape = resultVectorType.getShape();
 
@@ -5556,23 +5526,12 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   int64_t resultDimProduct = std::accumulate(
       resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
   if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
-  }
+    return emitOpError("source/result number of elements must match");
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return emitOpError("different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
@@ -5581,19 +5540,6 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
   return success();
 }
 
-LogicalResult ShapeCastOp::verify() {
-  auto sourceVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getSource().getType());
-  auto resultVectorType =
-      llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
-  // Check if source/result are of vector type.
-  if (sourceVectorType && resultVectorType)
-    return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
-
-  return success();
-}
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // No-op shape cast.
@@ -5609,15 +5555,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
     VectorType srcType = otherOp.getSource().getType();
     if (resultType == srcType)
       return otherOp.getSource();
-    if (srcType.getRank() < resultType.getRank()) {
-      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-        return {};
-    } else if (srcType.getRank() > resultType.getRank()) {
-      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-        return {};
-    } else {
-      return {};
-    }
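+    // With the relaxed verifier, composing the two shape_casts is always
+    // valid, so fold the chain by rebinding this op's source.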
     setOperand(otherOp.getSource());
     return getResult();
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 2d365ac2b4287..5bf8b9338c498 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: dont_fold_expand_collapse
-//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-//       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-//       CHECK:   return %[[B]] : vector<8x8xf32>
+// CHECK-LABEL: fold_expand_collapse
+//       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+//       CHECK:   return %[[A]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
+func.func @fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
     %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
     %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3a8320971bac4..14399d5a19394 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1145,19 +1145,5 @@ func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
 
 // -----
 
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error at +1 {{invalid shape cast}}
-  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
-
 func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
   // expected-error at +1 {{different number of scalable dims at source (1) and result (0)}}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..bbd8f51445549 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -543,6 +543,20 @@ func.func @vector_print_on_scalar(%arg0: i64) {
   return
 }
 
+// CHECK-LABEL: @shape_cast_valid_rank_reduction
+func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
+  // CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
+  return
+}
+
+// CHECK-LABEL: @shape_cast_valid_rank_expansion
+func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
+  // CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
+  return
+}
+
 // CHECK-LABEL: @shape_cast
 func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xf32>,
