[Mlir-commits] [mlir] 432341d - [mlir] Handle cases where transfer_read should turn into a scalar load

Stephen Neuendorffer llvmlistbot at llvm.org
Tue Aug 3 23:11:54 PDT 2021


Author: Stephen Neuendorffer
Date: 2021-08-03T22:53:40-07:00
New Revision: 432341d8a81afb95f12ca8e91fbbb4a4b526856f

URL: https://github.com/llvm/llvm-project/commit/432341d8a81afb95f12ca8e91fbbb4a4b526856f
DIFF: https://github.com/llvm/llvm-project/commit/432341d8a81afb95f12ca8e91fbbb4a4b526856f.diff

LOG: [mlir] Handle cases where transfer_read should turn into a scalar load

The existing vector transforms reduce the dimension of transfer_read
ops.  However, beyond a certain point, the vector op actually has
to be reduced to a scalar load, since we can't load a zero-dimension
vector.  This handles this case.

Note that in the longer term, it may be preferable to support
zero-dimension vectors.  See
https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.

Differential Revision: https://reviews.llvm.org/D103432

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 6e7f644d3aeb7..278cd6d639cfa 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2828,6 +2828,18 @@ struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
     // with broadasting. Otherwise we first want to permute the map.
     if (!newMap.isMinorIdentityWithBroadcasting())
       return failure();
+
+    // TODO: support zero-dimension vectors natively.  See:
+    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
+    // In the meantime, lower these to a scalar load when they pop up.
+    if (reducedShapeRank == 0) {
+      Value newRead = rewriter.create<memref::LoadOp>(
+          op.getLoc(), originalVecType.getElementType(), op.source(),
+          op.indices());
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
+                                                       newRead);
+      return success();
+    }
     SmallVector<int64_t> newShape = llvm::to_vector<4>(
         originalVecType.getShape().take_back(reducedShapeRank));
     // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
index 910100d61af4f..825d7468dc952 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -228,6 +228,7 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
 #map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)>
 #map4 = affine_map<(d0, d1) -> (0, d1, 0, d0)>
 #map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>
+#map6 = affine_map<(d0, d1) -> (0)>
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>
@@ -235,7 +236,7 @@ func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index)
 // CHECK-LABEL: func @transfer_read_permutations
 func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?xf32>)
     -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
-       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>) {
+       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) {
 // CHECK-DAG: %[[CF0:.*]] = constant 0.000000e+00 : f32
 // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   %cst = constant 0.000000e+00 : f32
@@ -275,9 +276,13 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
 // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref<?x?x?x?xf32>, vector<16x14x7x8xf32>
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
-  return %0, %1, %2, %3, %4, %5 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
+  %6 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
+// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
+// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+
+  return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
          vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
-         vector<7x14x8x16xf32>
+         vector<7x14x8x16xf32>, vector<8xf32>
 }
 
 // -----


        


More information about the Mlir-commits mailing list