[Mlir-commits] [mlir] [mlir][linalg] Support scalable vectorization of linalg.index operations (PR #96778)

Cullen Rhodes llvmlistbot at llvm.org
Wed Jun 26 07:52:30 PDT 2024


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/96778

The vectorization of linalg.index operations doesn't support scalable
vectors when computing the index vector. This patch fixes this by
creating the index vector with the vector.step operation instead of a
constant.

>From 1ba850c1801b94bcb903bbd6c4d715a0d4b6c959 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 19 Jun 2024 15:29:25 +0000
Subject: [PATCH 1/3] [mlir][vector] Add vector.step operation

This patch adds a new vector.step operation to the Vector dialect. It
produces a linear sequence of index values from 0 to N-1, where N is the
number of elements in the result vector, and can be used to create
vectors of indices.

It supports both fixed-width and scalable vectors. For fixed-width
vectors the canonical representation is
`arith.constant dense<[0, .., N-1]>`. A scalable step cannot be
represented as a constant and is lowered to the
`llvm.experimental.stepvector` intrinsic [1].

[1] https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 29 +++++++++++++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 17 +++++++++--
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 14 +++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 11 +++++++
 mlir/test/Dialect/Vector/canonicalize.mlir    | 10 +++++++
 mlir/test/Dialect/Vector/invalid.mlir         | 16 ++++++++++
 mlir/test/Dialect/Vector/ops.mlir             |  9 +++++-
 7 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 097e5e6fb0d61..94cba7d7882cd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3017,6 +3017,35 @@ def Vector_ScanOp :
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// VectorStepOp
+//===----------------------------------------------------------------------===//
+
+def Vector_StepOp : Vector_Op<"step", [Pure]> {
+  let summary = "A linear sequence of values from 0 to N";
+  let description = [{
+    A `step` operation produces an index vector, i.e. a 1-D vector of values of
+    index type that represents a linear sequence from 0 to N, where N is the
+    number of elements in the `result` vector.
+
+    Supports fixed-width and scalable vectors. For fixed the canonical
+    representation is `arith.constant dense<[0, .., N]>`. A scalable step
+    cannot be represented as a constant and is lowered to the
+    [llvm.experimental.stepvector](https://llvm.org/docs/LangRef.html#llvm-experimental-stepvector-intrinsic)
+    intrinsic.
+
+    Examples:
+
+    ```mlir
+    %0 = vector.step : vector<4xindex> ; [0, 1, 2, 3]
+    %1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
+    ```
+  }];
+  let hasFolder = 1;
+  let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
 def Vector_YieldOp : Vector_Op<"yield", [
     Pure, ReturnLike, Terminator]> {
   let summary = "Terminates and yields values from vector regions.";
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 0eac55255b133..6a8a9d818aad2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1860,6 +1860,19 @@ struct VectorFromElementsLowering
   }
 };
 
+/// Conversion pattern for vector.step.
+struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type llvmType = typeConverter->convertType(stepOp.getType());
+    rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1885,8 +1898,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
-               VectorDeinterleaveOpLowering, VectorFromElementsLowering>(
-      converter);
+               VectorDeinterleaveOpLowering, VectorFromElementsLowering,
+               VectorStepOpLowering>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..8efafcab5529e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6316,6 +6316,20 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
   return SplatElementsAttr::get(getType(), {constOperand});
 }
 
+//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
+  auto resultType = cast<VectorType>(getType());
+  if (resultType.isScalable())
+    return nullptr;
+  SmallVector<APInt> indices;
+  for (unsigned i = 0; i < resultType.getNumElements(); i++)
+    indices.push_back(APInt(/*width=*/64, i));
+  return DenseElementsAttr::get(resultType, indices);
+}
+
 //===----------------------------------------------------------------------===//
 // WarpExecuteOnLane0Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09b79708a9ab2..897ff7ad6b43a 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2621,3 +2621,14 @@ func.func @vector_from_elements_0d(%a: f32) -> vector<f32> {
   %0 = vector.from_elements %a : vector<f32>
   return %0 : vector<f32>
 }
+
+// -----
+
+// CHECK-LABEL: @vector_step
+// CHECK: %[[STEPVECTOR:.*]] = llvm.intr.experimental.stepvector : vector<[4]xi64>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[STEPVECTOR]] : vector<[4]xi64> to vector<[4]xindex>
+// CHECK: return %[[CAST]] : vector<[4]xindex>
+func.func @vector_step() -> vector<[4]xindex> {
+  %0 = vector.step : vector<[4]xindex>
+  return %0 : vector<[4]xindex>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..9c3bbb907cfb4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2711,3 +2711,13 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
   // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
   return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_vector_step_to_constant
+// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: return %[[CONSTANT]] : vector<4xindex>
+func.func @fold_vector_step_to_constant() -> vector<4xindex> {
+  %0 = vector.step : vector<4xindex>
+  return %0 : vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index d0eaed8f98cc5..db169a6c1f8ae 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1871,3 +1871,19 @@ func.func @invalid_from_elements(%a: f32, %b: i32) {
   vector.from_elements %a, %b : vector<2xf32>
   return
 }
+
+// -----
+
+func.func @invalid_step_0d() {
+  // expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<f32>'}}
+  vector.step : vector<f32>
+  return
+}
+
+// -----
+
+func.func @invalid_step_2d() {
+  // expected-error @+1 {{vector.step' op result #0 must be vector of index values of ranks 1, but got 'vector<2x4xf32>'}}
+  vector.step : vector<2x4xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4da09584db88b..7908e61abc704 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1171,4 +1171,11 @@ func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vecto
   // CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32>
   %3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32>
   return %0, %1, %2, %3 : vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: @step
+func.func @step() {
+  %0 = vector.step : vector<2xindex>
+  %1 = vector.step : vector<[4]xindex>
+  return
+}

>From 01efec6d34e1f1ff701aaf17ce62776ce88decdd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 20 Jun 2024 13:47:19 +0000
Subject: [PATCH 2/3] [mlir][linalg] Add scalable vectorization of tensor
 extract test

---
 .../Linalg/vectorize-tensor-extract.mlir      | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..4b0df6a01c8fc 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -113,6 +113,56 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_scalable_nd_tensor_extract_transfer_read_basic(%arg0: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<?x?x?xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<?x?x?xf32>
+    linalg.yield %5 : f32
+  } -> tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: @vectorize_scalable_nd_tensor_extract_transfer_read_basic
+// CHECK-SAME: %[[BASE:.*]]: tensor<?x?x?xf32>, %[[DEST:.*]]: tensor<?x?x?xf32>
+// CHECK:           %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
+// CHECK:           %[[INDEX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[DEST_DIM0:.*]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK:           %[[DEST_DIM1:.*]] = tensor.dim %[[DEST]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK:           %[[DEST_DIM2:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK:           %[[DEST_MASK:.*]] = vector.create_mask %[[DEST_DIM0]], %[[DEST_DIM1]], %[[DEST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK:           %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<4xindex> to vector<1x1x[4]xindex>
+// CHECK:           %[[GATHER:.*]] = vector.mask %[[DEST_MASK]] { vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<?x?x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK:           %[[OUT:.*]] = vector.mask %[[DEST_MASK]] { vector.transfer_write %[[GATHER]], %[[DEST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<?x?x?xf32> } : vector<1x1x[4]xi1> -> tensor<?x?x?xf32>
+// CHECK:           return %[[OUT]] : tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+      transform.apply_patterns.linalg.tiling_canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}
+
  // -----
 
 func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {

>From 228281e6cbd9f8889188318e70ecec01e8630415 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 26 Jun 2024 14:27:26 +0000
Subject: [PATCH 3/3] [mlir][linalg] Support scalable vectorization of
 linalg.index operations

The vectorization of linalg.index operations doesn't support scalable
vectors when computing the index vector. This patch fixes this by
creating the index vector with the vector.step operation instead of a
constant.
---
 .../Linalg/Transforms/Vectorization.cpp       | 21 ++++++++++++-------
 .../vectorize-tensor-extract-masked.mlir      |  4 ++--
 .../Linalg/vectorize-tensor-extract.mlir      |  4 ++--
 3 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 511835a226e7a..20b151f76df00 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -195,6 +195,10 @@ struct VectorizationState {
   /// Returns the canonical vector shape used to vectorize the iteration space.
   ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
 
+  /// Returns the vector dimensions that are scalable in the canonical vector
+  /// shape.
+  ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
+
   /// Returns a vector type of the provided `elementType` with the canonical
   /// vector shape and the corresponding fixed/scalable dimensions bit. If
   /// `dimPermutation` is provided, the canonical vector dimensions are permuted
@@ -694,23 +698,24 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = indexOp.getLoc();
   // Compute the static loop sizes of the index op.
-  auto targetShape = state.getCanonicalVecShape();
+  ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
+  auto dim = indexOp.getDim();
   // Compute a one-dimensional index vector for the index op dimension.
-  auto constantSeq =
-      llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
-  auto indexSteps = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getIndexVectorAttr(constantSeq));
+  auto indexVectorType =
+      VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
+                      state.getScalableVecDims()[dim]);
+  auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
   // Return the one-dimensional index vector if it lives in the trailing
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
-  if (indexOp.getDim() == targetShape.size() - 1)
+  if (dim == targetShape.size() - 1)
     return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
   auto permPattern =
       llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
-  std::swap(permPattern[indexOp.getDim()], permPattern.back());
+  std::swap(permPattern[dim], permPattern.back());
   auto permMap =
       AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
 
@@ -719,7 +724,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
       indexSteps);
   SmallVector<int64_t> transposition =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
-  std::swap(transposition.back(), transposition[indexOp.getDim()]);
+  std::swap(transposition.back(), transposition[dim]);
   auto transposeOp =
       rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index e68d297dc41f2..f042753780013 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -63,7 +63,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
 // CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_12:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
 // CHECK-DAG:       %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
@@ -160,7 +160,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_gather(%
 // CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
 // CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_12:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.step : vector<4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 4b0df6a01c8fc..8ec1cdc609742 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -135,7 +135,6 @@ func.func @vectorize_scalable_nd_tensor_extract_transfer_read_basic(%arg0: tenso
 // CHECK-SAME: %[[BASE:.*]]: tensor<?x?x?xf32>, %[[DEST:.*]]: tensor<?x?x?xf32>
 // CHECK:           %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
 // CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
-// CHECK:           %[[INDEX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
 // CHECK:           %[[C2:.*]] = arith.constant 2 : index
 // CHECK:           %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
@@ -143,7 +142,8 @@ func.func @vectorize_scalable_nd_tensor_extract_transfer_read_basic(%arg0: tenso
 // CHECK:           %[[DEST_DIM1:.*]] = tensor.dim %[[DEST]], %[[C1]] : tensor<?x?x?xf32>
 // CHECK:           %[[DEST_DIM2:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<?x?x?xf32>
 // CHECK:           %[[DEST_MASK:.*]] = vector.create_mask %[[DEST_DIM0]], %[[DEST_DIM1]], %[[DEST_DIM2]] : vector<1x1x[4]xi1>
-// CHECK:           %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<4xindex> to vector<1x1x[4]xindex>
+// CHECK:           %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
+// CHECK:           %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
 // CHECK:           %[[GATHER:.*]] = vector.mask %[[DEST_MASK]] { vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<?x?x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
 // CHECK:           %[[OUT:.*]] = vector.mask %[[DEST_MASK]] { vector.transfer_write %[[GATHER]], %[[DEST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<?x?x?xf32> } : vector<1x1x[4]xi1> -> tensor<?x?x?xf32>
 // CHECK:           return %[[OUT]] : tensor<?x?x?xf32>



More information about the Mlir-commits mailing list