[Mlir-commits] [mlir] bd87c6b - [mlir][Vector] Add custom slt / SCF.if folding to VectorToSCF
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Jul 6 05:24:24 PDT 2020
Author: Nicolas Vasilache
Date: 2020-07-06T08:21:21-04:00
New Revision: bd87c6bce1c30cc089ffdea5e0f3cf5407ed37c5
URL: https://github.com/llvm/llvm-project/commit/bd87c6bce1c30cc089ffdea5e0f3cf5407ed37c5
DIFF: https://github.com/llvm/llvm-project/commit/bd87c6bce1c30cc089ffdea5e0f3cf5407ed37c5.diff
LOG: [mlir][Vector] Add custom slt / SCF.if folding to VectorToSCF
scf.if currently lacks folding of constant true / false conditions.
Implementing such foldings is more involved than can be addressed immediately.
This revision introduces an eager folding of the in-bounds conditions emitted when lowering vector.transfer operations in the presence of unrolling.
Differential revision: https://reviews.llvm.org/D83146
Added:
Modified:
mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cf3d9653d7df..c7b4db1d5ce3 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -174,6 +174,27 @@ void NDTransferOpHelper<ConcreteOp>::emitLoops(Lambda loopBodyBuilder) {
}
}
+// Returns the constant held by `v` when it is produced either by a
+// ConstantIndexOp or by an AffineApplyOp whose affine map is a single
+// constant result; returns None for any other producer.
+static Optional<int64_t> extractConstantIndex(Value v) {
+  if (auto constantOp = v.getDefiningOp<ConstantIndexOp>())
+    return constantOp.getValue();
+  auto applyOp = v.getDefiningOp<AffineApplyOp>();
+  if (!applyOp)
+    return None;
+  auto map = applyOp.getAffineMap();
+  if (!map.isSingleConstant())
+    return None;
+  return map.getSingleConstantResult();
+}
+
+// Eagerly folds the signed comparison `v < ub`. scf.if does not fold constant
+// true / false conditions yet, so without this an always-true bounds check
+// would still be materialized, notably when unrolling.
+// Returns a null Value when both `v` and `ub` are constants recognized by
+// extractConstantIndex and `v < ub` holds, i.e. the check is statically true
+// and callers can drop it; otherwise emits `slt(v, ub)` and returns it.
+// TODO: remove once scf.if folds constant conditions properly.
+static Value onTheFlyFoldSLT(Value v, Value ub) {
+  using namespace mlir::edsc::op;
+  Optional<int64_t> cstV = extractConstantIndex(v);
+  Optional<int64_t> cstUb = extractConstantIndex(ub);
+  // Anything not statically known to be in bounds keeps a runtime comparison.
+  if (!cstV || !cstUb || *cstV >= *cstUb)
+    return slt(v, ub);
+  return Value();
+}
+
template <typename ConcreteOp>
Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
ValueRange majorIvs, ValueRange majorOffsets,
@@ -187,9 +208,11 @@ Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off);
if (xferOp.isMaskedDim(leadingRank + idx)) {
- Value inBounds = slt(majorIvsPlusOffsets.back(), ub);
- inBoundsCondition =
- (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds;
+ Value inBoundsCond = onTheFlyFoldSLT(majorIvsPlusOffsets.back(), ub);
+ if (inBoundsCond)
+ inBoundsCondition = (inBoundsCondition)
+ ? (inBoundsCondition && inBoundsCond)
+ : inBoundsCond;
}
++idx;
}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
index f79b07d00d22..b8c27b51b469 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -383,3 +383,24 @@ func @transfer_write_progressive_not_masked(%A : memref<?x?xf32>, %base: index,
vector<3x15xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// FULL-UNROLL-LABEL: transfer_read_simple
+func @transfer_read_simple(%A : memref<2x2xf32>) -> vector<2x2xf32> {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  // FULL-UNROLL-DAG: %[[VC0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+  // FULL-UNROLL-DAG: %[[C0:.*]] = constant 0 : index
+  // FULL-UNROLL-DAG: %[[C1:.*]] = constant 1 : index
+  // Both unrolled rows (0 and 1) are statically within the 2-row memref, so
+  // the eager slt folding must leave no scf.if guard around either read.
+  // FULL-UNROLL-NOT: scf.if
+  // FULL-UNROLL: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]]
+  // FULL-UNROLL: %[[RES0:.*]] = vector.insert %[[V0]], %[[VC0]] [0] : vector<2xf32> into vector<2x2xf32>
+  // FULL-UNROLL-NOT: scf.if
+  // FULL-UNROLL: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[C0]]]
+  // FULL-UNROLL: %[[RES1:.*]] = vector.insert %[[V1]], %[[RES0]] [1] : vector<2xf32> into vector<2x2xf32>
+  // FULL-UNROLL: return %[[RES1]] : vector<2x2xf32>
+  %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<2x2xf32>, vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
More information about the Mlir-commits
mailing list