[Mlir-commits] [mlir] d3e1398 - [mlir][Vector] Prevent vector-to-scalar xfer patterns from triggering on sub-vectors

Diego Caballero llvmlistbot at llvm.org
Thu Jun 1 15:23:29 PDT 2023


Author: Diego Caballero
Date: 2023-06-01T22:22:16Z
New Revision: d3e1398bef44774e0f7965b8933d20f824460825

URL: https://github.com/llvm/llvm-project/commit/d3e1398bef44774e0f7965b8933d20f824460825
DIFF: https://github.com/llvm/llvm-project/commit/d3e1398bef44774e0f7965b8933d20f824460825.diff

LOG: [mlir][Vector] Prevent vector-to-scalar xfer patterns from triggering on sub-vectors

Patterns that convert extract(transfer_read) into a scalar load were
incorrectly triggering for cases where a sub-vector instead of a scalar
was extracted.

Reviewed By: nicolasvasilache, hanchung, awarzynski

Differential Revision: https://reviews.llvm.org/D151862

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index af0fcd097028d..8d97bbfb72572 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -581,6 +581,9 @@ class RewriteScalarExtractOfTransferReadBase
         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
       return failure();
+    // Check that we are extracting a scalar and not a sub-vector.
+    if (isa<VectorType>(extractOp.getResult().getType()))
+      return failure();
     // If multiple uses are not allowed, check if xfer has a single use.
     if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
       return failure();
@@ -622,6 +625,7 @@ class RewriteScalarExtractElementOfTransferRead
   void rewrite(vector::ExtractElementOp extractOp,
                PatternRewriter &rewriter) const override {
     // Construct scalar load.
+    auto loc = extractOp.getLoc();
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                   xferOp.getIndices().end());
@@ -629,13 +633,13 @@ class RewriteScalarExtractElementOfTransferRead
       AffineExpr sym0, sym1;
       bindSymbols(extractOp.getContext(), sym0, sym1);
       OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, extractOp.getLoc(), sym0 + sym1,
+          rewriter, loc, sym0 + sym1,
           {newIndices[newIndices.size() - 1], extractOp.getPosition()});
       if (ofr.is<Value>()) {
         newIndices[newIndices.size() - 1] = ofr.get<Value>();
       } else {
         newIndices[newIndices.size() - 1] =
-            rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+            rewriter.create<arith::ConstantIndexOp>(loc,
                                                     *getConstantIntValue(ofr));
       }
     }

diff  --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7029dc717ca4d..ca74a77477f63 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -136,3 +136,20 @@ func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32)
   return %1, %2 : f32, f32
 }
 
+// -----
+
+// Check that patterns don't trigger for a sub-vector (not scalar) extraction.
+// CHECK-LABEL: func @subvector_extract(
+//  CHECK-SAME:   %[[m:.*]]: memref<?x?xf32>, %[[idx:.*]]: index
+//   CHECK-NOT:   memref.load
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]], %[[idx]]]
+//       CHECK:   %[[e0:.*]] = vector.extract %[[r]][0]
+//       CHECK:   return %[[e0]]
+
+func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32>
+  %1 = vector.extract %0[0] : vector<8x16xf32>
+  return %1 : vector<16xf32>
+}
+


        


More information about the Mlir-commits mailing list