[Mlir-commits] [mlir] [mlir][vector] Remove `vector.reshape` operation (PR #101645)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 2 03:05:04 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
This operation was added five years ago and has no lowerings or uses within upstream MLIR (and no reported uses downstream). There are only a handful of round-trip tests.
See the related RFC:
https://discourse.llvm.org/t/rfc-should-vector-reshape-be-removed/80478/3
---
Full diff: https://github.com/llvm/llvm-project/pull/101645.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-116)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-62)
- (modified) mlir/test/Dialect/Vector/invalid.mlir (-60)
- (modified) mlir/test/Dialect/Vector/ops.mlir (-17)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..cd19d356a6739 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1178,122 +1178,6 @@ def Vector_OuterProductOp :
let hasVerifier = 1;
}
-// TODO: Add transformation which decomposes ReshapeOp into an optimized
-// sequence of vector rotate/shuffle/select operations.
-def Vector_ReshapeOp :
- Vector_Op<"reshape", [AttrSizedOperandSegments, Pure]>,
- Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
- Variadic<Index>:$output_shape,
- I64ArrayAttr:$fixed_vector_sizes)>,
- Results<(outs AnyVector:$result)> {
- let summary = "vector reshape operation";
- let description = [{
- Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
- fixed vector dimension 'fixed_vector_sizes' on the innermost vector
- dimensions.
-
- The parameters 'input_shape' and 'output_shape' represent valid data shapes
- across fixed vector shapes. For example, if a vector has a valid data
- shape [6] with fixed vector size [8], then the valid data elements are
- assumed to be stored at the beginning of the vector with the remaining
- vector elements undefined.
-
- In the examples below, valid data elements are represented by an alphabetic
- character, and undefined data elements are represented by '-'.
-
- Example
-
- vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
-
- input: [a, b, c, d, e, f]
-
- layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
-
- vector layout: [a, b, c, d, e, f, -, -]
-
- Example
-
- vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
-
- input: [a, b, c, d, e, f, g, h, i, j]
-
- layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
-
- vector layout: [[a, b, c, d, e, f, g, h],
- [i, j, -, -, -, -, -, -]]
-
- Example
-
- vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
- [2, 3]
-
- input: [[a, b, c, d, e],
- [f, g, h, i, j],
- [k, l, m, n, o]]
-
- layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5,
- d0 mod 3, d1 mod 5)
-
- vector layout: [[[[a, b, c],
- [f, g, h]]
- [[d, e, -],
- [i, j, -]]],
- [[[k, l, m],
- [-, -, -]]
- [[n, o, -],
- [-, -, -]]]]
-
- Example
-
- %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
- : vector<3x2x4xf32> to vector<2x3x4xf32>
-
- input: [[a, b, c, d, e, f],
- [g, h, i, j, k, l],
- [m, n, o, p, q, r]]
-
- layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
-
-
- Input vector: [[[a, b, c, d],
- [e, f, -, -]],
- [[g, h, i, j],
- [k, l, -, -]],
- [[m, n, o, p],
- [q, r, -, -]]]
-
- Output vector: [[[a, b, c, d],
- [e, f, g, h],
- [i, -, -, -]],
- [[j, k, l, m],
- [n, o, p, q],
- [r, -, -, -]]]
- }];
-
- let extraClassDeclaration = [{
- VectorType getInputVectorType() {
- return ::llvm::cast<VectorType>(getVector().getType());
- }
- VectorType getOutputVectorType() {
- return ::llvm::cast<VectorType>(getResult().getType());
- }
-
- /// Returns as integer value the number of input shape operands.
- int64_t getNumInputShapeSizes() { return getInputShape().size(); }
-
- /// Returns as integer value the number of output shape operands.
- int64_t getNumOutputShapeSizes() { return getOutputShape().size(); }
-
- void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
- }];
-
- let assemblyFormat = [{
- $vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
- $fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
- }];
- let hasVerifier = 1;
-}
-
def Vector_ExtractStridedSliceOp :
Vector_Op<"extract_strided_slice", [Pure,
PredOpTrait<"operand and result have same element type",
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..e65f4cbbff3e1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3327,68 +3327,6 @@ Type OuterProductOp::getExpectedMaskType() {
vecType.getScalableDims());
}
-//===----------------------------------------------------------------------===//
-// ReshapeOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ReshapeOp::verify() {
- // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
- auto inputVectorType = getInputVectorType();
- auto outputVectorType = getOutputVectorType();
- int64_t inputShapeRank = getNumInputShapeSizes();
- int64_t outputShapeRank = getNumOutputShapeSizes();
- SmallVector<int64_t, 4> fixedVectorSizes;
- getFixedVectorSizes(fixedVectorSizes);
- int64_t numFixedVectorSizes = fixedVectorSizes.size();
-
- if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
- return emitError("invalid input shape for vector type ") << inputVectorType;
-
- if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
- return emitError("invalid output shape for vector type ")
- << outputVectorType;
-
- // Verify that the 'fixedVectorSizes' match an input/output vector shape
- // suffix.
- unsigned inputVectorRank = inputVectorType.getRank();
- for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
- unsigned index = inputVectorRank - numFixedVectorSizes - i;
- if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
- return emitError("fixed vector size must match input vector for dim ")
- << i;
- }
-
- unsigned outputVectorRank = outputVectorType.getRank();
- for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
- unsigned index = outputVectorRank - numFixedVectorSizes - i;
- if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
- return emitError("fixed vector size must match output vector for dim ")
- << i;
- }
-
- // If all shape operands are produced by constant ops, verify that product
- // of dimensions for input/output shape match.
- auto isDefByConstant = [](Value operand) {
- return getConstantIntValue(operand).has_value();
- };
- if (llvm::all_of(getInputShape(), isDefByConstant) &&
- llvm::all_of(getOutputShape(), isDefByConstant)) {
- int64_t numInputElements = 1;
- for (auto operand : getInputShape())
- numInputElements *= getConstantIntValue(operand).value();
- int64_t numOutputElements = 1;
- for (auto operand : getOutputShape())
- numOutputElements *= getConstantIntValue(operand).value();
- if (numInputElements != numOutputElements)
- return emitError("product of input and output shape sizes must match");
- }
- return success();
-}
-
-void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
- populateFromInt64AttrArray(getFixedVectorSizes(), results);
-}
-
//===----------------------------------------------------------------------===//
// ExtractStridedSliceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 00914c1d1baf6..10ba895a1b3a4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1094,66 +1094,6 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// -----
-func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c6 = arith.constant 6 : index
- %c9 = arith.constant 9 : index
- // expected-error at +1 {{invalid input shape for vector type}}
- %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4]
- : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) {
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c6 = arith.constant 6 : index
- %c9 = arith.constant 9 : index
- // expected-error at +1 {{invalid output shape for vector type}}
- %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4]
- : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) {
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c6 = arith.constant 6 : index
- %c9 = arith.constant 9 : index
- // expected-error at +1 {{product of input and output shape sizes must match}}
- %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4]
- : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) {
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c6 = arith.constant 6 : index
- %c9 = arith.constant 9 : index
- // expected-error at +1 {{fixed vector size must match input vector for dim 0}}
- %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
- : vector<3x2x5xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c6 = arith.constant 6 : index
- %c9 = arith.constant 9 : index
- // expected-error at +1 {{fixed vector size must match output vector for dim 0}}
- %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
- : vector<3x2x4xf32> to vector<2x3x5xf32>
-}
-
-// -----
-
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
// expected-error at +1 {{op source/result vectors must have same element type}}
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 7e578452b82cc..4759fcc9511fb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -522,23 +522,6 @@ func.func @vector_print_on_scalar(%arg0: i64) {
return
}
-// CHECK-LABEL: @reshape
-func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- %c2 = arith.constant 2 : index
- // CHECK: %[[C3:.*]] = arith.constant 3 : index
- %c3 = arith.constant 3 : index
- // CHECK: %[[C6:.*]] = arith.constant 6 : index
- %c6 = arith.constant 6 : index
- // CHECK: %[[C9:.*]] = arith.constant 9 : index
- %c9 = arith.constant 9 : index
- // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32>
- %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
- : vector<3x2x4xf32> to vector<2x3x4xf32>
-
- return %1 : vector<2x3x4xf32>
-}
-
// CHECK-LABEL: @shape_cast
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
%arg1 : vector<8x1xf32>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/101645
More information about the Mlir-commits
mailing list