[Mlir-commits] [mlir] [mlir][vector] Add support for unrolling vector.bitcast ops. (PR #94064)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jun 3 03:00:17 PDT 2024


================
@@ -0,0 +1,94 @@
+//===- LowerVectorBitCast.cpp - Lower 'vector.bitcast' operation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.bitcast' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-bitcast-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// A one-shot unrolling of vector.bitcast to the `targetRank`.
+///
+/// Example:
+///
+///   vector.bitcast %a : vector<1x2x3x4xi64> to vector<1x2x3x8xi32>
+///
+/// Would be unrolled to:
+///
+/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
+/// %0 = vector.extract %a[0, 0, 0]                 ─┐
+///        : vector<4xi64> from vector<1x2x3x4xi64>  |
+/// %1 = vector.bitcast %0                           | - Repeated 6x for
+///        : vector<4xi64> to vector<8xi32>          |   all leading positions
+/// %2 = vector.insert %1, %result [0, 0, 0]         |
+///        : vector<8xi32> into vector<1x2x3x8xi32> ─┘
+///
+/// Note: If any leading dimension before the `targetRank` is scalable the
+/// unrolling will stop before the scalable dimension.
+class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
+public:
+  UnrollBitCastOp(int64_t targetRank, MLIRContext *context,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+  LogicalResult matchAndRewrite(vector::BitCastOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    // TODO: Support the scalable vector cases. It is not supported because
+    // the final rank could be values other than `targetRank`. It makes creating
+    // the result type of new vector.bitcast ops much harder.
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "unrolling vector.bitcast on scalable vectors is NIY");
----------------
MacDue wrote:

Not yet implemented (I guess) 

https://github.com/llvm/llvm-project/pull/94064


More information about the Mlir-commits mailing list