[Mlir-commits] [mlir] 067bd7d - [mlir][vector] Use optional for outerproduct accumulator instead of variadic

Cullen Rhodes llvmlistbot at llvm.org
Thu Aug 31 23:03:55 PDT 2023


Author: Cullen Rhodes
Date: 2023-09-01T05:50:01Z
New Revision: 067bd7d0512be3b4c0ad307ad6855da021194269

URL: https://github.com/llvm/llvm-project/commit/067bd7d0512be3b4c0ad307ad6855da021194269
DIFF: https://github.com/llvm/llvm-project/commit/067bd7d0512be3b4c0ad307ad6855da021194269.diff

LOG: [mlir][vector] Use optional for outerproduct accumulator instead of variadic

The accumulator operand was introduced before the `Optional` directive
existed, so it was declared `Variadic`, even though it is really optional
(it takes either zero or one value).

Reviewed By: nicolasvasilache, benmxwl-arm, dcaballe

Differential Revision: https://reviews.llvm.org/D159259

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index e23ed9258f3102..bf42b4053ac05b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -897,7 +897,7 @@ def Vector_OuterProductOp :
                 TCresVTEtIsSameAsOpBase<0, 1>>,
     DeclareOpInterfaceMethods<MaskableOpInterface>]>,
     Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
-               Variadic<AnyVector>:$acc,
+               Optional<AnyVector>:$acc,
                DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
     Results<(outs AnyVector)> {
   let summary = "vector outerproduct with optional fused add";
@@ -961,9 +961,9 @@ def Vector_OuterProductOp :
       return getRhs().getType();
     }
     VectorType getOperandVectorTypeACC() {
-      return getAcc().empty()
-        ? VectorType()
-        : ::llvm::cast<VectorType>((*getAcc().begin()).getType());
+      return getAcc()
+        ? ::llvm::cast<VectorType>(getAcc().getType())
+        : VectorType();
     }
     VectorType getResultVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f4dcadd59cd3b9..a2975ac468c05d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2756,7 +2756,7 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result,
 
 void OuterProductOp::print(OpAsmPrinter &p) {
   p << " " << getLhs() << ", " << getRhs();
-  if (!getAcc().empty()) {
+  if (getAcc()) {
     p << ", " << getAcc();
     p.printOptionalAttrDict((*this)->getAttrs());
   }

diff  --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 1b3d617a79edb7..b66077372164e7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1128,7 +1128,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
     VectorType resType = op.getResultVectorType();
     Type eltType = resType.getElementType();
     bool isInt = isa<IntegerType, IndexType>(eltType);
-    Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
+    Value acc = op.getAcc();
     vector::CombiningKind kind = op.getKind();
 
     // Vector mask setup.


        


More information about the Mlir-commits mailing list