[Mlir-commits] [mlir] e547b1e - [mlir] Rank reducing subview conversion to LLVM
Jakub Lichman
llvmlistbot at llvm.org
Thu Oct 8 06:47:54 PDT 2020
Author: Jakub Lichman
Date: 2020-10-08T13:47:22Z
New Revision: e547b1e2431f9b6175470ff703cf6e1988031cda
URL: https://github.com/llvm/llvm-project/commit/e547b1e2431f9b6175470ff703cf6e1988031cda
DIFF: https://github.com/llvm/llvm-project/commit/e547b1e2431f9b6175470ff703cf6e1988031cda.diff
LOG: [mlir] Rank reducing subview conversion to LLVM
This commit adjusts SubViewOp lowering to take rank reduction into account.
Differential Revision: https://reviews.llvm.org/D88883
Added:
mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 747a83414a08..d87869270c21 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,18 @@ class DmaWaitOp
LogicalResult verify();
};
+/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
+/// `originalShape` with some `1` entries erased, return the vector of booleans
+/// that specifies which of the entries of `originalShape` are kept to obtain
+/// `reducedShape`. The returned mask can be applied as a projection to
+/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
+/// which dimensions must be kept when e.g. computing MemRef strides under
+/// rank-reducing operations. Return None if reducedShape cannot be obtained
+/// by dropping only `1` entries in `originalShape`.
+llvm::Optional<SmallVector<bool, 4>>
+computeRankReductionMask(ArrayRef<int64_t> originalShape,
+ ArrayRef<int64_t> reducedShape);
+
/// Prints dimension and symbol list.
void printDimAndSymbolList(Operation::operand_iterator begin,
Operation::operand_iterator end, unsigned numDims,
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
new file mode 100644
index 000000000000..9157f9abaab9
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %f0 = constant 0.0 : f32
+ %x = dim %A, %c0 : memref<?x?xf32>
+ %y = dim %B, %c1 : memref<?x?xf32>
+ %C = alloc(%x, %y) : memref<?x?xf32>
+ linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+ linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+ outs(%C: memref<?x?xf32>)
+ return %C : memref<?x?xf32>
+}
+
+func @matvec(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %f0 = constant 0.0 : f32
+ %m = dim %A, %c0 : memref<?x?xf32>
+ %x = dim %A, %c1 : memref<?x?xf32>
+ %n = dim %B, %c1 : memref<?x?xf32>
+ %C = alloc(%m, %n) : memref<?x?xf32>
+ linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+ scf.for %i = %c0 to %n step %c1 {
+ %b = subview %B[0, %i][%x, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+ %c = subview %C[0, %i][%m, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+ linalg.matvec ins(%A, %b: memref<?x?xf32>, memref<?xf32, offset: ?, strides: [?]>)
+ outs(%c: memref<?xf32, offset: ?, strides: [?]>)
+ }
+ return %C : memref<?x?xf32>
+}
+
+func @main() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %m = constant 5 : index
+ %x = constant 3 : index
+ %n = constant 2 : index
+ %val1 = constant 13.0 : f32
+ %val2 = constant 17.0 : f32
+ %A = alloc(%m, %x) : memref<?x?xf32>
+ %B = alloc(%x, %n) : memref<?x?xf32>
+ linalg.fill(%A, %val1) : memref<?x?xf32>, f32
+ linalg.fill(%B, %val2) : memref<?x?xf32>, f32
+ store %val1, %B[%c0, %c0] : memref<?x?xf32>
+ %C1 = call @matmul(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+ %C2 = call @matvec(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+ scf.for %i = %c0 to %m step %c1 {
+ scf.for %j = %c0 to %n step %c1 {
+ %e1 = load %C1[%i, %j] : memref<?x?xf32>
+ %e2 = load %C2[%i, %j] : memref<?x?xf32>
+ %c = cmpf "oeq", %e1, %e2 : f32
+ assert %c, "Matmul does not produce same output as matvec"
+ }
+ }
+ %C2_ = memref_cast %C2 : memref<?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%C2_) : (memref<*xf32>) -> ()
+ return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [5, 2] strides = [2, 1] data =
+// CHECK-NEXT: [
+// CHECK-SAME: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663],
+// CHECK-NEXT: [611, 663]
+// CHECK-SAME: ]
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
new file mode 100644
index 000000000000..ceffd6f79d23
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @main() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %f0 = constant 0.0 : f32
+ %f1 = constant 1.0 : f32
+ %f2 = constant 2.0 : f32
+ %f3 = constant 3.0 : f32
+ %A = alloc(%c2, %c2) : memref<?x?xf32>
+ store %f0, %A[%c0, %c0] : memref<?x?xf32>
+ store %f1, %A[%c0, %c1] : memref<?x?xf32>
+ store %f2, %A[%c1, %c0] : memref<?x?xf32>
+ store %f3, %A[%c1, %c1] : memref<?x?xf32>
+ %B = subview %A[%c1, 0][1, %c2][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [1]>
+ %C = subview %A[0, %c1][%c2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+ %A_ = memref_cast %A : memref<?x?xf32> to memref<*xf32>
+ call @print_memref_f32(%A_) : (memref<*xf32>) -> ()
+ %B_ = memref_cast %B : memref<?xf32, offset: ?, strides: [1]> to memref<*xf32>
+ call @print_memref_f32(%B_) : (memref<*xf32>) -> ()
+ %C_ = memref_cast %C : memref<?xf32, offset: ?, strides: [?]> to memref<*xf32>
+ call @print_memref_f32(%C_) : (memref<*xf32>) -> ()
+ return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [2, 2] strides = [2, 1] data =
+// CHECK-NEXT: [
+// CHECK-SAME: [0, 1],
+// CHECK-NEXT: [2, 3]
+// CHECK-SAME: ]
+// CHECK: [2, 3]
+// CHECK: [1, 3]
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 93bdd1d89d93..e042fc3d1c4e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2928,6 +2928,14 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};
+/// Helper function that extracts int64_t from the assumed ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+ return llvm::to_vector<4>(
+ llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+ return a.cast<IntegerAttr>().getInt();
+ }));
+}
+
/// Conversion pattern that transforms a subview op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -2948,6 +2956,12 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
.dyn_cast_or_null<LLVM::LLVMType>();
auto viewMemRefType = subViewOp.getType();
+ auto inferredType = SubViewOp::inferResultType(
+ subViewOp.getSourceType(),
+ extractFromI64ArrayAttr(subViewOp.static_offsets()),
+ extractFromI64ArrayAttr(subViewOp.static_sizes()),
+ extractFromI64ArrayAttr(subViewOp.static_strides()))
+ .cast<MemRefType>();
auto targetElementTy =
typeConverter.convertType(viewMemRefType.getElementType())
.dyn_cast<LLVM::LLVMType>();
@@ -2959,7 +2973,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
// Extract the offset and strides from the type.
int64_t offset;
SmallVector<int64_t, 4> strides;
- auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ auto successStrides = getStridesAndOffset(inferredType, strides, offset);
if (failed(successStrides))
return failure();
@@ -2983,10 +2997,17 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
extracted);
targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
+ auto shape = viewMemRefType.getShape();
+ auto inferredShape = inferredType.getShape();
+ size_t inferredShapeRank = inferredShape.size();
+ size_t resultShapeRank = shape.size();
+ SmallVector<bool, 4> mask =
+ computeRankReductionMask(inferredShape, shape).getValue();
+
// Extract strides needed to compute offset.
SmallVector<Value, 4> strideValues;
- strideValues.reserve(viewMemRefType.getRank());
- for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+ strideValues.reserve(inferredShapeRank);
+ for (unsigned i = 0; i < inferredShapeRank; ++i)
strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
// Offset.
@@ -2995,7 +3016,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
targetMemRef.setConstantOffset(rewriter, loc, offset);
} else {
Value baseOffset = sourceMemRef.offset(rewriter, loc);
- for (unsigned i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+ for (unsigned i = 0; i < inferredShapeRank; ++i) {
Value offset =
subViewOp.isDynamicOffset(i)
? operands[subViewOp.getIndexOfDynamicOffset(i)]
@@ -3009,14 +3030,18 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
}
// Update sizes and strides.
- for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+ for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
+ i >= 0 && j >= 0; --i) {
+ if (!mask[i])
+ continue;
+
Value size =
subViewOp.isDynamicSize(i)
? operands[subViewOp.getIndexOfDynamicSize(i)]
: rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexType,
rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
- targetMemRef.setSize(rewriter, loc, i, size);
+ targetMemRef.setSize(rewriter, loc, j, size);
Value stride;
if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
stride = rewriter.create<LLVM::ConstantOp>(
@@ -3030,7 +3055,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i)));
stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
}
- targetMemRef.setStride(rewriter, loc, i, stride);
+ targetMemRef.setStride(rewriter, loc, j, stride);
+ j--;
}
rewriter.replaceOp(op, {targetMemRef});
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f445a0cce242..82058fdcc03c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2823,6 +2823,30 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
}));
}
+llvm::Optional<SmallVector<bool, 4>>
+mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
+ ArrayRef<int64_t> reducedShape) {
+ size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
+ SmallVector<bool, 4> mask(originalRank);
+ unsigned reducedIdx = 0;
+ for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+ // Skip matching dims greedily.
+ mask[originalIdx] =
+ (reducedIdx < reducedRank) &&
+ (originalShape[originalIdx] == reducedShape[reducedIdx]);
+ if (mask[originalIdx])
+ reducedIdx++;
+ // 1 is the only non-matching value allowed.
+ else if (originalShape[originalIdx] != 1)
+ return {};
+ }
+
+ if (reducedIdx != reducedRank)
+ return {};
+
+ return mask;
+}
+
enum SubViewVerificationResult {
Success,
RankTooLarge,
@@ -2859,20 +2883,10 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
if (reducedRank > originalRank)
return SubViewVerificationResult::RankTooLarge;
- unsigned reducedIdx = 0;
- SmallVector<bool, 4> keepMask(originalRank);
- for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
- // -2 is never used as a dim size so it will never match.
- int reducedVal = reducedIdx < reducedRank ? reducedShape[reducedIdx] : -2;
- // Skip matching dims greedily.
- if ((keepMask[originalIdx] = originalShape[originalIdx] == reducedVal))
- reducedIdx++;
- // 1 is the only non-matching allowed.
- else if (originalShape[originalIdx] != 1)
- return SubViewVerificationResult::SizeMismatch;
- }
- // Must match the reduced rank.
- if (reducedIdx != reducedRank)
+ auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+
+ // Sizes cannot be matched if no mask could be computed (None returned).
+ if (!optionalMask.hasValue())
return SubViewVerificationResult::SizeMismatch;
// We are done for the tensor case.
@@ -2885,12 +2899,13 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
MLIRContext *c = original.getContext();
int64_t originalOffset, reducedOffset;
SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+ SmallVector<bool, 4> keepMask = optionalMask.getValue();
getStridesAndOffset(original, originalStrides, originalOffset);
getStridesAndOffset(reduced, reducedStrides, reducedOffset);
// Filter strides based on the mask and check that they are the same
// as reduced ones.
- reducedIdx = 0;
+ unsigned reducedIdx = 0;
for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
if (keepMask[originalIdx]) {
if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
More information about the Mlir-commits
mailing list