[Mlir-commits] [mlir] 5571cc7 - [mlir][linalg] Add helpers for linalg.tiled_loop [nfc].

Alexander Belyaev llvmlistbot at llvm.org
Tue Apr 6 09:17:50 PDT 2021


Author: Alexander Belyaev
Date: 2021-04-06T18:17:37+02:00
New Revision: 5571cc7deed6dc01f4764adbbf3d668866f22173

URL: https://github.com/llvm/llvm-project/commit/5571cc7deed6dc01f4764adbbf3d668866f22173
DIFF: https://github.com/llvm/llvm-project/commit/5571cc7deed6dc01f4764adbbf3d668866f22173.diff

LOG: [mlir][linalg] Add helpers for linalg.tiled_loop [nfc].

Differential Revision: https://reviews.llvm.org/D99968

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index a2c3bb7b0279..b0a93f36ab75 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -584,9 +584,33 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
   ];
 
   let extraClassDeclaration = [{
+    /// Number of operands controlling the loop: lbs, ubs, steps.
+    unsigned getNumControlOperands() { return 3 * getNumLoops(); }
+
     ValueRange getInductionVars() {
       return getBody()->getArguments();
     }
+
+    /// Result tied to a ranked-tensor `outputs` operand; null OpResult if none.
+    OpResult getTiedOpResult(OpOperand &opOperand) {
+      // Only ranked tensors have tied results; memrefs etc. have none.
+      if (!opOperand.get().getType().isa<RankedTensorType>()) return OpResult();
+
+      // Check whether the operand index is in bounds of `outputs()` arg.
+      unsigned operandIndex = opOperand.getOperandNumber();
+      unsigned outputIndexStart =
+          getNumControlOperands() + inputs().size();
+      unsigned outputIndexEnd = outputIndexStart + outputs().size();
+      if (operandIndex < outputIndexStart || operandIndex >= outputIndexEnd)
+        return OpResult();
+
+      // Result index = number of ranked tensors preceding it in `outputs`.
+      unsigned tensorId = 0;
+      for (unsigned i = outputIndexStart; i < operandIndex; ++i)
+        tensorId += getOperand(i).getType().isa<RankedTensorType>();
+      return getOperation()->getResult(tensorId);
+    }
+
     unsigned getNumLoops() { return step().size(); }
   }];
 


        


More information about the Mlir-commits mailing list