[Mlir-commits] [mlir] 8dffb71 - [mlir][VectorOps] Add lowering for vector.shape_cast of scalable vectors

Benjamin Maxwell llvmlistbot at llvm.org
Thu Sep 7 09:00:15 PDT 2023


Author: Benjamin Maxwell
Date: 2023-09-07T15:58:44Z
New Revision: 8dffb71cbada73c5f9c1e9df7891566be0ff762d

URL: https://github.com/llvm/llvm-project/commit/8dffb71cbada73c5f9c1e9df7891566be0ff762d
DIFF: https://github.com/llvm/llvm-project/commit/8dffb71cbada73c5f9c1e9df7891566be0ff762d.diff

LOG: [mlir][VectorOps] Add lowering for vector.shape_cast of scalable vectors

This adds a lowering similar to the general shape_cast lowering, but
instead moves elements a (scalable) subvector at a time via
vector.scalable.extract/insert. It is restricted to the case where both
the source and result vector types have a single trailing scalable
dimension (due to limitations of the insert/extract ops).

The current lowerings are now disabled for scalable vectors, as they
produce incorrect results at runtime (due to assuming a fixed number
of elements).

Examples of casts that now work:

  // Flattening:
  %v = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8>

  // Un-flattening:
  %v = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>

Reviewed By: awarzynski, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D159217

Added: 
    mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir

Modified: 
    mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index f2b28cad76745fd..d6ea9931095e1c9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -54,6 +54,10 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
+
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
       return failure();
 
@@ -87,6 +91,10 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
+
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
       return failure();
 
@@ -106,6 +114,20 @@ class ShapeCastOp2DUpCastRewritePattern
   }
 };
 
+static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
+                   int dimIdx, int64_t initialStep = 1) {
+  int64_t step = initialStep;
+  for (int d = dimIdx; d >= 0; d--) {
+    idx[d] += step;
+    if (idx[d] >= tp.getDimSize(d)) {
+      idx[d] = 0;
+      step = 1;
+    } else {
+      break;
+    }
+  }
+}
+
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -121,6 +143,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
 
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     // Special case 2D / 1D lowerings with better implementations.
     // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
     int64_t srcRank = sourceVectorType.getRank();
@@ -175,21 +200,169 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     rewriter.replaceOp(op, result);
     return success();
   }
+};
+
+/// A shape_cast lowering for scalable vectors with a single trailing scalable
+/// dimension. This is similar to the general shape_cast lowering but makes use
+/// of vector.scalable.insert and vector.scalable.extract to move elements a
+/// subvector at a time.
+///
+/// E.g.:
+/// ```
+/// // Flatten scalable vector
+/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
+/// ```
+/// is rewritten to:
+/// ```
+/// // Flatten scalable vector
+/// %c = arith.constant dense<0> : vector<[8]xi32>
+/// %0 = vector.extract %arg0[0, 0] : vector<2x1x[4]xi32>
+/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
+/// %2 = vector.extract %arg0[1, 0] : vector<2x1x[4]xi32>
+/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
+/// ```
+/// or:
+/// ```
+/// // Un-flatten scalable vector
+/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
+/// ```
+/// is rewritten to:
+/// ```
+/// // Un-flatten scalable vector
+/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
+/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
+/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
+/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+/// ```
+class ScalableShapeCastOpRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    auto srcRank = sourceVectorType.getRank();
+    auto resRank = resultVectorType.getRank();
+
+    // This can only lower shape_casts where both the source and result types
+    // have a single trailing scalable dimension. This is because there is no
+    // legal representation of other scalable types in LLVM (and likely won't be
+    // soon). There are also (currently) no operations that can index or extract
+    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    if (!isTrailingDimScalable(sourceVectorType) ||
+        !isTrailingDimScalable(resultVectorType)) {
+      return failure();
+    }
+
+    // The sizes of the trailing dimension of the source and result vectors, the
+    // size of subvector to move, and the number of elements in the vectors.
+    // These are "min" sizes as they are the size when vscale == 1.
+    auto minSourceTrailingSize = sourceVectorType.getShape().back();
+    auto minResultTrailingSize = resultVectorType.getShape().back();
+    auto minExtractionSize =
+        std::min(minSourceTrailingSize, minResultTrailingSize);
+    int64_t minNumElts = 1;
+    for (auto size : sourceVectorType.getShape())
+      minNumElts *= size;
+
+    // Each moved subvector must lie within a single row of both the source and
+    // the result, so both trailing sizes must be multiples of the subvector
+    // size. Otherwise (e.g. vector<3x[4]xf32> to vector<2x[6]xf32>) subvectors
+    // would straddle rows, which this lowering does not handle.
+    if (minSourceTrailingSize % minExtractionSize != 0 ||
+        minResultTrailingSize % minExtractionSize != 0)
+      return failure();
+
+    // The subvector type to move from the source to the result. Note that this
+    // is a scalable vector. This rewrite will generate code in terms of the
+    // "min" size (vscale == 1 case), that scales to any vscale.
+    auto extractionVectorType = VectorType::get(
+        {minExtractionSize}, sourceVectorType.getElementType(), {true});
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank);
+
+    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
+    // once D150000 lands.
+    Value currentResultScalableVector;
+    Value currentSourceScalableVector;
+    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
+      // 1. Extract a scalable subvector from the source vector.
+      if (!currentSourceScalableVector) {
+        if (srcRank != 1) {
+          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
+              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
+        } else {
+          currentSourceScalableVector = op.getSource();
+        }
+      }
+      Value sourceSubVector = currentSourceScalableVector;
+      if (minExtractionSize < minSourceTrailingSize) {
+        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
+            loc, extractionVectorType, sourceSubVector, srcIdx.back());
+      }
 
-private:
-  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
-    assert(0 <= r && r < tp.getRank());
-    if (++idx[r] == tp.getDimSize(r)) {
-      idx[r] = 0;
-      incIdx(idx, tp, r - 1);
+      // 2. Insert the scalable subvector into the result vector.
+      if (!currentResultScalableVector) {
+        if (minExtractionSize == minResultTrailingSize) {
+          currentResultScalableVector = sourceSubVector;
+        } else if (resRank != 1) {
+          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
+              loc, result, llvm::ArrayRef(resIdx).drop_back());
+        } else {
+          currentResultScalableVector = result;
+        }
+      }
+      if (minExtractionSize < minResultTrailingSize) {
+        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
+            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
+      }
+
+      // 3. Update the source and result scalable vectors if needed.
+      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
+          currentResultScalableVector != result) {
+        // Finished row of result. Insert complete scalable vector into result
+        // (n-D) vector.
+        result = rewriter.create<vector::InsertOp>(
+            loc, currentResultScalableVector, result,
+            llvm::ArrayRef(resIdx).drop_back());
+        currentResultScalableVector = {};
+      }
+      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
+        // Finished row of source.
+        currentSourceScalableVector = {};
+      }
+
+      // 4. Increment the insert/extract indices, stepping by minExtractionSize
+      // for the trailing dimensions.
+      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
+      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
     }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+  static bool isTrailingDimScalable(VectorType type) {
+    return type.getRank() >= 1 && type.getScalableDims().back() &&
+           !llvm::is_contained(type.getScalableDims().drop_back(), true);
   }
 };
+
 } // namespace
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
-      patterns.getContext(), benefit);
+               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
+                                                  benefit);
 }

diff  --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
new file mode 100644
index 000000000000000..cdae5f963b28343
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -0,0 +1,214 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
+
+/// This tests that shape casts of scalable vectors (with one trailing scalable dim)
+/// can be correctly lowered to vector.scalable.insert/extract.
+
+// CHECK-LABEL: i32_3d_to_1d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
+func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][4] : vector<[4]xi32> into vector<[8]xi32>
+  %flat = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
+  // CHECK-NEXT: return %[[res1]] : vector<[8]xi32>
+  return %flat : vector<[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
+func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+  %unflat = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
+  // CHECK-NEXT: return %[[res1]] : vector<2x1x[4]xi32>
+  return %unflat : vector<2x1x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
+func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][8] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][2] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res2:.*]] = vector.scalable.insert %[[subvec2]], %[[res1]][16] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][3] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res3:.*]] = vector.scalable.insert %[[subvec3]], %[[res2]][24] : vector<[8]xi8> into vector<[32]xi8>
+  %flat = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8>
+  // CHECK-NEXT: return %[[res3]] : vector<[32]xi8>
+  return %flat : vector<[32]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
+func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[arg0]][16] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[arg0]][24] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[8]xi8> into vector<4x[8]xi8>
+  %unflat = vector.shape_cast %arg0 : vector<[32]xi8> to vector<4x[8]xi8>
+  // CHECK-NEXT: return %[[res3]] : vector<4x[8]xi8>
+  return %unflat : vector<4x[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: f32_permute_leading_non_scalable_dims
+// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
+  return %res : vector<3x2x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: f64_flatten_leading_non_scalable_dims
+// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
+func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
+  %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
+  // CHECK-NEXT: return %[[res3]] : vector<4x[2]xf64>
+  return %res : vector<4x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: f32_reduce_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
+  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
+  return %res: vector<6x[2]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: f32_increase_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
+func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[cst]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec3:.*]] = vector.extract %[[cst]][1] : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[resvec3]][0] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
+  // CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
+  return %res: vector<2x[4]xf32>
+}
+
+// -----
+
+/// The following shape_casts are not supported as the types cannot be
+/// represented in LLVM (and likely won't be supported soon), and currently
+/// there are no ops that could do the extracts/inserts required.
+
+// -----
+
+// CHECK-LABEL: cannot_cast_to_non_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[4]xf32>
+func.func @cannot_cast_to_non_trailing_scalable_dim(%arg0: vector<[4]xf32>) -> vector<[2]x2xf32> {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]xf32> to vector<[2]x2xf32>
+  %res = vector.shape_cast %arg0 : vector<[4]xf32> to vector<[2]x2xf32>
+  // CHECK-NEXT: return %[[res]] : vector<[2]x2xf32>
+  return %res: vector<[2]x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: cannot_shape_cast_from_non_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[2]x2xf32>
+func.func @cannot_shape_cast_from_non_trailing_scalable_dim(%arg0: vector<[2]x2xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[2]x2xf32> to vector<[4]xf32>
+  %res = vector.shape_cast %arg0 : vector<[2]x2xf32> to vector<[4]xf32>
+  // CHECK-NEXT: return %[[res]] : vector<[4]xf32>
+  return %res: vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: cannot_shape_cast_more_than_one_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[4]x[4]xf32>
+func.func @cannot_shape_cast_more_than_one_scalable_dim(%arg0: vector<[4]x[4]xf32>) -> vector<2x[2]x[4]xf32>  {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
+  %res = vector.shape_cast %arg0 : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
+  // CHECK-NEXT: return %[[res]] : vector<2x[2]x[4]xf32>
+  return %res: vector<2x[2]x[4]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_shape_cast
+  } : !transform.any_op
+}


        


More information about the Mlir-commits mailing list