[Mlir-commits] [mlir] [mlir][Vector] Support efficient shape cast lowering for n-D vectors (PR #123497)

Diego Caballero llvmlistbot at llvm.org
Sat Jan 18 21:38:38 PST 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/123497

This PR implements a generalization of the existing, more efficient lowering of shape casts from 2-D to 1-D and 1-D to 2-D vectors. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.

>From 7eca1d855af22fd34fa04636421732bb8eef377c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dieg0ca6aller0 at gmail.com>
Date: Sat, 18 Jan 2025 21:35:22 -0800
Subject: [PATCH] [mlir][Vector] Support efficient shape cast lowering for n-D
 vectors

This PR implements a generalization of the existing efficient lowering
of shape casts from 2-D to 1-D and 1-D to 2-D vectors. This significantly
reduces code size and generates more performant code for n-D shape casts
that make their way to LLVM/SPIR-V.
---
 .../Transforms/LowerVectorShapeCast.cpp       | 154 ++++++++++--------
 ...vector-shape-cast-lowering-transforms.mlir |  45 ++---
 2 files changed, 102 insertions(+), 97 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..edbf798e1c673b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,40 +11,41 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+                   int step = 1) {
+  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+    indices[dim] += step;
+    if (indices[dim] < vecType.getDimSize(dim))
+      break;
+
+    indices[dim] = 0;
+    step = 1;
+  }
+}
+
 namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1);
+    SmallVector<int64_t> resIdx(resRank);
+    int64_t extractSize = sourceVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+    // Compute the indices of each 1-D vector element of the source extraction
+    // and destination slice insertion and generate such instructions.
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/1);
+        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      }
+
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extract, result,
+          /*offsets=*/resIdx, /*strides=*/1);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into n-D
 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank - 1);
+    int64_t extractSize = resultVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType, /*step=*/1);
+      }
+
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
           /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
-                   int dimIdx, int initialStep = 1) {
-  int step = initialStep;
-  for (int d = dimIdx; d >= 0; d--) {
-    idx[d] += step;
-    if (idx[d] >= tp.getDimSize(d)) {
-      idx[d] = 0;
-      step = 1;
-    } else {
-      break;
-    }
-  }
-}
-
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -145,18 +163,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    // Special case 2D / 1D lowerings with better implementations.
-    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+    // Special case for n-D / 1-D lowerings with better implementations.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
       return failure();
 
     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
     // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
       numElts *= sourceVectorType.getDimSize(r);
@@ -172,8 +186,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
+        incIdx(srcIdx, sourceVectorType);
+        incIdx(resIdx, resultVectorType);
       }
 
       Value extract;
@@ -252,7 +266,7 @@ class ScalableShapeCastOpRewritePattern
     // have a single trailing scalable dimension. This is because there are no
     // legal representation of other scalable types in LLVM (and likely won't be
     // soon). There are also (currently) no operations that can index or extract
-    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
         !isTrailingDimScalable(resultVectorType)) {
       return failure();
@@ -334,8 +348,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
-      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
@@ -352,8 +366,8 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+  patterns.add<ShapeCastOpNDDownCastRewritePattern,
+               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                   benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index f2f1211fd70eed..b4c52d5533116c 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
@@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
-// CHECK: return %[[T11]] : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
+// CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME:           {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME:           {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
 
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: return %[[T11]] : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK:                {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
 
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>



More information about the Mlir-commits mailing list