[Mlir-commits] [mlir] bd20756 - [mlir] Support tensor types in unrolled VectorToSCF
Matthias Springer
llvmlistbot at llvm.org
Tue Jun 1 18:44:22 PDT 2021
Author: Matthias Springer
Date: 2021-06-02T10:44:04+09:00
New Revision: bd20756d2c583002de862cb2aa41d54c8e9bc3d0
URL: https://github.com/llvm/llvm-project/commit/bd20756d2c583002de862cb2aa41d54c8e9bc3d0
DIFF: https://github.com/llvm/llvm-project/commit/bd20756d2c583002de862cb2aa41d54c8e9bc3d0.diff
LOG: [mlir] Support tensor types in unrolled VectorToSCF
Differential Revision: https://reviews.llvm.org/D102668
Added:
mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
Modified:
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 7637c22d17bd5..9993c823d6d6f 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,7 +866,7 @@ struct UnrollTransferReadConversion
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
- if (xferOp.getShapedType().template isa<RankedTensorType>())
+ if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
@@ -988,7 +988,7 @@ struct UnrollTransferWriteConversion
PatternRewriter &rewriter) const override {
if (xferOp.getVectorType().getRank() <= options.targetRank)
return failure();
- if (xferOp.getShapedType().template isa<RankedTensorType>())
+ if (isTensorOp(xferOp) && !options.lowerTensors)
return failure();
// Transfer ops that modify the element type are not supported atm.
if (xferOp.getVectorType().getElementType() !=
@@ -998,15 +998,19 @@ struct UnrollTransferWriteConversion
auto vec = getDataVector(xferOp);
auto xferVecType = xferOp.getVectorType();
int64_t dimSize = xferVecType.getShape()[0];
+ auto source = xferOp.source(); // memref or tensor to be written to.
+ auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
// Generate fully unrolled loop of transfer ops.
Location loc = xferOp.getLoc();
for (int64_t i = 0; i < dimSize; ++i) {
Value iv = rewriter.create<ConstantIndexOp>(loc, i);
- generateInBoundsCheck(
+ auto updatedSource = generateInBoundsCheck(
rewriter, xferOp, iv, unpackedDim(xferOp),
- /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+ isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
+ /*inBoundsCase=*/
+ [&](OpBuilder &b, Location loc) {
// Indices for the new transfer op.
SmallVector<Value, 8> xferIndices;
getXferIndices(b, xferOp, iv, xferIndices);
@@ -1019,17 +1023,29 @@ struct UnrollTransferWriteConversion
auto extracted =
b.create<vector::ExtractOp>(loc, vec, extractionIndices);
auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
-
auto newXferOp = b.create<vector::TransferWriteOp>(
- loc, Type(), extracted, xferOp.source(), xferIndices,
+ loc, sourceType, extracted, source, xferIndices,
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
inBoundsAttr);
maybeAssignMask(b, xferOp, newXferOp, i);
+
+ return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
+ },
+ /*outOfBoundsCase=*/
+ [&](OpBuilder &b, Location loc) {
+ return isTensorOp(xferOp) ? source : Value();
});
+
+ if (isTensorOp(xferOp))
+ source = updatedSource;
}
- rewriter.eraseOp(xferOp);
+ if (isTensorOp(xferOp))
+ rewriter.replaceOp(xferOp, source);
+ else
+ rewriter.eraseOp(xferOp);
+
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
new file mode 100644
index 0000000000000..443d67eba53da
--- /dev/null
+++ b/mlir/test/Conversion/VectorToSCF/unrolled-tensor-transfer-ops.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @transfer_read_2d(
+// CHECK: %[[V_INIT:.*]] = constant dense<-4.200000e+01> : vector<4x9xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[V0]], %[[V_INIT]] [0] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[V1]], %[[I0]] [1] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[V2]], %[[I1]] [2] : vector<9xf32> into vector<4x9xf32>
+// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[V3]], %[[I2]] [3] : vector<9xf32> into vector<4x9xf32>
+// CHECK: return %[[I3]] : vector<4x9xf32>
+func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
+ -> (vector<4x9xf32>){
+ %p = constant -42.0: f32
+ %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
+ : tensor<?x?xf32>, vector<4x9xf32>
+ return %f : vector<4x9xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_2d(
+// CHECK: %[[V0:.*]] = vector.extract %{{.*}}[0] : vector<2x3xf32>
+// CHECK: %[[T0:.*]] = vector.transfer_write %[[V0]], %{{.*}}[{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: %[[V1:.*]] = vector.extract %{{.*}}[1] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.transfer_write %[[V1]], %[[T0]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
+// CHECK: return %[[T1]] : tensor<?x?xf32>
+func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
+ %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
+ %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
+ : vector<2x3xf32>, tensor<?x?xf32>
+ return %t : tensor<?x?xf32>
+}
+
More information about the Mlir-commits mailing list