[Mlir-commits] [mlir] [mlir][vector] ND vectors linearization pass (PR #81159)

Ivan Butygin llvmlistbot at llvm.org
Fri Feb 9 10:01:44 PST 2024


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/81159

>From a25509263063f973013159d7e4b25937f65f6ed7 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Wed, 7 Feb 2024 23:31:50 +0100
Subject: [PATCH 1/4] [mlir][vector] ND vectors linearization pass

Common backends (LLVM, SPIR-V) only support 1D vectors: LLVM conversion handles ND vectors as `array<array<... vector>>`, and SPIR-V doesn't handle them at all at the moment.
Sometimes it's preferable to treat multidimensional vectors as linearized 1D vectors. Add a pass to do this.
Only constants and simple elementwise ops are supported for now.

Also, move the generic op result-type conversion utility to a common place and add a ConversionPattern that operates on traits.
---
 .../mlir/Dialect/Vector/Transforms/Passes.td  |   9 ++
 .../Vector/Transforms/VectorRewritePatterns.h |   6 +
 .../mlir/Transforms/DialectConversion.h       |  23 ++++
 .../Dialect/Math/Transforms/LegalizeToF32.cpp |  20 +--
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 .../Vector/Transforms/VectorLinearize.cpp     | 122 ++++++++++++++++++
 .../Transforms/Utils/DialectConversion.cpp    |  21 +++
 mlir/test/Dialect/Vector/linearize.mlir       |  15 +++
 8 files changed, 204 insertions(+), 13 deletions(-)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
 create mode 100644 mlir/test/Dialect/Vector/linearize.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..71f412507457c2 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def VectorLinearize : Pass<"vector-linearize"> {
+  let summary = "Linearize ND vectors into !d";
+  let description = [{
+    Linearizes ND vectors for N >= 2 into 1D vectors.
+  }];
+  let dependentDialects = ["vector::VectorDialect"];
+ }
+
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that matches
+/// and rewrites any operation that has the trait `TraitType`. A TypeConverter
+/// may optionally be supplied through the second constructor.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Creates a new op with the name and attributes of `op`, the given `operands`
+/// and result types converted by `converter`. Fails if `op` is already legal
+/// or a result type can't be converted. Regions/successors are not copied.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+/// Rewrites an ND vector `arith.constant` into an equivalent constant of the
+/// linearized (1D) type produced by the type converter.
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor /*adaptor*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto linearType = getTypeConverter()->convertType<VectorType>(op.getType());
+    if (!linearType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    // Only dense payloads can be reshaped; other attribute kinds are rejected.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+    if (!denseAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+    DenseElementsAttr linearAttr = denseAttr.reshape(linearType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, linearType, linearAttr);
+    return success();
+  }
+};
+
+/// Linearizes any op with the Vectorizable trait by recreating it with its
+/// result types converted by the type converter (via convertOpResultTypes).
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> linearOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(linearOp))
+      return failure();
+    rewriter.replaceOp(op, (*linearOp)->getResults());
+    return success();
+  }
+};
+
+struct VectorLinearizePass final
+    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+  using VectorLinearizeBase::VectorLinearizeBase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  // Identity fallback for non-vector types (scalars, index, ...): without it
+  // they are unconvertible and e.g. a scalar arith.addi becomes illegal. It is
+  // registered first because later-added conversions are tried first.
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Ignore scalable vectors for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+  // Constants and Vectorizable ops are legal once their types are converted.
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+        return std::nullopt;
+      });
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+  // Regions and successors are not transferred; refuse rather than drop them.
+  if (op->getNumRegions() != 0 || op->getNumSuccessors() != 0)
+    return rewriter.notifyMatchFailure(loc, "op has regions or successors");
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled generically; check a couple of them.
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

>From dea16dcb590df6100a2333ac5af1352322c964cc Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 8 Feb 2024 17:42:36 +0100
Subject: [PATCH 2/4] Fix typo in VectorLinearize pass summary

---
 mlir/include/mlir/Dialect/Vector/Transforms/Passes.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 71f412507457c2..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -22,7 +22,7 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
 }
 
 def VectorLinearize : Pass<"vector-linearize"> {
-  let summary = "Linearize ND vectors into !d";
+  let summary = "Linearize ND vectors into 1D";
   let description = [{
     Linearizes ND vectors for N >= 2 into 1D vectors.
   }];

>From bb14ff7721e014d2422f7bbf4485c54d769db0ae Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 8 Feb 2024 17:48:21 +0100
Subject: [PATCH 3/4] Document populateVectorLinearizeTypeConversionsAndLegality

---
 .../mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h       | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 45f54fc70e3261..9a98b103d9934b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -377,6 +377,7 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Linearizes ND vectors (N >= 2) into 1D
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);

>From 8f57b09e8fc3ff78d8edc46fd2a5e042bbdbdec6 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Fri, 9 Feb 2024 18:59:02 +0100
Subject: [PATCH 4/4] Switch vector linearization to a test pass

---
 .../mlir/Dialect/Vector/Transforms/Passes.td  |  9 ------
 .../Vector/Transforms/VectorRewritePatterns.h |  2 +-
 .../Vector/Transforms/VectorLinearize.cpp     | 25 ----------------
 mlir/test/Dialect/Vector/linearize.mlir       |  2 +-
 .../Dialect/Vector/TestVectorTransforms.cpp   | 29 +++++++++++++++++++
 5 files changed, 31 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 32b4363be00949..4911a61ab3c25d 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,13 +21,4 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
-def VectorLinearize : Pass<"vector-linearize"> {
-  let summary = "Linearize ND vectors into 1D";
-  let description = [{
-    Linearizes ND vectors for N >= 2 into 1D vectors.
-  }];
-  let dependentDialects = ["vector::VectorDialect"];
- }
-
-
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 9a98b103d9934b..31b4eec35ec864 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -377,7 +377,7 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
-/// Linearizes ND vectors (N >= 2) into 1D
+/// Linearizes ND vectors (N >= 2) into 1D.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 7602e8c1976a9a..c5352043955579 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -12,17 +12,11 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-namespace mlir::vector {
-#define GEN_PASS_DEF_VECTORLINEARIZE
-#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
-} // namespace mlir::vector
-
 using namespace mlir;
 
 namespace {
@@ -65,25 +59,6 @@ struct LinearizeVectorizable final
     return success();
   }
 };
-
-struct VectorLinearizePass final
-    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
-  using VectorLinearizeBase::VectorLinearizeBase;
-
-  void runOnOperation() override {
-    auto *context = &getContext();
-
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(context);
-    ConversionTarget target(*context);
-
-    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
-                                                              patterns, target);
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
-};
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index e0fac81199bc8d..824e4b5515d43f 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
 
 // CHECK-LABEL: test_linearize
 //  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..acd38980514a56 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+/// Test pass that linearizes ND vectors (N >= 2) into 1D vectors via
+/// populateVectorLinearizeTypeConversionsAndLegality and partial conversion.
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+    TypeConverter converter;
+    ConversionTarget target(ctx);
+    RewritePatternSet patterns(&ctx);
+    vector::populateVectorLinearizeTypeConversionsAndLegality(converter,
+                                                              patterns, target);
+    LogicalResult result =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -867,6 +894,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
+
+  PassRegistration<TestVectorLinearize>();
 }
 } // namespace test
 } // namespace mlir



More information about the Mlir-commits mailing list