[Mlir-commits] [mlir] 3aeb27d - [mlir][Vector] Fix 0-D tensor vectorization in Linalg

Diego Caballero llvmlistbot at llvm.org
Fri Jun 16 16:46:20 PDT 2023


Author: Diego Caballero
Date: 2023-06-16T23:45:03Z
New Revision: 3aeb27d69a16e7d8aa3fe684c33d506cdda97b78

URL: https://github.com/llvm/llvm-project/commit/3aeb27d69a16e7d8aa3fe684c33d506cdda97b78
DIFF: https://github.com/llvm/llvm-project/commit/3aeb27d69a16e7d8aa3fe684c33d506cdda97b78.diff

LOG: [mlir][Vector] Fix 0-D tensor vectorization in Linalg

It looks like scalable vector support broke vectorization for 0-D
tensors, and we didn't have any test covering that case. This patch
provides a fix and a test.

Differential Revision: https://reviews.llvm.org/D153181

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 685567d1631be..bbcde44f08618 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1199,38 +1199,43 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   //   a. Get the first max ranked shape.
   VectorType firstMaxRankedType;
   for (Value operand : op->getOperands()) {
-    auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+    auto vecOperand = bvm.lookup(operand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
     if (vecType && (!firstMaxRankedType ||
                     firstMaxRankedType.getRank() < vecType.getRank()))
       firstMaxRankedType = vecType;
   }
   //   b. Broadcast each op if needed.
-  SmallVector<Value> vectorizedOperands;
+  SmallVector<Value> vecOperands;
   for (Value scalarOperand : op->getOperands()) {
-    Value vectorizedOperand = bvm.lookup(scalarOperand);
-    auto vecType =
-        VectorType::get(firstMaxRankedType.getShape(),
-                        getElementTypeOrSelf(vectorizedOperand.getType()),
-                        firstMaxRankedType.getNumScalableDims());
-    vectorizedOperands.push_back(
-        !firstMaxRankedType
-            ? vectorizedOperand
-            : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+    Value vecOperand = bvm.lookup(scalarOperand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    if (firstMaxRankedType) {
+      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
+                                     getElementTypeOrSelf(vecOperand.getType()),
+                                     firstMaxRankedType.getNumScalableDims());
+      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
+    } else {
+      vecOperands.push_back(vecOperand);
+    }
   }
   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
   SmallVector<Type> resultTypes;
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
-        !firstMaxRankedType
-            ? resultType
-            : VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims()));
+        firstMaxRankedType
+            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+                              firstMaxRankedType.getNumScalableDims())
+            : resultType);
   }
   //   d. Build and return the new op.
   return VectorizationResult{
       VectorizationStatus::NewOp,
-      rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                      vectorizedOperands, resultTypes, op->getAttrs())};
+      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
+                      resultTypes, op->getAttrs())};
 }
 
 /// Generic vectorization function that rewrites the body of a `linalgOp` into

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 404c3492470fa..130c6bcc11abb 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1719,3 +1719,35 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
   %2 = transform.structured.vectorize %1  { vectorize_padding } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+func.func @zero_dim_tensor(%input: tensor<f32>, %output: tensor<f32>) -> tensor<f32>
+{
+  %0 = linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
+                        iterator_types = [] }
+                        ins(%input : tensor<f32>)
+                        outs(%output : tensor<f32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %2 = arith.addf %arg0, %arg1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @zero_dim_tensor
+//       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+//       CHECK:     vector.extractelement
+//       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+//       CHECK:     vector.extractelement
+//       CHECK:     arith.addf {{.*}} : f32
+//       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
+//       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+


        


More information about the Mlir-commits mailing list