[Mlir-commits] [mlir] 93f9922 - [mlir][linalg] adding operation to access the iteration index of enclosing linalg ops.

Tobias Gysi llvmlistbot at llvm.org
Mon Apr 12 06:39:46 PDT 2021


Author: Tobias Gysi
Date: 2021-04-12T13:37:17Z
New Revision: 93f9922d65f8a89b8a9299eeab61511ce8baa3bc

URL: https://github.com/llvm/llvm-project/commit/93f9922d65f8a89b8a9299eeab61511ce8baa3bc
DIFF: https://github.com/llvm/llvm-project/commit/93f9922d65f8a89b8a9299eeab61511ce8baa3bc.diff

LOG: [mlir][linalg] adding operation to access the iteration index of enclosing linalg ops.

The `linalg.index` operation provides access to the iteration indices of the immediately enclosing linalg operation. It takes a `dim` attribute and returns the iteration index in the given dimension. Having `linalg.index` allows us to unify `linalg.generic` and `linalg.indexed_generic` and also enables index access in named operations.

Differential Revision: https://reviews.llvm.org/D100292

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
    mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
    mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
    mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
    mlir/test/Dialect/Linalg/fusion-tensor.mlir
    mlir/test/Dialect/Linalg/invalid.mlir
    mlir/test/Dialect/Linalg/loop-order.mlir
    mlir/test/Dialect/Linalg/roundtrip.mlir
    mlir/test/Dialect/Linalg/tile.mlir
    mlir/test/Dialect/Linalg/vectorization.mlir
    mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 95e008aacc451..6c3e86c6d2f14 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1026,7 +1026,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return $_op.getLibraryCallName();
       }]
     >,
-
+    InterfaceMethod<
+      /*desc=*/[{
+         Return whether the op accesses the iteration indices.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasIndexSemantics",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/""
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index b0a93f36ab758..931eada61ac4a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -618,5 +618,51 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   let hasFolder = 1;
 }
 
+def Linalg_IndexOp : Linalg_Op<"index", [NoSideEffect]>,
+    Arguments<(ins Confined<I64Attr, [IntMinValue<0>]>:$dim)>,
+    Results<(outs Index:$result)> {
+  let summary = "linalg index operation";
+  let description = [{
+    The `linalg.index` operation returns the iteration index of the immediately
+    enclosing linalg structured operation for the iteration dimension `dim`. The
+    `dim` attribute specifies the position of the accessed dimension in the
+    indexing map domain.
+
+    Example:
+
+    ```mlir
+    #map = affine_map<(i, j) -> (i, j)>
+    linalg.generic {indexing_maps = [#map, #map],
+                    iterator_types = ["parallel", "parallel"]}
+      outs(%I, %J : memref<?x?xindex>, memref<?x?xindex>) {
+      ^bb0(%arg0 : index, %arg1 : index):
+      // Access the outer iteration dimension i
+      %i = linalg.index 0 : index
+      // Access the inner iteration dimension j
+      %j = linalg.index 1 : index
+      linalg.yield %i, %j : index, index
+    }
+    ```
+
+    This may lower to IR resembling:
+
+    ```mlir
+    %0 = memref.dim %I, %c0 : memref<?x?xindex>
+    %1 = memref.dim %I, %c1 : memref<?x?xindex>
+    scf.for %i = %c0 to %0 step %c1 {
+      scf.for %j = %c0 to %1 step %c1 {
+        memref.store %i, %I[%i, %j] : memref<?x?xindex>
+        memref.store %j, %J[%i, %j] : memref<?x?xindex>
+      }
+    }
+    ```
+  }];
+  let builders = [
+    OpBuilder<(ins "int64_t":$dim),
+      [{ build($_builder, $_state, $_builder.getIndexType(), dim); }]>
+  ];
+
+  let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
+}
 
 #endif // LINALG_OPS

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index f80d6759172c1..b158cec529d99 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -35,6 +35,14 @@ class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
       return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
     }
 
+    // Return whether the body block (region 0) holds linalg.index ops.
+    bool hasIndexSemantics() {
+      Operation *op = this->getOperation();
+      if (op->getNumRegions() == 0 || op->getRegion(0).empty())
+        return false;
+      return !op->getRegion(0).front().getOps<IndexOp>().empty();
+    }
+
     LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
         SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
       return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,

diff  --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index 5260c503cce92..d385b46848df9 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -106,6 +106,10 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite(
   if (isa<CopyOp>(op) || isa<IndexedGenericOp>(op))
     return failure();
 
+  // TODO: remove once index ops are supported.
+  if (op.hasIndexSemantics())
+    return failure();
+
   auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
   if (!libraryCallName)
     return failure();

diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d8b512cdeea06..e93e13cb7e192 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2047,6 +2047,21 @@ LogicalResult TiledLoopOp::fold(ArrayRef<Attribute>,
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// IndexOp
+//===----------------------------------------------------------------------===//
+/// Requires a LinalgOp parent (null-safe) and `dim` below its loop count.
+static LogicalResult verify(IndexOp op) {
+  auto linalgOp = dyn_cast_or_null<LinalgOp>(op->getParentOp());
+  if (!linalgOp)
+    return op.emitOpError("expected parent op with LinalgOp interface");
+  if (linalgOp.getNumLoops() <= op.dim())
+    return op.emitOpError("expected dim (")
+           << op.dim() << ") to be lower than the number of loops ("
+           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
+  return success();
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template <typename LinalgPoolingOp>

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index b3af82c80ceea..0059178d30e5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -177,6 +177,10 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: remove once index ops are supported.
+    if (op.hasIndexSemantics())
+      return failure();
+
     SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
@@ -321,6 +325,10 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
+    // TODO: remove once index ops are supported.
+    if (op.hasIndexSemantics())
+      return failure();
+
     if (!op.hasTensorSemantics())
       return failure();
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 34eac4bdfcaaf..713de7b22c4b7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -28,6 +28,10 @@ using namespace mlir::linalg;
 /// Implementation of fusion of generic ops and indexed_generic ops.
 static bool areElementwiseOpsFusable(LinalgOp producer, LinalgOp consumer,
                                      unsigned consumerIdx) {
+  // TODO: remove once index ops are supported.
+  if (producer.hasIndexSemantics() || consumer.hasIndexSemantics())
+    return false;
+
   // Producer and consumer must have tensor semantics.
   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
     return false;

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index a6e296b0ea112..8b2b0cde8f9a5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -527,7 +527,9 @@ class LinalgRewritePattern : public RewritePattern {
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!isa<LinalgOp>(op))
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    // TODO: remove hasIndexSemantics check once index ops are supported.
+    if (!linalgOp || linalgOp.hasIndexSemantics())
       return failure();
     if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
       return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e7095a9f0b34e..00d6a3a1b95f4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -246,7 +246,8 @@ mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
     Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -314,7 +315,8 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -407,7 +409,8 @@ mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
@@ -465,7 +468,8 @@ mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  // TODO: remove hasIndexSemantics check once index ops are supported.
+  if (!linalgOp || linalgOp.hasIndexSemantics())
     return failure();
   if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 10562d68a9e0a..e2c069ea7f477 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -402,6 +402,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
+  // TODO: remove once index ops are supported.
+  if (linalgOp.hasIndexSemantics())
+    return failure();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5c8866662d2ba..e65a408a0a0e2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -566,3 +566,19 @@ func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x1xf32
 // CHECK-SAME:     outs(%[[FILL]] : tensor<?xf32>)
 //      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]]
 //      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+//  CHECK: #{{.+}} = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: @index_op
+func @index_op(%arg0: memref<1x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : memref<1x8xindex>) {
+  ^bb0(%arg1: index):   // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
index 76e9148cdff0c..2cb7ef1274b8d 100644
--- a/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed-generic.mlir
@@ -188,3 +188,53 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
 // CHECK:          {{.*}} = index_cast [[j_new]] : index to i32
 // CHECK:      linalg.generic
 // CHECK:          addf
+
+// -----
+
+#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#id_2d = affine_map<(d0, d1) -> (d0, d1)>
+#pointwise_2d_trait = {
+  indexing_maps = [#id_2d],
+  iterator_types = ["parallel", "parallel"]
+}
+func @index_op(%A: memref<?x?xindex>,
+               %B: memref<?x?xindex>) {
+  linalg.generic #pointwise_2d_trait
+   outs(%B : memref<?x?xindex>) {
+  ^bb0(%arg6: index):   // no predecessors
+    %2 = constant 0 : index
+    linalg.yield %2 : index
+  }
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c25 = constant 25 : index
+  %c10 = constant 10 : index
+  %0 = memref.dim %A, %c0 : memref<?x?xindex>
+  %1 = memref.dim %A, %c1 : memref<?x?xindex>
+  %2 = memref.dim %B, %c0 : memref<?x?xindex>
+  %3 = memref.dim %B, %c1 : memref<?x?xindex>
+  scf.for %arg2 = %c0 to %0 step %c10 {
+    scf.for %arg3 = %c0 to %1 step %c25 {
+      %4 = memref.subview %A[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+          memref<?x?xindex> to memref<?x?xindex, #map>
+      %5 = memref.subview %B[%arg2, %arg3][%c10, %c25][%c1, %c1] :
+          memref<?x?xindex> to memref<?x?xindex, #map>
+      linalg.generic {
+        indexing_maps = [#id_2d, #id_2d],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%4 : memref<?x?xindex, #map>)
+       outs(%5 : memref<?x?xindex, #map>) {
+      ^bb0(%arg6: index, %arg7: index):
+        %6 = linalg.index 0 : index
+        linalg.yield %6 : index
+      }
+    }
+  }
+  return
+}
+// CHECK-LABEL: func @index_op
+// CHECK:  linalg.generic
+// CHECK:  scf.for
+// CHECK:    scf.for
+// CHECK-NOT:  scf.for
+// CHECK:      linalg.generic

diff  --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
index b0a006398c991..7983fe19a95a3 100644
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -652,3 +652,29 @@ func @generic_index_op2(%arg0: tensor<1x8xf64>, %arg1: tensor<1x8xi32>) -> tenso
   // CHECK-NEXT:   return %[[R]] : tensor<1x8xi32>
   return %1 : tensor<1x8xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @index_op(
+// CHECK-COUNT-2: linalg.generic
+func @index_op(%arg0: tensor<1x8xindex>, %arg1: tensor<1x8xindex>) -> tensor<1x8xindex> {
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : tensor<1x8xindex>) {
+  ^bb0(%a: index):   // no predecessors
+    %2 = linalg.index 1 : index
+    linalg.yield %2 : index
+  } -> tensor<1x8xindex>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  ins(%0 : tensor<1x8xindex>)
+  outs(%arg1 : tensor<1x8xindex>) {
+  ^bb0(%a: index, %b: index):   // no predecessors
+    %2 = linalg.index 0 : index
+    %3 = addi %2, %a : index
+    linalg.yield %3 : index
+  } -> tensor<1x8xindex>
+  return %1 : tensor<1x8xindex>
+}

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index bdaf0ea351aa9..bb17c9bf1399c 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -24,6 +24,41 @@ func @yield_parent(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
 
 // -----
 
+func @index_parent() {
+  // expected-error @+1 {{op expected parent op with LinalgOp interface}}
+  linalg.index 0 : index
+}
+
+// -----
+
+func @index_dim_lower_than_number_of_loops(%arg0: memref<f32>) {
+  // expected-error @+6 {{op expected dim (2) to be lower than the number of loops (0) of the enclosing LinalgOp}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index 2 : index
+      linalg.yield %0 : f32
+  }
+}
+
+// -----
+
+func @index_dim_negative(%arg0: memref<f32>) {
+  // expected-error @+6 {{op attribute 'dim' failed to satisfy constraint: 64-bit signless integer attribute whose minimum value is 0}}
+  linalg.generic {
+      indexing_maps =  [ affine_map<() -> ()> ],
+      iterator_types = []}
+      outs(%arg0 : memref<f32>) {
+    ^bb(%0: f32):
+      linalg.index -1 : index
+      linalg.yield %0 : f32
+  }
+}
+
+// -----
+
 func @generic_no_region(%arg0: memref<f32>) {
   // expected-error @+5 {{expected '{' to begin a region}}
   linalg.generic {

diff  --git a/mlir/test/Dialect/Linalg/loop-order.mlir b/mlir/test/Dialect/Linalg/loop-order.mlir
index d1ff47977c351..968ffdc5e7478 100644
--- a/mlir/test/Dialect/Linalg/loop-order.mlir
+++ b/mlir/test/Dialect/Linalg/loop-order.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s
-// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s
-// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s
+// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
+// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
+// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
 
 func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
   linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
@@ -22,3 +22,24 @@ func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
 // AFFINE:       affine.for %{{.*}} = 0 to 2
 // AFFINE:         affine.for %{{.*}} = 0 to 3
 
+// -----
+
+func @index_op(%arg0: memref<4x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : memref<4x8xindex>) {
+  ^bb0(%arg1: index):   // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}
+// LOOP-LABEL: @index_op
+//      LOOP:   linalg.generic
+
+// PARALLEL-LABEL: @index_op
+//      PARALLEL:   linalg.generic
+
+// AFFINE-LABEL: @index_op
+//      AFFINE:   linalg.generic

diff  --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 084b8a339c0de..d03161541e2f6 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -525,6 +525,9 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
       outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
       attrs = {foo = 1} {
     ^bb(%a: vector<3x4xi4>, %b: f32) :
+      %0 = linalg.index 0 : index
+      %1 = linalg.index 1 : index
+      %2 = linalg.index 2 : index
       linalg.yield %b : f32
   }
   return
@@ -538,6 +541,9 @@ func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1
 //  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
 //  CHECK-SAME:     attrs = {foo = 1 : i64} {
 //       CHECK:  ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
+//       CHECK:    %{{.*}} = linalg.index 0 : index
+//       CHECK:    %{{.*}} = linalg.index 1 : index
+//       CHECK:    %{{.*}} = linalg.index 2 : index
 //       CHECK:    linalg.yield %{{.*}} : f32
 
 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,

diff  --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index c761bd6cd57e0..d8b20904f751a 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -377,3 +377,18 @@ func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memre
 //       TILE-234:     for
 //   TILE-234-NOT:   for
 //       TILE-234:       linalg.generic
+
+// TILE-2-LABEL: func @index_op
+//   TILE-2-NOT:   for
+//       TILE-2:   linalg.generic
+func @index_op(%arg0: memref<?x?xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : memref<?x?xindex>) {
+  ^bb0(%arg1: index):   // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 3daab904d0323..faaadcf94a5c0 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -469,3 +469,19 @@ func @pad_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: index,
     } : tensor<1x2x2x?xf32> to tensor<6x?x?x?xf32>
   return %0 : tensor<6x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @index_op
+//       CHECK:   linalg.generic
+func @index_op(%arg0: memref<4x8xindex>) {
+  linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>],
+    iterator_types = ["parallel", "parallel"]}
+  outs(%arg0 : memref<4x8xindex>) {
+  ^bb0(%arg1: index):   // no predecessors
+    %0 = linalg.index 1 : index
+    linalg.yield %0 : index
+  }
+  return
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 3ef6ed5e4b4ba..6931f07a8ab4c 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -126,6 +126,10 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
   // Save original Linalg ops, we only want to make a pass over those.
   SmallVector<LinalgOp, 8> linalgOps;
   f.walk([&](LinalgOp op) {
+    // TODO: remove hasIndexSemantics check once index ops are supported.
+    if (op.hasIndexSemantics())
+      return;
+
     // TODO: support multi-results.
     if (op->getNumResults() <= 1)
       linalgOps.push_back(op);


        


More information about the Mlir-commits mailing list