[Mlir-commits] [mlir] [mlir][Vector] Remove uses of vector.extractelement/vector.insertelement (PR #113827)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 27 11:17:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Kunwar Grover (Groverkss)

<details>
<summary>Changes</summary>

This patch removes usages of vector.extractelement/vector.insertelement. These operations can be fully represented by vector.extract/vector.insert. See https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116 for more information.

---

Patch is 73.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113827.diff


20 Files Affected:

- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+3-6) 
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+9-10) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+5-5) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+15-8) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+2-3) 
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+2-18) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+5-7) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp (+4-14) 
- (modified) mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir (+12-12) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+1-2) 
- (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+7-7) 
- (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+4-6) 
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+4-8) 
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+28-27) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+15-3) 
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+31-58) 
- (modified) mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir (+2-4) 
- (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+2-2) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 6b27ec9947cb0b..6b9cbaf57676c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -313,8 +313,7 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
-    Value result = rewriter.create<vector::ExtractElementOp>(
-        loc, asF16s, rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0));
+    Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
     return rewriter.replaceOp(op, result);
   }
   VectorType outType = cast<VectorType>(op.getOut().getType());
@@ -334,13 +333,11 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
   for (int64_t i = 0; i < numElements; i += 2) {
     int64_t elemsThisOp = std::min(numElements, i + 2) - i;
     Value thisResult = nullptr;
-    Value elemA = rewriter.create<vector::ExtractElementOp>(
-        loc, in, rewriter.create<arith::ConstantIndexOp>(loc, i));
+    Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
     Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
 
     if (elemsThisOp == 2) {
-      elemB = rewriter.create<vector::ExtractElementOp>(
-          loc, in, rewriter.createOrFold<arith::ConstantIndexOp>(loc, i + 1));
+      elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
     }
 
     thisResult =
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 3a4dc806efe976..ddbc4d2c4a4f3d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -157,7 +157,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
     return Value();
 
   Location loc = xferOp.getLoc();
-  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
+  return b.create<vector::ExtractOp>(loc, xferOp.getMask(), iv);
 }
 
 /// Helper function TransferOpConversion and TransferOp1dConversion.
@@ -686,7 +686,7 @@ struct PrepareTransferWriteConversion
 /// %lastIndex = arith.subi %length, %c1 : index
 /// vector.print punctuation <open>
 /// scf.for %i = %c0 to %length step %c1 {
-///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///   %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
 ///   vector.print %el : i32 punctuation <no_punctuation>
 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///   scf.if %notLastIndex {
@@ -756,7 +756,8 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     if (vectorType.getRank() != 1) {
       // Flatten n-D vectors to 1D. This is done to allow indexing with a
       // non-constant value (which can currently only be done via
-      // vector.extractelement for 1D vectors).
+      // vector.extract for 1D vectors).
+      // TODO: vector.extract supports N-D non-constant indices now.
       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                         std::multiplies<int64_t>());
       auto flatVectorType =
@@ -819,8 +820,7 @@ struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
     }
 
     // Print the scalar elements in the inner most loop.
-    auto element =
-        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
+    auto element = rewriter.create<vector::ExtractOp>(loc, value, flatIndex);
     rewriter.create<vector::PrintOp>(loc, element,
                                      vector::PrintPunctuation::NoPunctuation);
 
@@ -1563,7 +1563,7 @@ struct Strategy1d<TransferReadOp> {
         [&](OpBuilder &b, Location loc) {
           Value val =
               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
-          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
+          return b.create<vector::InsertOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
         [&](OpBuilder & /*b*/, Location loc) { return vec; });
@@ -1591,8 +1591,7 @@ struct Strategy1d<TransferWriteOp> {
     generateInBoundsCheck(
         b, xferOp, iv, dim,
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
-          auto val =
-              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
+          auto val = b.create<vector::ExtractOp>(loc, xferOp.getVector(), iv);
           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
         });
     b.create<scf::YieldOp>(loc);
@@ -1614,7 +1613,7 @@ struct Strategy1d<TransferWriteOp> {
 /// This pattern generates IR as follows:
 ///
 /// 1. Generate a for loop iterating over each vector element.
-/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+/// 2. Inside the loop, generate an InsertOp or ExtractOp,
 ///    depending on OpTy.
 ///
 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
@@ -1630,7 +1629,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : f32 from vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..e38bbad1637d45 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1134,8 +1134,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
   SmallVector<Value> transferReadIdxs;
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     Value idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -1149,7 +1147,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
                         resultType.getScalableDims().back()),
         idx);
     transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+        rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0));
   }
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
@@ -1415,7 +1413,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
     // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readType.getRank() == 0)
-      readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+      readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                     SmallVector<int64_t>{});
 
     LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
                                  << "\n");
@@ -2268,7 +2267,8 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       loc, readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (cast<VectorType>(readValue.getType()).getRank() == 0) {
-    readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
+    readValue = rewriter.create<vector::ExtractOp>(loc, readValue,
+                                                   SmallVector<int64_t>{});
     readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d71a236f62f454..af5b3637bf5b10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -697,8 +697,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     Value result;
     if (vectorType.getRank() == 0) {
       if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
+        mask = rewriter.create<ExtractOp>(loc, mask, SmallVector<int64_t>{});
+      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(),
+                                          SmallVector<int64_t>{});
     } else {
       if (mask)
         mask = rewriter.create<ExtractOp>(loc, mask, 0);
@@ -1983,12 +1984,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
     if (extractResultRank < broadcastSrcRank)
       return failure();
 
-    // Special case if broadcast src is a 0D vector.
+    // If extractResultRank is 0, broadcastSrcRank has to be zero, since
+    // broadcastSrcRank <= extractResultRank for this pattern. The broadcast
+    // input is then a 0-d vector (e.g. vector<f32>), while the extract result
+    // is a scalar (e.g. f32) due to vector.extract 0-d semantics. Therefore,
+    // extract the scalar directly from the broadcast source instead.
     if (extractResultRank == 0) {
       assert(broadcastSrcRank == 0 && llvm::isa<VectorType>(source.getType()));
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, source,
+                                                     SmallVector<int64_t>{});
       return success();
     }
+
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
@@ -2951,11 +2958,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
               InsertOpConstantFolder>(context);
 }
 
-// Eliminates insert operations that produce values identical to their source
-// value. This happens when the source and destination vectors have identical
-// sizes.
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
-  if (getNumIndices() == 0)
+  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
+  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
+  // (type mismatch).
+  if (getNumIndices() == 0 && getSourceType() == getResult().getType())
     return getSource();
   return {};
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 6c36bbaee85237..6d82d753eeed80 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -65,7 +65,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     if (srcRank <= 1 && dstRank == 1) {
       Value ext;
       if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
+        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(),
+                                                 SmallVector<int64_t>{});
       else
         ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 716da55ba09aec..72bf329daaa76e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -391,9 +391,8 @@ struct TwoDimMultiReductionToReduction
         reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
       }
 
-      result = rewriter.create<vector::InsertElementOp>(
-          loc, reductionOp->getResult(0), result,
-          rewriter.create<arith::ConstantIndexOp>(loc, i));
+      result = rewriter.create<vector::InsertOp>(loc, reductionOp->getResult(0),
+                                                 result, i);
     }
 
     rewriter.replaceOp(rootOp, result);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..343178c8156d25 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -177,24 +177,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
       }
 
       Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      extract = rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 2289fd1ff1364e..4ea6bcf3181dae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1238,7 +1238,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
     if (extractOp.getNumIndices() == 0)
       return failure();
 
-    // Rewrite vector.extract with 1d source to vector.extractelement.
+    // FIXME: recreates an identical vector.extract, so no progress is made.
     if (extractSrcType.getRank() == 1) {
       if (extractOp.hasDynamicPosition())
         // TODO: Dinamic position not supported yet.
@@ -1247,9 +1247,8 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(extractOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = extractOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(extractOp);
-      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
-          extractOp, extractOp.getVector(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+          extractOp, extractOp.getVector(), pos);
       return success();
     }
 
@@ -1519,9 +1518,8 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
       assert(insertOp.getNumIndices() == 1 && "expected 1 index");
       int64_t pos = insertOp.getStaticPosition()[0];
       rewriter.setInsertionPoint(insertOp);
-      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
-          insertOp, insertOp.getSource(), insertOp.getDest(),
-          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      rewriter.replaceOpWithNewOp<vector::InsertOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(), pos);
       return success();
     }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..a5d5dc00b33cd3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -21,23 +21,13 @@ using namespace mlir::vector;
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                        Value into, int64_t offset) {
-  auto vectorType = cast<VectorType>(into.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<InsertOp>(loc, from, into, offset);
-  return rewriter.create<vector::InsertElementOp>(
-      loc, vectorType, from, into,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<InsertOp>(loc, from, into, offset);
 }
 
 // Helper that picks the proper sequence for extracting.
 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                         int64_t offset) {
-  auto vectorType = cast<VectorType>(vector.getType());
-  if (vectorType.getRank() > 1)
-    return rewriter.create<ExtractOp>(loc, vector, offset);
-  return rewriter.create<vector::ExtractElementOp>(
-      loc, vectorType.getElementType(), vector,
-      rewriter.create<arith::ConstantIndexOp>(loc, offset));
+  return rewriter.create<ExtractOp>(loc, vector, offset);
 }
 
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
@@ -277,8 +267,8 @@ class Convert1DExtractStridedSliceIntoExtractInsertChain final
 };
 
 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
-/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
-/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+/// For such cases, we can rewrite it to ExtractOp + lower rank
+/// ExtractStridedSliceOp + InsertOp for the n-D case.
 class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
diff --git a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
index 121cae26748a82..8991506dee1dfb 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/16-bit-floats.mlir
@@ -5,7 +5,7 @@
 func.func @scalar_trunc(%v: f32) -> f16{
   // CHECK: %[[poison:.*]] = llvm.mlir.poison : f32
   // CHECK: %[[trunc:.*]] = rocdl.cvt.pkrtz %[[value]], %[[poison]] : vector<2xf16>
-  // CHECK: %[[extract:.*]] = vector.extractelement %[[trunc]][%c0 : index] : vector<2xf16>
+  // CHECK: %[[extract:.*]] = vector.extract %[[trunc]][0] : f16 from vector<2xf16>
   // CHECK: return %[[extract]] : f16
   %w = arith.truncf %v : f32 to f16
   return %w : f16
@@ -14,8 +14,8 @@ func.func @scalar_trunc(%v: f32) -> f16{
 // CHECK-LABEL: @vector_trunc
 // CHECK-SAME: (%[[value:.*]]: vector<2xf32>)
 func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]]
   // CHECK: %[[ret:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: return %[[ret]]
   %w = arith.truncf %v : vector<2xf32> to vector<2xf16>
@@ -25,23 +25,23 @@ func.func @vector_trunc_short(%v: vector<2xf32>) -> vector<2xf16> {
 // CHECK-LABEL:  @vector_trunc_long
 // CHECK-SAME: (%[[value:.*]]: vector<9xf32>)
 func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf16> {
-  // CHECK: %[[elem0:.*]] = vector.extractelement %[[value]][%c0 : index]
-  // CHECK: %[[elem1:.*]] = vector.extractelement %[[value]][%c1 : index]
+  // CHECK: %[[elem0:.*]] = vector.extract %[[value]][0]
+  // CHECK: %[[elem1:.*]] = vector.extract %[[value]][1]
   // CHECK: %[[packed0:.*]] = rocdl.cvt.pkrtz %[[elem0]], %[[elem1]] : vector<2xf16>
   // CHECK: %[[out0:.*]] = vector.insert_strided_slice %[[packed0]], {{.*}} {offsets = [0], strides = [1]} : vector<2xf16> into vector<9xf16>
-  // CHECK: %[[elem2:.*]] = vector.extractelement %[[value]][%c2 : index]
-  // CHECK: %[[elem3:.*]] = vector.extractelement %[[value]][%c3 : index]
+  // CHECK: %[[elem2:.*]] = vector.extract %[[value]][2]
+  // CHECK: %[[elem3:.*]] = vector.extract %[[value]][3]
   // CHECK: %[[packed1:.*]] = rocdl.cvt.pkrtz %[[elem2]], %[[elem3]] : vector<2xf16>
   // CHECK: %[[out1:.*]] = vector.insert_strided_slice %[[packed1]], %[[out0]] {offsets = [2], strides = [1]} : vector<2xf16> into...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/113827


More information about the Mlir-commits mailing list