[Mlir-commits] [mlir] 9328412 - [mlir][vector] Fix TransferOpReduceRank for 0-D tensors

Lei Zhang llvmlistbot at llvm.org
Mon Nov 22 09:34:30 PST 2021


Author: Lei Zhang
Date: 2021-11-22T12:30:46-05:00
New Revision: 93284120f28c82503138f3e594358349ed0ab37f

URL: https://github.com/llvm/llvm-project/commit/93284120f28c82503138f3e594358349ed0ab37f
DIFF: https://github.com/llvm/llvm-project/commit/93284120f28c82503138f3e594358349ed0ab37f.diff

LOG: [mlir][vector] Fix TransferOpReduceRank for 0-D tensors

We cannot unconditionally generate memref.load ops for such cases;
we need to check the source's type and emit tensor.extract when the
source is a tensor.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D114376

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
    mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
index 3f5c3127a286c..a27ebfc8e5c62 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferPermutationMapRewritePatterns.cpp
@@ -224,9 +224,15 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
     // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
     // In the meantime, lower these to a scalar load when they pop up.
     if (reducedShapeRank == 0) {
-      Value newRead = rewriter.create<memref::LoadOp>(
-          op.getLoc(), originalVecType.getElementType(), op.source(),
-          op.indices());
+      Value newRead;
+      if (op.getShapedType().isa<TensorType>()) {
+        newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(),
+                                                     op.indices());
+      } else {
+        newRead = rewriter.create<memref::LoadOp>(
+            op.getLoc(), originalVecType.getElementType(), op.source(),
+            op.indices());
+      }
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                        newRead);
       return success();

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 866d791c7c19f..a5c0cb584b11b 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
-// CHECK-LABEL: func @vector_transfer_ops_0d(
+// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
 //  CHECK-SAME:   %[[MEM:.*]]: memref<f32>
 //  CHECK-SAME:   %[[VV:.*]]: vector<1x1x1xf32>
-func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
+func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>) {
     %f0 = arith.constant 0.0 : f32
 
 //  CHECK-NEXT:   %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
@@ -23,6 +23,22 @@ func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @vector_transfer_ops_0d_tensor(
+//  CHECK-SAME:   %[[SOURCE:.*]]: tensor<f32>
+func @vector_transfer_ops_0d_tensor(%M: tensor<f32>) -> vector<1xf32> {
+    %f0 = arith.constant 0.0 : f32
+
+//  CHECK-NEXT:   %[[S:.*]] = tensor.extract %[[SOURCE]][] : tensor<f32>
+//  CHECK-NEXT:   %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
+    %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+      tensor<f32>, vector<1xf32>
+
+//  CHECK-NEXT:   return %[[V]]
+    return %0: vector<1xf32>
+}
+
+// -----
+
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL:   func @transfer_to_load(
 // CHECK-SAME:                                %[[MEM:.*]]: memref<8x8xf32>,


        


More information about the Mlir-commits mailing list