[Mlir-commits] [mlir] [mlir][vector] `vector.fma` is not `ElementwiseMappable` (PR #132611)
llvmlistbot at llvm.org
Sun Mar 23 06:50:25 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
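`ElementwiseMappable` implies the `Scalarizable` and `Tensorizable` traits, but `vector.fma` only supports vector operands. This change replaces `ElementwiseMappable.traits` on `vector.fma` with just `Elementwise` and `Vectorizable`, and relaxes the matching check in `UnrollElementwisePattern` accordingly.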
---
Full diff: https://github.com/llvm/llvm-project/pull/132611.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+2-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index fbbf817ecff98..c895e32839acd 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -759,8 +759,9 @@ def Vector_ExtractOp :
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
- DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
- ] # ElementwiseMappable.traits>,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
+ Elementwise, Vectorizable
+ ] >,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
VectorOfAnyRankOf<[AnyFloat]>:$rhs,
VectorOfAnyRankOf<[AnyFloat]>:$acc)>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 04c38f9f7b2e3..2d3febffddc74 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -434,7 +434,8 @@ struct UnrollElementwisePattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
- if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+ if (!op->hasTrait<OpTrait::Elementwise>() ||
+ !op->hasTrait<OpTrait::Vectorizable>() || op->getNumResults() != 1)
return failure();
auto targetShape = getTargetShape(options, op);
if (!targetShape)
``````````
</details>
https://github.com/llvm/llvm-project/pull/132611
More information about the Mlir-commits mailing list