[Mlir-commits] [mlir] [mlir][vector] `vector.fma` is not `ElementwiseMappable` (PR #132611)

Ivan Butygin llvmlistbot at llvm.org
Sun Mar 23 06:49:53 PDT 2025


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/132611

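`ElementwiseMappable` implies `Scalarizable` and `Tensorizable`, but `vector.fma` only supports vector inputs. The op therefore keeps only the `Elementwise` and `Vectorizable` traits, and `UnrollElementwisePattern` now checks for those two traits directly instead of calling `hasElementwiseMappableTraits`.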
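
To illustrate why the unroll pattern's check has to change along with the trait list, here is a minimal standalone sketch. It does not use MLIR: `ToyOp`, the `Trait` bitmask, `matchesUnrollElementwise`, and the two example constants are hypothetical stand-ins, and the model assumes the old helper requires all four traits, as the description above suggests. The `getNumResults() == 1` condition from the real pattern is left out.

```cpp
#include <cassert>

// Toy stand-ins for the MLIR traits; these are NOT the real MLIR classes.
enum Trait : unsigned {
  Elementwise  = 1u << 0,
  Scalarizable = 1u << 1,
  Vectorizable = 1u << 2,
  Tensorizable = 1u << 3,
};

// Hypothetical op description: nothing but a bitmask of traits.
struct ToyOp {
  unsigned traits;
  constexpr bool hasTrait(Trait t) const { return (traits & t) != 0; }
};

// Models the old check (assumption: all four traits are required).
constexpr bool hasElementwiseMappableTraits(ToyOp op) {
  return op.hasTrait(Elementwise) && op.hasTrait(Scalarizable) &&
         op.hasTrait(Vectorizable) && op.hasTrait(Tensorizable);
}

// Models the trait part of the new check in UnrollElementwisePattern.
constexpr bool matchesUnrollElementwise(ToyOp op) {
  return op.hasTrait(Elementwise) && op.hasTrait(Vectorizable);
}

// vector.fma after this patch: Elementwise and Vectorizable only.
constexpr ToyOp kFmaAfterPatch{Elementwise | Vectorizable};
// An op that still carries the full ElementwiseMappable trait list.
constexpr ToyOp kFullyMappable{Elementwise | Scalarizable | Vectorizable |
                               Tensorizable};

// Small runtime check of the model.
inline void checkToyModel() {
  // With the old check, the patched vector.fma would no longer be unrolled.
  assert(!hasElementwiseMappableTraits(kFmaAfterPatch));
  // With the new check, it still matches.
  assert(matchesUnrollElementwise(kFmaAfterPatch));
  // Fully mappable ops satisfy both checks.
  assert(matchesUnrollElementwise(kFullyMappable));
}
```

In this model, dropping `Scalarizable` and `Tensorizable` from the op without relaxing the pattern's condition would make the pattern stop matching `vector.fma`, which is presumably why both files change in the same patch.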

From 1fbced6f62232476a08523062fd2dc18e1ede827 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Sun, 23 Mar 2025 14:48:00 +0100
Subject: [PATCH] [mlir][vector] `vector.fma` is not `ElementwiseMappable`

`ElementwiseMappable` implies `Scalarizable` and `Tensorizable`, but `vector.fma` only supports vector inputs.
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td    | 5 +++--
 mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..c895e32839acd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -759,8 +759,9 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
-       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
-     ] # ElementwiseMappable.traits>,
+       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+       Elementwise, Vectorizable
+     ] >,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
                    VectorOfAnyRankOf<[AnyFloat]>:$rhs,
                    VectorOfAnyRankOf<[AnyFloat]>:$acc)>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 04c38f9f7b2e3..2d3febffddc74 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -434,7 +434,8 @@ struct UnrollElementwisePattern : public RewritePattern {
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+    if (!op->hasTrait<OpTrait::Elementwise>() ||
+        !op->hasTrait<OpTrait::Vectorizable>() || op->getNumResults() != 1)
       return failure();
     auto targetShape = getTargetShape(options, op);
     if (!targetShape)


