[Mlir-commits] [mlir] 2b04291 - [mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (#124861)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 31 14:21:39 PST 2025


Author: Diego Caballero
Date: 2025-01-31T14:21:35-08:00
New Revision: 2b04291830a2a34b681ae711dabfa1032f6c84f7

URL: https://github.com/llvm/llvm-project/commit/2b04291830a2a34b681ae711dabfa1032f6c84f7
DIFF: https://github.com/llvm/llvm-project/commit/2b04291830a2a34b681ae711dabfa1032f6c84f7.diff

LOG: [mlir][Vector] Fix scalable InsertSlice/ExtractSlice lowering (#124861)

It looks like scalable `vector.insert_strided_slice`/`vector.extract_strided_slice`
ops made their way through lowering patterns that generate `vector.shuffle`
ops, which do not support scalable vectors. I'm not sure why this wasn't
caught by the verifier; probably the shuffle op was folded into something
else as part of the same rewrite, so the IR containing it was never verified.

This PR fixes the issue by preventing scalable `vector.insert_strided_slice`/
`vector.extract_strided_slice` ops from being lowered to vector shuffles.
Instead, they are now lowered to a chain of extract/insert ops using an
existing pattern.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72405fcfef00f13..82a985c9e582481 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -96,11 +96,15 @@ class ConvertSameRankInsertStridedSliceIntoShuffle
                                 PatternRewriter &rewriter) const override {
     auto srcType = op.getSourceVectorType();
     auto dstType = op.getDestVectorType();
+    int64_t srcRank = srcType.getRank();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if ((srcType.isScalable() || dstType.isScalable()) && srcRank == 1)
+      return failure();
 
     if (op.getOffsets().getValue().empty())
       return failure();
 
-    int64_t srcRank = srcType.getRank();
     int64_t dstRank = dstType.getRank();
     assert(dstRank >= srcRank);
     if (dstRank != srcRank)
@@ -184,6 +188,11 @@ class Convert1DExtractStridedSliceIntoShuffle
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
     auto dstType = op.getType();
+    auto srcType = op.getSourceVectorType();
+
+    // Scalable vectors are not supported by vector shuffle.
+    if (dstType.isScalable() || srcType.isScalable())
+      return failure();
 
     assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
 
@@ -309,6 +318,8 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
+// TODO: Make sure these `populate*` patterns are tested in isolation.
+
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
@@ -331,4 +342,14 @@ void vector::populateVectorInsertExtractStridedSliceTransforms(
   patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
                Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
                                                         benefit);
+  // Generate chains of extract/insert ops for scalable vectors only as they
+  // can't be lowered to vector shuffles.
+  populateVectorExtractStridedSliceToExtractInsertChainPatterns(
+      patterns,
+      /*controlFn=*/
+      [](ExtractStridedSliceOp op) {
+        return op.getType().isScalable() ||
+               op.getSourceVectorType().isScalable();
+      },
+      benefit);
 }

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 62649b83d887d1a..7df6defc0f202f1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2026,13 +2026,14 @@ func.func @extract_strided_slice_f32_1d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL:   func.func @extract_strided_slice_f32_1d_from_2d_scalable(
 //  CHECK-SAME:    %[[ARG:.*]]: vector<4x[8]xf32>)
 //       CHECK:    %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T0:.*]] = llvm.mlir.undef : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T1:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T2:.*]] = llvm.insertvalue %[[T1]], %[[T0]][0] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T3:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
-//       CHECK:    %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[T2]][1] : !llvm.array<2 x vector<[8]xf32>>
-//       CHECK:    %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
-//       CHECK:    return %[[T5]]
+//       CHECK:    %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
+//       CHECK:    %[[DST:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[E0:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[E1:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<[8]xf32>>
+//       CHECK:    %[[I0:.*]] = llvm.insertvalue %[[E0]], %[[DST]][0] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[I1:.*]] = llvm.insertvalue %[[E1]], %[[I0]][1] : !llvm.array<2 x vector<[8]xf32>>
+//       CHECK:    %[[RES:.*]] = builtin.unrealized_conversion_cast %[[I1]] : !llvm.array<2 x vector<[8]xf32>> to vector<2x[8]xf32>
+//       CHECK:    return %[[RES]]
 
 // -----
 


        


More information about the Mlir-commits mailing list