[Mlir-commits] [mlir] [mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (PR #80148)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 31 07:07:07 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
<details>
<summary>Changes</summary>
When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks.
For example:
```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
: vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```
ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking whether the extraction index is within the true region, then using the result of that check to select the first dimension of the 2D mask (the original dimension if the index is in the true region, zero otherwise). This is shown below.
```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/80148.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+80-3)
- (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+36)
``````````diff
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..14b9d8e34da65 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,13 +7,12 @@
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
namespace {
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
// Common match failure reasons.
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
"op vector size is not multiple of SME tiles");
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
+/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
+/// necessary for the mask to be lowered to ArmSME.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
+/// %subMask = vector.extract %mask[2]
+/// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
+/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
+/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
+/// ```
+struct FoldExtractFromVectorOfSMELikeCreateMasks
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ auto loc = extractOp.getLoc();
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ extractOp, "extract not from vector.create_mask op");
+
+ VectorType extractedMaskType =
+ llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+ if (!extractedMaskType)
+ return rewriter.notifyMatchFailure(extractOp,
+ "extracted type is not a vector type");
+
+ auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
+ if (numScalable != 2)
+ return rewriter.notifyMatchFailure(
+ extractOp, "expected extracted type to be an SME-like mask");
+
+ // TODO: Support multiple extraction indices.
+ if (extractOp.getStaticPosition().size() != 1)
+ return rewriter.notifyMatchFailure(
+ extractOp, "only a single extraction index is supported");
+
+ auto frontMaskDim = createMaskOp.getOperand(0);
+ if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
+ return rewriter.notifyMatchFailure(
+ extractOp,
+ "constant vector.create_masks dims should be folded elsewhere");
+
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto extractionIndex = getValueOrCreateConstantIndexOp(
+ rewriter, loc, extractOp.getMixedPosition()[0]);
+ auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
+ loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
+ frontMaskDim);
+ auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
+ loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
+
+ rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+ extractOp, extractedMaskType,
+ ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
RewritePatternSet patterns(context);
-
converter.addConversion([](Type type) { return type; });
converter.addConversion(
[](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
return success();
});
+ patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..a2526db9b4831 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
+func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
+ // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+ // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+ // CHECK-NEXT: return %[[EXTRACT]]
+ %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+ %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
+func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
+ // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+ // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+ // CHECK-NEXT: return %[[EXTRACT]]
+ %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+ %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+ return %extract : vector<[4]x[4]xi1>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/80148
More information about the Mlir-commits
mailing list