[Mlir-commits] [mlir] 688d650 - [mlir][vector] Add scalable vectors support to OuterProductOp

Prabhdeep Singh Soni llvmlistbot at llvm.org
Mon Jan 16 08:50:16 PST 2023


Author: Frank (Fang) Gao
Date: 2023-01-16T11:49:11-05:00
New Revision: 688d6507c7e2f49668ab0d1f71a1f86f933f99f1

URL: https://github.com/llvm/llvm-project/commit/688d6507c7e2f49668ab0d1f71a1f86f933f99f1
DIFF: https://github.com/llvm/llvm-project/commit/688d6507c7e2f49668ab0d1f71a1f86f933f99f1.diff

LOG: [mlir][vector] Add scalable vectors support to OuterProductOp

This will probably be the first in a series of patches that try to
enable code generation for ARM SME (the Scalable Matrix Extension, an
extension of SVE).

Since SME's core operation is the outer product instruction, it seemed
like a good first step to make the outer product operation properly
accept and generate scalable vectors.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D138718

Added: 
    mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d565ecdd8735d..933945233c885 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2664,10 +2664,18 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
   if (!vLHS)
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");
-  VectorType resType =
-      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                             vLHS.getElementType())
-           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+
+  unsigned numScalableDims = vLHS.getNumScalableDims();
+  VectorType resType;
+  if (vRHS) {
+    numScalableDims += vRHS.getNumScalableDims();
+    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+                              vLHS.getElementType(), numScalableDims);
+  } else {
+    // Scalar RHS operand
+    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
+                              numScalableDims);
+  }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
     result.attributes.append(
@@ -2703,6 +2711,9 @@ LogicalResult OuterProductOp::verify() {
       return emitOpError("expected #1 operand dim to match result dim #1");
     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
       return emitOpError("expected #2 operand dim to match result dim #2");
+    if (vRHS.isScalable() != vLHS.isScalable())
+      return emitOpError("expected either all or none of vector operands #1 "
+                         "and #2 to be scalable");
   } else {
     // An AXPY operation.
     if (vRES.getRank() != 1)

diff  --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
new file mode 100644
index 0000000000000..8e010d2183fd7
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt
+
+func.func @scalable_outerproduct(%src : memref<?xf32>) {
+  %idx = arith.constant 0 : index
+  %cst = arith.constant 1.0 : f32
+  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+  %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+  %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32>
+  vector.store %op, %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+
+  %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32
+  vector.store %op2, %src[%idx] : memref<?xf32>, vector<[4]xf32>
+  return
+}
+
+// -----
+
+func.func @invalid_outerproduct(%src : memref<?xf32>) {
+  %idx = arith.constant 0 : index
+  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+  %1 = vector.load %src[%idx] : memref<?xf32>, vector<4xf32>
+
+  // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}}
+  %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32>
+}
+// -----
+
+func.func @invalid_outerproduct1(%src : memref<?xf32>) {
+  %idx = arith.constant 0 : index
+  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+  %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
+
+  // expected-error @+1 {{expected 1-d vector for operand #1}}
+  %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32>
+}


        


More information about the Mlir-commits mailing list