[Mlir-commits] [mlir] 31a346c - [MLIR][Vector] Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp.

Andy Davis llvmlistbot at llvm.org
Tue Mar 31 08:39:24 PDT 2020


Author: Andy Davis
Date: 2020-03-31T08:39:17-07:00
New Revision: 31a346cc35c829e207fbf24d8f8f1ee22ca9f140

URL: https://github.com/llvm/llvm-project/commit/31a346cc35c829e207fbf24d8f8f1ee22ca9f140
DIFF: https://github.com/llvm/llvm-project/commit/31a346cc35c829e207fbf24d8f8f1ee22ca9f140.diff

LOG: [MLIR][Vector] Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp.

Summary:
Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp.
Vector-to-vector transformations for unrolling and lowering to hardware vectors
can generate chains of structured vector operations (InsertSlicesOp,
ExtractSlicesOp and ShapeCastOp) between the producer of a hardware vector
value and its consumer. Because InsertSlicesOp, ExtractSlicesOp and ShapeCastOp
are structured, we can track the location (tuple index and vector offsets) of
the consumer vector value through the chain of structured operations to the
producer, enabling a much more powerful producer-consumer forwarding of values
through structured ops and tuple, which in turn enables a more powerful
TupleGetOp folding transformation.

Reviewers: nicolasvasilache, aartbik

Reviewed By: aartbik

Subscribers: grosul1, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76889

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Vector/vector-transforms.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 4bc03e4943fd..35527542d639 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -31,6 +31,9 @@ class VectorType;
 SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> sizes);
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
 /// Given the slice strides together with a linear index in the dimension
 /// space, returns the vector-space offsets in each dimension for a
 /// de-linearized index.

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 2716d62a4ff1..dbb0bf443700 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -69,15 +69,6 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   return res;
 }
 
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
-  assert(offsets.size() == basis.size());
-  int64_t linearIndex = 0;
-  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
-    linearIndex += offsets[idx] * basis[idx];
-  return linearIndex;
-}
-
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -683,6 +674,99 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
   }
 };
 
+/// Returns the producer Value of the same type as 'consumerValue', by tracking
+/// the tuple index and offsets of the consumer vector value through the
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
+/// from consumer to producer. Each operation in the chain is structured, and
+/// so the tuple index and offsets can be mapped from result to input, while
+/// visiting each operation in the chain.
+/// Returns nullptr on failure.
+static Value getProducerValue(Value consumerValue) {
+  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
+  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
+  int64_t tupleIndex = -1;
+  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
+  auto *op = consumerValue.getDefiningOp();
+  while (op != nullptr) {
+    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
+      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = tupleGetOp.getIndex();
+      op = tupleGetOp.vectors().getDefiningOp();
+    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Compute slice strides for 'extractSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      extractSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          extractSlicesOp.getSourceVectorType().getShape(), sizes);
+
+      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
+      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
+      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+
+      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
+      // to the 'extractSlicesOp' input vector type.
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        offsets[i] += elementOffsets[i];
+
+      // Clear 'tupleIndex' and update next defining 'op' to visit.
+      tupleIndex = -1;
+      op = extractSlicesOp.vector().getDefiningOp();
+    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Compute slice strides for 'insertSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      insertSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          insertSlicesOp.getResultVectorType().getShape(), sizes);
+
+      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
+      // of 'insertSlicesOp' result vector type at 'offsets'.
+      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
+      assert(offsets.size() == sizes.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        vectorOffsets[i] = offsets[i] / sizes[i];
+
+      // Compute the source tuple element index.
+      tupleIndex = linearize(vectorOffsets, sliceStrides);
+
+      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
+      // relative to input tuple element vector type at 'tupleIndex'.
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] -= elementOffsets[i];
+        assert(offsets[i] >= 0);
+      }
+
+      // Update next defining 'op' to visit.
+      op = insertSlicesOp.vectors().getDefiningOp();
+    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Return tuple element 'value' at 'tupleIndex' if it matches type.
+      auto value = tupleOp.getOperand(tupleIndex);
+      if (value.getType() == consumerVectorType)
+        return value;
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = -1;
+      op = value.getDefiningOp();
+    } else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
 // Example:
@@ -740,28 +824,11 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
 
   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
-    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
-        tupleGetOp.vectors().getDefiningOp());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
-        extractSlicesOp.vector().getDefiningOp());
-    if (!insertSlicesOp)
-      return failure();
-
-    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
-    if (!tupleOp)
-      return failure();
-
-    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
-    rewriter.replaceOp(tupleGetOp, tupleValue);
-    return success();
+    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
+      rewriter.replaceOp(tupleGetOp, producer);
+      return success();
+    }
+    return failure();
   }
 };
 

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index f929dddd6d8d..9038b7ad6617 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -28,10 +28,10 @@
 
 using llvm::SetVector;
 
-namespace mlir {
+using namespace mlir;
 
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
-                                       ArrayRef<int64_t> sizes) {
+SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
+                                             ArrayRef<int64_t> sizes) {
   int64_t rank = shape.size();
   // Compute the count for each dimension.
   SmallVector<int64_t, 4> sliceDimCounts(rank);
@@ -45,8 +45,16 @@ SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
   return sliceStrides;
 }
 
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
-                                    int64_t index) {
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
+SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+                                          int64_t index) {
   int64_t rank = sliceStrides.size();
   SmallVector<int64_t, 4> vectorOffsets(rank);
   for (int64_t r = 0; r < rank; ++r) {
@@ -57,16 +65,15 @@ SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
   return vectorOffsets;
 }
 
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
-                                            ArrayRef<int64_t> vectorOffsets) {
+SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
+    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
   return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                             vectorOffsets, sizes);
 }
 
-SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
-                                          ArrayRef<int64_t> sizes,
-                                          ArrayRef<int64_t> elementOffsets) {
+SmallVector<int64_t, 4>
+mlir::computeSliceSizes(ArrayRef<int64_t> shape, ArrayRef<int64_t> sizes,
+                        ArrayRef<int64_t> elementOffsets) {
   int64_t rank = shape.size();
   SmallVector<int64_t, 4> sliceSizes(rank);
   for (unsigned r = 0; r < rank; ++r)
@@ -74,8 +81,8 @@ SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
   return sliceSizes;
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
-                                             ArrayRef<int64_t> subShape) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+                                                   ArrayRef<int64_t> subShape) {
   if (superShape.size() < subShape.size()) {
     return Optional<SmallVector<int64_t, 4>>();
   }
@@ -114,8 +121,8 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
   return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
-                                             VectorType subVectorType) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                   VectorType subVectorType) {
   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
          "vector types must be of the same elemental type");
   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
@@ -201,9 +208,9 @@ static SetVector<Operation *> getEnclosingforOps(Operation *op) {
   return getParentsOfType<AffineForOp>(op);
 }
 
-AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
-                   const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
   auto enclosingLoops = getEnclosingforOps(op);
   for (auto *forInst : enclosingLoops) {
@@ -212,7 +219,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
       enclosingLoopToVectorDim.insert(*it);
     }
   }
-  return makePermutationMap(indices, enclosingLoopToVectorDim);
+  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
 bool matcher::operatesOnSuperVectorsOf(Operation &op,
@@ -275,4 +282,3 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   return true;
 }
 
-} // namespace mlir

diff  --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 7582758384ce..082afbaff0b0 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -313,6 +313,95 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
   return %1 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @tuple_get_producer_consumer
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+//      CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
+  %5 = vector.tuple_get %4, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %5
+  return %5 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+//      CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer_swizzle(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+
+  // Extract tuple elements.
+  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+
+  // Swizzle tuple elements.
+  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
+  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
+  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
+  %8 = vector.tuple_get %7, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %8
+  return %8 : vector<2x4xf32>
+}
+
 // CHECK-LABEL: func @vector_transfers_vector_element_type
 //      CHECK: %[[C0:.*]] = constant 0 : index
 //      CHECK: %[[C1:.*]] = constant 1 : index


        


More information about the Mlir-commits mailing list