[Mlir-commits] [mlir] f0fe380 - [mlir][vector] Add fold pattern to constant-fold InsertStridedSliceOp

Jakub Kuderski llvmlistbot at llvm.org
Mon Nov 28 11:28:49 PST 2022


Author: Jakub Kuderski
Date: 2022-11-28T14:25:28-05:00
New Revision: f0fe38035c92bae9ccc827097ab55b37e2427a29

URL: https://github.com/llvm/llvm-project/commit/f0fe38035c92bae9ccc827097ab55b37e2427a29
DIFF: https://github.com/llvm/llvm-project/commit/f0fe38035c92bae9ccc827097ab55b37e2427a29.diff

LOG: [mlir][vector] Add fold pattern to constant-fold InsertStridedSliceOp

Fold InsertStridedSliceOp(ConstantOp into ConstantOp) -> ConstantOp.

This pattern comes with a vector size threshold to make sure we do not
introduce too many large constants.

This helps clean up code created by the Wide Integer Emulation pass.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D138739

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9e1b6307b75f2..8802930b46767 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -31,6 +31,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
@@ -38,6 +39,7 @@
 #include "llvm/ADT/bit.h"
 
 #include <cassert>
+#include <cstdint>
 #include <numeric>
 
 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
@@ -212,6 +214,28 @@ bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
   return isDisjointTransferIndices(transferA, transferB);
 }
 
+// Helper to iterate over n-D vector slice elements. Calculate the next
+// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
+// Modifies the `position` in place. Returns a failure when `position` becomes
+// the end position.
+static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
+                                      ArrayRef<int64_t> shape,
+                                      ArrayRef<int64_t> offsets) {
+  assert(position.size() == shape.size());
+  assert(position.size() == offsets.size());
+  for (auto [posInDim, dimSize, offsetInDim] :
+       llvm::reverse(llvm::zip(position, shape, offsets))) {
+    ++posInDim;
+    if (posInDim < dimSize + offsetInDim)
+      return success();
+
+    // Carry the overflow to the next loop iteration.
+    posInDim = offsetInDim;
+  }
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -2392,12 +2416,88 @@ class FoldInsertStridedSliceOfExtract final
   }
 };
 
+// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
+// ConstantOp.
+class InsertStridedSliceConstantFolder final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
+  // unless the destination vector constant has a single use.
+  static constexpr int64_t vectorSizeFoldThreshold = 256;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Return if the destination operand is not defined by a compatible vector
+    // ConstantOp.
+    TypedValue<VectorType> destVector = op.getDest();
+    Attribute vectorDestCst;
+    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
+      return failure();
+
+    VectorType destTy = destVector.getType();
+    if (destTy.isScalable())
+      return failure();
+
+    // Make sure we do not create too many large constants.
+    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
+        !destVector.hasOneUse())
+      return failure();
+
+    auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
+
+    TypedValue<VectorType> sourceValue = op.getSource();
+    Attribute sourceCst;
+    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
+      return failure();
+
+    // TODO: Handle non-unit strides when they become available.
+    if (op.hasNonUnitStrides())
+      return failure();
+
+    VectorType sliceVecTy = sourceValue.getType();
+    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
+    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
+    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
+    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
+
+    // Calculate the destination element indices by enumerating all slice
+    // positions within the destination and linearizing them. The enumeration
+    // order is lexicographic, which yields a sequence of monotonically
+    // increasing linearized position indices.
+    // Because the destination may have higher dimensionality than the slice,
+    // we keep track of two overlapping sets of positions and offsets.
+    auto denseSlice = sourceCst.cast<DenseElementsAttr>();
+    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
+    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
+    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
+    MutableArrayRef<int64_t> currSlicePosition(
+        currDestPosition.begin() + rankDifference, currDestPosition.end());
+    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
+                                   offsets.end());
+    do {
+      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
+      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
+      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
+             "Invalid slice element");
+      newValues[linearizedPosition] = *sliceValuesIt;
+      ++sliceValuesIt;
+    } while (succeeded(
+        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
+
+    auto newAttr = DenseElementsAttr::get(destTy, newValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
-      context);
+  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
+              InsertStridedSliceConstantFolder>(context);
 }
 
 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
@@ -2855,7 +2955,8 @@ class StridedSliceNonSplatConstantFolder final
       assert(linearizedPosition < sourceVecTy.getNumElements() &&
              "Invalid index");
       sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
-    } while (succeeded(incPosition(currSlicePosition, sliceShape, offsets)));
+    } while (
+        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
 
     assert(static_cast<int64_t>(sliceValues.size()) ==
                sliceVecTy.getNumElements() &&
@@ -2865,28 +2966,6 @@ class StridedSliceNonSplatConstantFolder final
                                                    newAttr);
     return success();
   }
-
-private:
-  // Calculate the next `position` in the n-D vector of size `shape`,
-  // applying an offset `offsets`. Modifies the `position` in place.
-  // Returns a failure when `position` becomes the end position.
-  static LogicalResult incPosition(MutableArrayRef<int64_t> position,
-                                   ArrayRef<int64_t> shape,
-                                   ArrayRef<int64_t> offsets) {
-    assert(position.size() == shape.size());
-    assert(position.size() == offsets.size());
-    for (auto [posInDim, dimSize, offsetInDim] :
-         llvm::reverse(llvm::zip(position, shape, offsets))) {
-      ++posInDim;
-      if (posInDim < dimSize + offsetInDim)
-        return success();
-
-      // Carry the overflow to the next loop iteration.
-      posInDim = offsetInDim;
-    }
-
-    return failure();
-  }
 };
 
 // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 872767c7f8440..2addc78f96dbc 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1983,6 +1983,57 @@ func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf3
 
 // -----
 
+// CHECK-LABEL: func.func @insert_strided_1d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<[4, 1, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<[0, 1, 4]> : vector<3xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<[5, 6, 2]> : vector<3xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<[0, 5, 6]> : vector<3xi32>
+//   CHECK-DAG: %[[ECST:.*]] = arith.constant dense<[7, 8, 9]> : vector<3xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]]
+func.func @insert_strided_1d_constant() ->
+  (vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>) {
+  %vcst = arith.constant dense<[0, 1, 2]> : vector<3xi32>
+  %cst_1 = arith.constant dense<4> : vector<1xi32>
+  %cst_2 = arith.constant dense<[5, 6]> : vector<2xi32>
+  %cst_3 = arith.constant dense<[7, 8, 9]> : vector<3xi32>
+  %a = vector.insert_strided_slice %cst_1, %vcst {offsets = [0], strides = [1]} : vector<1xi32> into vector<3xi32>
+  %b = vector.insert_strided_slice %cst_1, %vcst {offsets = [2], strides = [1]} : vector<1xi32> into vector<3xi32>
+  %c = vector.insert_strided_slice %cst_2, %vcst {offsets = [0], strides = [1]} : vector<2xi32> into vector<3xi32>
+  %d = vector.insert_strided_slice %cst_2, %vcst {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
+  %e = vector.insert_strided_slice %cst_3, %vcst {offsets = [0], strides = [1]} : vector<3xi32> into vector<3xi32>
+  return %a, %b, %c, %d, %e : vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @insert_strided_2d_constant
+//   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<{{\[\[0, 1\], \[9, 3\], \[4, 5\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[4, 9\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[CCST:.*]] = arith.constant dense<{{\[\[18, 19\], \[2, 3\], \[4, 5\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[DCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[18, 19\], \[4, 5\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[ECST:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\], \[18, 19\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[FCST:.*]] = arith.constant dense<{{\[\[28, 29\], \[38, 39\], \[4, 5\]\]}}> : vector<3x2xi32>
+//   CHECK-DAG: %[[GCST:.*]] = arith.constant dense<{{\[\[0, 1\], \[28, 29\], \[38, 39\]\]}}> : vector<3x2xi32>
+//  CHECK-NEXT: return %[[ACST]], %[[BCST]], %[[CCST]], %[[DCST]], %[[ECST]], %[[FCST]], %[[GCST]]
+func.func @insert_strided_2d_constant() ->
+  (vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>) {
+  %vcst = arith.constant dense<[[0, 1], [2, 3], [4, 5]]> : vector<3x2xi32>
+  %cst_1 = arith.constant dense<9> : vector<1xi32>
+  %cst_2 = arith.constant dense<[18, 19]> : vector<2xi32>
+  %cst_3 = arith.constant dense<[[28, 29], [38, 39]]> : vector<2x2xi32>
+  %a = vector.insert_strided_slice %cst_1, %vcst {offsets = [1, 0], strides = [1]} : vector<1xi32> into vector<3x2xi32>
+  %b = vector.insert_strided_slice %cst_1, %vcst {offsets = [2, 1], strides = [1]} : vector<1xi32> into vector<3x2xi32>
+  %c = vector.insert_strided_slice %cst_2, %vcst {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<3x2xi32>
+  %d = vector.insert_strided_slice %cst_2, %vcst {offsets = [1, 0], strides = [1]} : vector<2xi32> into vector<3x2xi32>
+  %e = vector.insert_strided_slice %cst_2, %vcst {offsets = [2, 0], strides = [1]} : vector<2xi32> into vector<3x2xi32>
+  %f = vector.insert_strided_slice %cst_3, %vcst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<3x2xi32>
+  %g = vector.insert_strided_slice %cst_3, %vcst {offsets = [1, 0], strides = [1, 1]} : vector<2x2xi32> into vector<3x2xi32>
+  return %a, %b, %c, %d, %e, %f, %g :
+    vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>, vector<3x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle_splat
 //  CHECK-SAME:   (%[[ARG:.*]]: i32)
 //  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>


        


More information about the Mlir-commits mailing list