[Mlir-commits] [mlir] [mlir][vector] Canonicalize gathers/scatters with trivial offsets (PR #117939)

Ivan Butygin llvmlistbot at llvm.org
Thu Nov 28 18:39:55 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/117939

>From f4c618ea9a2a902bdaac0cf9c632acfa7219bd08 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 28 Nov 2024 00:27:25 +0100
Subject: [PATCH 1/4] [mlir][vector] Canonicalize gathers/scatters with trivial
 offsets

Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 46 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 31 +++++++++++++++
 2 files changed, 75 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0c0a7bc98d8b5e..21e62085be5a49 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5181,6 +5181,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+static LogicalResult isContiguousIndices(Value val) {
+  auto vecType = dyn_cast<VectorType>(val.getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+    return failure();
+
+  DenseIntElementsAttr elements;
+  if (!matchPattern(val, m_Constant(&elements)))
+    return failure();
+
+  return success(
+      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
@@ -5199,11 +5212,26 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
   }
 };
+
+class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isContiguousIndices(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+                                              op.getIndices(), op.getMask(),
+                                              op.getPassThru());
+    return success();
+  }
+};
 } // namespace
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<GatherFolder>(context);
+  results.add<GatherFolder, GatherTrivialIndices>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5245,11 +5273,25 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
   }
 };
+
+class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isContiguousIndices(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+    return success();
+  }
+};
 } // namespace
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ScatterFolder>(context);
+  results.add<ScatterFolder, ScatterTrivialIndices>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5ae769090dac66..b4f9d98e729771 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2826,3 +2826,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
   return %1 : vector<1x1x2x1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//       CHECK:   return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter(%base: memref<?xf32>,
+                              %mask: vector<16xi1>, %value: vector<16xf32>){
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}
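
A note on the check introduced above: `isContiguousIndices` accepts an index
vector only if it is exactly the sequence [0, 1, ..., N-1], so strided or
shifted constants keep the gather/scatter form. Below is a minimal standalone
sketch of the underlying comparison; it assumes only that LLVM's ADT headers
are on the include path and is an illustration, not part of the patch.

// Compare candidate index vectors against the half-open range produced by
// llvm::seq. Only the exact sequence 0..N-1 matches.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t contiguous[] = {0, 1, 2, 3};
  const int64_t strided[] = {0, 2, 4, 6};

  // llvm::seq<int64_t>(0, 4) yields 0, 1, 2, 3.
  bool foldsContiguous = llvm::equal(contiguous, llvm::seq<int64_t>(0, 4));
  bool foldsStrided = llvm::equal(strided, llvm::seq<int64_t>(0, 4));

  std::printf("contiguous: %d, strided: %d\n", foldsContiguous, foldsStrided);
  return 0;
}

In the pattern itself, the DenseIntElementsAttr of the matched constant plays
the role of the left-hand range.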

>From ff5a39b9daa00f3f883d10d83830b19ff061115a Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 02:22:36 +0100
Subject: [PATCH 2/4] add const mask tests

---
 mlir/test/Dialect/Vector/canonicalize.mlir | 35 +++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b4f9d98e729771..b9ae28112d8a0f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2845,14 +2845,47 @@ func.func @contiguous_gather(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @contiguous_gather_const_mask
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+//       CHECK:   return %[[R]]
+func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
+                                        %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %mask = arith.constant dense<true> : vector<16xi1>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_scatter
 //  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
 //       CHECK:   %[[C0:.*]] = arith.constant 0 : index
 //       CHECK:   vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
 func.func @contiguous_scatter(%base: memref<?xf32>,
-                              %mask: vector<16xi1>, %value: vector<16xf32>){
+                              %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_const_mask
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
+                                         %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %mask = vector.constant_mask [16] : vector<16xi1>
   vector.scatter %base[%c0][%indices], %mask, %value :
     memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
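
The const-mask tests above expect plain vector.load / vector.store rather than
the masked forms: the new patterns first rewrite to vector.maskedload /
vector.maskedstore, and the existing canonicalizations for those ops then drop
a mask that is provably all-true. Below is a schematic, deliberately
simplified sketch of that second step, written in the style of VectorOps.cpp
but not taken from it; the class name is made up, and it only recognizes
arith.constant masks, not vector.constant_mask.

// Schematic sketch only: fold a maskedload whose mask is a constant all-true
// splat into a plain vector.load. The upstream canonicalization that these
// tests actually exercise handles more mask forms than this.
class AllTrueMaskedLoadToLoad final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp op,
                                PatternRewriter &rewriter) const override {
    DenseIntElementsAttr mask;
    if (!matchPattern(op.getMask(), m_Constant(&mask)) || !mask.isSplat() ||
        !mask.getSplatValue<APInt>().isOne())
      return failure();

    rewriter.replaceOpWithNewOp<LoadOp>(op, cast<VectorType>(op.getType()),
                                        op.getBase(), op.getIndices());
    return success();
  }
};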

>From 185c98cb431345902e02a440b402652b90fe3dda Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 02:30:21 +0100
Subject: [PATCH 3/4] nits

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 21e62085be5a49..788e8f555ada45 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5181,13 +5181,14 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-static LogicalResult isContiguousIndices(Value val) {
-  auto vecType = dyn_cast<VectorType>(val.getType());
+/// Check if `indexVec` is a constant 1D vector of consecutive values [0, 1, 2, ...]
+static LogicalResult isContiguousIndices(Value indexVec) {
+  auto vecType = dyn_cast<VectorType>(indexVec.getType());
   if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
     return failure();
 
   DenseIntElementsAttr elements;
-  if (!matchPattern(val, m_Constant(&elements)))
+  if (!matchPattern(indexVec, m_Constant(&elements)))
     return failure();
 
   return success(
@@ -5213,6 +5214,8 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
   }
 };
 
+/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedload. Only 1D non-scalable vectors are supported for now.
 class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -5274,6 +5277,8 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
   }
 };
 
+/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedstore. Only 1D non-scalable vectors are supported for now.
 class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
 public:
   using OpRewritePattern::OpRewritePattern;

>From 06e4f95c0b73d1c93207b0b2fca1b8312801d610 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 29 Nov 2024 03:39:35 +0100
Subject: [PATCH 4/4] vector.step support

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   |  3 +++
 mlir/test/Dialect/Vector/canonicalize.mlir | 31 ++++++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 788e8f555ada45..7053aaafdafdea 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5187,6 +5187,9 @@ static LogicalResult isContiguousIndices(Value indexVec) {
   if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
     return failure();
 
+  if (indexVec.getDefiningOp<StepOp>())
+    return success();
+
   DenseIntElementsAttr elements;
   if (!matchPattern(indexVec, m_Constant(&elements)))
     return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b9ae28112d8a0f..058e3bd35eb255 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2862,6 +2862,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
 
 // -----
 
+// CHECK-LABEL: @contiguous_gather_step
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+//       CHECK:   return %[[R]]
+func.func @contiguous_gather_step(%base: memref<?xf32>,
+                                  %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = vector.step : vector<16xindex>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @contiguous_scatter
 //  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
 //       CHECK:   %[[C0:.*]] = arith.constant 0 : index
@@ -2890,3 +2906,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
     memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_step
+//  CHECK-SAME:   (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter_step(%base: memref<?xf32>,
+                                   %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %indices = vector.step : vector<16xindex>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+  return
+}
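
For reference, after all four patches the helper reads as below (reassembled
from the hunks above; vector.step is accepted directly since it produces
[0, 1, ..., N-1] by construction).

/// Check if `indexVec` is a constant 1D vector of consecutive values
/// [0, 1, 2, ...].
static LogicalResult isContiguousIndices(Value indexVec) {
  auto vecType = dyn_cast<VectorType>(indexVec.getType());
  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
    return failure();

  // vector.step yields [0, 1, ..., N-1], so it is trivially contiguous.
  if (indexVec.getDefiningOp<StepOp>())
    return success();

  // Otherwise require a constant index vector equal to [0, 1, ..., N-1].
  DenseIntElementsAttr elements;
  if (!matchPattern(indexVec, m_Constant(&elements)))
    return failure();

  return success(
      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
}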


