[Mlir-commits] [mlir] 5718460 - [mlir][vector] Relax constraints on shape_cast (#136587)
llvmlistbot at llvm.org
Thu May 1 10:18:36 PDT 2025
Author: James Newling
Date: 2025-05-01T10:18:33-07:00
New Revision: 5718460b22400e71e1832be489c9090f2c7d3ebb
URL: https://github.com/llvm/llvm-project/commit/5718460b22400e71e1832be489c9090f2c7d3ebb
DIFF: https://github.com/llvm/llvm-project/commit/5718460b22400e71e1832be489c9090f2c7d3ebb.diff
LOG: [mlir][vector] Relax constraints on shape_cast (#136587)
`vector.shape_cast` was initially designed to be the union of
collapse_shape and expand_shape. There was an inconsistency in the
verifier that allowed any shape casts when the rank did not change, which
led to a strange middle ground where you could cast from shape (4,3) to
(3,4) but not from (4,3) to (2,3,2). That issue was fixed (verifier made stricter)
in https://github.com/llvm/llvm-project/pull/135855, but further feedback
there (and polling) suggested that vector.shape_cast should instead allow all
shape casts (making it more like tensor.reshape than
tensor.collapse_shape/tensor.expand_shape). This PR makes that simplification
by relaxing the verifier.
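As a concrete illustration, here is a minimal sketch (the function name is hypothetical and not part of this commit) of the (4,3) to (2,3,2) cast mentioned above, which the stricter verifier rejected and the relaxed verifier accepts:
```mlir
// Hypothetical example: the cast preserves the element count (4*3 = 2*3*2 = 12),
// the element type (f32), and the number of scalable dims (0), so the relaxed
// verifier accepts it even though (2,3,2) is neither a collapse nor an
// expansion of (4,3).
func.func @shape_cast_4x3_to_2x3x2(%arg0 : vector<4x3xf32>) -> vector<2x3x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<4x3xf32> to vector<2x3x2xf32>
  return %0 : vector<2x3x2xf32>
}
```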
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index d7518943229ea..4d49e52b21563 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
Results<(outs AnyVectorOfAnyRank:$result)> {
let summary = "shape_cast casts between vector shapes";
let description = [{
- The shape_cast operation casts between an n-D source vector shape and
- a k-D result vector shape (the element type remains the same).
-
- If reducing rank (n > k), result dimension sizes must be a product
- of contiguous source dimension sizes.
- If expanding rank (n < k), source dimensions must factor into a
- contiguous sequence of destination dimension sizes.
- Each source dim is expanded (or contiguous sequence of source dims combined)
- in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
- sequence of result dims (or a single result dim), in result dimension list
- order (i.e. 0 <= j < k). The product of all source dimension sizes and all
- result dimension sizes must match.
+ Casts to a vector with the same number of elements, element type, and
+ number of scalable dimensions.
It is currently assumed that this operation does not require moving data,
and that it will be folded away before lowering vector operations.
@@ -2265,15 +2255,13 @@ def Vector_ShapeCastOp :
2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
is supported in that particular case, for now.
- Example:
+ Examples:
```mlir
- // Example casting to a lower vector rank.
- %1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
-
- // Example casting to a higher vector rank.
- %3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
+ %1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
+ // with 2 scalable dimensions (number of which must be preserved).
+ %3 = vector.shape_cast %2 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
```
}];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 67e3aa564a184..f47e356d6fe14 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5546,124 +5546,56 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
- unsigned rankA = a.size();
- unsigned rankB = b.size();
- assert(rankA < rankB);
-
- auto isOne = [](int64_t v) { return v == 1; };
-
- // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
- // casted to a 0-d vector.
- if (rankA == 0 && llvm::all_of(b, isOne))
- return true;
+LogicalResult ShapeCastOp::verify() {
- unsigned i = 0;
- unsigned j = 0;
- while (i < rankA && j < rankB) {
- int64_t dimA = a[i];
- int64_t dimB = 1;
- while (dimB < dimA && j < rankB)
- dimB *= b[j++];
- if (dimA != dimB)
- break;
- ++i;
+ VectorType sourceType = getSourceVectorType();
+ VectorType resultType = getResultVectorType();
- // Handle the case when trailing dimensions are of size 1.
- // Include them into the contiguous sequence.
- if (i < rankA && llvm::all_of(a.slice(i), isOne))
- i = rankA;
- if (j < rankB && llvm::all_of(b.slice(j), isOne))
- j = rankB;
- }
+ // Check that element type is preserved
+ if (sourceType.getElementType() != resultType.getElementType())
+ return emitOpError("has
diff erent source and result element types");
- return i == rankA && j == rankB;
-}
-
-static LogicalResult verifyVectorShapeCast(Operation *op,
- VectorType sourceVectorType,
- VectorType resultVectorType) {
- // Check that element type is the same.
- if (sourceVectorType.getElementType() != resultVectorType.getElementType())
- return op->emitOpError("source/result vectors must have same element type");
- auto sourceShape = sourceVectorType.getShape();
- auto resultShape = resultVectorType.getShape();
-
- // Check that product of source dim sizes matches product of result dim sizes.
- int64_t sourceDimProduct = std::accumulate(
- sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
- int64_t resultDimProduct = std::accumulate(
- resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
- if (sourceDimProduct != resultDimProduct)
- return op->emitOpError("source/result number of elements must match");
-
- // Check that expanding/contracting rank cases.
- unsigned sourceRank = sourceVectorType.getRank();
- unsigned resultRank = resultVectorType.getRank();
- if (sourceRank < resultRank) {
- if (!isValidShapeCast(sourceShape, resultShape))
- return op->emitOpError("invalid shape cast");
- } else if (sourceRank > resultRank) {
- if (!isValidShapeCast(resultShape, sourceShape))
- return op->emitOpError("invalid shape cast");
+ // Check that number of elements is preserved
+ int64_t sourceNElms = sourceType.getNumElements();
+ int64_t resultNElms = resultType.getNumElements();
+ if (sourceNElms != resultNElms) {
+ return emitOpError() << "has
diff erent number of elements at source ("
+ << sourceNElms << ") and result (" << resultNElms
+ << ")";
}
// Check that (non-)scalability is preserved
- int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
- int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
+ int64_t sourceNScalableDims = sourceType.getNumScalableDims();
+ int64_t resultNScalableDims = resultType.getNumScalableDims();
if (sourceNScalableDims != resultNScalableDims)
- return op->emitOpError("
diff erent number of scalable dims at source (")
- << sourceNScalableDims << ") and result (" << resultNScalableDims
- << ")";
- sourceVectorType.getNumDynamicDims();
-
- return success();
-}
-
-LogicalResult ShapeCastOp::verify() {
- auto sourceVectorType =
- llvm::dyn_cast_or_null<VectorType>(getSource().getType());
- auto resultVectorType =
- llvm::dyn_cast_or_null<VectorType>(getResult().getType());
-
- // Check if source/result are of vector type.
- if (sourceVectorType && resultVectorType)
- return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
+ return emitOpError() << "has
diff erent number of scalable dims at source ("
+ << sourceNScalableDims << ") and result ("
+ << resultNScalableDims << ")";
return success();
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+ VectorType resultType = getType();
+
// No-op shape cast.
- if (getSource().getType() == getType())
+ if (getSource().getType() == resultType)
return getSource();
- VectorType resultType = getType();
-
- // Canceling shape casts.
+ // Y = shape_cast(shape_cast(X))
+ // -> X, if X and Y have same type
+ // -> shape_cast(X) otherwise.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
-
- // Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
- return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
- return {};
- } else {
- return {};
- }
setOperand(otherOp.getSource());
return getResult();
}
- // Cancelling broadcast and shape cast ops.
+ // Y = shape_cast(broadcast(X))
+ // -> X, if X and Y have same type
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
diff --git a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
index ae2b5393ca449..60ad54bf5c370 100644
--- a/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/CPU/X86/vector-transpose-lowering.mlir
@@ -26,8 +26,7 @@ func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<8x4xf32>
%0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
}
@@ -54,8 +53,7 @@ func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32>
// CHECK-NEXT: vector.insert {{.*}}[1]
// CHECK-NEXT: vector.insert {{.*}}[2]
// CHECK-NEXT: vector.insert {{.*}}[3]
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
- // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
+ // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<1x8x4xf32>
%0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
return %0 : vector<1x8x4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 29a11f47481c8..e0ec9c66d3a48 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -977,10 +977,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
// -----
-// CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
+// CHECK-LABEL: fold_expand_collapse
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
+// CHECK: return %[[A]] : vector<8x8xf32>
func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 349a58d4eb4e4..be65d4c2eef58 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1165,34 +1165,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
+
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result vectors must have same element type}}
+ // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
}
// -----
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{op source/result number of elements must match}}
+ // expected-error@+1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
}
// -----
-func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
-}
-
-// -----
-
-func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
- // expected-error@+1 {{invalid shape cast}}
- %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
-}
-
-// -----
-
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 8ae1e9f9d0c64..f3220aed4360c 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -564,6 +564,17 @@ func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
return %0, %1, %2, %3 : vector<15x2xf32>, vector<8xf32>, vector<16xf32>, vector<16x1xf32>
}
+// A vector.shape_cast can cast between any 2 shapes as long as the
+// number of elements is preserved. For those familiar with the tensor
+// dialect: this behaviour is like the tensor.reshape operation, i.e.
+// less restrictive than tensor.collapse_shape and tensor.expand_shape
+// CHECK-LABEL: @shape_cast_general_reshape
+func.func @shape_cast_general_reshape(%arg0 : vector<2x3xf32>) -> (vector<3x1x2xf32>) {
+ // CHECK: vector.shape_cast %{{.*}} : vector<2x3xf32> to vector<3x1x2xf32>
+ %0 = vector.shape_cast %arg0 : vector<2x3xf32> to vector<3x1x2xf32>
+ return %0 : vector<3x1x2xf32>
+}
+
// CHECK-LABEL: @shape_cast_0d
func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {