[Mlir-commits] [mlir] [mlir][vector] Add support for unrolling vector.bitcast ops. (PR #94064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 31 15:28:09 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

<details>
<summary>Changes</summary>

This revision unrolls n-D vector.bitcast ops into 1-D ones. For example:

```mlir
%0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
```

to

```mlir
%cst = arith.constant dense<0> : vector<2x2xi64>
%0 = vector.extract %arg0[0] : vector<4xi32> from vector<2x4xi32>
%1 = vector.bitcast %0 : vector<4xi32> to vector<2xi64>
%2 = vector.insert %1, %cst [0] : vector<2xi64> into vector<2x2xi64>
%3 = vector.extract %arg0[1] : vector<4xi32> from vector<2x4xi32>
%4 = vector.bitcast %3 : vector<4xi32> to vector<2xi64>
%5 = vector.insert %4, %2 [1] : vector<2xi64> into vector<2x2xi64>
```

Scalable vectors are not supported yet because of a limitation of `vector::createUnrollIterator`: unrolling stops before the first leading scalable dimension, so the rank of the emitted bitcasts can differ from `targetRank`, and there is no direct way to query that final rank from the returned iterator object.

---
Full diff: https://github.com/llvm/llvm-project/pull/94064.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+14) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+9) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1) 
- (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp (+94) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+10) 
- (added) mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir (+35) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index bc3c16d40520e..c91e8fbbae90f 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -89,6 +89,20 @@ def ApplyTransferPermutationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLowerBitCastPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.lower_bitcast",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector bitcast operations should be lowered to
+    finer-grained vector primitives.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_broadcast",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8fd9904fabc0e..1976b8399c7f9 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -276,6 +276,15 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
 void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
                                                PatternBenefit benefit = 1);
 
+/// Populates the pattern set with the following patterns:
+///
+/// [UnrollBitCastOp]
+/// A one-shot unrolling of BitCastOp to (one or more) ExtractOp +
+/// BitCastOp (of rank `targetRank`) + InsertOp; scalable vectors are skipped.
+void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
+                                           int64_t targetRank = 1,
+                                           PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index e3a436c4a9400..55143d5939ba2 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -64,6 +64,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorBitCastLoweringPatterns(patterns);
     populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 61fd6bd972e3a..23960269095e5 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -79,6 +79,11 @@ void transform::ApplyTransferPermutationPatternsOp::populatePatterns(
   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
 
+void transform::ApplyLowerBitCastPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorBitCastLoweringPatterns(patterns, /*targetRank=*/1);
+}
+
 void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorBroadcastLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4dbefdd376a8b..723b2f62d65d4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
+  LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
new file mode 100644
index 0000000000000..581ee54fb2935
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp
@@ -0,0 +1,94 @@
+//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.bitcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-bitcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.bitcast down to `targetRank`-D bitcasts.
+///
+/// Example:
+///
+///   %r = vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
+///
+/// Would be unrolled (with `targetRank` = 1) to:
+///
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<4xi64> from vector<1x2x3x4xi64>  |
+/// %1 = vector.bitcast %0                           | - Repeated 6x for
+///        : vector<4xi64> to vector<8xi32>          |   all leading positions
+/// %2 = vector.insert %1, %result [0, 0, 0]         |
+///        : vector<8xi32> into vector<1x2x3x8xi32> ─┘
+///
+/// Note: scalable vectors are not supported yet; the pattern reports a match
+/// failure for them, and for bitcasts whose rank is already <= `targetRank`.
+class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
+public:
+  UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // TODO: Support the scalable vector cases. It is not supported because
+    // the final rank could be values other than `targetRank`. It makes creating
+    // the result type of new vector.bitcast ops much harder.
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "unrolling vector.bitcast on scalable vectors is NIY");
+
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return rewriter.notifyMatchFailure(op, "rank is already <= targetRank");
+
+    ArrayRef<int64_t> shape = resultType.getShape().take_back(targetRank);
+    auto bitcastResType = VectorType::get(shape, resultType.getElementType());
+
+    Location loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (auto position : *unrollIterator) {
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      Value bitcast =
+          rewriter.create<vector::BitCastOp>(loc, bitcastResType, extract);
+      result =
+          rewriter.create<vector::InsertOp>(loc, bitcast, result, position);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1; // Rank of the vector.bitcast ops produced.
+};
+
+} // namespace
+
+void mlir::vector::populateVectorBitCastLoweringPatterns(
+    RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
+  patterns.add<UnrollBitCastOp>(targetRank, patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..12121ea0dd70e 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,13 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK-COUNT-2: llvm.bitcast %{{.+}} : vector<4xi32> to vector<2xi64>
+// CHECK-NOT:     vector.bitcast
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+  %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+  return %0 : vector<2x2xi64>
+}
diff --git a/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
new file mode 100644
index 0000000000000..e8c529dcacc75
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
+  %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64>
+  return %0 : vector<2x2xi64>
+}
+// CHECK-LABEL: func.func @vector_bitcast_2d
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[ZERO:.+]] = arith.constant {{.+}} : vector<2x2xi64>
+// CHECK:         %[[ROW0:.+]] = vector.extract %[[SRC]][0] : vector<4xi32> from vector<2x4xi32>
+// CHECK:         %[[CAST0:.+]] = vector.bitcast %[[ROW0]] : vector<4xi32> to vector<2xi64>
+// CHECK:         %[[ACC0:.+]] = vector.insert %[[CAST0]], %[[ZERO]] [0]
+// CHECK:         %[[ROW1:.+]] = vector.extract %[[SRC]][1] : vector<4xi32> from vector<2x4xi32>
+// CHECK:         %[[CAST1:.+]] = vector.bitcast %[[ROW1]] : vector<4xi32> to vector<2xi64>
+// CHECK:         %[[ACC1:.+]] = vector.insert %[[CAST1]], %[[ACC0]] [1]
+// CHECK:         return %[[ACC1]]
+
+func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) -> vector<1x2x[3]x8xi32> {
+  %0 = vector.bitcast %arg0 : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+  return %0 : vector<1x2x[3]x8xi32>
+}
+// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
+// CHECK:         vector.bitcast {{.+}} : vector<1x2x[3]x4xi64> to vector<1x2x[3]x8xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %funcs = transform.structured.match ops{["func.func"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %funcs {
+      transform.apply_patterns.vector.lower_bitcast
+    } : !transform.any_op
+    transform.yield
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/94064


More information about the Mlir-commits mailing list