[Mlir-commits] [mlir] [mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (PR #84848)

Kojo Acquah llvmlistbot at llvm.org
Mon Mar 18 14:18:33 PDT 2024


================
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Collapse to 1D vectors required by smmla intrinsic
-    auto collapsedInputType = VectorType::get(
-        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
-    auto collapsedOutputType =
-        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-        res.getLoc(), collapsedOutputType, res);
-
-    // Replace the contract with a neon op
-    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
-        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-        collapsedRhs);
-
-    // Reshape output back to 2D
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     smmlaOp);
+    // Initial accumulator for the final result. This is the un-tiled result if
+    // tiling is done.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> loopOrder{0, 1, 2};
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
----------------
KoolJBlack wrote:

Added some, but I think the rest are pretty clear?

https://github.com/llvm/llvm-project/pull/84848


More information about the Mlir-commits mailing list