[Mlir-commits] [mlir] [mlir][vector] ND vectors linearization pass (PR #81159)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 8 08:50:11 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

Common backends (LLVM, SPIR-V) only support 1D vectors. The LLVM conversion handles ND vectors (N >= 2) as `array<array<... vector>>`, and the SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidimensional vectors as linearized 1D vectors. Add a pass to do this. Only constants and simple elementwise ops are supported for now.

@<!-- -->krzysz00 I've extracted your result type conversion code from LegalizeToF32 and moved it to a common place.

Also, add a ConversionPattern class operating on traits.

---
Full diff: https://github.com/llvm/llvm-project/pull/81159.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+9) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+6) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+23) 
- (modified) mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp (+7-13) 
- (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+122) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21) 
- (added) mlir/test/Dialect/Vector/linearize.mlir (+15) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 4911a61ab3c25d..32b4363be00949 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -21,4 +21,13 @@ def LowerVectorMaskPass : Pass<"lower-vector-mask", "func::FuncOp"> {
   let constructor = "mlir::vector::createLowerVectorMaskPass()";
 }
 
+def VectorLinearize : Pass<"vector-linearize"> {
+  let summary = "Linearize ND vectors into 1D";
+  let description = [{
+    Linearizes ND vectors for N >= 2 into 1D vectors.
+  }];
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..45f54fc70e3261 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,10 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f4..5081b4c06a617e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to the type converter
+/// without knowing the exact op type. Fails if `op` is already legal or its
+/// result types can't be converted; else returns a new op with converted types.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.
diff --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..7602e8c1976a9a
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,122 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::vector {
+#define GEN_PASS_DEF_VECTORLINEARIZE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = constOp.getLoc();
+    auto resType =
+        getTypeConverter()->convertType<VectorType>(constOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+                                                   dstElementsAttr);
+    return success();
+  }
+};
+
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
+    return success();
+  }
+};
+
+struct VectorLinearizePass final
+    : mlir::vector::impl::VectorLinearizeBase<VectorLinearizePass> {
+  using VectorLinearizeBase::VectorLinearizeBase;
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Leave rank <= 1 vectors as is; scalable vectors are ignored for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb447227..bfccef7cfe574b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3131,6 +3131,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..e0fac81199bc8d
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -split-input-file -vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+// Arith and math ops are handled in a generic way, check some of them.
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/81159


More information about the Mlir-commits mailing list