[Mlir-commits] [mlir] e547b1e - [mlir] Rank reducing subview conversion to LLVM

Jakub Lichman llvmlistbot at llvm.org
Thu Oct 8 06:47:54 PDT 2020


Author: Jakub Lichman
Date: 2020-10-08T13:47:22Z
New Revision: e547b1e2431f9b6175470ff703cf6e1988031cda

URL: https://github.com/llvm/llvm-project/commit/e547b1e2431f9b6175470ff703cf6e1988031cda
DIFF: https://github.com/llvm/llvm-project/commit/e547b1e2431f9b6175470ff703cf6e1988031cda.diff

LOG: [mlir] Rank reducing subview conversion to LLVM

This commit adjusts SubViewOp lowering to take rank reduction into account.

Differential Revision: https://reviews.llvm.org/D88883
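
For context, a rank-reducing subview is one whose declared result type drops
unit dimensions relative to the full-rank type inferred from the op's offsets,
sizes and strides (see the new rank-reducing-subview.mlir test below). A
minimal sketch of how such an op can be recognized, reusing calls that appear
in the diff; the helper name `isRankReducingSubView` is hypothetical and not
part of the patch:

    // Sketch only (hypothetical helper, not in the patch): a subview is
    // rank-reducing when its declared result type has lower rank than the
    // full-rank type inferred from its static offsets/sizes/strides.
    static bool isRankReducingSubView(SubViewOp op) {
      auto inferredType =
          SubViewOp::inferResultType(
              op.getSourceType(),
              extractFromI64ArrayAttr(op.static_offsets()),
              extractFromI64ArrayAttr(op.static_sizes()),
              extractFromI64ArrayAttr(op.static_strides()))
              .cast<MemRefType>();
      return op.getType().getRank() < inferredType.getRank();
    }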

Added: 
    mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
    mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
index 747a83414a08..d87869270c21 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,18 @@ class DmaWaitOp
   LogicalResult verify();
 };
 
+/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
+/// `originalShape` with some `1` entries erased, return the vector of booleans
+/// that specifies which of the entries of `originalShape` are kept to obtain
+/// `reducedShape`. The returned mask can be applied as a projection to
+/// `originalShape` to obtain `reducedShape`. This mask is useful to track
+/// which dimensions must be kept when e.g. computing MemRef strides under
+/// rank-reducing operations. Return None if `reducedShape` cannot be obtained
+/// by dropping only `1` entries in `originalShape`.
+llvm::Optional<SmallVector<bool, 4>>
+computeRankReductionMask(ArrayRef<int64_t> originalShape,
+                         ArrayRef<int64_t> reducedShape);
+
 /// Prints dimension and symbol list.
 void printDimAndSymbolList(Operation::operand_iterator begin,
                            Operation::operand_iterator end, unsigned numDims,
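
As a worked example of the documented semantics, here is a standalone sketch
(plain C++, deliberately independent of the MLIR types, so all names are
illustrative only) of how the mask is computed and what it yields for concrete
shapes:

    // Standalone illustration of computeRankReductionMask semantics:
    // greedily match dims; only `1` entries of the original shape may be
    // dropped.
    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    static std::optional<std::vector<bool>>
    computeMask(const std::vector<std::int64_t> &original,
                const std::vector<std::int64_t> &reduced) {
      std::vector<bool> mask(original.size());
      size_t reducedIdx = 0;
      for (size_t i = 0, e = original.size(); i < e; ++i) {
        mask[i] =
            reducedIdx < reduced.size() && original[i] == reduced[reducedIdx];
        if (mask[i])
          ++reducedIdx;        // Greedily consume the matching reduced dim.
        else if (original[i] != 1)
          return std::nullopt; // Only unit dims may be dropped.
      }
      if (reducedIdx != reduced.size())
        return std::nullopt;   // Some reduced dims were left unmatched.
      return mask;
    }

    int main() {
      auto m = computeMask({1, 5, 1, 2}, {5, 2});
      assert(m && *m == std::vector<bool>({false, true, false, true}));
      assert(!computeMask({5, 3}, {5})); // 3 is not a unit dim.
    }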

diff --git a/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
new file mode 100644
index 000000000000..9157f9abaab9
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @matmul(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %f0 = constant 0.0 : f32
+  %x = dim %A, %c0 : memref<?x?xf32>
+  %y = dim %B, %c1 : memref<?x?xf32>
+  %C = alloc(%x, %y) : memref<?x?xf32>
+  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+  linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
+                outs(%C: memref<?x?xf32>)
+  return %C : memref<?x?xf32>
+}
+
+func @matvec(%A: memref<?x?xf32>, %B: memref<?x?xf32>) -> (memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %f0 = constant 0.0 : f32
+  %m = dim %A, %c0 : memref<?x?xf32>
+  %x = dim %A, %c1 : memref<?x?xf32>
+  %n = dim %B, %c1 : memref<?x?xf32>
+  %C = alloc(%m, %n) : memref<?x?xf32>
+  linalg.fill(%C, %f0) : memref<?x?xf32>, f32
+  scf.for %i = %c0 to %n step %c1 {
+    %b = subview %B[0, %i][%x, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+    %c = subview %C[0, %i][%m, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+    linalg.matvec ins(%A, %b: memref<?x?xf32>, memref<?xf32, offset: ?, strides: [?]>)
+                  outs(%c: memref<?xf32, offset: ?, strides: [?]>)
+  }
+  return %C : memref<?x?xf32>
+}
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %m = constant 5 : index
+  %x = constant 3 : index
+  %n = constant 2 : index
+  %val1 = constant 13.0 : f32
+  %val2 = constant 17.0 : f32
+  %A = alloc(%m, %x) : memref<?x?xf32>
+  %B = alloc(%x, %n) : memref<?x?xf32>
+  linalg.fill(%A, %val1) : memref<?x?xf32>, f32
+  linalg.fill(%B, %val2) : memref<?x?xf32>, f32
+  store %val1, %B[%c0, %c0] : memref<?x?xf32>
+  %C1 = call @matmul(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+  %C2 = call @matvec(%A, %B) : (memref<?x?xf32>, memref<?x?xf32>) -> memref<?x?xf32>
+  scf.for %i = %c0 to %m step %c1 {
+    scf.for %j = %c0 to %n step %c1 {
+      %e1 = load %C1[%i, %j] : memref<?x?xf32>
+      %e2 = load %C2[%i, %j] : memref<?x?xf32>
+      %c = cmpf "oeq", %e1, %e2 : f32
+      assert %c, "Matmul does not produce the same output as matvec"
+    }
+  }
+  %C2_ = memref_cast %C2 : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%C2_) : (memref<*xf32>) -> ()
+  return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [5, 2] strides = [2, 1] data =
+// CHECK-NEXT:      [
+// CHECK-SAME:  [611,   663],
+// CHECK-NEXT:  [611,   663],
+// CHECK-NEXT:  [611,   663],
+// CHECK-NEXT:  [611,   663],
+// CHECK-NEXT:  [611,   663]
+// CHECK-SAME: ]

diff --git a/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
new file mode 100644
index 000000000000..ceffd6f79d23
--- /dev/null
+++ b/mlir/integration_test/Dialect/Linalg/CPU/rank-reducing-subview.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm | \
+// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>)
+
+func @main() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %f0 = constant 0.0 : f32
+  %f1 = constant 1.0 : f32
+  %f2 = constant 2.0 : f32
+  %f3 = constant 3.0 : f32
+  %A = alloc(%c2, %c2) : memref<?x?xf32>
+  store %f0, %A[%c0, %c0] : memref<?x?xf32>
+  store %f1, %A[%c0, %c1] : memref<?x?xf32>
+  store %f2, %A[%c1, %c0] : memref<?x?xf32>
+  store %f3, %A[%c1, %c1] : memref<?x?xf32>
+  %B = subview %A[%c1, 0][1, %c2][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [1]>
+  %C = subview %A[0, %c1][%c2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, offset: ?, strides: [?]>
+  %A_ = memref_cast %A : memref<?x?xf32> to memref<*xf32>
+  call @print_memref_f32(%A_) : (memref<*xf32>) -> ()
+  %B_ = memref_cast %B : memref<?xf32, offset: ?, strides: [1]> to memref<*xf32>
+  call @print_memref_f32(%B_) : (memref<*xf32>) -> ()
+  %C_ = memref_cast %C : memref<?xf32, offset: ?, strides: [?]> to memref<*xf32>
+  call @print_memref_f32(%C_) : (memref<*xf32>) -> ()
+  return
+}
+
+// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [2, 2] strides = [2, 1] data =
+// CHECK-NEXT:      [
+// CHECK-SAME:  [0,   1],
+// CHECK-NEXT:  [2,   3]
+// CHECK-SAME: ]
+// CHECK: [2,  3]
+// CHECK: [1,  3]

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 93bdd1d89d93..e042fc3d1c4e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2928,6 +2928,14 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
+/// Helper function that extracts int64_t values from an attribute assumed to be an ArrayAttr of IntegerAttr.
+static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
+        return a.cast<IntegerAttr>().getInt();
+      }));
+}
+
 /// Conversion pattern that transforms a subview op into:
 ///   1. An `llvm.mlir.undef` operation to create a memref descriptor
 ///   2. Updates to the descriptor to introduce the data ptr, offset, size
@@ -2948,6 +2956,12 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
             .dyn_cast_or_null<LLVM::LLVMType>();
 
     auto viewMemRefType = subViewOp.getType();
+    auto inferredType = SubViewOp::inferResultType(
+                            subViewOp.getSourceType(),
+                            extractFromI64ArrayAttr(subViewOp.static_offsets()),
+                            extractFromI64ArrayAttr(subViewOp.static_sizes()),
+                            extractFromI64ArrayAttr(subViewOp.static_strides()))
+                            .cast<MemRefType>();
     auto targetElementTy =
         typeConverter.convertType(viewMemRefType.getElementType())
             .dyn_cast<LLVM::LLVMType>();
@@ -2959,7 +2973,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     // Extract the offset and strides from the type.
     int64_t offset;
     SmallVector<int64_t, 4> strides;
-    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
     if (failed(successStrides))
       return failure();
 
@@ -2983,10 +2997,17 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
         extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
 
+    auto shape = viewMemRefType.getShape();
+    auto inferredShape = inferredType.getShape();
+    size_t inferredShapeRank = inferredShape.size();
+    size_t resultShapeRank = shape.size();
+    SmallVector<bool, 4> mask =
+        computeRankReductionMask(inferredShape, shape).getValue();
+
     // Extract strides needed to compute offset.
     SmallVector<Value, 4> strideValues;
-    strideValues.reserve(viewMemRefType.getRank());
-    for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i)
+    strideValues.reserve(inferredShapeRank);
+    for (unsigned i = 0; i < inferredShapeRank; ++i)
       strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));
 
     // Offset.
@@ -2995,7 +3016,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
       targetMemRef.setConstantOffset(rewriter, loc, offset);
     } else {
       Value baseOffset = sourceMemRef.offset(rewriter, loc);
-      for (unsigned i = 0, e = viewMemRefType.getRank(); i < e; ++i) {
+      for (unsigned i = 0; i < inferredShapeRank; ++i) {
         Value offset =
             subViewOp.isDynamicOffset(i)
                 ? operands[subViewOp.getIndexOfDynamicOffset(i)]
@@ -3009,14 +3030,18 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
     }
 
     // Update sizes and strides.
-    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
+         i >= 0 && j >= 0; --i) {
+      if (!mask[i])
+        continue;
+
       Value size =
           subViewOp.isDynamicSize(i)
               ? operands[subViewOp.getIndexOfDynamicSize(i)]
               : rewriter.create<LLVM::ConstantOp>(
                     loc, llvmIndexType,
                     rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i)));
-      targetMemRef.setSize(rewriter, loc, i, size);
+      targetMemRef.setSize(rewriter, loc, j, size);
       Value stride;
       if (!ShapedType::isDynamicStrideOrOffset(strides[i])) {
         stride = rewriter.create<LLVM::ConstantOp>(
@@ -3030,7 +3055,8 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
                       rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i)));
         stride = rewriter.create<LLVM::MulOp>(loc, stride, strideValues[i]);
       }
-      targetMemRef.setStride(rewriter, loc, i, stride);
+      targetMemRef.setStride(rewriter, loc, j, stride);
+      j--;
     }
 
     rewriter.replaceOp(op, {targetMemRef});
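
To make the size/stride update above concrete, here is a small standalone
sketch (hypothetical, not from the patch) of how the keep-mask remaps the
full-rank inferred sizes and strides onto the lower-rank result, mirroring
the i/j loop in SubViewOpLowering:

    #include <cstdint>
    #include <vector>

    // Walk inferred dims from the back; `j` tracks the next result dim and
    // only advances for dims the mask keeps.
    static void remap(const std::vector<std::int64_t> &inferredSizes,
                      const std::vector<std::int64_t> &inferredStrides,
                      const std::vector<bool> &mask,
                      std::vector<std::int64_t> &resultSizes,
                      std::vector<std::int64_t> &resultStrides) {
      int j = static_cast<int>(resultSizes.size()) - 1;
      for (int i = static_cast<int>(inferredSizes.size()) - 1;
           i >= 0 && j >= 0; --i) {
        if (!mask[i])
          continue; // Dropped unit dim: no slot in the reduced descriptor.
        resultSizes[j] = inferredSizes[i];
        resultStrides[j] = inferredStrides[i];
        --j;
      }
    }

    int main() {
      std::vector<std::int64_t> sizes(1), strides(1);
      remap({1, 2}, {2, 1}, {false, true}, sizes, strides);
      // sizes == {2}, strides == {1}: the unit dim and its stride are gone.
    }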

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index f445a0cce242..82058fdcc03c 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2823,6 +2823,30 @@ static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
       }));
 }
 
+llvm::Optional<SmallVector<bool, 4>>
+mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
+                               ArrayRef<int64_t> reducedShape) {
+  size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
+  SmallVector<bool, 4> mask(originalRank);
+  unsigned reducedIdx = 0;
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    // Skip matching dims greedily.
+    mask[originalIdx] =
+        (reducedIdx < reducedRank) &&
+        (originalShape[originalIdx] == reducedShape[reducedIdx]);
+    if (mask[originalIdx])
+      reducedIdx++;
+    // 1 is the only non-matching value allowed.
+    else if (originalShape[originalIdx] != 1)
+      return {};
+  }
+
+  if (reducedIdx != reducedRank)
+    return {};
+
+  return mask;
+}
+
 enum SubViewVerificationResult {
   Success,
   RankTooLarge,
@@ -2859,20 +2883,10 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
   if (reducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  unsigned reducedIdx = 0;
-  SmallVector<bool, 4> keepMask(originalRank);
-  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // -2 is never used as a dim size so it will never match.
-    int reducedVal = reducedIdx < reducedRank ? reducedShape[reducedIdx] : -2;
-    // Skip matching dims greedily.
-    if ((keepMask[originalIdx] = originalShape[originalIdx] == reducedVal))
-      reducedIdx++;
-    // 1 is the only non-matching allowed.
-    else if (originalShape[originalIdx] != 1)
-      return SubViewVerificationResult::SizeMismatch;
-  }
-  // Must match the reduced rank.
-  if (reducedIdx != reducedRank)
+  auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+
+  // Sizes cannot be matched if no mask could be computed.
+  if (!optionalMask.hasValue())
     return SubViewVerificationResult::SizeMismatch;
 
   // We are done for the tensor case.
@@ -2885,12 +2899,13 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
   MLIRContext *c = original.getContext();
   int64_t originalOffset, reducedOffset;
   SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+  SmallVector<bool, 4> keepMask = optionalMask.getValue();
   getStridesAndOffset(original, originalStrides, originalOffset);
   getStridesAndOffset(reduced, reducedStrides, reducedOffset);
 
   // Filter strides based on the mask and check that they are the same
   // as reduced ones.
-  reducedIdx = 0;
+  unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     if (keepMask[originalIdx]) {
       if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
