[Mlir-commits] [mlir] b971515 - [mlir][linalg] lower index operations during linalg to vector lowering.

Tobias Gysi llvmlistbot at llvm.org
Tue Apr 20 04:56:54 PDT 2021


Author: Tobias Gysi
Date: 2021-04-20T11:55:44Z
New Revision: b9715156ff909fb38725893afb1d18709cb7f1bd

URL: https://github.com/llvm/llvm-project/commit/b9715156ff909fb38725893afb1d18709cb7f1bd
DIFF: https://github.com/llvm/llvm-project/commit/b9715156ff909fb38725893afb1d18709cb7f1bd.diff

LOG: [mlir][linalg] lower index operations during linalg to vector lowering.

The patch extends the vectorization pass to lower linalg index operations to vector code. For each index operation it creates a constant 1-d vector that enumerates the indices along the corresponding iteration dimension and broadcasts it to the iteration space, additionally transposing the result when that dimension is not the innermost (trailing) one.

Differential Revision: https://reviews.llvm.org/D100373

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/IR/Builders.h
    mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/lib/IR/Builders.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 6c3e86c6d2f14..0512b351650e9 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1242,11 +1242,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     /// appear in the operands.
     SmallVector<Value, 4> createFlatListOfOperandDims(OpBuilder &, Location);
 
+    /// Return the dimension sizes of all shaped operands, concatenated in the
+    /// order the operands appear. Every operand must have a static shape; the
+    /// implementation only asserts this.
+    SmallVector<int64_t, 4> createFlatListOfOperandStaticDims();
+
     /// Create the loop ranges to materialize the computation over the current
     /// operands. This is done by applying `getShapesToLoopsMap` to
     /// `createFlatListOfOperandDims`.
     SmallVector<Range, 4> createLoopRanges(OpBuilder &b, Location loc);
 
+    /// Return the static size of every loop, taken from the operand dimensions
+    /// whose `getLoopsToShapesMap` result is a plain loop dimension (see
+    /// `createFlatListOfOperandStaticDims`). Loops bound nowhere get size 0.
+    SmallVector<int64_t, 4> computeStaticLoopSizes();
+
     /// Returns all the operands past the inputs, output_buffers and
     /// init_tensors operands. Asserts that these operands are value types to
     /// allow transformations like tiling to just use the values when cloning

diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index f8b119cf962a0..1e0863c7a7a42 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -124,6 +124,7 @@ class Builder {
   DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
   DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
   DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
+  DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);
 
   /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
   /// These are generally preferable for representing general lists of integers

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f1bf22cf3d698..1c45467afbb4d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -193,6 +193,16 @@ SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
   return res;
 }
 
+SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
+  SmallVector<int64_t, 4> res; // Shapes of all operands, concatenated.
+  for (Value v : getShapedOperands()) {
+    ShapedType t = v.getType().template cast<ShapedType>();
+    assert(t.hasStaticShape() && "expected operands to have static shapes");
+    llvm::append_range(res, t.getShape()); // Unchecked if NDEBUG is set.
+  }
+  return res;
+}
+
 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   AffineMap map = getLoopsToShapesMap();
   unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
@@ -211,6 +221,19 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
   return res;
 }
 
+SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
+  AffineMap map = getLoopsToShapesMap(); // Loop dims -> flat operand dims.
+  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
+  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
+  SmallVector<int64_t, 4> res(numDims, 0); // Unbound loops stay 0.
+  for (unsigned idx = 0; idx < numRes; ++idx) {
+    auto result = map.getResult(idx);
+    if (auto d = result.dyn_cast<AffineDimExpr>()) // Plain loop dims only.
+      res[d.getPosition()] = allShapeSizes[idx]; // Last match wins.
+  }
+  return res;
+}
+
 /// Visitor to check if any of the given set of positions from AffineDimExprs
 /// are used within an AffineExpr.
 struct HasAffineDimExprVisitor

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 55402a737cbb1..2e8b1580c5c73 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -462,8 +462,7 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: remove hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
+  if (!linalgOp)
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c14a3b3628ba2..14ef418ed5911 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -166,6 +166,42 @@ vectorizeLinalgYield(OpBuilder &builder, Operation *op,
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
 
+/// CustomVectorizationHook for `linalg.index` ops of `linalgOp`. Materializes
+/// the indices of loop `indexOp.dim()` as a constant index vector, broadcast
+/// and transposed to the static loop shape unless it is the trailing loop.
+/// Returns NewOp so the vectorizer maps the new value, Failure for other ops.
+static VectorizationResult
+vectorizeLinalgIndex(OpBuilder &builder, Operation *op, LinalgOp linalgOp) {
+  IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
+  if (!indexOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  auto loc = indexOp.getLoc();
+  // Static sizes of all loops of `linalgOp` (operand shapes must be static).
+  auto targetShape = linalgOp.computeStaticLoopSizes();
+  // Build the 1-D index vector [0, 1, ..., size - 1] of loop `dim`.
+  SmallVector<int64_t> constantSeq(
+      llvm::seq<int64_t>(0, targetShape[indexOp.dim()]));
+  ConstantOp constantOp =
+      builder.create<ConstantOp>(loc, builder.getIndexVectorAttr(constantSeq));
+  // vector.broadcast can only add leading dimensions, so when `dim` is the
+  // trailing loop the 1-D vector is returned as is and the vectorization
+  // algorithm broadcasts it to the iteration space.
+  if (indexOp.dim() == targetShape.size() - 1)
+    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
+  // Otherwise swap loop `dim` with the trailing loop, broadcast the 1-D vector
+  // to that permuted shape, and transpose back. A single swap is its own
+  // inverse, so the same swap also serves as the transpose permutation.
+  std::swap(targetShape[indexOp.dim()], targetShape.back());
+  auto broadCastOp = builder.create<vector::BroadcastOp>(
+      loc, VectorType::get(targetShape, builder.getIndexType()), constantOp);
+  SmallVector<int64_t> transposition(
+      llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
+  std::swap(transposition.back(), transposition[indexOp.dim()]);
+  auto transposeOp =
+      builder.create<vector::TransposeOp>(loc, broadCastOp, transposition);
+  return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 ///   1. Try to apply any of the `customVectorizationHooks` and return its
@@ -245,7 +281,7 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
   if (!llvm::hasSingleElement(r))
     return false;
   for (Operation &op : r.front()) {
-    if (!(isa<ConstantOp, linalg::YieldOp>(op) ||
+    if (!(isa<ConstantOp, linalg::YieldOp, linalg::IndexOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
@@ -293,7 +329,9 @@ static AffineMap getTransferReadMap(LinalgOp linalgOp, unsigned argIndex) {
 ///   3. Each region argument is vectorized into a vector.transfer_read (or 0-d
 ///   load).
 ///   TODO: Reuse opportunities for RAR dependencies.
-///   4. Register CustomVectorizationHook for YieldOp to capture the results.
+///   4a. Register CustomVectorizationHook for YieldOp to capture the results.
+///   4b. Register CustomVectorizationHook for IndexOp to access the iteration
+///   indices.
 ///   5. Iteratively call vectorizeOneOp on the region operations.
 LogicalResult vectorizeAsLinalgGeneric(
     OpBuilder &builder, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults,
@@ -333,16 +371,23 @@ LogicalResult vectorizeAsLinalgGeneric(
     bvm.map(vectorArg, vectorRead);
   }
 
-  // 4. Register CustomVectorizationHook for yieldOp.
+  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
+  // 4a. Register CustomVectorizationHook for yieldOp.
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
     return vectorizeLinalgYield(builder, op, bvm, linalgOp, newResults);
   };
-  // Append the vectorizeYield hook.
-  auto hooks = llvm::to_vector<4>(customVectorizationHooks);
   hooks.push_back(vectorizeYield);
 
+  // 4b. Register CustomVectorizationHook for indexOp.
+  CustomVectorizationHook vectorizeIndex =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeLinalgIndex(builder, op, linalgOp);
+  };
+  hooks.push_back(vectorizeIndex);
+
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block.getOperations()) {
     VectorizationResult result = vectorizeOneOp(builder, &op, bvm, hooks);
@@ -401,9 +446,6 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
-  // TODO: remove once index ops are supported.
-  if (linalgOp.hasIndexSemantics())
-    return failure();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d1ab3795d00ec..4f8aa9e820757 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -120,6 +120,12 @@ DenseIntElementsAttr Builder::getI64VectorAttr(ArrayRef<int64_t> values) {
       values);
 }
 
+DenseIntElementsAttr Builder::getIndexVectorAttr(ArrayRef<int64_t> values) {
+  return DenseIntElementsAttr::get( // Rank-1 vector<N x index>.
+      VectorType::get(static_cast<int64_t>(values.size()), getIndexType()),
+      values);
+}
+
 DenseIntElementsAttr Builder::getI32TensorAttr(ArrayRef<int32_t> values) {
   return DenseIntElementsAttr::get(
       RankedTensorType::get(static_cast<int64_t>(values.size()),

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index faaadcf94a5c0..c18bf5b5cd8b9 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -174,6 +174,49 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_trailing_index
+  //  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
+  //   CHECK-DAG:   %[[CST0:.*]] = constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  outs(%arg0: memref<1x2x4x8xindex>) {
+  ^bb0(%arg1: index):
+  //       CHECK:   %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<8xindex> to vector<1x2x4x8xindex>
+  //       CHECK:   vector.transfer_write %[[BCST]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+    %0 = linalg.index 3 : index  // Trailing loop: broadcast only, no transpose.
+    linalg.yield %0 : index
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_inner_index
+  //  CHECK-SAME: (%[[ARG0:.*]]: memref<1x2x4x8xindex>)
+func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
+  //   CHECK-DAG:   %[[CST0:.*]] = constant dense<[0, 1]> : vector<2xindex>
+  //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  outs(%arg0: memref<1x2x4x8xindex>) {
+  ^bb0(%arg1: index):
+  //       CHECK:   %[[BCST:.*]] = vector.broadcast %[[CST0]] : vector<2xindex> to vector<1x8x4x2xindex>
+  //       CHECK:   %[[TRAN:.*]] = vector.transpose %[[BCST]], [0, 3, 2, 1] : vector<1x8x4x2xindex> to vector<1x2x4x8xindex>
+  //       CHECK:   vector.transfer_write %[[TRAN]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {{.*}} : vector<1x2x4x8xindex>, memref<1x2x4x8xindex>
+    %0 = linalg.index 1 : index  // Non-trailing loop: swap dims 1 and 3.
+    linalg.yield %0 : index
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @generic_vectorize
   //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x256xf32>, %[[ARG1:.*]]: memref<4x256xf32>,
   //  CHECK-SAME:  %[[ARG2:.*]]: memref<256xf32>, %[[ARG3:.*]]: f32)
@@ -252,7 +295,6 @@ func @generic_vectorize(%arg0: memref<4x256xf32>,
   return
 }
 
-
 // -----
 
 // CHECK-LABEL: func @generic_vectorize_tensor
@@ -469,19 +511,3 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
     } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
   return %0 : tensor<6x?x?x?xf32>
 }
-
-// -----
-
-// CHECK-LABEL: @index_op
-//       CHECK:   linalg.generic
-func @index_op(%arg0: memref<4x8xindex>) {
-  linalg.generic {
-    indexing_maps = [affine_map<(i, j) -> (i, j)>],
-    iterator_types = ["parallel", "parallel"]}
-  outs(%arg0 : memref<4x8xindex>) {
-  ^bb0(%arg1: index):   // no predecessors
-    %0 = linalg.index 1 : index
-    linalg.yield %0 : index
-  }
-  return
-}


        


More information about the Mlir-commits mailing list