[Mlir-commits] [mlir] [mlir][vector] Add support for unrolling vector.bitcast ops. (PR #94064)

Han-Chung Wang llvmlistbot at llvm.org
Mon Jun 3 11:35:07 PDT 2024


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/94064

>From c08faa4dcd4c120c442798b33f74480bdb9b4f8f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 31 May 2024 15:23:00 -0700
Subject: [PATCH 1/3] [mlir][vector] Add support for unrolling vector.bitcast
 ops.

The revision unrolls vector.bitcast like:

```mlir
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
```

to

```mlir
%cst = arith.constant dense<0> : vector<2x2xi64>
%0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32>
%1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64>
%2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64>
%3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32>
%4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64>
%5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64>
```

Scalable vectors are not supported because of a limitation of
`vector::createUnrollIterator`: the final rank it unrolls to can differ
from `targetRank`, and there is no direct way to query that final rank
from the iterator object.
---
 .../Vector/TransformOps/VectorTransformOps.td | 14 +++
 .../Vector/Transforms/LoweringPatterns.h      |  9 ++
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 .../TransformOps/VectorTransformOps.cpp       |  5 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |  1 +
 .../Vector/Transforms/LowerVectorBitCast.cpp  | 94 +++++++++++++++++++
 .../VectorToLLVM/vector-to-llvm.mlir          | 10 ++
 .../vector-bitcast-lowering-transforms.mlir   | 35 +++++++
 8 files changed, 169 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..c91e8fbbae90f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_bitcast",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector bitcast operations should be lowered to
+    finer-grained vector primitives.
+
+    This is usually a late step that is run after bufferization as part of
+    the process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_broadcast",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8fd9904fabc0e..1976b8399c7f9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
+/// Populates the pattern set with the following patterns:
+///
+/// [UnrollBitCastOp]
+/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
+/// BitCastOp (of `targetRank`) + InsertOp.
+void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
+                                           int64_t targetRank = 1,
+                                           PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a436c4a9400..55143d5939ba2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..23960269095e5 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
 
+void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorBitCastLoweringPatterns(patterns);
+}
+
 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorBroadcastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4dbefdd376a8b..723b2f62d65d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
+  LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
new file mode 100644
index 0000000000000..581ee54fb2935
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -0,0 +1,94 @@
+//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.bitcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-bitcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.bitcast to the `targetRank`.
+///
+/// Example:
+///
+///   vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
+///
+/// Would be unrolled to:
+///
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<4xi64> from vector<1x2x3x4xi64>  |
+/// %1 = vector.bitcast %0                           | - Repeated 6x for
+///        : vector<4xi64> to vector<8xi32>          |   all leading positions
+/// %2 = vector.insert %1, %result [0, 0, 0]         |
+///        : vector<8xi32> into vector<1x2x3x8xi32> ─┘
+///
+/// Note: Scalable vectors are not supported yet. If the result type has any
+/// scalable dimension, the pattern fails to match and the op is left as is.
+class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
+public:
+  UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank){};
+
+  LogicalResult matchAndRewrite(vector::BitCastOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    // TODO: Support the scalable vector cases. It is not supported because
+    // the final rank could be values other than `targetRank`. It makes creating
+    // the result type of new vector.bitcast ops much harder.
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "unrolling vector.bitcast on scalable vectors is NIY");
+
+    SmallVector<int64_t> shape(resultType.getShape().take_back(targetRank));
+    auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+
+    Location loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (auto position : *unrollIterator) {
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      Value bitcast =
+          rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+      result =
+          rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
+
+} // namespace
+
+void mlir::vector::populateVectorBitCastLoweringPatterns(
+    RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+  patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..12121ea0dd70e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK:         llvm.bitcast
+// CHECK-NOT:     vector.bitcast
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+  %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+  return %0 : vector<2x2xi64>
+}
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
new file mode 100644
index 0000000000000..e8c529dcacc75
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+  %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+  return %0 : vector<2x2xi64>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK:         %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
+// CHECK:         %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
+// CHECK:         %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
+// CHECK:         %[[V2:.+]] = vector.extract %[[IN]][1] : vector<4xi32> from vector<2x4xi32>
+// CHECK:         %[[B2:.+]] = vector.bitcast %[[V2]] : vector<4xi32> to vector<2xi64>
+// CHECK:         %[[R2:.+]] = vector.insert %[[B2]], %[[R1]] [1]
+// CHECK:         return %[[R2]]
+
+func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
+  %0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+  return %0 : vector<1x2x[3]x8xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
+// CHECK:         vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_bitcast
+    } : !transform.any_op
+    transform.yield
+  }
+}

>From 2dd32ce1d16fc44e5a913edd1cef0d02aadd99af Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 31 May 2024 16:49:06 -0700
Subject: [PATCH 2/3] clang format

---
 mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 581ee54fb2935..969da47e12af9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -47,7 +47,7 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
 public:
   UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
                   PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit), targetRank(targetRank){};
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {};
 
   LogicalResult matchAndRewrite(vector::BitCastOp op,
                                 PatternRewriter &rewriter) const override {

>From b9fa8fcf51cc363f9cd7db5e834d467d3af331bc Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 3 Jun 2024 11:34:45 -0700
Subject: [PATCH 3/3] address comments

---
 .../Vector/Transforms/LowerVectorBitCast.cpp   | 10 ++++++----
 .../vector-bitcast-lowering-transforms.mlir    | 18 ++++++++++++++++++
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
index 969da47e12af9..092ec927c92ae 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -59,11 +59,13 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
     // TODO: Support the scalable vector cases. It is not supported because
     // the final rank could be values other than `targetRank`. It makes creating
     // the result type of new vector.bitcast ops much harder.
-    if (resultType.isScalable())
-      return rewriter.notifyMatchFailure(
-          op, "unrolling vector.bitcast on scalable vectors is NIY");
+    if (resultType.isScalable()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unrolling vector.bitcast on scalable "
+                                         "vectors is not yet implemented");
+    }
 
-    SmallVector<int64_t> shape(resultType.getShape().take_back(targetRank));
+    ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
     auto bitcastResType = VectorType::get(shape, resultType.getElementType());
 
     Location loc = op.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
index e8c529dcacc75..23fece208c561 100644
--- a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -1,5 +1,23 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+func.func @vector_bitcast_0d(%arg0: vector<i32>) -> vector<f32> {
+  %0 = vector.bitcast %arg0 : vector<i32> to vector<f32>
+  return %0 : vector<f32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_0d
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[RES:.+]] = vector.bitcast %[[IN]] : vector<i32> to vector<f32>
+// CHECK:         return %[[RES]]
+
+func.func @vector_bitcast_1d(%arg0: vector<10xi64>) -> vector<20xi32> {
+  %0 = vector.bitcast %arg0 : vector<10xi64> to vector<20xi32>
+  return %0 : vector<20xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_1d
+// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
+// CHECK:         %[[RES:.+]] = vector.bitcast %[[IN]] : vector<10xi64> to vector<20xi32>
+// CHECK:         return %[[RES]]
+
 func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
   %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
   return %0 : vector<2x2xi64>



More information about the Mlir-commits mailing list