[Mlir-commits] [mlir] [mlir][vector] Remove `vector.reshape` operation (PR #101645)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 2 03:05:04 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This operation was added five years ago and has no lowerings or uses within upstream MLIR (and no reported uses downstream). There are only a handful of round-trip tests.

See related RFC:
https://discourse.llvm.org/t/rfc-should-vector-reshape-be-removed/80478/3

---
Full diff: https://github.com/llvm/llvm-project/pull/101645.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-116) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (-62) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (-60) 
- (modified) mlir/test/Dialect/Vector/ops.mlir (-17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..cd19d356a6739 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1178,122 +1178,6 @@ def Vector_OuterProductOp :
   let hasVerifier = 1;
 }
 
-// TODO: Add transformation which decomposes ReshapeOp into an optimized
-// sequence of vector rotate/shuffle/select operations.
-def Vector_ReshapeOp :
-  Vector_Op<"reshape", [AttrSizedOperandSegments, Pure]>,
-    Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
-               Variadic<Index>:$output_shape,
-               I64ArrayAttr:$fixed_vector_sizes)>,
-    Results<(outs AnyVector:$result)> {
-  let summary = "vector reshape operation";
-  let description = [{
-    Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
-    fixed vector dimension 'fixed_vector_sizes' on the innermost vector
-    dimensions.
-
-    The parameters 'input_shape' and 'output_shape' represent valid data shapes
-    across fixed vector shapes. For example, if a vector has a valid data
-    shape [6] with fixed vector size [8], then the valid data elements are
-    assumed to be stored at the beginning of the vector with the remaining
-    vector elements undefined.
-
-    In the examples below, valid data elements are represented by an alphabetic
-    character, and undefined data elements are represented by '-'.
-
-    Example
-
-      vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
-
-                input: [a, b, c, d, e, f]
-
-           layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
-
-        vector layout: [a, b, c, d, e, f, -, -]
-
-    Example
-
-      vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
-
-                input: [a, b, c, d, e, f, g, h, i, j]
-
-           layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
-
-        vector layout: [[a, b, c, d, e, f, g, h],
-                        [i, j, -, -, -, -, -, -]]
-
-    Example
-
-      vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
-      [2, 3]
-
-                input: [[a, b, c, d, e],
-                        [f, g, h, i, j],
-                        [k, l, m, n, o]]
-
-           layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5,
-                                    d0 mod 3, d1 mod 5)
-
-        vector layout: [[[[a, b, c],
-                          [f, g, h]]
-                         [[d, e, -],
-                          [i, j, -]]],
-                        [[[k, l, m],
-                          [-, -, -]]
-                         [[n, o, -],
-                          [-, -, -]]]]
-
-    Example
-
-      %1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
-        : vector<3x2x4xf32> to vector<2x3x4xf32>
-
-             input: [[a, b, c, d, e, f],
-                     [g, h, i, j, k, l],
-                     [m, n, o, p, q, r]]
-
-        layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
-
-
-      Input vector:  [[[a, b, c, d],
-                       [e, f, -, -]],
-                      [[g, h, i, j],
-                       [k, l, -, -]],
-                      [[m, n, o, p],
-                       [q, r, -, -]]]
-
-      Output vector:  [[[a, b, c, d],
-                        [e, f, g, h],
-                        [i, -, -, -]],
-                       [[j, k, l, m],
-                        [n, o, p, q],
-                        [r, -, -, -]]]
-  }];
-
-  let extraClassDeclaration = [{
-    VectorType getInputVectorType() {
-      return ::llvm::cast<VectorType>(getVector().getType());
-    }
-    VectorType getOutputVectorType() {
-      return ::llvm::cast<VectorType>(getResult().getType());
-    }
-
-    /// Returns as integer value the number of input shape operands.
-    int64_t getNumInputShapeSizes() { return getInputShape().size(); }
-
-    /// Returns as integer value the number of output shape operands.
-    int64_t getNumOutputShapeSizes() { return getOutputShape().size(); }
-
-    void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
-  }];
-
-  let assemblyFormat = [{
-    $vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
-    $fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
-  }];
-  let hasVerifier = 1;
-}
-
 def Vector_ExtractStridedSliceOp :
   Vector_Op<"extract_strided_slice", [Pure,
     PredOpTrait<"operand and result have same element type",
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..e65f4cbbff3e1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3327,68 +3327,6 @@ Type OuterProductOp::getExpectedMaskType() {
                          vecType.getScalableDims());
 }
 
-//===----------------------------------------------------------------------===//
-// ReshapeOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ReshapeOp::verify() {
-  // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
-  auto inputVectorType = getInputVectorType();
-  auto outputVectorType = getOutputVectorType();
-  int64_t inputShapeRank = getNumInputShapeSizes();
-  int64_t outputShapeRank = getNumOutputShapeSizes();
-  SmallVector<int64_t, 4> fixedVectorSizes;
-  getFixedVectorSizes(fixedVectorSizes);
-  int64_t numFixedVectorSizes = fixedVectorSizes.size();
-
-  if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
-    return emitError("invalid input shape for vector type ") << inputVectorType;
-
-  if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
-    return emitError("invalid output shape for vector type ")
-           << outputVectorType;
-
-  // Verify that the 'fixedVectorSizes' match an input/output vector shape
-  // suffix.
-  unsigned inputVectorRank = inputVectorType.getRank();
-  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
-    unsigned index = inputVectorRank - numFixedVectorSizes - i;
-    if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
-      return emitError("fixed vector size must match input vector for dim ")
-             << i;
-  }
-
-  unsigned outputVectorRank = outputVectorType.getRank();
-  for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
-    unsigned index = outputVectorRank - numFixedVectorSizes - i;
-    if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
-      return emitError("fixed vector size must match output vector for dim ")
-             << i;
-  }
-
-  // If all shape operands are produced by constant ops, verify that product
-  // of dimensions for input/output shape match.
-  auto isDefByConstant = [](Value operand) {
-    return getConstantIntValue(operand).has_value();
-  };
-  if (llvm::all_of(getInputShape(), isDefByConstant) &&
-      llvm::all_of(getOutputShape(), isDefByConstant)) {
-    int64_t numInputElements = 1;
-    for (auto operand : getInputShape())
-      numInputElements *= getConstantIntValue(operand).value();
-    int64_t numOutputElements = 1;
-    for (auto operand : getOutputShape())
-      numOutputElements *= getConstantIntValue(operand).value();
-    if (numInputElements != numOutputElements)
-      return emitError("product of input and output shape sizes must match");
-  }
-  return success();
-}
-
-void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
-  populateFromInt64AttrArray(getFixedVectorSizes(), results);
-}
-
 //===----------------------------------------------------------------------===//
 // ExtractStridedSliceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 00914c1d1baf6..10ba895a1b3a4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1094,66 +1094,6 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 // -----
 
-func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c6 = arith.constant 6 : index
-  %c9 = arith.constant 9 : index
-  // expected-error@+1 {{invalid input shape for vector type}}
-  %1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4]
-    : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) {
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c6 = arith.constant 6 : index
-  %c9 = arith.constant 9 : index
-  // expected-error@+1 {{invalid output shape for vector type}}
-  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4]
-    : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) {
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c6 = arith.constant 6 : index
-  %c9 = arith.constant 9 : index
-  // expected-error@+1 {{product of input and output shape sizes must match}}
-  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4]
-    : vector<3x2x4xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) {
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c6 = arith.constant 6 : index
-  %c9 = arith.constant 9 : index
-  // expected-error@+1 {{fixed vector size must match input vector for dim 0}}
-  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
-    : vector<3x2x5xf32> to vector<2x3x4xf32>
-}
-
-// -----
-
-func.func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c6 = arith.constant 6 : index
-  %c9 = arith.constant 9 : index
-  // expected-error@+1 {{fixed vector size must match output vector for dim 0}}
-  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
-    : vector<3x2x4xf32> to vector<2x3x5xf32>
-}
-
-// -----
-
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
   // expected-error@+1 {{op source/result vectors must have same element type}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 7e578452b82cc..4759fcc9511fb 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -522,23 +522,6 @@ func.func @vector_print_on_scalar(%arg0: i64) {
   return
 }
 
-// CHECK-LABEL: @reshape
-func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
-  // CHECK:      %[[C2:.*]] = arith.constant 2 : index
-  %c2 = arith.constant 2 : index
-  // CHECK:      %[[C3:.*]] = arith.constant 3 : index
-  %c3 = arith.constant 3 : index
-  // CHECK:      %[[C6:.*]] = arith.constant 6 : index
-  %c6 = arith.constant 6 : index
-  // CHECK:      %[[C9:.*]] = arith.constant 9 : index
-  %c9 = arith.constant 9 : index
-  // CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32>
-  %1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
-    : vector<3x2x4xf32> to vector<2x3x4xf32>
-
-  return %1 : vector<2x3x4xf32>
-}
-
 // CHECK-LABEL: @shape_cast
 func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
                  %arg1 : vector<8x1xf32>,

``````````

</details>


https://github.com/llvm/llvm-project/pull/101645


More information about the Mlir-commits mailing list