[Mlir-commits] [mlir] f80a976 - [mlir][vector] Add gather lowering patterns

Jakub Kuderski llvmlistbot at llvm.org
Tue Mar 14 08:00:12 PDT 2023


Author: Jakub Kuderski
Date: 2023-03-14T10:59:30-04:00
New Revision: f80a976acd85611acd795225999a92bba57c76e6

URL: https://github.com/llvm/llvm-project/commit/f80a976acd85611acd795225999a92bba57c76e6
DIFF: https://github.com/llvm/llvm-project/commit/f80a976acd85611acd795225999a92bba57c76e6.diff

LOG: [mlir][vector] Add gather lowering patterns

This is for targets that do not support gather-like ops, e.g., SPIR-V.

Gather is expanded into lower-level vector ops with memory accesses
guarded with `scf.if`.

I also considered generating `vector.maskedload`s, but decided against
it to keep the `memref` and `tensor` codepaths closer together. There's a
good chance that if a target doesn't support gather, it doesn't support
masked loads either.
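
Roughly, a 1-D gather from a memref such as

```mlir
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
       : memref<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
```

is expanded into one guarded access per element (a sketch condensed from the
new `vector-gather-lowering.mlir` test; SSA names are illustrative, and the
same pattern repeats for element 1 with the result threaded through):

```mlir
%m0   = vector.extract %mask[0] : vector<2xi1>
%idx0 = vector.extract %v[0] : vector<2xindex>
%res0 = scf.if %m0 -> (vector<2xf32>) {
  // Masked-on lane: load a single-element vector and insert it into the result.
  %ld   = vector.load %base[%idx0] : memref<?xf32>, vector<1xf32>
  %elem = vector.extract %ld[0] : vector<1xf32>
  %ins  = vector.insert %elem, %pass_thru [0] : f32 into vector<2xf32>
  scf.yield %ins : vector<2xf32>
} else {
  // Masked-off lane: keep the pass-through value.
  scf.yield %pass_thru : vector<2xf32>
}
```

For tensors, the `vector.load` + `vector.extract` pair becomes a single
`tensor.extract`.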

Issue: https://github.com/llvm/llvm-project/issues/60905

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D145942

Added: 
    mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 1d572435cc2cd..af68de7e0051e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -334,6 +334,14 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
                                   const UnrollVectorOptions &options,
                                   PatternBenefit benefit = 1);
 
+/// Expands `vector.gather` ops into a series of conditional scalar loads
+/// (`vector.load` for memrefs or `tensor.extract` for tensors). Each access is
+/// guarded with an `scf.if` op to avoid out-of-bounds memory accesses. This
+/// lowering path is intended for targets that do not feature dedicated gather
+/// ops.
+void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+
 //===----------------------------------------------------------------------===//
 // Finer-grained patterns exposed for more control over individual lowerings.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 9e5d787856b1f..fe59143ebd55f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 
+#include <cstdint>
 #include <functional>
 #include <optional>
 #include <type_traits>
@@ -22,6 +23,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -30,8 +32,10 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -3153,6 +3157,132 @@ struct CanonicalizeContractMatmulToMMT final
   FilterConstraintType filter;
 };
 
+/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
+///        ... into vector<2x3xf32>
+///
+/// ==>
+///
+/// %0   = arith.constant dense<0.0> : vector<2x3xf32>
+/// %g0  = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
+/// %1   = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %g1  = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
+/// %g   = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d gather ops.
+struct FlattenGather : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = op.getType();
+    if (resultTy.getRank() < 2)
+      return rewriter.notifyMatchFailure(op, "already flat");
+
+    Location loc = op.getLoc();
+    Value indexVec = op.getIndexVec();
+    Value maskVec = op.getMask();
+    Value passThruVec = op.getPassThru();
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultTy, rewriter.getZeroAttr(resultTy));
+
+    Type subTy = VectorType::get(resultTy.getShape().drop_front(),
+                                 resultTy.getElementType());
+
+    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+      int64_t thisIdx[1] = {i};
+
+      Value indexSubVec =
+          rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+      Value maskSubVec =
+          rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
+      Value passThruSubVec =
+          rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
+      Value subGather = rewriter.create<vector::GatherOp>(
+          loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
+          passThruSubVec);
+      result =
+          rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.load`s or
+/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
+/// loads/extracts are made conditional using `scf.if` ops.
+struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = op.getType();
+    if (resultTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "unsupported rank");
+
+    Location loc = op.getLoc();
+    Type elemTy = resultTy.getElementType();
+    // Vector type with a single element. Used to generate `vector.loads`.
+    VectorType elemVecTy = VectorType::get({1}, elemTy);
+
+    Value condMask = op.getMask();
+    Value base = op.getBase();
+    Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
+        loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
+        op.getIndexVec());
+    auto baseOffsets = llvm::to_vector(op.getIndices());
+    Value lastBaseOffset = baseOffsets.back();
+
+    Value result = op.getPassThru();
+
+    // Emit a conditional access for each vector element.
+    for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
+      int64_t thisIdx[1] = {i};
+      Value condition =
+          rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
+      Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
+      baseOffsets.back() =
+          rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
+
+      auto loadBuilder = [&](OpBuilder &b, Location loc) {
+        Value extracted;
+        if (isa<MemRefType>(base.getType())) {
+          // `vector.load` does not support scalar result; emit a vector load
+          // and extract the single result instead.
+          Value load =
+              b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
+          int64_t zeroIdx[1] = {0};
+          extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
+        } else {
+          extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
+        }
+
+        Value newResult =
+            b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
+        b.create<scf::YieldOp>(loc, newResult);
+      };
+      auto passThruBuilder = [result](OpBuilder &b, Location loc) {
+        b.create<scf::YieldOp>(loc, result);
+      };
+
+      result =
+          rewriter
+              .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
+                                 /*elseBuilder=*/passThruBuilder)
+              .getResult(0);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorMaskMaterializationPatterns(
@@ -3249,6 +3379,12 @@ void mlir::vector::populateVectorScanLoweringPatterns(
   patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateVectorGatherLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
+                                                          benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
new file mode 100644
index 0000000000000..5afd2fc73a7cf
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt %s --test-vector-gather-lowering | FileCheck %s
+
+// CHECK-LABEL: @gather_memref_1d
+// CHECK-SAME:    ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
+// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0] : vector<2xi1>
+// CHECK-DAG:     %[[IDX0:.+]]  = vector.extract [[IDXVEC]][0] : vector<2xindex>
+// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<2xf32>)
+// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[IDX0]]] : memref<?xf32>, vector<1xf32>
+// CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : vector<1xf32>
+// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
+// CHECK-NEXT:      scf.yield [[INS0]] : vector<2xf32>
+// CHECK-NEXT:    else
+// CHECK-NEXT:      scf.yield [[PASS]] : vector<2xf32>
+// CHECK-DAG:     [[M1:%.+]]    = vector.extract [[MASK]][1] : vector<2xi1>
+// CHECK-DAG:     %[[IDX1:.+]]  = vector.extract [[IDXVEC]][1] : vector<2xindex>
+// CHECK-NEXT:    [[RES1:%.+]]  = scf.if [[M1]] -> (vector<2xf32>)
+// CHECK-NEXT:      [[LD1:%.+]]   = vector.load [[BASE]][%[[IDX1]]] : memref<?xf32>, vector<1xf32>
+// CHECK-NEXT:      [[ELEM1:%.+]] = vector.extract [[LD1]][0] : vector<1xf32>
+// CHECK-NEXT:      [[INS1:%.+]]  = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT:      scf.yield [[INS1]] : vector<2xf32>
+// CHECK-NEXT:    else
+// CHECK-NEXT:      scf.yield [[RES0]] : vector<2xf32>
+// CHECK:         return [[RES1]] : vector<2xf32>
+func.func @gather_memref_1d(%base: memref<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_memref_1d_i32_index
+// CHECK-SAME:    ([[BASE:%.+]]: memref<?xf32>, [[IDXVEC:%.+]]: vector<2xi32>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
+// CHECK-DAG:     [[C42:%.+]]   = arith.constant 42 : index
+// CHECK-DAG:     [[IDXS:%.+]]  = arith.index_cast [[IDXVEC]] : vector<2xi32> to vector<2xindex>
+// CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXS]][0] : vector<2xindex>
+// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], [[C42]] : index
+// CHECK-NEXT:    [[RES0:%.+]]  = scf.if
+// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[OFF0]]] : memref<?xf32>, vector<1xf32>
+// CHECK:         else
+// CHECK:         [[IDX1:%.+]]  = vector.extract [[IDXS]][1] : vector<2xindex>
+// CHECK:         %[[OFF1:.+]]  = arith.addi [[IDX1]], [[C42]] : index
+// CHECK:         [[RES1:%.+]]  = scf.if
+// CHECK-NEXT:      [[LD1:%.+]]   = vector.load [[BASE]][%[[OFF1]]] : memref<?xf32>, vector<1xf32>
+// CHECK:         else
+// CHECK:         return [[RES1]] : vector<2xf32>
+func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 42 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : memref<?xf32>, vector<2xi32>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_memref_2d
+// CHECK-SAME:    ([[BASE:%.+]]: memref<?x?xf32>, [[IDXVEC:%.+]]: vector<2x3xindex>, [[MASK:%.+]]: vector<2x3xi1>, [[PASS:%.+]]: vector<2x3xf32>)
+// CHECK-DAG:     %[[C0:.+]]    = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.+]]    = arith.constant 1 : index
+// CHECK-DAG:     [[PTV0:%.+]]  = vector.extract [[PASS]][0] : vector<2x3xf32>
+// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0, 0] : vector<2x3xi1>
+// CHECK-DAG:     [[IDX0:%.+]]  = vector.extract [[IDXVEC]][0, 0] : vector<2x3xindex>
+// CHECK-NEXT:    %[[OFF0:.+]]  = arith.addi [[IDX0]], %[[C1]] : index
+// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<3xf32>)
+// CHECK-NEXT:      [[LD0:%.+]]   = vector.load [[BASE]][%[[C0]], %[[OFF0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK-NEXT:      [[ELEM0:%.+]] = vector.extract [[LD0]][0] : vector<1xf32>
+// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PTV0]] [0] : f32 into vector<3xf32>
+// CHECK-NEXT:      scf.yield [[INS0]] : vector<3xf32>
+// CHECK-NEXT:    else
+// CHECK-NEXT:      scf.yield [[PTV0]] : vector<3xf32>
+// CHECK-COUNT-5: scf.if
+// CHECK:         [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK-NEXT:    return [[FINAL]] : vector<2x3xf32>
+ func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+ }
+
+// CHECK-LABEL: @gather_tensor_1d
+// CHECK-SAME:    ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
+// CHECK-DAG:     [[M0:%.+]]    = vector.extract [[MASK]][0] : vector<2xi1>
+// CHECK-DAG:     %[[IDX0:.+]]  = vector.extract [[IDXVEC]][0] : vector<2xindex>
+// CHECK-NEXT:    [[RES0:%.+]]  = scf.if [[M0]] -> (vector<2xf32>)
+// CHECK-NEXT:      [[ELEM0:%.+]] = tensor.extract [[BASE]][%[[IDX0]]] : tensor<?xf32>
+// CHECK-NEXT:      [[INS0:%.+]]  = vector.insert [[ELEM0]], [[PASS]] [0] : f32 into vector<2xf32>
+// CHECK-NEXT:      scf.yield [[INS0]] : vector<2xf32>
+// CHECK-NEXT:    else
+// CHECK-NEXT:      scf.yield [[PASS]] : vector<2xf32>
+// CHECK-DAG:     [[M1:%.+]]    = vector.extract [[MASK]][1] : vector<2xi1>
+// CHECK-DAG:     %[[IDX1:.+]]  = vector.extract [[IDXVEC]][1] : vector<2xindex>
+// CHECK-NEXT:    [[RES1:%.+]]  = scf.if [[M1]] -> (vector<2xf32>)
+// CHECK-NEXT:      [[ELEM1:%.+]] = tensor.extract [[BASE]][%[[IDX1]]] : tensor<?xf32>
+// CHECK-NEXT:      [[INS1:%.+]]  = vector.insert [[ELEM1]], [[RES0]] [1] : f32 into vector<2xf32>
+// CHECK-NEXT:      scf.yield [[INS1]] : vector<2xf32>
+// CHECK-NEXT:    else
+// CHECK-NEXT:      scf.yield [[RES0]] : vector<2xf32>
+// CHECK:         return [[RES1]] : vector<2xf32>
+func.func @gather_tensor_1d(%base: tensor<?xf32>, %v: vector<2xindex>, %mask: vector<2xi1>, %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @gather_tensor_2d
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:  scf.if
+// CHECK:    tensor.extract
+// CHECK:  else
+// CHECK:       [[FINAL:%.+]] = vector.insert %{{.+}}, %{{.+}} [1] : vector<3xf32> into vector<2x3xf32>
+// CHECK-NEXT:  return [[FINAL]] : vector<2x3xf32>
+ func.func @gather_tensor_2d(%base: tensor<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
+  return %0 : vector<2x3xf32>
+ }

diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 93736dade444e..5a21bff0b39c3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
@@ -911,6 +912,29 @@ struct TestCreateVectorBroadcast
   }
 };
 
+struct TestVectorGatherLowering
+    : public PassWrapper<TestVectorGatherLowering,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorGatherLowering)
+
+  StringRef getArgument() const final { return "test-vector-gather-lowering"; }
+  StringRef getDescription() const final {
+    return "Test patterns that lower the gather op in the vector conditional "
+           "loads";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    memref::MemRefDialect, scf::SCFDialect,
+                    tensor::TensorDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorGatherLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -953,6 +977,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
   PassRegistration<TestCreateVectorBroadcast>();
+
+  PassRegistration<TestVectorGatherLowering>();
 }
 } // namespace test
 } // namespace mlir
