[Mlir-commits] [mlir] 35ef399 - [mlir][vector] ND vectors linearization pass (#81159)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 13 04:31:02 PST 2024


Author: Ivan Butygin
Date: 2024-02-13T15:30:58+03:00
New Revision: 35ef3994bf738318b59ce640910fb1ccd3bb7dcb

URL: https://github.com/llvm/llvm-project/commit/35ef3994bf738318b59ce640910fb1ccd3bb7dcb
DIFF: https://github.com/llvm/llvm-project/commit/35ef3994bf738318b59ce640910fb1ccd3bb7dcb.diff

LOG: [mlir][vector] ND vectors linearization pass (#81159)

Common backends (LLVM, SPIR-V) only support 1D vectors: the LLVM conversion
handles ND vectors (N >= 2) as `array<array<... vector>>`, and the SPIR-V
conversion doesn't handle them at all at the moment. Sometimes it's
preferable to treat multidim vectors as linearized 1D. Add a pass to do
this. Only constants and simple elementwise ops are supported for now.

@krzysz00 I've extracted your result type conversion code from
LegalizeToF32 and moved it to a common place.

Also, add a ConversionPattern class operating on traits.

Added: 
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/test/Dialect/Vector/linearize.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f5941d32e683fc..7c943f07066c70 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -20,7 +20,9 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc"
 
 namespace mlir {
+class ConversionTarget;
 class RewritePatternSet;
+class TypeConverter;
 
 namespace arith {
 class AndIOp;
@@ -375,6 +377,13 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Populates patterns for ND vectors (N >= 2) linearization and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+void populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target);
+
 } // namespace vector
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b1ec1fe4ecd51a..091131651bbf56 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -604,6 +604,29 @@ class OpInterfaceConversionPattern : public ConversionPattern {
   using ConversionPattern::matchAndRewrite;
 };
 
+/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
+/// for matching and rewriting against instances of an operation that possess a
+/// given trait.
+template <template <typename> class TraitType>
+class OpTraitConversionPattern : public ConversionPattern {
+public:
+  OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+  OpTraitConversionPattern(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
+                          TypeID::get<TraitType>(), benefit, context) {}
+};
+
+/// Generic utility to convert op result types according to type converter
+/// without knowing exact op type.
+/// Clones existing op with new result types and returns it.
+FailureOr<Operation *>
+convertOpResultTypes(Operation *op, ValueRange operands,
+                     const TypeConverter &converter,
+                     ConversionPatternRewriter &rewriter);
+
 /// Add a pattern to the given pattern list to convert the signature of a
 /// FunctionOpInterface op with the given type converter. This only supports
 /// ops which use FunctionType to represent their type.

diff  --git a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
index d281790e877152..5998133b7eab8b 100644
--- a/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp
@@ -76,20 +76,14 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
-  if (converter->isLegal(op))
-    return rewriter.notifyMatchFailure(loc, "op already legal");
-  OperationState newOp(loc, op->getName());
-  newOp.addOperands(operands);
+  FailureOr<Operation *> legalized =
+      convertOpResultTypes(op, operands, *converter, rewriter);
+  if (failed(legalized))
+    return failure();
 
-  SmallVector<Type> newResultTypes;
-  if (failed(converter->convertTypes(op->getResultTypes(), newResultTypes)))
-    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
-  newOp.addTypes(newResultTypes);
-  newOp.addAttributes(op->getAttrs());
-  Operation *legalized = rewriter.create(newOp);
-  SmallVector<Value> results = legalized->getResults();
-  for (auto [result, newType, origType] :
-       llvm::zip_equal(results, newResultTypes, op->getResultTypes())) {
+  SmallVector<Value> results = (*legalized)->getResults();
+  for (auto [result, newType, origType] : llvm::zip_equal(
+           results, (*legalized)->getResultTypes(), op->getResultTypes())) {
     if (newType != origType)
       result = rewriter.create<arith::TruncFOp>(loc, origType, result);
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef6..adf961ff935ffb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   VectorEmulateMaskedLoadStore.cpp
   VectorEmulateNarrowType.cpp
   VectorInsertExtractStridedSliceRewritePatterns.cpp
+  VectorLinearize.cpp
   VectorTransferOpTransforms.cpp
   VectorTransferSplitRewritePatterns.cpp
   VectorTransforms.cpp

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
new file mode 100644
index 00000000000000..c5352043955579
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -0,0 +1,97 @@
+//===- VectorLinearize.cpp - vector linearization transforms --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns and pass for linearizing ND vectors into 1D.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = constOp.getLoc();
+    auto resType =
+        getTypeConverter()->convertType<VectorType>(constOp.getType());
+    if (!resType)
+      return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
+    auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!dstElementsAttr)
+      return rewriter.notifyMatchFailure(loc, "unsupported attr type");
+
+    dstElementsAttr = dstElementsAttr.reshape(resType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, resType,
+                                                   dstElementsAttr);
+    return success();
+  }
+};
+
+struct LinearizeVectorizable final
+    : OpTraitConversionPattern<OpTrait::Vectorizable> {
+  using OpTraitConversionPattern::OpTraitConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    FailureOr<Operation *> newOp =
+        convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(op, (*newOp)->getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
+    // Ignore scalable vectors for now.
+    if (type.getRank() <= 1 || type.isScalable())
+      return type;
+
+    return VectorType::get(type.getNumElements(), type.getElementType());
+  });
+
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
+    if (inputs.size() != 1 || !isa<VectorType>(inputs.front().getType()) ||
+        !isa<VectorType>(type))
+      return nullptr;
+
+    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
+  };
+  typeConverter.addArgumentMaterialization(materializeCast);
+  typeConverter.addSourceMaterialization(materializeCast);
+  typeConverter.addTargetMaterialization(materializeCast);
+
+  target.markUnknownOpDynamicallyLegal(
+      [&](Operation *op) -> std::optional<bool> {
+        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
+          return typeConverter.isLegal(op);
+
+        return std::nullopt;
+      });
+
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
+                                                         patterns.getContext());
+}

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e90447084d68bd..a5a77e00fbfb5f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3130,6 +3130,27 @@ struct AnyFunctionOpInterfaceSignatureConversion
 };
 } // namespace
 
+FailureOr<Operation *>
+mlir::convertOpResultTypes(Operation *op, ValueRange operands,
+                           const TypeConverter &converter,
+                           ConversionPatternRewriter &rewriter) {
+  assert(op && "Invalid op");
+  Location loc = op->getLoc();
+  if (converter.isLegal(op))
+    return rewriter.notifyMatchFailure(loc, "op already legal");
+
+  OperationState newOp(loc, op->getName());
+  newOp.addOperands(operands);
+
+  SmallVector<Type> newResultTypes;
+  if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+    return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
+
+  newOp.addTypes(newResultTypes);
+  newOp.addAttributes(op->getAttrs());
+  return rewriter.create(newOp);
+}
+
 void mlir::populateFunctionOpInterfaceTypeConversionPattern(
     StringRef functionLikeOpName, RewritePatternSet &patterns,
     const TypeConverter &converter) {

diff  --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
new file mode 100644
index 00000000000000..85e23103eaedb7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+
+// CHECK-LABEL: test_linearize
+//  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
+//       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
+func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
+//       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
+//       CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
+
+// Arith and math ops are handled in generic way, check some of them
+//       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+  %1 = math.sin %arg0 : vector<2x2xf32>
+//       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+  %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
+
+//       CHECK: return %[[RES]] : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b8487f..acd38980514a56 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 };
+
+struct TestVectorLinearize final
+    : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+  StringRef getArgument() const override { return "test-vector-linearize"; }
+  StringRef getDescription() const override {
+    return "Linearizes ND vectors for N >= 2 into 1D vectors";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+                                                              patterns, target);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -867,6 +894,8 @@ void registerTestVectorLowerings() {
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
+
+  PassRegistration<TestVectorLinearize>();
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list