[Mlir-commits] [mlir] 688d650 - [mlir][vector] Add scalable vectors support to OuterProductOp
Prabhdeep Singh Soni
llvmlistbot at llvm.org
Mon Jan 16 08:50:16 PST 2023
Author: Frank (Fang) Gao
Date: 2023-01-16T11:49:11-05:00
New Revision: 688d6507c7e2f49668ab0d1f71a1f86f933f99f1
URL: https://github.com/llvm/llvm-project/commit/688d6507c7e2f49668ab0d1f71a1f86f933f99f1
DIFF: https://github.com/llvm/llvm-project/commit/688d6507c7e2f49668ab0d1f71a1f86f933f99f1.diff
LOG: [mlir][vector] Add scalable vectors support to OuterProductOp
This will probably be the first in a series of patches that try to
enable code generation for ARM SME (the Scalable Matrix Extension, an
extension of SVE).
Since SME's core operation is the outer-product instruction, this patch
enables the vector.outerproduct operation to properly accept and generate
scalable vectors. It also makes the verifier reject outer products whose
two vector operands are not either both scalable or both fixed-length.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D138718
Added:
mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d565ecdd8735d..933945233c885 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2664,10 +2664,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
if (!vLHS)
return parser.emitError(parser.getNameLoc(),
"expected vector type for operand #1");
- VectorType resType =
- vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType())
- : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+
+ unsigned numScalableDims = vLHS.getNumScalableDims();
+ VectorType resType;
+ if (vRHS) {
+ numScalableDims += vRHS.getNumScalableDims();
+ resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+ vLHS.getElementType(), numScalableDims);
+ } else {
+ // Scalar RHS operand
+ resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
+ numScalableDims);
+ }
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
result.attributes.append(
@@ -2703,6 +2711,9 @@ LogicalResult OuterProductOp::verify() {
return emitOpError("expected #1 operand dim to match result dim #1");
if (vRHS.getDimSize(0) != vRES.getDimSize(1))
return emitOpError("expected #2 operand dim to match result dim #2");
+ if (vRHS.isScalable() != vLHS.isScalable())
+ return emitOpError("expected either all or none of vector operands #1 "
+ "and #2 to be scalable");
} else {
// An AXPY operation.
if (vRES.getRank() != 1)
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
new file mode 100644
index 0000000000000..8e010d2183fd7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt
+
+func.func @scalable_outerproduct(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %cst = arith.constant 1.0 : f32
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+ %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32>
+ vector.store %op, %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+
+ %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32
+ vector.store %op2, %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ return
+}
+
+// -----
+
+func.func @invalid_outerproduct(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
+
+ // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+ %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+}
+// -----
+
+func.func @invalid_outerproduct1(%src : memref<?xf32>) {
+ %idx = arith.constant 0 : index
+ %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+ %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+ // expected-error @+1 {{expected 1-d vector for operand #1}}
+ %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32>
+}
More information about the Mlir-commits
mailing list