[Mlir-commits] [mlir] [mlir][vector] Add mask elimination transform (PR #99314)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jul 22 06:04:52 PDT 2024


================
@@ -0,0 +1,129 @@
+//===- VectorMaskElimination.cpp - Eliminate Vector Masks -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+namespace {
+
+/// If `value` is a constant multiple of `vector.vscale` return the multiplier.
+std::optional<int64_t> getConstantVscaleMultiplier(Value value) {
+  if (value.getDefiningOp<vector::VectorScaleOp>())
+    return 1;
+  auto mul = value.getDefiningOp<arith::MulIOp>();
+  if (!mul)
+    return {};
+  auto lhs = mul.getLhs();
+  auto rhs = mul.getRhs();
+  if (lhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(rhs);
+  if (rhs.getDefiningOp<vector::VectorScaleOp>())
+    return getConstantIntValue(lhs);
----------------
MacDue wrote:

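I'm not sure how it does? :sweat_smile:

`int * int` is not a valid case, nor is `vscale * vscale`. With `int * int`, neither operand is defined by `vector.vscale`, so neither of the two `if`s on the multiply operands matches. With `vscale * vscale`, the first `if` matches, but the other operand is not a constant, so `getConstantIntValue` returns `std::nullopt`.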
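
For illustration, here is a minimal standalone sketch of the operand checks on the multiply, with no MLIR dependency. `Operand`, `constant`, `vscale`, `constantValue` and `vscaleMultiplierOfMul` are hypothetical stand-ins, not names from the patch, and the final `return std::nullopt` is an assumption because the hunk is cut off before the end of the function.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for an SSA value: either a constant integer or the
// result of `vector.vscale`. This is not an MLIR type.
struct Operand {
  bool isVscale;
  std::optional<int64_t> constantInt;
};

Operand constant(int64_t v) { return Operand{false, v}; }
Operand vscale() { return Operand{true, std::nullopt}; }

// Stand-in for getConstantIntValue: only constants yield a value.
std::optional<int64_t> constantValue(const Operand &v) { return v.constantInt; }

// Mirrors the two operand checks in the hunk for a multiply `lhs * rhs`.
std::optional<int64_t> vscaleMultiplierOfMul(const Operand &lhs,
                                             const Operand &rhs) {
  if (lhs.isVscale)
    return constantValue(rhs); // vscale * c -> c, vscale * vscale -> nullopt
  if (rhs.isVscale)
    return constantValue(lhs); // c * vscale -> c
  return std::nullopt;         // assumed fallthrough: int * int is rejected
}
```

Under these assumptions, `vscaleMultiplierOfMul(vscale(), constant(4))` yields 4, while both `vscaleMultiplierOfMul(constant(2), constant(3))` and `vscaleMultiplierOfMul(vscale(), vscale())` yield an empty optional.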

https://github.com/llvm/llvm-project/pull/99314

