[Mlir-commits] [mlir] d3e1398 - [mlir][Vector] Prevent vector-to-scalar xfer patterns from triggering on sub-vectors
Diego Caballero
llvmlistbot at llvm.org
Thu Jun 1 15:23:29 PDT 2023
Author: Diego Caballero
Date: 2023-06-01T22:22:16Z
New Revision: d3e1398bef44774e0f7965b8933d20f824460825
URL: https://github.com/llvm/llvm-project/commit/d3e1398bef44774e0f7965b8933d20f824460825
DIFF: https://github.com/llvm/llvm-project/commit/d3e1398bef44774e0f7965b8933d20f824460825.diff
LOG: [mlir][Vector] Prevent vector-to-scalar xfer patterns from triggering on sub-vectors
Patterns that convert extract(transfer_read) into a scalar load were
incorrectly triggering for cases where a sub-vector instead of a scalar
was extracted.
Reviewed By: nicolasvasilache, hanchung, awarzynski
Differential Revision: https://reviews.llvm.org/D151862
Added:
Modified:
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index af0fcd097028d..8d97bbfb72572 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -581,6 +581,9 @@ class RewriteScalarExtractOfTransferReadBase
extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
+ // Check that we are extracting a scalar and not a sub-vector.
+ if (isa<VectorType>(extractOp.getResult().getType()))
+ return failure();
// If multiple uses are not allowed, check if xfer has a single use.
if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
return failure();
@@ -622,6 +625,7 @@ class RewriteScalarExtractElementOfTransferRead
void rewrite(vector::ExtractElementOp extractOp,
PatternRewriter &rewriter) const override {
// Construct scalar load.
+ auto loc = extractOp.getLoc();
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
@@ -629,13 +633,13 @@ class RewriteScalarExtractElementOfTransferRead
AffineExpr sym0, sym1;
bindSymbols(extractOp.getContext(), sym0, sym1);
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, extractOp.getLoc(), sym0 + sym1,
+ rewriter, loc, sym0 + sym1,
{newIndices[newIndices.size() - 1], extractOp.getPosition()});
if (ofr.is<Value>()) {
newIndices[newIndices.size() - 1] = ofr.get<Value>();
} else {
newIndices[newIndices.size() - 1] =
- rewriter.create<arith::ConstantIndexOp>(extractOp.getLoc(),
+ rewriter.create<arith::ConstantIndexOp>(loc,
*getConstantIntValue(ofr));
}
}
diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
index 7029dc717ca4d..ca74a77477f63 100644
--- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
+++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir
@@ -136,3 +136,20 @@ func.func @transfer_read_multi_use(%m: memref<?xf32>, %idx: index) -> (f32, f32)
return %1, %2 : f32, f32
}
+// -----
+
+// Check that patterns don't trigger for a sub-vector (not scalar) extraction.
+// CHECK-LABEL: func @subvector_extract(
+// CHECK-SAME: %[[m:.*]]: memref<?x?xf32>, %[[idx:.*]]: index
+// CHECK-NOT: memref.load
+// CHECK: %[[r:.*]] = vector.transfer_read %[[m]][%[[idx]], %[[idx]]]
+// CHECK: %[[e0:.*]] = vector.extract %[[r]][0]
+// CHECK: return %[[e0]]
+
+func.func @subvector_extract(%m: memref<?x?xf32>, %idx: index) -> vector<16xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %m[%idx, %idx], %cst {in_bounds = [true, true]} : memref<?x?xf32>, vector<8x16xf32>
+ %1 = vector.extract %0[0] : vector<8x16xf32>
+ return %1 : vector<16xf32>
+}
+
More information about the Mlir-commits
mailing list