[Mlir-commits] [mlir] [mlir][spirv] Fix function signature legalization for n-D vectors (PR #99872)

Angel Zhang llvmlistbot at llvm.org
Mon Jul 22 06:24:08 PDT 2024


https://github.com/angelz913 created https://github.com/llvm/llvm-project/pull/99872

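No description was provided with this pull request. Summary of the patch below: in `ReturnOpVectorUnroll`, the strides passed to `vector.extract_strided_slice` are now sized by the rank of the original vector rather than the rank of the target shape, and each slice is extracted with a shape that is 1 in every dimension except the last, which is the last dimension of the target shape. When the original vector has rank greater than 1, a `vector.extract` with all-zero indices follows, so that each returned value is a 1-D vector. Two tests are added, `@simple_vector_2d` and `@vector_2dand1d`.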

>From b14e3fe50ef08979ad248382a34b10025649bffd Mon Sep 17 00:00:00 2001
From: Angel Zhang <angel.zhang at amd.com>
Date: Sun, 21 Jul 2024 15:26:32 +0000
Subject: [PATCH] [mlir][spirv] Fix function signature legalization for n-D
 vectors

---
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 10 ++++-
 .../func-signature-vector-unroll.mlir         | 44 +++++++++++++++++++
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index bf5044437fd09..c146589612b5e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1098,13 +1098,19 @@ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
       // the original operand of illegal type.
       auto originalShape =
           llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
-      SmallVector<int64_t> strides(targetShape->size(), 1);
+      SmallVector<int64_t> strides(originalShape.size(), 1);
+      SmallVector<int64_t> extractShape(originalShape.size(), 1);
+      extractShape.back() = targetShape->back();
       SmallVector<Type> newTypes;
       Value returnValue = returnOp.getOperand(origResultNo);
       for (SmallVector<int64_t> offsets :
            StaticTileOffsetRange(originalShape, *targetShape)) {
         Value result = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, returnValue, offsets, *targetShape, strides);
+            loc, returnValue, offsets, extractShape, strides);
+        SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
+        if (originalShape.size() > 1)
+          result =
+              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
         newOperands.push_back(result);
         newTypes.push_back(unrolledType);
       }
diff --git a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
index 347d282f9ee0c..c018ccb924983 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/func-signature-vector-unroll.mlir
@@ -66,6 +66,28 @@ func.func @simple_vector_8(%arg0 : vector<8xi32>) -> vector<8xi32> {
 
 // -----
 
+// CHECK-LABEL: @simple_vector_2d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
+func.func @simple_vector_2d(%arg0 : vector<4x4xi32>) -> vector<4x4xi32> {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4x4xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [1, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [2, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [3, 0], strides = [1]} : vector<4xi32> into vector<4x4xi32>
+  // CHECK: %[[EXTRACT0:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT0_1:.*]] = vector.extract %[[EXTRACT0]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT1:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT1_1:.*]] = vector.extract %[[EXTRACT1]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT2:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [2, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT2_1:.*]] = vector.extract %[[EXTRACT2]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: %[[EXTRACT3:.*]] = vector.extract_strided_slice %[[INSERT3]] {offsets = [3, 0], sizes = [1, 4], strides = [1, 1]} : vector<4x4xi32> to vector<1x4xi32>
+  // CHECK: %[[EXTRACT3_1:.*]] = vector.extract %[[EXTRACT3]][0] : vector<4xi32> from vector<1x4xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]] : vector<4xi32>, vector<4xi32>, vector<4xi32>, vector<4xi32>
+  return %arg0 : vector<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @vector_6and8
 // CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>)
 func.func @vector_6and8(%arg0 : vector<6xi32>, %arg1 : vector<8xi32>) -> (vector<6xi32>, vector<8xi32>) {
@@ -113,6 +135,28 @@ func.func @scalar_vector(%arg0 : vector<8xi32>, %arg1 : vector<3xi32>, %arg2 : i
 
 // -----
 
+// CHECK-LABEL: @vector_2dand1d
+// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<3xi32>, %[[ARG2:.+]]: vector<3xi32>, %[[ARG3:.+]]: vector<3xi32>, %[[ARG4:.+]]: vector<4xi32>)
+func.func @vector_2dand1d(%arg0 : vector<2x6xi32>, %arg1 : vector<4xi32>) -> (vector<2x6xi32>, vector<4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi32>
+  // CHECK: %[[INSERT0:.*]] = vector.insert_strided_slice %[[ARG0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT1:.*]] = vector.insert_strided_slice %[[ARG1]], %[[INSERT0]] {offsets = [0, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT2:.*]] = vector.insert_strided_slice %[[ARG2]], %[[INSERT1]] {offsets = [1, 0], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[INSERT3:.*]] = vector.insert_strided_slice %[[ARG3]], %[[INSERT2]] {offsets = [1, 3], strides = [1]} : vector<3xi32> into vector<2x6xi32>
+  // CHECK: %[[EXTRACT0:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT0_1:.*]]  = vector.extract %[[EXTRACT0]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT1:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [0, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT1_1:.*]]  = vector.extract %[[EXTRACT1]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT2:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT2_1:.*]]  = vector.extract %[[EXTRACT2]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: %[[EXTRACT3:.*]]  = vector.extract_strided_slice %[[INSERT3]] {offsets = [1, 3], sizes = [1, 3], strides = [1, 1]} : vector<2x6xi32> to vector<1x3xi32>
+  // CHECK: %[[EXTRACT3_1:.*]]  = vector.extract %[[EXTRACT3]][0] : vector<3xi32> from vector<1x3xi32>
+  // CHECK: return %[[EXTRACT0_1]], %[[EXTRACT1_1]], %[[EXTRACT2_1]], %[[EXTRACT3_1]], %[[ARG4]] : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<4xi32>
+  return %arg0, %arg1 : vector<2x6xi32>, vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @reduction
 // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>, %[[ARG2:.+]]: vector<4xi32>, %[[ARG3:.+]]: vector<4xi32>, %[[ARG4:.+]]: i32)
 func.func @reduction(%arg0 : vector<8xi32>, %arg1 : vector<8xi32>, %arg2 : i32) -> (i32) {



More information about the Mlir-commits mailing list