[Mlir-commits] [mlir] [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (PR #84848)

Kojo Acquah llvmlistbot at llvm.org
Tue Mar 19 10:00:07 PDT 2024


https://github.com/KoolJBlack updated https://github.com/llvm/llvm-project/pull/84848

From ae74220769fc9be5ea8229a3093956d67f02a8f2 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 11 Mar 2024 19:26:39 +0000
Subject: [PATCH 1/2] implement vector unroll for i8mm

---
 .../LowerContractionToSMMLAPattern.cpp        | 112 +++++++++++++-----
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    |  50 ++++++++
 2 files changed, 130 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 47c84708f3c38b..acb03927b5d23e 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// supports up to an 8x8x8 vector contract, tiled into (up to 16) smmla
+/// instructions via unrolling. If no unrolling is necessary, a single smmla
+/// instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -45,10 +49,6 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value lhs = op.getLhs();
-    Value rhs = op.getRhs();
-    Value res = op.getAcc();
-
     // Check index maps that represent M N K in contract.
     auto indexingMaps = op.getIndexingMapsArray();
     if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
         })) {
       return failure();
     }
-
     // Check iterator types for contract.
     auto iteratorTypes = op.getIteratorTypesArray();
     if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
         iteratorTypes[2] != vector::IteratorType::reduction) {
       return failure();
     }
-
-    // Check the tile size by mapping the dimensions of the contract.
+    // Infer tile sizes from the operands. Note: the RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
     auto dimM = lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = lhsType.getDimSize(1);
-    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+    // Unrolling patterns can handle [(2|4|8), (2|4|8), 8] shaped inputs for
+    // tiling.
+    if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
       return failure();
     }
 
     // Check two extsi inputs Rhs Lhs for contract.
     arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
     arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
     if (!origLhsExtOp || !origRhsExtOp) {
       return failure();
     }
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Collapse to 1D vectors required by smmla intrinsic
-    auto collapsedInputType = VectorType::get(
-        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
-    auto collapsedOutputType =
-        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-        res.getLoc(), collapsedOutputType, res);
-
-    // Replace the contract with a neon op
-    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
-        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-        collapsedRhs);
-
-    // Reshape output back to 2D
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     smmlaOp);
+    // Initial accumulator for the final result; each tiled partial result
+    // is inserted into this un-tiled accumulator below.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> loopOrder{0, 1, 2};
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledAcc =
+          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+      // Collapse tiled operands to 1D vectors required by the smmla intrinsic
+      auto collapsedInputType = VectorType::get(
+          tiledLhs.getType().cast<ShapedType>().getNumElements(),
+          tiledLhs.getType().cast<ShapedType>().getElementType());
+      auto collapsedOutputType = VectorType::get(
+          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Emit the smmla intrinsic op for this tile
+      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+          collapsedRhs);
+
+      // Reshape output back to 2D
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the untiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          tiledRes.getType().cast<ShapedType>().getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index cba7b00ba77a82..a4b873144b8b83 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -40,3 +40,53 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  return %[[VAL_39]] : vector<4x4xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}

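For reference, here is a minimal standalone sketch (not part of this patch; the
helper name printSmmlaTileOffsets is made up for illustration) of how the
unrolling loop above walks the [M, N, K] iteration space in steps of the native
[2, 2, 8] smmla tile. For the 4x4x8 contract in the new test, it visits the
four offset tuples [0, 0, 0], [0, 2, 0], [2, 0, 0], [2, 2, 0], one per emitted
smmla; each tuple is then projected through the operand indexing maps, e.g. the
lhs map (d0, d1, d2) -> (d0, d2) turns the offsets [2, 0, 0] into the
extract_strided_slice offsets [2, 0] seen in the CHECK lines.

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

// Walks tile offsets the same way the pattern's unrolling loop does.
static void printSmmlaTileOffsets() {
  llvm::SmallVector<int64_t> contractShape{4, 4, 8}; // [M, N, K] of the test
  llvm::SmallVector<int64_t> smmlaShape{2, 2, 8};    // native smmla tile
  llvm::SmallVector<int64_t> loopOrder{0, 1, 2};     // M outermost, K innermost
  for (llvm::SmallVector<int64_t> offsets :
       mlir::StaticTileOffsetRange(contractShape, smmlaShape, loopOrder)) {
    llvm::interleaveComma(offsets, llvm::outs()); // 0,0,0  0,2,0  2,0,0  2,2,0
    llvm::outs() << "\n";
  }
}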
From f6cb2ee58e59460c015b4de5379f66f790e30003 Mon Sep 17 00:00:00 2001
From: Kojo Acquah <kooljblack at google.com>
Date: Mon, 18 Mar 2024 21:17:32 +0000
Subject: [PATCH 2/2] diego comments

---
 .../LowerContractionToSMMLAPattern.cpp        |  18 +--
 .../Dialect/ArmNeon/lower-to-arm-neon.mlir    | 122 ++++++++++++------
 2 files changed, 92 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index acb03927b5d23e..1f48d27aa27b17 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -38,10 +38,10 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
-/// supports up to an 8x8x8 vector contract, tiled into (up to 16) smmla
-/// instructions via unrolling. If no unrolling is necessary, a single smmla
-/// instruction is emitted.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// will tile any vector.contract into multiple smmla instructions with
+/// unrolling so long as [2, 2, 8] divides its shape. If no unrolling is
+/// necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -72,9 +72,9 @@ class LowerContractionToSMMLAPattern
     auto dimN = rhsType.getDimSize(0);
     auto dimK = lhsType.getDimSize(1);
 
-    // Unrolling patterns can handle [(2|4|8), (2|4|8), 8] shaped inputs for
+    // Unrolling patterns can handle any shape that is a multiple of [2, 2, 8] for
     // tiling.
-    if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
+    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
       return failure();
     }
 
@@ -139,15 +139,15 @@ class LowerContractionToSMMLAPattern
       AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
       SmallVector<int64_t> lhsOffsets =
           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
       AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
       SmallVector<int64_t> rhsOffsets =
           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
       AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
       SmallVector<int64_t> accOffsets =
           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      auto tiledAcc =
+      Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
       // Collapse tiled operands to 1D vectors required by the smmla intrinsic
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index a4b873144b8b83..e2be87453bf6f2 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -45,48 +45,92 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
 
 // CHECK-LABEL: test_lower_vector_arm_neon_unroll
 // CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
-// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
-// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
-// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
-// CHECK:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:  %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
-// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
-// CHECK:  %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:  %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
-// CHECK:  %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
-// CHECK:  %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
-// CHECK:  %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
-// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
-// CHECK:  %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
-// CHECK:  %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
-// CHECK:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
-// CHECK:  return %[[VAL_39]] : vector<4x4xi32>
-// CHECK:  }
+// CHECK-DAG:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK-DAG:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG:  %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG:  %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG:  return %[[VAL_39]] : vector<4x4xi32>
+// CHECK-DAG:  }
 func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
   %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
   %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_mixed_unroll(
+// CHECK-SAME:                                                       %[[VAL_0:.*]]: vector<4x8xi8>,
+// CHECK-SAME:                                                       %[[VAL_1:.*]]: vector<2x8xi4>,
+// CHECK-SAME:                                                       %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
+// CHECK-DAG:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32>
+// CHECK-DAG:  %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG:  %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG:  %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG:  return %[[VAL_20]] : vector<4x2xi32>
+// CHECK-DAG:  }
+func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
+  return %res : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
+// CHECK-DAG:  %[[result:.*]] = vector.contract
+func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
