[Mlir-commits] [mlir] [mlir][linalg] Generate `vector.transfer_read` for contiguous `tensor.extract` loads (PR #76436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 27 03:35:34 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Prathamesh Tagore (meshtag)

<details>
<summary>Changes</summary>

This PR aims to cover more cases in which a `tensor.extract` can be vectorized into a `vector.transfer_read` rather than a `vector.gather`.

I have replaced the generation of `vector.shape_cast` and `vector.extractelement` ops with a `vector.extract` op in the cases involving scalar broadcast and contiguous access. The motivation is to make the vectorization algorithm more powerful by enabling it to vectorize better for target shapes with n non-unit dimensions. A `vector.extract` on a 1-dimensional load (or on one that can be made 1-dimensional through simple casting) should be folded away by a simple rewrite pattern (if that is not happening already).

---

Patch is 35.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76436.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+55-96) 
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+208-42) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..0d0b1ef0d085df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
@@ -788,9 +789,6 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
-         "n-D vectors are not yet supported");
   assert(targetShape.back() != 1 &&
          "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
@@ -806,12 +804,20 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    // If target shape is of the form 1x1x1x..xn and val is obtained from a
+    // linalg.index op, it will be loop invariant only if index op's dim is not
+    // the trailing dimension.
+    if (llvm::count_if(targetShape,
+                       [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+        targetShape.back() != 1) {
+      auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+      return indexOp.getDim() != trailingLoopDim;
+    }
+    // val will be loop variant in some of the other cases.
+    // TODO: Relax this condition
+    return false;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
@@ -830,50 +836,35 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
   return result;
 }
 
-/// Check whether \p val could be used for calculating the trailing index for a
-/// contiguous load operation.
-///
-/// There are currently 3 types of values that are allowed here:
-///   1. loop-invariant values,
-///   2. values that increment by 1 with every loop iteration,
-///   3. results of basic arithmetic operations (linear and continuous)
-///      involving 1., 2. and 3.
-/// This method returns True if indeed only such values are used in calculating
-/// \p val.
-///
-/// Additionally, the trailing index for a contiguous load operation should
-/// increment by 1 with every loop iteration, i.e. be based on:
-///   * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
-
+// Determine if the val is obtained from a linalg.index op for the dimension at
+// which it is used to extract a value from the tensor and if it could be used
+// for contigous memory access.
+static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val,
+                              uint64_t valuePosInExtract) {
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
-         "n-D vectors are not yet supported");
   assert(targetShape.back() != 1 &&
          "1-D vectors with the trailing dim 1 are not yet supported");
 
-  // Blocks outside _this_ linalg.generic are effectively loop invariant.
-  // However, analysing block arguments for _this_ linalg.generic Op is a bit
-  // tricky. Just bail out in the latter case.
-  // TODO: We could try analysing the corresponding affine map here.
+  // val can't be a result of linalg.index for this linalg.generic if it is a
+  // block argument.
   auto *block = linalgOp.getBlock();
   if (isa<BlockArgument>(val))
-    return llvm::all_of(block->getArguments(),
-                        [&val](Value v) { return (v != val); });
+    return false;
 
   Operation *defOp = val.getDefiningOp();
-  assert(defOp && "This is neither a block argument nor an operation result");
+  assert(defOp && "This is not an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
-    return true;
+    // If target shape is of the form 1x1x1x..xn and val is obtained from a
+    // linalg.index op, it will be used for contiguous access only when it is
+    // obtained for the trailing dimension.
+    if (llvm::count_if(targetShape,
+                       [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+        targetShape.back() != 1) {
+      auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+      return indexOp.getDim() == trailingLoopDim;
+    }
+    return indexOp.getDim() == valuePosInExtract;
   }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
@@ -882,14 +873,14 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
     return false;
 
   // Conservatively reject Ops that could lead to indices with stride other
-  // than 1.
+  // than 1 after processing the result of linalg.index.
   if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
           ancestor))
     return false;
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isProperLinalgIdx(linalgOp, op, valuePosInExtract);
 
   return result;
 }
@@ -915,14 +906,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
-  //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
+  if (targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;
 
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -931,51 +915,29 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  bool leadingIdxsLoopInvariant = true;
+  bool isLoopInvariantLoad = true;
+  bool isProperLinalgIdxLoad = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
-  // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  auto leadIndices = indices.drop_back(1);
-
-  for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
-  }
+    isLoopInvariantLoad &= isLoopInvariantIdx(linalgOp, indexVal);
+    isProperLinalgIdxLoad &= !isLoopInvariantLoad
+                                 ? isProperLinalgIdx(linalgOp, indexVal, i)
+                                 : isProperLinalgIdxLoad;
 
-  if (!leadingIdxsLoopInvariant) {
-    LDBG("Found gather load: " << extractOp);
-    return VectorMemoryAccessKind::Gather;
+    if (!isLoopInvariantLoad && !isProperLinalgIdxLoad) {
+      LDBG("Found gather load: " << extractOp);
+      return VectorMemoryAccessKind::Gather;
+    }
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
-  // At this point we know that the leading indices are loop invariant. This
-  // means that is potentially a scalar or a contiguous load. We can decide
-  // based on the trailing idx.
-  auto extractOpTrailingIdx = indices.back();
-
-  // 4a. Scalar broadcast load
-  // If the trailing index is loop invariant then this is a scalar load.
-  if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+  if (isLoopInvariantLoad) {
     LDBG("Found scalar broadcast load: " << extractOp);
-
     return VectorMemoryAccessKind::ScalarBroadcast;
-  }
-
-  // 4b. Contiguous loads
-  // The trailing `extractOp` index should increment with every loop iteration.
-  // This effectively means that it must be based on the trailing loop index.
-  // This is what the following bool captures.
-  bool foundIndexOp = false;
-  bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
-  isContiguousLoad &= foundIndexOp;
-
-  if (isContiguousLoad) {
+  } else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
     LDBG("Found contigous load: " << extractOp);
     return VectorMemoryAccessKind::Contiguous;
   }
@@ -1048,9 +1010,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     auto idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1058,11 +1017,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       continue;
     }
 
-    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
-        bvm.lookup(extractOp.getIndices()[i]));
-    transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+    auto idxShapedType = dyn_cast<ShapedType>(idx.getType());
+    SmallVector<int64_t> extractIndicesVec(idxShapedType.getRank(), 0);
+
+    transferReadIdxs.push_back(rewriter.create<vector::ExtractOp>(
+        loc, idx, ArrayRef<int64_t>(extractIndicesVec)));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 3fd4fcd536624c..0ac67ca6af6ca7 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -92,17 +92,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
-// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
-// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
-// CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
-// CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+//      CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+//      CHECK: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+//      CHECK: %[[CST_1:.*]] = arith.constant 0.000000e+00 : f32
+//      CHECK: %[[C0:.*]] = arith.constant 0 : index
+//      CHECK: %[[E0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+//      CHECK: %[[E1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+//      CHECK: %[[E2:.*]] = vector.extract %[[CST_0]][0] : index from vector<3xindex>
+//      CHECK: %[[R1:.*]] = vector.transfer_read %[[ARG0]][%[[E0]], %[[E1]], %[[E2]]], %[[CST_1]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+//      CHECK: %[[RES:.*]] = vector.transfer_write %[[R1]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+//      CHECK: return %[[RES]] : tensor<1x1x3xf32>
+//      CHECK: }
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -113,7 +115,7 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
- // -----
+// -----
 
 func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
   %c79 = arith.constant 79 : index
@@ -134,26 +136,21 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
   return %25 : tensor<1x4xf32>
 }
 
-
-// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex(
+/// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
 // CHECK-SAME:      %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK:           %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_10:.*]] = arith.constant 79 : index
 // CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
 // CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
 // CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
 // CHECK:           %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
 // CHECK:           %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-// CHECK:           %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
+// CHECK:           %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_18]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
 // CHECK:           return %[[VAL_21]] : tensor<1x4xf32>
 // CHECK:         }
@@ -239,19 +236,21 @@ func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32
 // CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_contiguous_and_gather(
 // CHECK-SAME:                    %[[VAL_0:.*]]: tensor<6xf32>
 // CHECK-SAME:                    %[[VAL_1:.*]]: tensor<5xi32>
-// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4]> : vector<5xindex>
 // CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_4:.*]] = arith.constant dense<0> : vector<5xindex>
 // CHECK:           %[[VAL_5:.*]] = arith.constant dense<5> : vector<5xindex>
 // CHECK:           %[[VAL_6:.*]] = arith.constant dense<true> : vector<5xi1>
 // CHECK:           %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<5xf32>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<5xf32>
-// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
+// CHECK:           %[[E0:.*]] = vector.extract %[[CST]][0] : index from vector<5xindex>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[E0]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
 // CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : vector<5xi32> to vector<5xindex>
 // CHECK:           %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : vector<5xindex>
 // CHECK:           %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_5]] : vector<5xindex>
-// CHECK:           %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_2]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
-// CHECK:           %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[C0]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
 // CHECK:           return %[[VAL_14]] : tensor<5xf32>
 
 module attributes {transform.with_named_sequence} {
@@ -286,13 +285,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
 // CHECK-SAME:                                                                        %[[VAL_1:.*]]: index,
 // CHECK-SAME:                                                                        %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
 // CHECK:           %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_7:.*]] = arith.constant 79 : index
 // CHECK:           %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
 // CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK:           %[[VAL_10:.*]] = vector.extractelement %[[VAL_9]]{{\[}}%[[VAL_4]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
 // CHECK:           %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
 // CHECK:           %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76436


More information about the Mlir-commits mailing list