[Mlir-commits] [mlir] [mlir][vector] Canonicalize gathers/scatters with trivial offsets (PR #117939)
Ivan Butygin
llvmlistbot at llvm.org
Fri Dec 27 13:01:25 PST 2024
https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/117939
>From 5877c6e1797feb0fee310387927b5e3d942aef60 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 28 Nov 2024 00:27:25 +0100
Subject: [PATCH 1/4] [mlir][vector] Canonicalize gathers/scatters with trivial
offsets
Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 46 +++++++++++++++++++++-
mlir/test/Dialect/Vector/canonicalize.mlir | 31 +++++++++++++++
2 files changed, 75 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 491b5f44b722b1..b2a15a5a7a0bf7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5176,6 +5176,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
+static LogicalResult isContiguousIndices(Value val) {
+ auto vecType = dyn_cast<VectorType>(val.getType());
+ if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+ return failure();
+
+ DenseIntElementsAttr elements;
+ if (!matchPattern(val, m_Constant(&elements)))
+ return failure();
+
+ return success(
+ llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
@@ -5194,11 +5207,26 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
}
};
+
+class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(GatherOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(isContiguousIndices(op.getIndexVec())))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+ op.getIndices(), op.getMask(),
+ op.getPassThru());
+ return success();
+ }
+};
} // namespace
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<GatherFolder>(context);
+ results.add<GatherFolder, GatherTrivialIndices>(context);
}
//===----------------------------------------------------------------------===//
@@ -5240,11 +5268,25 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
}
};
+
+class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(ScatterOp op,
+ PatternRewriter &rewriter) const override {
+ if (failed(isContiguousIndices(op.getIndexVec())))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+ op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+ return success();
+ }
+};
} // namespace
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ScatterFolder>(context);
+ results.add<ScatterFolder, ScatterTrivialIndices>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 89af0f7332f5c4..fb567761476a02 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2838,3 +2838,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
return %1 : vector<1x1x2x1x1x1xi32>
}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>){
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ vector.scatter %base[%c0][%indices], %mask, %value :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ return
+}
>From 021eda4e2d8b6b9dcb4e79e3dceaae0c9788373a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 02:22:36 +0100
Subject: [PATCH 2/4] add const mask tests
---
mlir/test/Dialect/Vector/canonicalize.mlir | 35 +++++++++++++++++++++-
1 file changed, 34 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index fb567761476a02..4c9b8dd83dc28d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2857,14 +2857,47 @@ func.func @contiguous_gather(%base: memref<?xf32>,
// -----
+// CHECK-LABEL: @contiguous_gather_const_mask
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
+ %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %mask = arith.constant dense<true> : vector<16xi1>
+ %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// -----
+
// CHECK-LABEL: @contiguous_scatter
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter(%base: memref<?xf32>,
- %mask: vector<16xi1>, %value: vector<16xf32>){
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ vector.scatter %base[%c0][%indices], %mask, %value :
+ memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_const_mask
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
+ %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+ %mask = vector.constant_mask [16] : vector<16xi1>
vector.scatter %base[%c0][%indices], %mask, %value :
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
>From facf4733b063fa44af38a1ec8f004a2c26534425 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 02:30:21 +0100
Subject: [PATCH 3/4] nits
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b2a15a5a7a0bf7..4a32b0817c5223 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5176,13 +5176,14 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
return llvm::to_vector<4>(getVectorType().getShape());
}
-static LogicalResult isContiguousIndices(Value val) {
- auto vecType = dyn_cast<VectorType>(val.getType());
+/// Check if `indexVec` is a constant 1D vec of consecutive values [0, 1, 2, ...]
+static LogicalResult isContiguousIndices(Value indexVec) {
+ auto vecType = dyn_cast<VectorType>(indexVec.getType());
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
return failure();
DenseIntElementsAttr elements;
- if (!matchPattern(val, m_Constant(&elements)))
+ if (!matchPattern(indexVec, m_Constant(&elements)))
return failure();
return success(
@@ -5208,6 +5209,8 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
}
};
+/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedload. Only 1D non-scalable vectors are supported for now.
class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -5269,6 +5272,8 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
}
};
+/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedstore. Only 1D non-scalable vectors are supported for now.
class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
public:
using OpRewritePattern::OpRewritePattern;
>From 8b1f69cf99201e0dcd1f1ea91f0f9aeadf9190c3 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 03:39:35 +0100
Subject: [PATCH 4/4] vector.step support
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 3 +++
mlir/test/Dialect/Vector/canonicalize.mlir | 31 ++++++++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4a32b0817c5223..980c7f1ac1ed64 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5182,6 +5182,9 @@ static LogicalResult isContiguousIndices(Value indexVec) {
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
return failure();
+ if (indexVec.getDefiningOp<StepOp>())
+ return success();
+
DenseIntElementsAttr elements;
if (!matchPattern(indexVec, m_Constant(&elements)))
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 4c9b8dd83dc28d..c9e0d32d8db0d0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2874,6 +2874,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
// -----
+// CHECK-LABEL: @contiguous_gather_step
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather_step(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+ %c0 = arith.constant 0 : index
+ %indices = vector.step : vector<16xindex>
+ %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+ memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+ return %1 : vector<16xf32>
+}
+
+// -----
+
// CHECK-LABEL: @contiguous_scatter
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -2902,3 +2918,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_step
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter_step(%base: memref<?xf32>,
+ %mask: vector<16xi1>, %value: vector<16xf32>) {
+ %c0 = arith.constant 0 : index
+ %indices = vector.step : vector<16xindex>
+ vector.scatter %base[%c0][%indices], %mask, %value :
+ memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+ return
+}
More information about the Mlir-commits
mailing list