[Mlir-commits] [mlir] [mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (PR #80148)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Jan 31 07:06:37 PST 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/80148

When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking if the extraction index is within the true region, then using that to select the first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```

>From 00948ff5b06c2544d6d23cf4b4516da02a198e7c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 31 Jan 2024 13:48:06 +0000
Subject: [PATCH] [mlir][ArmSME] Fold extracts from 3D create_masks of SME-like
 masks

When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that to select
the first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 83 ++++++++++++++++++-
 .../Dialect/ArmSME/vector-legalization.mlir   | 36 ++++++++
 2 files changed, 116 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..14b9d8e34da65 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,13 +7,12 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
 //
 // Note: In the context of this pass 'tile' always refers to an SME tile.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
 // Common match failure reasons.
 static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
     "op vector size is not multiple of SME tiles");
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Folds a single-index extract of a 2D SME-like mask from a 3D
+/// `vector.create_mask` into a compare, select, and 2D `vector.create_mask`.
+/// This is necessary for the mask to be lowered to ArmSME.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
+///  %subMask = vector.extract %mask[2]
+///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
+///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
+///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
+///  ```
+struct FoldExtractFromVectorOfSMELikeCreateMasks
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = extractOp.getLoc();
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract not from vector.create_mask op");
+
+    VectorType extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extracted type is not a vector type");
+
+    // Only a 2D mask with both dims scalable can be rebuilt from two operands.
+    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
+    if (extractedMaskType.getRank() != 2 || numScalable != 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected extracted type to be an SME-like mask");
+
+    // TODO: Support multiple extraction indices.
+    if (extractOp.getStaticPosition().size() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "only a single extraction index is supported");
+
+    auto frontMaskDim = createMaskOp.getOperand(0);
+    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
+      return rewriter.notifyMatchFailure(
+          extractOp,
+          "constant vector.create_masks dims should be folded elsewhere");
+
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto extractionIndex = getValueOrCreateConstantIndexOp(
+        rewriter, loc, extractOp.getMixedPosition()[0]);
+    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, extractionIndex, frontMaskDim);
+    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
+        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        extractOp, extractedMaskType,
+        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
     OneToNTypeConverter converter;
     RewritePatternSet patterns(context);
-
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
         [](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
           return success();
         });
 
+    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..a2526db9b4831 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
+// CHECK-SAME:                                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                    %[[DIM2:[a-z0-9]+]]: index)
+func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
+// CHECK-SAME:                                                             %[[INDEX:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM2:[a-z0-9]+]]: index)
+func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}



More information about the Mlir-commits mailing list