[Mlir-commits] [mlir] [mlir][ArmNeon] Implements unrolling patterns for LowerVectorToArmNeon LowerContractionToSMMLAPattern (PR #84848)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 12 08:40:23 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kojo Acquah (KoolJBlack)

<details>
<summary>Changes</summary>

This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions.

Now accepts [M, N, K] contract shapes up to [8, 8, 8] (previously only [2, 2, 8]). The M and N dimensions must be even (2, 4, 6, or 8) and K must be 8. `vector.extract_strided_slice`/`vector.insert_strided_slice` ops divide the contract into 2x2x8 tiles, each of which is lowered to one smmla instruction in turn.



---
Full diff: https://github.com/llvm/llvm-project/pull/84848.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+80-32) 
- (modified) mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir (+50) 


``````````diff
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 47c84708f3c38b..acb03927b5d23e 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,8 +38,10 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// handles contracts of up to 8x8x8, which are unrolled into (up to 16) 2x2x8
+/// tiles, each lowered to one smmla instruction. If no unrolling is necessary,
+/// a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
@@ -45,10 +49,6 @@ class LowerContractionToSMMLAPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value lhs = op.getLhs();
-    Value rhs = op.getRhs();
-    Value res = op.getAcc();
-
     // Check index maps that represent M N K in contract.
     auto indexingMaps = op.getIndexingMapsArray();
     if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
         })) {
       return failure();
     }
-
     // Check iterator types for contract.
     auto iteratorTypes = op.getIteratorTypesArray();
     if (iteratorTypes.size() != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
         iteratorTypes[2] != vector::IteratorType::reduction) {
       return failure();
     }
-
-    // Check the tile size by mapping the dimensions of the contract.
+    // Infer tile sizes from operands. Note: RHS is laid out as [N, K].
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
     auto dimM = lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = lhsType.getDimSize(1);
-    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+    // Unrolling patterns can handle [(2|4|6|8), (2|4|6|8), 8] shaped inputs
+    // for tiling: M and N must be even and at most 8, and K must be 8.
+    if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
       return failure();
     }
 
     // Check two extsi inputs Rhs Lhs for contract.
     arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
     arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
     if (!origLhsExtOp || !origRhsExtOp) {
       return failure();
     }
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Collapse to 1D vectors required by smmla intrinsic
-    auto collapsedInputType = VectorType::get(
-        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
-    auto collapsedOutputType =
-        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-        res.getLoc(), collapsedOutputType, res);
-
-    // Replace the contract with a neon op
-    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
-        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-        collapsedRhs);
-
-    // Reshape output back to 2D
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     smmlaOp);
+    // Initial value for the final result: each tile's smmla output is inserted
+    // into this zero vector, assembling the full (un-tiled) contract result.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> loopOrder{0, 1, 2};
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledAcc =
+          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+      // Collapse tiled operands to 1D vectors required by smmla intrinsic
+      auto collapsedInputType = VectorType::get(
+          tiledLhs.getType().cast<ShapedType>().getNumElements(),
+          tiledLhs.getType().cast<ShapedType>().getElementType());
+      auto collapsedOutputType = VectorType::get(
+          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Insert smmla op
+      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+          collapsedRhs);
+
+      // Reshape output back to 2D
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the non-tiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          tiledRes.getType().cast<ShapedType>().getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
index cba7b00ba77a82..a4b873144b8b83 100644
--- a/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
+++ b/mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir
@@ -40,3 +40,53 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK:  %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK:  %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK:  %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK:  %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK:  %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK:  %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK:  %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK:  return %[[VAL_39]] : vector<4x4xi32>
+// CHECK:  }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/84848


More information about the Mlir-commits mailing list